Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Apr 9, 2024
1 parent 81e3e72 commit 1b4778e
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,27 @@ def save(self, out_dir):
if it does not already exist.
"""

def _log_to_tensorboard(self, epoch, extras=None):
"""Write data to tensorboard log file. Includes history values, some
timing info, and provided extras.
Parameters
----------
epoch : int
Current epoch to write info for
extras : dict | None
Extra kwargs/parameters to save in the epoch history.
"""
if self._tb_writer is not None:
with self._tb_writer.as_default():
for col in self._history.columns:
tf.summary.scalar(col, self._history.at[epoch, col], epoch)
for name, value in self._timing_details.items():
tf.summary.scalar(name, value, epoch)
if extras is not None:
for name, value in extras.items():
tf.summary.scalar(name, value, epoch)

def finish_epoch(self,
epoch,
epochs,
Expand Down Expand Up @@ -1062,7 +1083,6 @@ def finish_epoch(self,
stop : bool
Flag to early stop training.
"""

self.log_loss_details(loss_details)
self._history.at[epoch, 'elapsed_time'] = time.time() - t0
for key, value in loss_details.items():
Expand Down Expand Up @@ -1090,14 +1110,7 @@ def finish_epoch(self,
for k, v in extras.items():
self._history.at[epoch, k] = v

if self._tb_writer is not None:
with self._tb_writer.as_default():
for col in self._history.columns:
tf.summary.scalar(col, self._history.at[epoch, col], epoch)
for name, value in extras.items():
tf.summary.scalar(name, value, epoch)
for name, value in self._timing_details.items():
tf.summary.scalar(name, value, epoch)
self._log_to_tensorboard(epoch, extras=extras)

return stop

Expand Down Expand Up @@ -1433,7 +1446,8 @@ def get_single_grad(self,
self._timing_details['dt:tape.watch'] = time.time() - t0
t0 = time.time()
hi_res_exo = self.get_high_res_exo_input(hi_res_true)
self._timing_details['dt:get_high_res_exo_input'] = time.time() - t0
self._timing_details[
'dt:get_high_res_exo_input'] = time.time() - t0
t0 = time.time()
hi_res_gen = self._tf_generate(low_res, hi_res_exo)
self._timing_details['dt:tf.generate'] = time.time() - t0
Expand Down

0 comments on commit 1b4778e

Please sign in to comment.