Skip to content

Commit

Permalink
feat: enable feature transforms
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Apr 24, 2024
1 parent edcb9e6 commit 8b2e939
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions numalogic/backtest/_prom.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,14 @@ def train_models(
x_train = df_train.to_numpy(dtype=np.float32)
LOGGER.info("Training data shape: %s", x_train.shape)

if self.nlconf.trainer.transforms:
train_txs = PreprocessFactory().get_pipeline_instance(self.nlconf.trainer.transforms)
else:
train_txs = None
artifacts = UDFFactory.get_udf_cls("promtrainer").compute(
model=ModelFactory().get_instance(self.nlconf.model),
input_=x_train,
trainer_transform=train_txs,
preproc_clf=PreprocessFactory().get_pipeline_instance(self.nlconf.preprocess),
threshold_clf=ThresholdFactory().get_instance(self.nlconf.threshold),
numalogic_cfg=self.nlconf,
Expand Down

0 comments on commit 8b2e939

Please sign in to comment.