Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tutorial release, new features, and bug fix #75

Merged
merged 23 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
5325d25
fix: add missing https:// in the issue-template config file;
WenjieDu Apr 22, 2023
44335be
Add unit-test cases for `pypots-cli` (#72)
WenjieDu Apr 23, 2023
d84e595
fix: only report coverage if file .coverage exists;
WenjieDu Apr 23, 2023
a54cea3
Merge branch 'main' into dev
WenjieDu Apr 23, 2023
07128b4
fix: remove cli-testing case of show-coverage to avoid mis-calculation;
WenjieDu Apr 23, 2023
dd6b793
fix: must not delete .coverage file after testing;
WenjieDu Apr 23, 2023
568b3c5
Fix bugs in the code-coverage report (#73)
WenjieDu Apr 23, 2023
f330f85
feat: default disabling early-stopping mechanism during model training;
WenjieDu Apr 24, 2023
c27f22c
fix: return correct val_X and test_X in gene_physionet2012() when art…
WenjieDu Apr 24, 2023
ea04dd6
feat: add pypots.random.set_random_seed();
WenjieDu Apr 24, 2023
0787260
feat: enable `return_labels` in Dataset classes;
WenjieDu Apr 24, 2023
895f9bc
refactor: remove autoflake that is not quite useful;
WenjieDu Apr 25, 2023
504bdd0
feat: enable automatically saving model into file if necessary;
WenjieDu Apr 25, 2023
e2485de
fix: remove typing.Literal which is not supported in python 3.7;
WenjieDu Apr 25, 2023
922bbfb
fix: the disordered labels in the returned data;
WenjieDu Apr 25, 2023
c7b6e26
fix: mistaken logical code in auto_save_model_if_necessary;
WenjieDu Apr 25, 2023
ea560d4
Add devcontainer config (#76)
WenjieDu Apr 27, 2023
4df32de
fix: set return_labels=False for training Dataset for CRLI and VaDER;
WenjieDu Apr 27, 2023
baab39e
feat: add git stale config file;
WenjieDu Apr 27, 2023
cce28bd
doc: remove tutorials dir, will create a new repo to put all tutorials;
WenjieDu Apr 27, 2023
4b25fb6
fix: remove tutorials from checking;
WenjieDu Apr 27, 2023
1f42c77
feat: add jupyterlab as a dev dependency, update README;
WenjieDu Apr 27, 2023
39b2bbe
doc: update README to add the link of BrewedPOTS;
WenjieDu Apr 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: add pypots.random.set_random_seed();
  • Loading branch information
WenjieDu committed Apr 24, 2023
commit ea04dd634204428c52c17f853782273f9e47aa11
23 changes: 23 additions & 0 deletions pypots/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import shutil
import unittest

import torch

from pypots.utils.logging import Logger
from pypots.utils.random import set_random_seed


class TestLogging(unittest.TestCase):
Expand Down Expand Up @@ -46,5 +49,25 @@ def test_saving_log_into_file(self):
shutil.rmtree("test_log", ignore_errors=True)


class TestRandom(unittest.TestCase):
def test_set_random_seed(self):
random_state1 = torch.get_rng_state()
torch.rand(
1, 3
) # randomly generate something, the random state will be reset, so two states should be varying
random_state2 = torch.get_rng_state()
assert not torch.equal(
random_state1, random_state2
), "The random seed hasn't set, so two random states should be different."

set_random_seed(26)
random_state1 = torch.get_rng_state()
set_random_seed(26)
random_state2 = torch.get_rng_state()
assert torch.equal(
random_state1, random_state2
), "The random seed has been set, two random states are not the same."


if __name__ == "__main__":
unittest.main()
29 changes: 29 additions & 0 deletions pypots/utils/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
Transformer model for time-series imputation.
"""

# Created by Wenjie Du <[email protected]>
# License: GLP-v3

import numpy as np
import torch
from pypots.utils.logging import logger

RANDOM_SEED = 2204


def set_random_seed(random_seed: int = RANDOM_SEED):
"""Manually set the random state to make PyPOTS output reproducible results.

Parameters
----------
random_seed : int, default = RANDOM_SEED,
The seed to be set for generating random numbers in PyPOTS.

"""

np.random.seed(RANDOM_SEED)
torch.manual_seed(random_seed)
logger.info(
f"Done. Have already set the random seed as {random_seed} for numpy and pytorch."
)