Temporal slicing in Validation data #123

Closed
malihass opened this issue Nov 22, 2022 · 2 comments · Fixed by #121
Labels
bug Something isn't working

Comments

@malihass
Collaborator

malihass commented Nov 22, 2022

Bug Description
When using a temporal slice of (None, None, x) with x > 1, some tests fail. The failure seems to come from the tuple_index used in ValidationData.

Full Traceback

Traceback (most recent call last):
  File "test_train_gan_tslice.py", line 69, in <module>
    test_train_st_weight_update(log=True, n_epoch=1, temporal_slice=(None, None, 3))
  File "test_train_gan_tslice.py", line 63, in test_train_st_weight_update
    adaptive_update_fraction=0.05)
  File "/Users/mhassana/Desktop/GitHub/sup3r_nov24_issue/sup3r/models/base.py", line 1140, in train
    loss_details)
  File "/Users/mhassana/Desktop/GitHub/sup3r_nov24_issue/sup3r/models/base.py", line 901, in calc_val_loss
    for val_batch in batch_handler.val_data:
  File "/Users/mhassana/Desktop/GitHub/sup3r_nov24_issue/sup3r/preprocessing/batch_handling.py", line 316, in __next__
    val_index['tuple_index']]
ValueError: could not broadcast input array from shape (18,18,14,3) into shape (18,18,24,3)
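
The shape mismatch in the last line of the traceback can be reproduced with plain NumPy. The following is a minimal sketch, not the sup3r batch-handling code: the array names are hypothetical, and it assumes the validation data has only 14 time steps while the batch was allocated for 24.

```python
import numpy as np

# Hypothetical stand-in for the validation data, laid out as
# (lat, lon, time, features), with only 14 time steps available.
val_data = np.zeros((20, 20, 14, 3))

# Hypothetical pre-allocated batch that expects 24 time steps per sample.
batch = np.zeros((4, 18, 18, 24, 3))

# Slicing past the end of an axis does not raise; NumPy silently truncates,
# so the requested 24 time steps come back as 14.
sample = val_data[0:18, 0:18, 0:24, :]
print(sample.shape)

# Assigning the truncated sample into the batch slot is what fails.
try:
    batch[0] = sample
    error = None
except ValueError as err:
    error = err
print(error)
```

The point of the sketch is that the slice itself succeeds, so the problem only surfaces later as a broadcast error at assignment time.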

Code Sample

import os
import json
import numpy as np
import pytest
import tempfile
import tensorflow as tf
from tensorflow.python.framework.errors_impl import InvalidArgumentError

from rex import init_logger

from sup3r import TEST_DATA_DIR
from sup3r import CONFIG_DIR
from sup3r.models import Sup3rGan
from sup3r.models.data_centric import Sup3rGanDC, Sup3rGanSpatialDC
from sup3r.preprocessing.data_handling import (DataHandlerH5,
                                               DataHandlerDCforH5)
from sup3r.preprocessing.batch_handling import (BatchHandler,
                                                BatchHandlerDC,
                                                SpatialBatchHandler,
                                                BatchHandlerSpatialDC)
from sup3r.utilities.loss_metrics import MmdMseLoss


FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
TARGET_COORD = (39.01, -105.15)
FEATURES = ['U_100m', 'V_100m', 'BVF2_200m']


def test_train_st_weight_update(n_epoch=5, log=False, temporal_slice=slice(None, None, 1)):
    """Test basic spatiotemporal model training with discriminators and
    adversarial loss updating."""
    if log:
        init_logger('sup3r', log_level='DEBUG')

    fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
    fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')

    Sup3rGan.seed()
    model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4,
                     learning_rate_disc=3e-4)

    handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD,
                            shape=(20, 20),
                            sample_shape=(18, 18, 24),
                            temporal_slice=temporal_slice,
                            val_split=0.005,
                            max_workers=1)

    batch_handler = BatchHandler([handler], batch_size=4,
                                 s_enhance=3, t_enhance=4,
                                 n_batches=4)

    adaptive_update_bounds = (0.9, 0.99)
    with tempfile.TemporaryDirectory() as td:
        model.train(batch_handler, n_epoch=n_epoch,
                    weight_gen_advers=1e-6,
                    train_gen=True, train_disc=True,
                    checkpoint_int=10,
                    out_dir=os.path.join(td, 'test_{epoch}'),
                    adaptive_update_bounds=adaptive_update_bounds,
                    adaptive_update_fraction=0.05)

if __name__ == "__main__":
    print("\n\n DOING temporal_slice=(None, None, 1) \n\n")
    test_train_st_weight_update(log=True, n_epoch=1, temporal_slice=(None, None, 1))
    print("\n\n DOING temporal_slice=(None, None, 3) \n\n")
    test_train_st_weight_update(log=True, n_epoch=1, temporal_slice=(None, None, 3))

To Reproduce
Steps to reproduce the problem behavior

  1. Copy code sample to tests/
  2. Execute the Python script

Expected behavior
Any temporal slice should work, within the limits of the dataset size.

@malihass malihass added the bug Something isn't working label Nov 22, 2022
@bnb32
Collaborator

bnb32 commented Nov 22, 2022

It looks like, with val_split=0.005 and step=3, only 14 time steps are left in the validation data (8784 * 0.005 // 3), while sample_shape requests 24 time steps. There should definitely be a clearer error message.

Added a warning here:

def _val_split_check(self):
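
The arithmetic in this comment can be sketched as a small standalone check. This is an illustration only and not the actual `_val_split_check` implementation in sup3r: the function names and the error message are hypothetical, and the order of operations (validation split first, then temporal stride) simply follows the expression quoted in the comment.

```python
def n_val_timesteps(n_time, val_split, step):
    """Approximate number of validation time steps after temporal striding.

    Mirrors the expression from the comment: n_time * val_split // step.
    """
    return int(n_time * val_split // step)


def check_val_split(n_time, val_split, step, sample_shape):
    """Raise a descriptive error if validation data is shorter than a sample.

    sample_shape is assumed to be (spatial_1, spatial_2, temporal).
    """
    n_val = n_val_timesteps(n_time, val_split, step)
    n_requested = sample_shape[2]
    if n_val < n_requested:
        raise ValueError(
            f"Validation data has only {n_val} time steps "
            f"(n_time={n_time}, val_split={val_split}, step={step}), "
            f"but sample_shape requests {n_requested}."
        )
    return n_val


# Values from this issue: 8784 time steps, val_split=0.005, sample_shape
# with 24 time steps.
print(n_val_timesteps(8784, 0.005, 3))

try:
    check_val_split(8784, 0.005, 3, (18, 18, 24))
except ValueError as err:
    print(err)
```

With a step of 1 the same check passes, which matches the report that only slices with x > 1 fail.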

@malihass
Collaborator Author

Ah, you are right, this was the issue. I'll let you close this when the PR with val_split_check is merged.
Thanks!

@bnb32 bnb32 linked a pull request Nov 23, 2022 that will close this issue