From 1b4778e8ab24533afc68b87f9253e44b4b5f6766 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 9 Apr 2024 09:58:05 -0600 Subject: [PATCH] linting --- sup3r/models/abstract.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index a69b23c6b..21888ce9f 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1009,6 +1009,27 @@ def save(self, out_dir): if it does not already exist. """ + def _log_to_tensorboard(self, epoch, extras=None): + """Write data to tensorboard log file. Includes history values, some + timing info, and provided extras. + + Parameters + ---------- + epoch : int + Current epoch to write info for + extras : dict | None + Extra kwargs/parameters to save in the epoch history. + """ + if self._tb_writer is not None: + with self._tb_writer.as_default(): + for col in self._history.columns: + tf.summary.scalar(col, self._history.at[epoch, col], epoch) + for name, value in self._timing_details.items(): + tf.summary.scalar(name, value, epoch) + if extras is not None: + for name, value in extras.items(): + tf.summary.scalar(name, value, epoch) + def finish_epoch(self, epoch, epochs, @@ -1062,7 +1083,6 @@ def finish_epoch(self, stop : bool Flag to early stop training. """ - self.log_loss_details(loss_details) self._history.at[epoch, 'elapsed_time'] = time.time() - t0 for key, value in loss_details.items(): @@ -1090,14 +1110,7 @@ def finish_epoch(self, for k, v in extras.items(): self._history.at[epoch, k] = v - if self._tb_writer is not None: - with self._tb_writer.as_default(): - for col in self._history.columns: - tf.summary.scalar(col, self._history.at[epoch, col], epoch) - for name, value in extras.items(): - tf.summary.scalar(name, value, epoch) - for name, value in self._timing_details.items(): - tf.summary.scalar(name, value, epoch) + self._log_to_tensorboard(epoch, extras=extras) return stop @@ -1433,7 +1446,8 @@ def get_single_grad(self, self._timing_details['dt:tape.watch'] = time.time() - t0 t0 = time.time() hi_res_exo = self.get_high_res_exo_input(hi_res_true) - self._timing_details['dt:get_high_res_exo_input'] = time.time() - t0 + self._timing_details[ + 'dt:get_high_res_exo_input'] = time.time() - t0 t0 = time.time() hi_res_gen = self._tf_generate(low_res, hi_res_exo) self._timing_details['dt:tf.generate'] = time.time() - t0