Skip to content

Commit

Permalink
added kwarg to fwp to allow for constant tensor output e.g., in the c…
Browse files Browse the repository at this point in the history
…ase of precip=0
  • Loading branch information
grantbuster committed Feb 6, 2024
1 parent 92ac031 commit 6dddfc1
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,8 @@ def __init__(self,
exo_kwargs=None,
bias_correct_method=None,
bias_correct_kwargs=None,
max_nodes=None):
max_nodes=None,
allowed_const=False):
"""Use these inputs to initialize data handlers on different nodes and
to define the size of the data chunks that will be passed through the
generator.
Expand Down Expand Up @@ -715,6 +716,15 @@ def __init__(self,
max_nodes : int | None
Maximum number of nodes to distribute spatiotemporal chunks across.
If None then a node will be used for each temporal chunk.
allowed_const : list | bool
Tensorflow has a tensor memory limit of 2GB (result of protobuf
limitation) and when exceeded can return a tensor with a
constant output. sup3r will raise a ``MemoryError`` in response. If
your model is allowed to output a constant output, set this to True
to allow any constant output or a list of allowed possible constant
outputs. For example, a precipitation model should be allowed to
output all zeros so set this to ``[0]``. For details on this limit:
https://github.com/tensorflow/tensorflow/issues/51870
"""
self._input_handler_kwargs = input_handler_kwargs or {}
target = self._input_handler_kwargs.get('target', None)
Expand Down Expand Up @@ -749,6 +759,7 @@ def __init__(self,
self._lr_lat_lon = None
self._init_handler = None
self._handle_features = None
self.allowed_const = allowed_const

self._single_ts_files = self._input_handler_kwargs.get(
'single_ts_files', None)
Expand Down Expand Up @@ -1779,9 +1790,20 @@ def _constant_output_check(self, out_data):
out_data : ndarray
Forward pass output corresponding to the given chunk index
"""

allowed_const = self.strategy.allowed_const
if allowed_const is True:
return
elif allowed_const is False:
allowed_const = []
elif not isinstance(allowed_const, (list, tuple)):
allowed_const = [allowed_const]

for i, f in enumerate(self.output_features):
msg = f'All spatiotemporal values are the same for {f} output!'
if np.all(out_data[0, 0, 0, i] == out_data[..., i]):
value0 = out_data[0, 0, 0, i]
all_same = (value0 == out_data[..., i]).all()
if all_same and value0 not in allowed_const:
self.strategy.failed_chunks = True
logger.error(msg)
raise MemoryError(msg)
Expand Down

0 comments on commit 6dddfc1

Please sign in to comment.