Skip to content

Commit

Permalink
self.output_data attribute addition for easier input/output comparison.
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Apr 13, 2023
1 parent febf350 commit 7236cef
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
self.strategy = strategy
self.chunk_index = chunk_index
self.node_index = node_index
self.output_data = None

msg = (f'Requested forward pass on chunk_index={chunk_index} > '
f'n_chunks={strategy.chunks}')
Expand Down Expand Up @@ -1815,20 +1816,20 @@ def run_chunk(self):
if self.exogenous_data is not None:
exo_data = self._prep_exogenous_input(data_chunk.shape)

out_data = self._run_generator(
self.output_data = self._run_generator(
data_chunk, hr_crop_slices=self.hr_crop_slice, model=self.model,
model_kwargs=self.model_kwargs, model_class=self.model_class,
s_enhance=self.s_enhance, t_enhance=self.t_enhance,
exo_data=exo_data)

self._constant_output_check(out_data)
self._constant_output_check(self.output_data)

if self.out_file is not None:
logger.info(f'Saving forward pass output to {self.out_file}.')
self.output_handler_class._write_output(
data=out_data, features=self.model.output_features,
data=self.output_data, features=self.model.output_features,
lat_lon=self.hr_lat_lon, times=self.hr_times,
out_file=self.out_file, meta_data=self.meta,
max_workers=self.output_workers, gids=self.gids)
else:
return out_data
return self.output_data

0 comments on commit 7236cef

Please sign in to comment.