feat: multivariate backtesting
Signed-off-by: Avik Basu <[email protected]>
ab93 committed Jan 5, 2024
1 parent c0596d7 commit 3035407
Showing 4 changed files with 91 additions and 168 deletions.
4 changes: 2 additions & 2 deletions numalogic/backtest/__init__.py
@@ -1,5 +1,5 @@
from importlib.util import find_spec
from numalogic.backtest._prom import PromUnivarBacktester
from numalogic.backtest._prom import PromBacktester


def _validate_req_pkgs():
@@ -12,4 +12,4 @@ def _validate_req_pkgs():
_validate_req_pkgs()


__all__ = ["PromUnivarBacktester"]
__all__ = ["PromBacktester"]
244 changes: 84 additions & 160 deletions numalogic/backtest/_prom.py
@@ -18,16 +18,13 @@
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from numpy.typing import NDArray
from omegaconf import OmegaConf

from numalogic._constants import BASE_DIR
from numalogic.backtest._constants import DEFAULT_SEQUENCE_LEN
from numalogic.config import (
TrainerConf,
LightningTrainerConf,
NumalogicConf,
ModelInfo,
ModelFactory,
PreprocessFactory,
PostprocessFactory,
@@ -43,110 +40,46 @@
LOGGER = logging.getLogger(__name__)


def _init_default_streamconf(metrics: list[str]) -> StreamConf:
numalogic_cfg = NumalogicConf(
model=ModelInfo(
"VanillaAE", conf={"seq_len": DEFAULT_SEQUENCE_LEN, "n_features": len(metrics)}
),
preprocess=[ModelInfo("StandardScaler")],
trainer=TrainerConf(pltrainer_conf=LightningTrainerConf(accelerator="cpu")),
)
return StreamConf(
source=ConnectorType.prometheus,
window_size=DEFAULT_SEQUENCE_LEN,
metrics=metrics,
numalogic_conf=numalogic_cfg,
)


class PromUnivarBacktester:
"""
Class for running backtest for a single metric on data from Prometheus or Thanos.
Args:
url: Prometheus/Thanos URL
namespace: Namespace of the metric
appname: Application name
metric: Metric name
return_labels: Prometheus label names as columns to return
lookback_days: Number of days of data to fetch
output_dir: Output directory
test_ratio: Ratio of test data to total data
stream_conf: Stream configuration
"""

class PromBacktester:
def __init__(
self,
url: str,
namespace: str,
appname: str,
metric: str,
query: str,
return_labels: Optional[list[str]] = None,
metrics: Optional[list[str]] = None,
lookback_days: int = 8,
output_dir: Union[str, Path] = DEFAULT_OUTPUT_DIR,
test_ratio: float = 0.25,
stream_conf: Optional[StreamConf] = None,
numalogic_cfg: Optional[dict] = None,
experiment_name: str = "exp",
):
self._url = url
self.namespace = namespace
self.appname = appname
self.metric = metric
self.conf = stream_conf or _init_default_streamconf([metric])
self.test_ratio = test_ratio
self.lookback_days = lookback_days
self.return_labels = return_labels

self._seq_len = self.conf.window_size
self._n_features = len(self.conf.metrics)

self.out_dir = self.get_outdir(appname, metric, outdir=output_dir)
self.out_dir = self.get_outdir(experiment_name, outdir=output_dir)
self._datapath = os.path.join(self.out_dir, "data.csv")
self._modelpath = os.path.join(self.out_dir, "models.pt")
self._outpath = os.path.join(self.out_dir, "output.csv")

if not os.path.exists(self.out_dir):
os.makedirs(self.out_dir)

@classmethod
def get_outdir(cls, appname: str, metric: str, outdir=DEFAULT_OUTPUT_DIR) -> str:
"""Get the output directory for the given metric."""
if not appname:
return os.path.join(outdir, metric)
_key = ":".join([appname, metric])
return os.path.join(outdir, _key)

def read_data(self, fill_na_value: float = 0.0) -> pd.DataFrame:
"""
Reads data from Prometheus/Thanos and returns a dataframe.
Args:
fill_na_value: Value to fill NaNs with
Returns
-------
Dataframe with timestamp and metric values
"""
datafetcher = PrometheusFetcher(self._url)
df = datafetcher.fetch(
metric_name=self.metric,
start=(datetime.now() - timedelta(days=self.lookback_days)),
end=datetime.now(),
filters={"namespace": self.namespace, "app": self.appname},
return_labels=self.return_labels,
aggregate=False,
)
LOGGER.info(
"Fetched dataframe with lookback days: %s with shape: %s", self.lookback_days, df.shape
self.query = query
self.metrics = metrics or []
self.conf = StreamConf(
source=ConnectorType.prometheus,
window_size=DEFAULT_SEQUENCE_LEN,
metrics=metrics,
numalogic_conf=OmegaConf.to_object(
OmegaConf.merge(
OmegaConf.structured(NumalogicConf), OmegaConf.create(numalogic_cfg)
),
),
)

df.set_index(["timestamp"], inplace=True)
df.index = pd.to_datetime(df.index)

df.replace([np.inf, -np.inf], np.nan, inplace=True)
df = df.fillna(fill_na_value)

df.to_csv(self._datapath, index=True)
return df
self._seq_len = self.conf.window_size
self._n_features = len(self.conf.metrics)
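
The StreamConf construction above merges the plain numalogic_cfg dict into the structured NumalogicConf schema, so omitted keys fall back to schema defaults and mistyped keys fail fast. A standalone sketch of that OmegaConf pattern (the override values are illustrative):

from omegaconf import OmegaConf
from numalogic.config import NumalogicConf

user_cfg = {"model": {"name": "VanillaAE"}}                   # illustrative override
schema = OmegaConf.structured(NumalogicConf)                  # schema plus defaults
merged = OmegaConf.merge(schema, OmegaConf.create(user_cfg))  # validated merge
conf = OmegaConf.to_object(merged)                            # back to a NumalogicConf instance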

def train_models(
self,
@@ -165,11 +98,15 @@ def train_models(
if df is None:
df = self._read_or_fetch_data()

df_train, _ = self._split_data(df[[self.metric]])
if self.metrics:
df = df[self.metrics]

df_train, _ = self._split_data(df)

x_train = df_train.to_numpy(dtype=np.float32)
LOGGER.info("Training data shape: %s", x_train.shape)

artifacts = UDFFactory.get_udf_cls("trainer").compute(
artifacts = UDFFactory.get_udf_cls("promtrainer").compute(
model=ModelFactory().get_instance(self.conf.numalogic_conf.model),
input_=x_train,
preproc_clf=PreprocessFactory().get_pipeline_instance(
@@ -219,86 +156,70 @@ def generate_scores(
) from err

if use_full_data:
df_test = df[[self.metric]]
df_test = df[self.metrics]
else:
_, df_test = self._split_data(df[[self.metric]])
_, df_test = self._split_data(df[self.metrics])
x_test = df_test.to_numpy(dtype=np.float32)
LOGGER.info("Test data shape: %s", df_test.shape)

preproc_udf = UDFFactory.get_udf_cls("preprocess")
nn_udf = UDFFactory.get_udf_cls("inference")
postproc_udf = UDFFactory.get_udf_cls("postprocess")

# Preprocess
x_scaled = preproc_udf.compute(model=artifacts["preproc_clf"], input_=x_test)

ds = StreamingDataset(x_scaled, seq_len=self.conf.window_size)
anomaly_scores = np.zeros(
anomaly_scores = np.zeros((len(ds), self.conf.window_size), dtype=np.float32)
postproc_func = PostprocessFactory().get_instance(self.conf.numalogic_conf.postprocess)

x_recon = np.zeros(
(len(ds), self.conf.window_size, len(self.conf.metrics)), dtype=np.float32
)
x_recon = np.zeros_like(anomaly_scores, dtype=np.float32)
postproc_func = PostprocessFactory().get_instance(self.conf.numalogic_conf.postprocess)

# Model Inference
for idx, arr in enumerate(ds):
x_recon[idx] = nn_udf.compute(model=artifacts["model"], input_=arr)
_, anomaly_scores[idx] = postproc_udf.compute(
anomaly_scores[idx], _ = postproc_udf.compute(
model=artifacts["threshold_clf"],
input_=x_recon[idx],
postproc_clf=postproc_func,
)

x_recon = inverse_window(torch.from_numpy(x_recon)).numpy()
final_scores = np.mean(anomaly_scores, axis=1)
output_df = self._construct_output_df(
timestamps=df_test.index,
input_=x_test,
anomaly_scores = inverse_window(
torch.unsqueeze(torch.from_numpy(anomaly_scores), dim=2)
).numpy()

return self._construct_output(
df_test,
preproc_out=x_scaled,
nn_out=x_recon,
postproc_out=final_scores,
postproc_out=anomaly_scores,
)
output_df.to_csv(self._outpath, index=True, index_label="timestamp")
LOGGER.info("Results saved in: %s", self._outpath)
return output_df
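
The loop above scores one row per sliding window, and inverse_window then folds the overlapping windows back onto the original timeline (after unsqueezing the scores to 3-D, since the helper expects a batch-by-seq_len-by-features tensor). A toy sketch of the shape round trip, assuming the helpers live in numalogic.tools.data:

import numpy as np
import torch
from numalogic.tools.data import StreamingDataset, inverse_window  # assumed import path

x = np.arange(10, dtype=np.float32).reshape(-1, 1)  # 10 timestamps, 1 feature
ds = StreamingDataset(x, seq_len=3)                 # yields 10 - 3 + 1 = 8 windows
batch = torch.from_numpy(np.stack(list(ds)))        # shape: (8, 3, 1)
restored = inverse_window(batch).numpy()            # shape: (10, 1), one value per timestamp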

def save_plots(
self, output_df: Optional[pd.DataFrame] = None, plotname: str = "plot.png"
) -> None:
"""
Save plots for the given data.

Args:
output_df: Dataframe with timestamp, and anomaly scores
plotname: Name of the plot file
"""
if output_df is None:
output_df = pd.read_csv(self._outpath, index_col="timestamp", parse_dates=True)

fig, axs = plt.subplots(4, 1, sharex="col", figsize=(15, 8))
@classmethod
def get_outdir(cls, expname: str, outdir=DEFAULT_OUTPUT_DIR) -> str:
"""Get the output directory for the given metric."""
return os.path.join(outdir, expname)

axs[0].plot(output_df["metric"], color="b")
axs[0].set_ylabel("Original metric")
axs[0].grid(True)
axs[0].set_title(
f"TEST SET RESULTS\nMetric: {self.metric}\n"
f"namespace: {self.namespace}\napp: {self.appname}"
def read_data(self, fill_na_value: float = 0.0, save=True) -> pd.DataFrame:
datafetcher = PrometheusFetcher(self._url)
df = datafetcher.raw_fetch(
query=self.query,
start=(datetime.now() - timedelta(days=self.lookback_days)),
end=datetime.now(),
return_labels=self.return_labels,
)

axs[1].plot(output_df["preprocessed"], color="g")
axs[1].grid(True)
axs[1].set_ylabel("Preprocessed metric")

axs[2].plot(output_df["model_out"], color="black")
axs[2].grid(True)
axs[2].set_ylabel("NN model output")

axs[3].plot(output_df["scores"], color="r")
axs[3].grid(True)
axs[3].set_ylabel("Anomaly Score")
axs[3].set_xlabel("Time")
axs[3].set_ylim(0, 10)

fig.tight_layout()
_fname = os.path.join(self.out_dir, plotname)
fig.savefig(_fname)
LOGGER.info("Plot file: %s saved in %s", plotname, self.out_dir)
LOGGER.info(
"Fetched dataframe with lookback days: %s with shape: %s", self.lookback_days, df.shape
)
if self.metrics:
df = df[self.metrics]
df = df.replace([np.inf, -np.inf], np.nan).fillna(fill_na_value)
if save:
df.to_csv(self._datapath, index=True)
return df

def _split_data(self, df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
test_size = int(df.shape[0] * self.test_ratio)
@@ -332,27 +253,30 @@ def _load_or_train_model(
LOGGER.info("Loaded models from %s", _modelpath)
return artifacts

def _construct_output_df(
def _construct_output(
self,
timestamps: pd.Index,
input_: NDArray[float],
input_df: pd.DataFrame,
preproc_out: NDArray[float],
nn_out: NDArray[float],
postproc_out: NDArray[float],
) -> pd.DataFrame:
scores = np.vstack(
[
np.full((self._seq_len - 1, self._n_features), fill_value=np.nan),
ts_idx = input_df.index
dfs = {
"input": input_df,
"preproc_out": pd.DataFrame(
preproc_out,
columns=self.metrics,
index=ts_idx,
),
"model_out": pd.DataFrame(
nn_out,
columns=self.metrics,
index=ts_idx,
),
"postproc_out": pd.DataFrame(
postproc_out,
]
)

return pd.DataFrame(
{
"metric": input_.squeeze(),
"preprocessed": preproc_out.squeeze(),
"model_out": nn_out.squeeze(),
"scores": scores.squeeze(),
},
index=timestamps,
)
columns=["unified_score"],
index=ts_idx,
),
}
return pd.concat(dfs, axis=1)
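
Because _construct_output hands pd.concat a dict, the keys become the top level of a column MultiIndex, keeping each stage's columns addressable by name. A minimal sketch with made-up values:

import pandas as pd

idx = pd.date_range("2024-01-01", periods=3, freq="min")
out = pd.concat(
    {
        "input": pd.DataFrame({"m1": [1.0, 2.0, 3.0]}, index=idx),
        "postproc_out": pd.DataFrame({"unified_score": [0.1, 0.2, 0.9]}, index=idx),
    },
    axis=1,
)
print(out["postproc_out"]["unified_score"])  # columns are (stage, column) pairs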
2 changes: 1 addition & 1 deletion numalogic/udfs/postprocess.py
@@ -220,4 +220,4 @@ def compute(
_LOGGER.debug(
"Time taken in postprocess compute: %.4f sec", time.perf_counter() - _start_time
)
return y_score, score.reshape(-1)
return y_score.reshape(-1), score.reshape(-1)
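
The one-line fix above flattens y_score so both returned arrays are 1-D; the backtester depends on this when assigning a returned row straight into its (n_windows, window_size) score matrix. Illustrative shapes only:

import numpy as np

window_size = 12                      # illustrative value
y_score = np.zeros((window_size, 1))  # 2-D, as returned before the fix
scores = np.zeros((5, window_size))
scores[0] = y_score.reshape(-1)       # OK: 1-D row of matching length
# scores[0] = y_score                 # would raise: cannot broadcast (12, 1) into (12,)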
(fourth changed file not shown)
