
feat: trainer join vertex for preprocess & inference
Signed-off-by: Avik Basu <[email protected]>
ab93 committed Feb 27, 2024
1 parent a331577 commit 6623e2b
Showing 6 changed files with 300 additions and 312 deletions.
103 changes: 56 additions & 47 deletions numalogic/udfs/inference.py
@@ -14,7 +14,7 @@
from numalogic.registry import LocalLRUCache, ArtifactData
from numalogic.tools.types import artifact_t, redis_client_t
from numalogic.udfs._base import NumalogicUDF
from numalogic.udfs._config import PipelineConf
from numalogic.udfs._config import PipelineConf, StreamConf
from numalogic.udfs._metrics import (
MODEL_STATUS_COUNTER,
RUNTIME_ERROR_COUNTER,
@@ -23,7 +23,7 @@
UDF_TIME,
_increment_counter,
)
from numalogic.udfs.entities import StreamPayload, Header, Status
from numalogic.udfs.entities import StreamPayload, Status, TrainerPayload
from numalogic.udfs.tools import _load_artifact, _update_info_metric

_LOGGER = logging.getLogger(__name__)
@@ -86,6 +86,33 @@ def compute(cls, model: artifact_t, input_: npt.NDArray[float], **_) -> npt.NDAr
raise RuntimeError("Model forward pass failed!") from err
return np.ascontiguousarray(recon_err).squeeze(0)

@staticmethod
def _get_trainer_message(
keys: list[str],
stream_conf: StreamConf,
payload: StreamPayload,
*metric_values: str,
) -> Message:
ckeys = [_ckey for _, _ckey in zip(stream_conf.composite_keys, payload.composite_keys)]
train_payload = TrainerPayload(
uuid=payload.uuid,
composite_keys=ckeys,
metrics=payload.metrics,
config_id=payload.config_id,
pipeline_id=payload.pipeline_id,
)
if metric_values:
_increment_counter(
counter=MODEL_STATUS_COUNTER,
labels=(payload.status.value, *metric_values),
)
_LOGGER.info(
"%s - Sending training request for: %s",
train_payload.uuid,
train_payload.composite_keys,
)
return Message(keys=keys, value=train_payload.to_json(), tags=["train"])

@UDF_TIME.time()
def exec(self, keys: list[str], datum: Datum) -> Messages:
"""
@@ -110,26 +137,14 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
payload.config_id,
payload.pipeline_id,
)

_increment_counter(counter=MSG_IN_COUNTER, labels=_metric_label_values)

_LOGGER.debug(
"%s - Received Msg: { CompositeKeys: %s, Metrics: %s }",
payload.uuid,
payload.composite_keys,
payload.metrics,
)

# Forward payload if a training request is tagged
if payload.header == Header.TRAIN_REQUEST:
_LOGGER.info(
"%s - Forwarding the message with the key: %s to next vertex because header is: %s",
payload.uuid,
payload.composite_keys,
payload.header,
)
return Messages(Message(keys=keys, value=payload.to_json()))

_stream_conf = self.get_stream_conf(payload.config_id)
_conf = _stream_conf.ml_pipelines[payload.pipeline_id]

@@ -144,17 +159,9 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:

# Send training request if artifact loading is not successful
if not artifact_data:
payload = replace(
payload, status=Status.ARTIFACT_NOT_FOUND, header=Header.TRAIN_REQUEST
)
_increment_counter(
counter=MODEL_STATUS_COUNTER,
labels=(
payload.status.value,
*_metric_label_values,
),
return Messages(
self._get_trainer_message(keys, _stream_conf, payload, *_metric_label_values)
)
return Messages(Message(keys=keys, value=payload.to_json()))

# Perform inference
try:
@@ -168,25 +175,28 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
payload.composite_keys,
payload.metrics,
)
payload = replace(payload, status=Status.RUNTIME_ERROR, header=Header.TRAIN_REQUEST)
_increment_counter(
counter=MODEL_STATUS_COUNTER, labels=(payload.status.value, *_metric_label_values)
)
return Messages(Message(keys=keys, value=payload.to_json()))
else:
status = (
Status.ARTIFACT_STALE
if self.is_model_stale(artifact_data, payload)
else Status.ARTIFACT_FOUND
return Messages(
self._get_trainer_message(keys, _stream_conf, payload, *_metric_label_values)
)
payload = replace(
payload,
data=x_inferred,
status=status,
metadata={
"model_version": int(artifact_data.extras.get("version")),
**payload.metadata,
},

msgs = Messages()
status = (
Status.ARTIFACT_STALE
if self.is_model_stale(artifact_data, payload)
else Status.ARTIFACT_FOUND
)
payload = replace(
payload,
data=x_inferred,
status=status,
metadata={
"model_version": int(artifact_data.extras.get("version")),
**payload.metadata,
},
)
if status == Status.ARTIFACT_STALE:
msgs.append(
self._get_trainer_message(keys, _stream_conf, payload, *_metric_label_values)
)

_LOGGER.info(
@@ -195,16 +205,15 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
payload.composite_keys,
payload.metrics,
)
_increment_counter(counter=MSG_PROCESSED_COUNTER, labels=_metric_label_values)
msgs.append(Message(keys=keys, value=payload.to_json()))

_LOGGER.debug(
"%s - Time taken in inference: %.4f sec",
payload.uuid,
time.perf_counter() - _start_time,
)
_increment_counter(
counter=MODEL_STATUS_COUNTER, labels=(payload.status.value, *_metric_label_values)
)
_increment_counter(counter=MSG_PROCESSED_COUNTER, labels=_metric_label_values)
return Messages(Message(keys=keys, value=payload.to_json()))
return msgs

def is_model_stale(self, artifact_data: ArtifactData, payload: StreamPayload) -> bool:
"""
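The new `_get_trainer_message` helper replaces the old `Header.TRAIN_REQUEST` flag: instead of re-emitting the stream payload with a mutated header, the UDF builds a `TrainerPayload` and tags the outgoing message `train`, so the pipeline spec (not downstream code) decides where it goes. Below is a minimal, self-contained sketch of that pattern; the stub dataclasses are hypothetical stand-ins for numalogic's entities, not the library's actual classes.

```python
import json
from dataclasses import dataclass


# Hypothetical stand-ins for numalogic's StreamConf and StreamPayload,
# reduced to the fields the helper actually touches.
@dataclass
class StreamConfStub:
    composite_keys: list  # e.g. ["namespace", "app"]


@dataclass
class StreamPayloadStub:
    uuid: str
    composite_keys: list
    metrics: list
    config_id: str
    pipeline_id: str


def make_trainer_message(keys, stream_conf, payload):
    # zip() truncates the payload's keys to the configured key length,
    # mirroring the ckeys construction in _get_trainer_message above.
    ckeys = [ck for _, ck in zip(stream_conf.composite_keys, payload.composite_keys)]
    value = json.dumps(
        {
            "uuid": payload.uuid,
            "composite_keys": ckeys,
            "metrics": payload.metrics,
            "config_id": payload.config_id,
            "pipeline_id": payload.pipeline_id,
        }
    ).encode("utf-8")
    # The "train" tag is what conditional forwarding keys on.
    return keys, value, ["train"]


keys, value, tags = make_trainer_message(
    ["ns1", "app1"],
    StreamConfStub(composite_keys=["namespace", "app"]),
    StreamPayloadStub("uuid-1", ["ns1", "app1", "extra"], ["col1", "col2"], "conf1", "pipeline1"),
)
assert tags == ["train"] and b"pipeline1" in value
```

Note that both `inference.py` and `preprocess.py` carry an identical copy of this helper; the behavior is the same, only the metric label values passed in differ.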
75 changes: 42 additions & 33 deletions numalogic/udfs/preprocess.py
@@ -25,8 +25,8 @@
from numalogic.registry import LocalLRUCache
from numalogic.tools.types import redis_client_t, artifact_t
from numalogic.udfs import NumalogicUDF
from numalogic.udfs._config import PipelineConf
from numalogic.udfs.entities import Status, Header
from numalogic.udfs._config import PipelineConf, StreamConf
from numalogic.udfs.entities import Status, Header, TrainerPayload, StreamPayload
from numalogic.udfs.tools import make_stream_payload, get_df, _load_artifact, _update_info_metric

# TODO: move to config
@@ -130,7 +130,8 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
_increment_counter(counter=MSG_IN_COUNTER, labels=_metric_label_values)
# Drop message if dataframe shape conditions are not met
if raw_df.shape[0] < _stream_conf.window_size or raw_df.shape[1] != len(_conf.metrics):
_LOGGER.error("Dataframe shape: (%f, %f) error ", raw_df.shape[0], raw_df.shape[1])
_LOGGER.critical("Dataframe shape: (%d, %d) error", raw_df.shape[0], raw_df.shape[1])
_increment_counter(
counter=DATASHAPE_ERROR_COUNTER,
labels=_metric_label_values,
@@ -168,14 +169,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
)
payload = replace(payload, status=Status.ARTIFACT_FOUND)
else:
payload = replace(
payload, status=Status.ARTIFACT_NOT_FOUND, header=Header.TRAIN_REQUEST
)
_increment_counter(
counter=MODEL_STATUS_COUNTER,
labels=(payload.status.value, *_metric_label_values),
)
return Messages(Message(keys=keys, value=payload.to_json()))
return Messages(self._get_trainer_message(keys, _stream_conf, payload))
# Model will not be in registry
else:
# Load configuration for the config_id
@@ -217,32 +211,50 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
payload.metrics,
)
# TODO check again what error is causing this and if retraining is required
payload = replace(payload, status=Status.RUNTIME_ERROR, header=Header.TRAIN_REQUEST)
_increment_counter(
counter=MODEL_STATUS_COUNTER,
labels=(
payload.status.value,
*_metric_label_values,
),
payload = replace(
payload,
status=Status.RUNTIME_ERROR,
)
return Messages(Message(keys=keys, value=payload.to_json()))
return Messages(
self._get_trainer_message(keys, _stream_conf, payload, *_metric_label_values)
)
_increment_counter(
counter=MSG_PROCESSED_COUNTER,
labels=_metric_label_values,
)
_LOGGER.debug(
"%s - Time taken to execute Preprocess: %.4f sec",
payload.uuid,
time.perf_counter() - _start_time,
)
_increment_counter(
counter=MODEL_STATUS_COUNTER,
labels=(
payload.status.value,
*_metric_label_values,
),
return Messages(Message(keys=keys, value=payload.to_json(), tags=["inference"]))

@staticmethod
def _get_trainer_message(
keys: list[str],
stream_conf: StreamConf,
payload: StreamPayload,
*metric_values: str,
) -> Message:
ckeys = [_ckey for _, _ckey in zip(stream_conf.composite_keys, payload.composite_keys)]
train_payload = TrainerPayload(
uuid=payload.uuid,
composite_keys=ckeys,
metrics=payload.metrics,
config_id=payload.config_id,
pipeline_id=payload.pipeline_id,
)
_increment_counter(
counter=MSG_PROCESSED_COUNTER,
labels=_metric_label_values,
if metric_values:
_increment_counter(
counter=MODEL_STATUS_COUNTER,
labels=(payload.status.value, *metric_values),
)
_LOGGER.info(
"%s - Sending training request for: %s",
train_payload.uuid,
train_payload.composite_keys,
)
return Messages(Message(keys=keys, value=payload.to_json()))
return Message(keys=keys, value=train_payload.to_json(), tags=["train"])

@classmethod
def compute(
Expand All @@ -263,11 +275,8 @@ def compute(
------
RuntimeError: If preprocess fails
"""
_start_time = time.perf_counter()
try:
x_scaled = model.transform(input_)
_LOGGER.info("Time taken in preprocessing: %.4f sec", time.perf_counter() - _start_time)
except Exception as err:
raise RuntimeError("Model transform failed!") from err
else:
return x_scaled
return x_scaled
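`compute` is now a thin wrapper: it calls the artifact's `transform` and translates any failure into a `RuntimeError`, with the per-call timing log removed from the hot path. A hedged usage sketch, assuming any fitted scikit-learn-style transformer (the scaler here is an illustration, not the artifact numalogic would actually load):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Fit offline; at serve time only transform() runs, as in compute() above.
rng = np.random.default_rng(42)
model = StandardScaler().fit(rng.normal(size=(1000, 2)))

window = rng.normal(size=(12, 2))  # one stream window, two metrics
try:
    x_scaled = model.transform(window)
except Exception as err:
    # Mirror compute(): surface any transform failure uniformly so the
    # caller can emit a trainer message instead of crashing the vertex.
    raise RuntimeError("Model transform failed!") from err
print(x_scaled.shape)  # (12, 2)
```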
1 change: 1 addition & 0 deletions pytest.ini
@@ -1,2 +1,3 @@
[pytest]
log_cli = true
log_cli_level = INFO
4 changes: 1 addition & 3 deletions tests/udfs/resources/_config.yaml
@@ -8,8 +8,6 @@ stream_confs:
pipeline1:
pipeline_id: "pipeline1"
metrics: [ "col1" , "col2" ]
unified_scoring_conf:
scoring_function: "MEAN"
numalogic_conf:
model:
name: "VanillaAE"
@@ -71,4 +69,4 @@ druid_conf:
datasource: "table-name"
group_by: [ "timestamp", "col1" ]
pivot:
columns: [ "col2" ]
columns: [ "col2" ]
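For the `train` and `inference` tags to matter, the Numaflow pipeline spec has to attach tag conditions to its edges. A hypothetical wiring (vertex names and the exact spec are assumptions, not part of this commit) would look roughly like:

```yaml
# Hypothetical Numaflow edge wiring: both UDF vertices fan out to a
# shared trainer vertex on the "train" tag. Vertex names are illustrative.
edges:
  - from: preprocess
    to: inference
    conditions:
      tags:
        values: [ "inference" ]
  - from: preprocess
    to: trainer
    conditions:
      tags:
        values: [ "train" ]
  - from: inference
    to: trainer
    conditions:
      tags:
        values: [ "train" ]
```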
