-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
examples: update with new numalogic and pynumaflow (#202)
- update examples for numalogic version and pynumaflow versions - base numalogic udf - pynumaflow as an optional dependency - class-based udfs - remove protobuf 3.20 requirement --------- Signed-off-by: Avik Basu <[email protected]>
- Loading branch information
Showing 17 changed files with 715 additions and 446 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
cachetools==5.2.0 | ||
dataclasses-json==0.5.7 | ||
numalogic[mlflow]==0.3.0 | ||
pytorch-lightning==1.8.6 | ||
protobuf==3.20 # need this to avoid errors with tensorboard | ||
pynumaflow==0.2.6 | ||
cachetools>5.2,<6.0 | ||
numalogic[mlflow,numaflow] @ git+https://github.com/numaproj/numalogic.git@main | ||
# ../../../numalogic[mlflow,numaflow] # for local testing | ||
pytorch-lightning>2.0,< 3.0 | ||
pynumaflow>0.4,<0.5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from src.udf import Preprocess, Inference, Postprocess, Trainer, Threshold | ||
from numalogic.numaflow import NumalogicUDF | ||
|
||
|
||
class UDFFactory: | ||
"""Factory class to return the handler for the given step.""" | ||
|
||
_UDF_MAP = { | ||
"preprocess": Preprocess, | ||
"inference": Inference, | ||
"postprocess": Postprocess, | ||
"train": Trainer, | ||
"threshold": Threshold, | ||
} | ||
|
||
@classmethod | ||
def get_handler(cls, step: str) -> NumalogicUDF: | ||
"""Return the handler for the given step.""" | ||
udf_cls = cls._UDF_MAP[step] | ||
return udf_cls() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
from src.udf.inference import inference | ||
from src.udf.postprocess import postprocess | ||
from src.udf.preprocess import preprocess | ||
from src.udf.train import train | ||
from src.udf.threshold import threshold | ||
from src.udf.inference import Inference | ||
from src.udf.postprocess import Postprocess | ||
from src.udf.preprocess import Preprocess | ||
from src.udf.train import Trainer | ||
from src.udf.threshold import Threshold | ||
|
||
|
||
__all__ = ["preprocess", "inference", "postprocess", "train", "threshold"] | ||
__all__ = ["Preprocess", "Inference", "Postprocess", "Trainer", "Threshold"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,52 +1,68 @@ | ||
import logging | ||
import os | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
from numalogic.models.autoencoder import AutoencoderTrainer | ||
from numalogic.numaflow import NumalogicUDF | ||
from numalogic.registry import MLflowRegistry, ArtifactData | ||
from numalogic.tools.data import StreamingDataset | ||
from pynumaflow.function import Messages, Message, Datum | ||
from torch.utils.data import DataLoader | ||
|
||
from src.utils import Payload, load_artifact | ||
from src.utils import Payload | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
WIN_SIZE = int(os.getenv("WIN_SIZE")) | ||
TRACKING_URI = "http:https://mlflow-service.default.svc.cluster.local:5000" | ||
|
||
|
||
def inference(_: str, datum: Datum) -> Messages: | ||
r"""Here inference is done on the data, given, the ML model is present | ||
in the registry. If a model does not exist, the payload is flagged for training. | ||
It then passes to the threshold vertex. | ||
For more information about the arguments, refer: | ||
https://github.com/numaproj/numaflow-python/blob/main/pynumaflow/function/_dtypes.py | ||
class Inference(NumalogicUDF): | ||
""" | ||
The inference function here performs inference on the streaming data and sends | ||
the payload to threshold vertex. | ||
""" | ||
# Load data and convert bytes to Payload | ||
payload = Payload.from_json(datum.value.decode("utf-8")) | ||
messages = Messages() | ||
|
||
artifact_data = load_artifact(skeys=["ae"], dkeys=["model"], type_="pytorch") | ||
stream_data = np.asarray(payload.ts_data).reshape(-1, 1) | ||
def __init__(self): | ||
super().__init__() | ||
self.registry = MLflowRegistry(tracking_uri=TRACKING_URI) | ||
|
||
# Check if model exists for inference | ||
if artifact_data: | ||
LOGGER.info("%s - Model found!", payload.uuid) | ||
def load_model(self) -> ArtifactData: | ||
"""Loads the model from the registry.""" | ||
return self.registry.load(skeys=["ae"], dkeys=["model"]) | ||
|
||
# Load model from registry | ||
@staticmethod | ||
def _infer(artifact_data: ArtifactData, stream_data: npt.NDArray[float]) -> list[float]: | ||
"""Performs inference on the streaming data.""" | ||
main_model = artifact_data.artifact | ||
streamloader = DataLoader(StreamingDataset(stream_data, WIN_SIZE)) | ||
|
||
trainer = AutoencoderTrainer() | ||
recon_err = trainer.predict(main_model, dataloaders=streamloader) | ||
reconerr = trainer.predict(main_model, dataloaders=streamloader) | ||
return reconerr.tolist() | ||
|
||
def exec(self, keys: list[str], datum: Datum) -> Messages: | ||
""" | ||
Here inference is done on the data, given, the ML model is present | ||
in the registry. If a model does not exist, the payload is flagged for training. | ||
It then passes to the threshold vertex. | ||
For more information about the arguments, refer: | ||
https://github.com/numaproj/numaflow-python/blob/main/pynumaflow/function/_dtypes.py | ||
""" | ||
# Load data and convert bytes to Payload | ||
payload = Payload.from_json(datum.value) | ||
|
||
payload.ts_data = recon_err.tolist() | ||
LOGGER.info("%s - Inference complete", payload.uuid) | ||
artifact_data = self.load_model() | ||
stream_data = payload.get_array().reshape(-1, 1) | ||
|
||
else: | ||
# If model not found, set status as not found | ||
LOGGER.warning("%s - Model not found", payload.uuid) | ||
payload.is_artifact_valid = False | ||
# Check if model exists for inference | ||
if artifact_data: | ||
payload.set_array(self._infer(artifact_data, stream_data)) | ||
LOGGER.info("%s - Inference complete", payload.uuid) | ||
else: | ||
# If model not found, set status as not found | ||
LOGGER.warning("%s - Model not found", payload.uuid) | ||
payload.is_artifact_valid = False | ||
|
||
# Convert Payload back to bytes and conditional forward to threshold vertex | ||
messages.append(Message.to_vtx(key="threshold", value=payload.to_json().encode("utf-8"))) | ||
return messages | ||
# Convert Payload back to bytes and conditional forward to threshold vertex | ||
return Messages(Message(value=payload.to_json())) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,58 @@ | ||
import logging | ||
|
||
import numpy as np | ||
from numalogic.numaflow import NumalogicUDF | ||
from numalogic.registry import MLflowRegistry | ||
from pynumaflow.function import Messages, Message, Datum | ||
|
||
from src.utils import Payload, load_artifact | ||
from src.utils import Payload | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
TRACKING_URI = "http:https://mlflow-service.default.svc.cluster.local:5000" | ||
|
||
|
||
def threshold(_: str, datum: Datum) -> Messages: | ||
r"""UDF that applies thresholding to the reconstruction error returned by the autoencoder. | ||
class Threshold(NumalogicUDF): | ||
"""UDF to apply thresholding to the reconstruction error returned by the autoencoder.""" | ||
|
||
For more information about the arguments, refer: | ||
https://github.com/numaproj/numaflow-python/blob/main/pynumaflow/function/_dtypes.py | ||
""" | ||
# Load data and convert bytes to Payload | ||
payload = Payload.from_json(datum.value.decode("utf-8")) | ||
messages = Messages() | ||
def __init__(self): | ||
super().__init__() | ||
self.registry = MLflowRegistry(tracking_uri=TRACKING_URI) | ||
|
||
# Load the threshold model from registry | ||
thresh_clf_artifact = load_artifact(skeys=["thresh_clf"], dkeys=["model"]) | ||
recon_err = np.asarray(payload.ts_data).reshape(-1, 1) | ||
|
||
# Check if model exists for inference | ||
if (not thresh_clf_artifact) or (not payload.is_artifact_valid): | ||
# If model not found, send it to trainer for training | ||
@staticmethod | ||
def _handle_not_found(payload: Payload) -> Messages: | ||
""" | ||
Handles the case when the model is not found. | ||
If model not found, send it to trainer for training. | ||
""" | ||
LOGGER.warning("%s - Model not found. Training the model.", payload.uuid) | ||
|
||
# Convert Payload back to bytes and conditional forward to train vertex | ||
payload.is_artifact_valid = False | ||
messages.append(Message.to_vtx(key="train", value=payload.to_json().encode("utf-8"))) | ||
return messages | ||
return Messages(Message(keys=["train"], value=payload.to_json())) | ||
|
||
def exec(self, _: list[str], datum: Datum) -> Messages: | ||
""" | ||
UDF that applies thresholding to the reconstruction error returned by the autoencoder. | ||
For more information about the arguments, refer: | ||
https://github.com/numaproj/numaflow-python/blob/main/pynumaflow/function/_dtypes.py | ||
""" | ||
# Load data and convert bytes to Payload | ||
payload = Payload.from_json(datum.value) | ||
|
||
# Load the threshold model from registry | ||
thresh_clf_artifact = self.registry.load( | ||
skeys=["thresh_clf"], dkeys=["model"], artifact_type="sklearn" | ||
) | ||
recon_err = payload.get_array().reshape(-1, 1) | ||
|
||
LOGGER.debug("%s - Threshold Model found!", payload.uuid) | ||
# Check if model exists for inference | ||
if (not thresh_clf_artifact) or (not payload.is_artifact_valid): | ||
return self._handle_not_found(payload) | ||
|
||
thresh_clf = thresh_clf_artifact.artifact | ||
payload.ts_data = thresh_clf.predict(recon_err).tolist() | ||
thresh_clf = thresh_clf_artifact.artifact | ||
payload.set_array(thresh_clf.predict(recon_err).tolist()) | ||
|
||
LOGGER.info("%s - Thresholding complete", payload.uuid) | ||
LOGGER.info("%s - Thresholding complete", payload.uuid) | ||
|
||
# Convert Payload back to bytes and conditional forward to postprocess vertex | ||
messages.append(Message.to_vtx(key="postprocess", value=payload.to_json().encode("utf-8"))) | ||
return messages | ||
# Convert Payload back to bytes and conditional forward to postprocess vertex | ||
return Messages(Message(keys=["postprocess"], value=payload.to_json())) |
Oops, something went wrong.