
Commit fdec237
fix: trainer bug (#297)
1. Fix the training error in the trainer UDF.
2. Rework the save_multiple logic to make it more robust.

---------

Signed-off-by: s0nicboOm <[email protected]>
s0nicboOm committed Sep 23, 2023
1 parent d249942 commit fdec237
Showing 6 changed files with 71 additions and 62 deletions.
11 changes: 8 additions & 3 deletions numalogic/backtest/_prom.py
@@ -176,12 +176,17 @@ def train_models(
                 self.conf.numalogic_conf.preprocess
             ),
             threshold_clf=ThresholdFactory().get_instance(self.conf.numalogic_conf.threshold),
-            trainer_cfg=self.conf.numalogic_conf.trainer,
+            numalogic_cfg=self.conf.numalogic_conf,
         )
+        artifacts_dict = {
+            "model": artifacts["inference"].artifact,
+            "preproc_clf": artifacts["preproc_clf"].artifact,
+            "threshold_clf": artifacts["threshold_clf"].artifact,
+        }
         with open(self._modelpath, "wb") as f:
-            torch.save(artifacts, f)
+            torch.save(artifacts_dict, f)
         LOGGER.info("Models saved in %s", self._modelpath)
-        return artifacts
+        return artifacts_dict

     def generate_scores(
         self,
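
For context: compute() now returns KeyedArtifact wrappers keyed by "inference", "preproc_clf" and "threshold_clf", so train_models() unwraps them into plain artifacts before torch.save, keeping the on-disk layout unchanged. A minimal sketch of that unwrapping step, using a stand-in KeyedArtifact (the real class lives in numalogic.tools.types) and placeholder artifact values:

# Illustrative sketch only -- stand-in types and placeholder values, not the library code.
from typing import Any, NamedTuple


class KeyedArtifact(NamedTuple):
    # Stand-in for numalogic.tools.types.KeyedArtifact; only the fields the backtest uses.
    dkeys: list[str]
    artifact: Any


artifacts = {
    "inference": KeyedArtifact(dkeys=["VanillaAE"], artifact="<trained model>"),
    "preproc_clf": KeyedArtifact(dkeys=["LogTransformer"], artifact="<preproc pipeline>"),
    "threshold_clf": KeyedArtifact(dkeys=["StdDevThreshold"], artifact="<threshold estimator>"),
}

# Unwrap so that only the raw artifacts are serialized, keyed the same way as before.
artifacts_dict = {
    "model": artifacts["inference"].artifact,
    "preproc_clf": artifacts["preproc_clf"].artifact,
    "threshold_clf": artifacts["threshold_clf"].artifact,
}
print(artifacts_dict)  # this is the dict that torch.save(artifacts_dict, f) would persist

Loading the saved file back with torch.load then yields the raw artifacts rather than registry wrappers.
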
48 changes: 27 additions & 21 deletions numalogic/registry/redis_registry.py
@@ -370,6 +370,21 @@ def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
         stale_ts = (datetime.now() - timedelta(hours=freq_hr)).timestamp()
         return stale_ts > artifact_ts

+    def __update_metadata(self, skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], metadata):
+        try:
+            with self.client.pipeline(transaction=self.transactional) as pipe:
+                pipe.multi()
+                for _, value in dict_artifacts.items():
+                    key = self.construct_key(skeys, value.dkeys)
+                    latest_key = self.__construct_latest_key(key)
+                    version_key = self.client.get(name=latest_key)
+                    pipe.hset(
+                        name=version_key.decode(), key="metadata", value=orjson.dumps(metadata)
+                    )
+                pipe.execute()
+        except RedisError as err:
+            raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
+
     def save_multiple(
         self,
         skeys: KEYS,
@@ -388,27 +403,18 @@ def save_multiple(
         """
         dict_model_ver = {}
         try:
-            with self.client.pipeline(transaction=self.transactional) as pipe:
-                pipe.multi()
-                for key, value in dict_artifacts.items():
-                    dict_model_ver[":".join(value.dkeys)] = self.save(
-                        skeys=skeys,
-                        dkeys=value.dkeys,
-                        artifact=value.artifact,
-                        _pipe=pipe,
-                        **metadata,
-                    )
-
-                if len(dict_artifacts) == len(dict_model_ver):
-                    self.save(
-                        skeys=skeys,
-                        dkeys=value.dkeys,
-                        artifact=value.artifact,
-                        _pipe=pipe,
-                        artifact_versions=dict_model_ver,
-                        **metadata,
-                    )
-                pipe.execute()
+            for key, value in dict_artifacts.items():
+                dict_model_ver[":".join(value.dkeys)] = self.save(
+                    skeys=skeys,
+                    dkeys=value.dkeys,
+                    artifact=value.artifact,
+                    **metadata,
+                )
+            self.__update_metadata(
+                skeys=skeys,
+                dict_artifacts=dict_artifacts,
+                metadata={**{"artifact_versions": dict_model_ver}, **metadata},
+            )
             _LOGGER.info("Successfully saved all the artifacts with: %s", dict_model_ver)
         except RedisError as err:
             raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
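
For context: save_multiple now saves each artifact individually and then, in a second pass, stamps the combined version map into every saved artifact's metadata through the new __update_metadata helper. A minimal, self-contained sketch of that two-pass pattern against a toy in-memory store (the store and function names below are hypothetical stand-ins, not the registry API):

# Illustrative sketch only: a toy in-memory "registry" showing the two-pass save pattern.
from typing import Any

_STORE: dict[str, dict[str, Any]] = {}  # "<key>::<version>" -> {"artifact": ..., "metadata": ...}


def save(key: str, artifact: Any, **metadata) -> str:
    """Save one artifact and return the version it was stored under."""
    version = str(sum(1 for k in _STORE if k.startswith(f"{key}::")))
    _STORE[f"{key}::{version}"] = {"artifact": artifact, "metadata": dict(metadata)}
    return version


def update_metadata(versions: dict[str, str], metadata: dict[str, Any]) -> None:
    """Second pass: stamp the full version map onto every artifact saved in the first pass."""
    for key, version in versions.items():
        _STORE[f"{key}::{version}"]["metadata"].update({"artifact_versions": versions, **metadata})


artifacts = {"inference": "<model>", "threshold_clf": "<threshold>"}
versions = {name: save(name, art, lr=0.01) for name, art in artifacts.items()}
update_metadata(versions, {"lr": 0.01})
print(versions)                            # {'inference': '0', 'threshold_clf': '0'}
print(_STORE["inference::0"]["metadata"])  # includes 'artifact_versions' after the second pass

This mirrors the new flow above: save() per artifact collects dict_model_ver, and the metadata update attaches artifact_versions to the latest version of each one.
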
64 changes: 30 additions & 34 deletions numalogic/udfs/trainer.py
@@ -13,15 +13,15 @@

 from numalogic.base import StatelessTransformer
 from numalogic.config import PreprocessFactory, ModelFactory, ThresholdFactory, RegistryFactory
-from numalogic.config._config import TrainerConf
+from numalogic.config._config import NumalogicConf
 from numalogic.config.factory import ConnectorFactory
 from numalogic.models.autoencoder import AutoencoderTrainer
 from numalogic.tools.data import StreamingDataset
 from numalogic.tools.exceptions import ConfigNotFoundError, RedisRegistryError
 from numalogic.tools.types import redis_client_t, artifact_t, KEYS, KeyedArtifact
 from numalogic.udfs import NumalogicUDF
 from numalogic.udfs._config import StreamConf, PipelineConf
-from numalogic.udfs.entities import TrainerPayload, StreamPayload
+from numalogic.udfs.entities import TrainerPayload

 _LOGGER = logging.getLogger(__name__)

@@ -97,8 +97,8 @@ def compute(
         input_: npt.NDArray[float],
         preproc_clf: Optional[artifact_t] = None,
         threshold_clf: Optional[artifact_t] = None,
-        trainer_cfg: Optional[TrainerConf] = None,
-    ) -> dict[str, artifact_t]:
+        numalogic_cfg: Optional[NumalogicConf] = None,
+    ) -> dict[str, KeyedArtifact]:
         """
         Train the model on the given input data.

@@ -107,7 +107,7 @@
             input_: Input data
             preproc_clf: Preprocessing artifact
             threshold_clf: Thresholding artifact
-            trainer_cfg: Trainer configuration
+            numalogic_cfg: Numalogic configuration

         Returns
         -------
@@ -117,11 +117,15 @@
         ------
         ConfigNotFoundError: If trainer config is not found
         """
-        if not trainer_cfg:
-            raise ConfigNotFoundError("Trainer config not found!")
-
+        if not (numalogic_cfg and numalogic_cfg.trainer):
+            raise ConfigNotFoundError("Numalogic Trainer config not found!")
+        dict_artifacts = {}
+        trainer_cfg = numalogic_cfg.trainer
         if preproc_clf:
             input_ = preproc_clf.fit_transform(input_)
+            dict_artifacts["preproc_clf"] = KeyedArtifact(
+                dkeys=[_conf.name for _conf in numalogic_cfg.preprocess], artifact=preproc_clf
+            )

         train_ds = StreamingDataset(input_, model.seq_len)
         trainer = AutoencoderTrainer(**asdict(trainer_cfg.pltrainer_conf))
@@ -131,15 +135,17 @@
         train_reconerr = trainer.predict(
             model, dataloaders=DataLoader(train_ds, batch_size=trainer_cfg.batch_size)
         ).numpy()
+        dict_artifacts["inference"] = KeyedArtifact(
+            dkeys=[numalogic_cfg.model.name], artifact=model
+        )

         if threshold_clf:
             threshold_clf.fit(train_reconerr)
+            dict_artifacts["threshold_clf"] = KeyedArtifact(
+                dkeys=[numalogic_cfg.threshold.name], artifact=threshold_clf
+            )

-        return {
-            "model": model,
-            "preproc_clf": preproc_clf,
-            "threshold_clf": threshold_clf,
-        }
+        return dict_artifacts

     def exec(self, keys: list[str], datum: Datum) -> Messages:
         """
@@ -184,29 +190,17 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
         thresh_clf = self._thresh_factory.get_instance(_conf.numalogic_conf.threshold)

         # Train artifacts
-        artifacts = self.compute(
+        dict_artifacts = self.compute(
             model,
             x_train,
             preproc_clf=preproc_clf,
             threshold_clf=thresh_clf,
-            trainer_cfg=_conf.numalogic_conf.trainer,
+            numalogic_cfg=_conf.numalogic_conf,
         )

-        # Save artifacts
-        # TODO perform multi-save here
+        # Save artifacts`
         skeys = payload.composite_keys
-        dict_artifacts = {
-            "postproc": KeyedArtifact(
-                dkeys=[_conf.numalogic_conf.threshold.name], artifact=artifacts["threshold_clf"]
-            ),
-            "inference": KeyedArtifact(
-                dkeys=[_conf.numalogic_conf.model.name], artifact=artifacts["model"]
-            ),
-            "preproc": KeyedArtifact(
-                dkeys=[_conf.name for _conf in _conf.numalogic_conf.preprocess],
-                artifact=artifacts["preproc_clf"],
-            ),
-        }
-
         self.artifacts_to_save(
             skeys=skeys,
             dict_artifacts=dict_artifacts,
@@ -235,7 +229,7 @@ def artifacts_to_save(
         skeys: KEYS,
         dict_artifacts: dict[str, KeyedArtifact],
         model_registry,
-        payload: StreamPayload,
+        payload: TrainerPayload,
     ) -> None:
         """
         Save artifacts.
@@ -251,10 +245,12 @@
             Tuple of keys and artifacts
         """
-        for key, value in dict_artifacts.items():
-            if value.artifact:
-                if isinstance(value.artifact, StatelessTransformer):
-                    del dict_artifacts[key]
+        dict_artifacts = {
+            k: v
+            for k, v in dict_artifacts.items()
+            if not isinstance(v.artifact, StatelessTransformer)
+        }
+
         try:
             ver_dict = model_registry.save_multiple(
                 skeys=skeys,
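
For context: compute() now assembles the dict[str, KeyedArtifact] itself, and artifacts_to_save filters out stateless transformers with a dict comprehension instead of deleting keys from the dict while iterating over it, which raises RuntimeError in Python 3 and may be related to the train error mentioned in the commit message. A minimal, self-contained sketch of the difference, using stand-in classes rather than the numalogic ones:

# Illustrative sketch only: why the in-place deletion was replaced by a dict comprehension.
class StatelessTransformer:  # stand-in for numalogic.base.StatelessTransformer
    pass


class KeyedArtifact:  # stand-in exposing only the field used below
    def __init__(self, artifact):
        self.artifact = artifact


dict_artifacts = {
    "preproc_clf": KeyedArtifact(StatelessTransformer()),
    "inference": KeyedArtifact(object()),
}

# Old approach: mutating the dict while iterating over it.
try:
    for key, value in dict_artifacts.items():
        if isinstance(value.artifact, StatelessTransformer):
            del dict_artifacts[key]
except RuntimeError as err:
    print("old approach:", err)  # dictionary changed size during iteration

# New approach: rebuild the dict, keeping only stateful artifacts.
dict_artifacts = {
    k: v for k, v in dict_artifacts.items() if not isinstance(v.artifact, StatelessTransformer)
}
print(list(dict_artifacts))  # ['inference']
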
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "numalogic"
-version = "0.6.0a5"
+version = "0.6.0a6"
 description = "Collection of operational Machine Learning models and tools."
 authors = ["Numalogic Developers"]
 packages = [{ include = "numalogic" }]
1 change: 0 additions & 1 deletion tests/registry/test_redis_registry.py
@@ -66,7 +66,6 @@ def test_save_model_without_metadata_cache_hit(self):
             skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model, **{"lr": 0.01}
         )
         resave_data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys)
-        print(resave_data.extras)
         self.assertEqual(save_version, "0")
         self.assertEqual(resave_version1, "1")
         self.assertEqual(resave_data.extras["version"], "0")
7 changes: 5 additions & 2 deletions tests/udfs/test_trainer.py
@@ -21,7 +21,6 @@
 REDIS_CLIENT = FakeStrictRedis(server=FakeServer())
 KEYS = ["service-mesh", "1", "2"]

-
 logging.basicConfig(level=logging.DEBUG)


@@ -63,12 +62,16 @@ def test_trainer_01(self):
             "druid-config",
             StreamConf(
                 numalogic_conf=NumalogicConf(
-                    model=ModelInfo(name="VanillaAE", conf={"seq_len": 12, "n_features": 2}),
+                    model=ModelInfo(
+                        name="VanillaAE", stateful=True, conf={"seq_len": 12, "n_features": 2}
+                    ),
                     preprocess=[ModelInfo(name="LogTransformer", stateful=True, conf={})],
                     trainer=TrainerConf(pltrainer_conf=LightningTrainerConf(max_epochs=1)),
                 )
             ),
         )
         self.udf(KEYS, self.datum)
+
         self.assertEqual(
             2,
             REDIS_CLIENT.exists(
