Skip to content

Commit

Permalink
fixed input handler kwargs update
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Oct 5, 2022
1 parent e549566 commit 60d01f6
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 16 deletions.
18 changes: 4 additions & 14 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def get_input_handler_kwargs(self, input_handler_kwargs):
Parameters
----------
input_handler_kwargs : dict
Dictionary of args to pass to the
Dictionary of args to pass to the data handler
:class:`sup3r.preprocessing.data_handling.DataHandler`
"""
self._target = input_handler_kwargs.get('target', None)
Expand Down Expand Up @@ -1075,25 +1075,19 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
elif strategy.output_type == 'h5':
self.output_handler_class = OutputHandlerH5

input_handler_kwargs = dict(
input_handler_kwargs = copy.deepcopy(strategy._input_handler_kwargs)
fwp_input_handler_kwargs = dict(
file_paths=self.file_paths,
features=self.features,
target=self.target,
shape=self.shape,
temporal_slice=self.temporal_pad_slice,
raster_file=self.raster_file,
cache_pattern=self.cache_pattern,
time_chunk_size=self.strategy.time_chunk_size,
overwrite_cache=self.strategy.overwrite_cache,
max_workers=self.max_workers,
extract_workers=self.extract_workers,
compute_workers=self.compute_workers,
load_workers=self.load_workers,
ti_workers=self.ti_workers,
handle_features=self.strategy.handle_features,
val_split=0.0)

input_handler_kwargs.update(self.strategy._input_handler_kwargs)
input_handler_kwargs.update(fwp_input_handler_kwargs)
self.data_handler = self.input_handler_class(**input_handler_kwargs)
self.data_handler.load_cached_data()
self.input_data = self.data_handler.data
Expand Down Expand Up @@ -1251,10 +1245,6 @@ def get_chunk_kwargs(self, strategy, chunk_index):
self._file_paths = strategy.file_paths
self.max_workers = strategy.max_workers
self.pass_workers = strategy.pass_workers
self.ti_workers = strategy.ti_workers
self.extract_workers = strategy.extract_workers
self.compute_workers = strategy.compute_workers
self.load_workers = strategy.load_workers
self.output_workers = strategy.output_workers
self.exo_kwargs = strategy.exo_kwargs
self.single_time_step_files = strategy.single_time_step_files
Expand Down
2 changes: 1 addition & 1 deletion sup3r/postprocessing/file_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,5 +532,5 @@ def _write_output(cls, data, features, lat_lon, times, out_file,

if meta_data is not None:
fh.run_attrs = {'gan_meta': json.dumps(meta_data)}
os.rename(tmp_file, out_file)
os.replace(tmp_file, out_file)
logger.info(f'Saved output of size {data.shape} to: {out_file}')
2 changes: 1 addition & 1 deletion sup3r/preprocessing/data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,7 +1285,7 @@ def cache_data(self, cache_file_paths):
tmp_file = fp.replace('.pkl', '.pkl.tmp')
with open(tmp_file, 'wb') as fh:
pickle.dump(self.data[..., i], fh, protocol=4)
os.rename(tmp_file, fp)
os.replace(tmp_file, fp)
else:
msg = (f'Called cache_data but {fp} already exists. Set to '
'overwrite_cache to True to overwrite.')
Expand Down

0 comments on commit 60d01f6

Please sign in to comment.