Skip to content

Commit

Permalink
fix: AutoencoderPipeline logged loss mean (#55)
Browse files Browse the repository at this point in the history
Signed-off-by: Diego Ponce <[email protected]>
  • Loading branch information
diego-ponce committed Oct 20, 2022
1 parent 056786d commit 537fae5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
10 changes: 6 additions & 4 deletions numalogic/models/autoencoder/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def __init__(
):
if not (model and seq_len):
raise ValueError("No model and seq len provided!")
if num_epochs < 1:
raise ValueError("num_epochs must be a positive interger")

self._model = model
self.seq_len = seq_len
Expand Down Expand Up @@ -147,18 +149,18 @@ def fit(self, X: NDArray[float], y=None, log_freq: int = 5) -> "AutoencoderPipel
dataset = self._model.construct_dataset(X, self.seq_len)
loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
self._model.train()

loss = torch.Tensor([0.0])
losses = []
for epoch in range(1, self.num_epochs + 1):
for x_batch in loader:
self.optimizer.zero_grad()
_, decoded = self._model(x_batch)
loss = self.criterion(decoded, x_batch)
loss.backward()
self.optimizer.step()

losses.append(loss.item())
if epoch % log_freq == 0:
_LOGGER.info(f"epoch : {epoch}, loss_mean : {loss.item():.7f}")
_LOGGER.info(f"epoch : {epoch}, loss_mean : {np.mean(losses):.7f}")
losses = []

self._thresholds, _mean, _std = self.find_thresholds(X)
self._stats["mean"] = _mean
Expand Down
5 changes: 5 additions & 0 deletions numalogic/tests/models/autoencoder/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,11 @@ def test_exception_in_load_model(self):
)
self.assertEqual(model_pl2.err_stats["std"], model_pl1.err_stats["std"])

def test_exception_invalid_epoch(self):
model = VanillaAE(10)
with self.assertRaises(ValueError):
AutoencoderPipeline(model, 10, num_epochs=-10)


class TestSparseAEPipeline(unittest.TestCase):
X_train = None
Expand Down

0 comments on commit 537fae5

Please sign in to comment.