Skip to content

Commit

Permalink
Merge pull request #184 from NREL/gb/allowed_const
Browse files Browse the repository at this point in the history
added kwarg to fwp to allow for constant tensor output e.g., in the c…
  • Loading branch information
grantbuster committed Feb 7, 2024
2 parents 92ac031 + 6dddfc1 commit fc613de
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,8 @@ def __init__(self,
exo_kwargs=None,
bias_correct_method=None,
bias_correct_kwargs=None,
max_nodes=None):
max_nodes=None,
allowed_const=False):
"""Use these inputs to initialize data handlers on different nodes and
to define the size of the data chunks that will be passed through the
generator.
Expand Down Expand Up @@ -715,6 +716,15 @@ def __init__(self,
max_nodes : int | None
Maximum number of nodes to distribute spatiotemporal chunks across.
If None then a node will be used for each temporal chunk.
allowed_const : list | bool
Tensorflow has a tensor memory limit of 2GB (result of protobuf
limitation) and when exceeded can return a tensor with a
constant output. sup3r will raise a ``MemoryError`` in response. If
your model is allowed to output a constant output, set this to True
to allow any constant output or a list of allowed possible constant
outputs. For example, a precipitation model should be allowed to
output all zeros so set this to ``[0]``. For details on this limit:
https://github.com/tensorflow/tensorflow/issues/51870
"""
self._input_handler_kwargs = input_handler_kwargs or {}
target = self._input_handler_kwargs.get('target', None)
Expand Down Expand Up @@ -749,6 +759,7 @@ def __init__(self,
self._lr_lat_lon = None
self._init_handler = None
self._handle_features = None
self.allowed_const = allowed_const

self._single_ts_files = self._input_handler_kwargs.get(
'single_ts_files', None)
Expand Down Expand Up @@ -1779,9 +1790,20 @@ def _constant_output_check(self, out_data):
out_data : ndarray
Forward pass output corresponding to the given chunk index
"""

allowed_const = self.strategy.allowed_const
if allowed_const is True:
return
elif allowed_const is False:
allowed_const = []
elif not isinstance(allowed_const, (list, tuple)):
allowed_const = [allowed_const]

for i, f in enumerate(self.output_features):
msg = f'All spatiotemporal values are the same for {f} output!'
if np.all(out_data[0, 0, 0, i] == out_data[..., i]):
value0 = out_data[0, 0, 0, i]
all_same = (value0 == out_data[..., i]).all()
if all_same and value0 not in allowed_const:
self.strategy.failed_chunks = True
logger.error(msg)
raise MemoryError(msg)
Expand Down

0 comments on commit fc613de

Please sign in to comment.