From 6dddfc173827bf2079e7e11370304a8bbe09af41 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Tue, 6 Feb 2024 16:52:35 -0700 Subject: [PATCH] added kwarg to fwp to allow for constant tensor output e.g., in the case of precip=0 --- sup3r/pipeline/forward_pass.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 2173d7e4c..8b71b106d 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -610,7 +610,8 @@ def __init__(self, exo_kwargs=None, bias_correct_method=None, bias_correct_kwargs=None, - max_nodes=None): + max_nodes=None, + allowed_const=False): """Use these inputs to initialize data handlers on different nodes and to define the size of the data chunks that will be passed through the generator. @@ -715,6 +716,15 @@ def __init__(self, max_nodes : int | None Maximum number of nodes to distribute spatiotemporal chunks across. If None then a node will be used for each temporal chunk. + allowed_const : list | bool + Tensorflow has a tensor memory limit of 2GB (result of protobuf + limitation) and when exceeded can return a tensor with a + constant output. sup3r will raise a ``MemoryError`` in response. If + your model is allowed to output a constant output, set this to True + to allow any constant output or a list of allowed possible constant + outputs. For example, a precipitation model should be allowed to + output all zeros so set this to ``[0]``. For details on this limit: + https://github.com/tensorflow/tensorflow/issues/51870 """ self._input_handler_kwargs = input_handler_kwargs or {} target = self._input_handler_kwargs.get('target', None) @@ -749,6 +759,7 @@ def __init__(self, self._lr_lat_lon = None self._init_handler = None self._handle_features = None + self.allowed_const = allowed_const self._single_ts_files = self._input_handler_kwargs.get( 'single_ts_files', None) @@ -1779,9 +1790,20 @@ def _constant_output_check(self, out_data): out_data : ndarray Forward pass output corresponding to the given chunk index """ + + allowed_const = self.strategy.allowed_const + if allowed_const is True: + return + elif allowed_const is False: + allowed_const = [] + elif not isinstance(allowed_const, (list, tuple)): + allowed_const = [allowed_const] + for i, f in enumerate(self.output_features): msg = f'All spatiotemporal values are the same for {f} output!' - if np.all(out_data[0, 0, 0, i] == out_data[..., i]): + value0 = out_data[0, 0, 0, i] + all_same = (value0 == out_data[..., i]).all() + if all_same and value0 not in allowed_const: self.strategy.failed_chunks = True logger.error(msg) raise MemoryError(msg)