Skip to content

Commit

Permalink
feat: support loading nl conf
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Jan 6, 2024
1 parent 3035407 commit 7f9317c
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions numalogic/backtest/_prom.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
output_dir: Union[str, Path] = DEFAULT_OUTPUT_DIR,
test_ratio: float = 0.25,
numalogic_cfg: Optional[dict] = None,
load_saved_conf: bool = False,
experiment_name: str = "exp",
):
self._url = url
Expand All @@ -68,19 +69,32 @@ def __init__(

self.query = query
self.metrics = metrics or []
self.conf = StreamConf(
source=ConnectorType.prometheus,
window_size=DEFAULT_SEQUENCE_LEN,
metrics=metrics,
numalogic_conf=OmegaConf.to_object(
OmegaConf.merge(
OmegaConf.structured(NumalogicConf), OmegaConf.create(numalogic_cfg)
),
),
)
self.conf = self._init_conf(metrics, numalogic_cfg, load_saved_conf)
self._seq_len = self.conf.window_size
self._n_features = len(self.conf.metrics)

def _init_conf(self, metrics: list[str], nl_conf: dict, load_saved_conf: bool) -> StreamConf:
if load_saved_conf:
try:
nl_conf = OmegaConf.load(os.path.join(self.out_dir, "config.yaml"))
except FileNotFoundError:
LOGGER.warning("No saved config found in %s", self.out_dir)
else:
LOGGER.info("Loaded saved config from %s", self.out_dir)

if nl_conf:
LOGGER.info("Using provided config!")
return StreamConf(
source=ConnectorType.prometheus,
window_size=DEFAULT_SEQUENCE_LEN,
metrics=metrics,
numalogic_conf=OmegaConf.to_object(
OmegaConf.merge(OmegaConf.structured(NumalogicConf), OmegaConf.create(nl_conf)),
),
)

raise ValueError("Provide one of numalogic_conf or load_saved_conf")

def train_models(
self,
df: Optional[pd.DataFrame] = None,
Expand Down Expand Up @@ -122,6 +136,10 @@ def train_models(
}
with open(self._modelpath, "wb") as f:
torch.save(artifacts_dict, f)

with open(os.path.join(self.out_dir, "config.yaml"), "w") as f:
OmegaConf.save(self.conf.numalogic_conf, f)

LOGGER.info("Models saved in %s", self._modelpath)
return artifacts_dict

Expand Down

0 comments on commit 7f9317c

Please sign in to comment.