Merge pull request #62 from moskomule/dev
Accumulated Updates
moskomule committed Dec 20, 2021
2 parents 24b83ae + a2c4929 commit 721b03a
Showing 28 changed files with 261 additions and 230 deletions.
10 changes: 4 additions & 6 deletions .github/workflows/ghpage.yml
@@ -1,8 +1,5 @@
name: ghpage
on:
push:
branches:
- master
on: [ push ]

jobs:
build:
@@ -21,8 +18,8 @@ jobs:
python -m venv venv
. venv/bin/activate
pip install -U pip
pip install Sphinx sphinx-rtd-theme
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install Sphinx sphinx-rtd-theme myst-parser
pip install torch==1.10.0+cpu torchvision==0.11.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install -U .
- name: build
@@ -33,6 +30,7 @@ jobs:
make html
- name: push
if: github.ref == 'refs/heads/master'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
4 changes: 2 additions & 2 deletions docs/source/conf.py
@@ -41,6 +41,7 @@
'sphinx.ext.autodoc',
'sphinx.ext.viewcode',
'sphinx.ext.autosummary',
'myst_parser'
]

# Add any paths that contain templates here, relative to this directory.
@@ -49,8 +50,7 @@
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
source_suffix = ['.rst', '.md']

# The master toctree document.
master_doc = 'index'
2 changes: 0 additions & 2 deletions docs/source/index.rst
@@ -12,9 +12,7 @@ Welcome to `homura`'s documentation!

homura.metrics
homura.modules

homura.utils
homura.nlp
homura.vision

Indices and tables
6 changes: 3 additions & 3 deletions examples/cifar10.py
@@ -8,7 +8,7 @@

@chika.config
class Config:
model: str = chika.choices(*MODEL_REGISTRY.choices())
model: str = chika.choices("wrn28_2", "wrn28_10")
batch_size: int = 128

epochs: int = 200
@@ -38,7 +38,7 @@ def main(cfg):
model = MODEL_REGISTRY(cfg.model)(num_classes=data.num_classes)
optimizer = None if cfg.bn_no_wd else optim.SGD(lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay,
multi_tensor=cfg.use_multi_tensor)
scheduler = lr_scheduler.CosineAnnealingWithWarmup(cfg.epochs, 4, 5)
scheduler = lr_scheduler.CosineAnnealingWithWarmup(cfg.epochs, 5)

if cfg.bn_no_wd:
def set_optimizer(trainer):
@@ -53,7 +53,7 @@ def set_optimizer(trainer):
{"params": bn_params, "weight_decay": 0},
{"params": non_bn_parameters, "weight_decay": cfg.weight_decay},
]
trainer.optimizer = torch.optim.SGD(optim_params, lr=1e-1, momentum=0.9)
trainer.optimizer = torch.optim.SGD(optim_params, lr=cfg.lr, momentum=0.9)

trainers.SupervisedTrainer.set_optimizer = set_optimizer

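The last hunk above replaces the hardcoded learning rate with cfg.lr, so the custom optimizer honors the same setting as the default one. As a standalone illustration of the batch-norm / weight-decay split used in set_optimizer, here is a minimal sketch against a generic nn.Module (the helper name and the toy model are illustrative, not homura's API):

import torch
from torch import nn


def split_bn_params(model: nn.Module, weight_decay: float):
    # BatchNorm parameters get no weight decay; everything else keeps it.
    bn_params, other_params = [], []
    for module in model.modules():
        params = list(module.parameters(recurse=False))
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            bn_params += params
        else:
            other_params += params
    return [{"params": bn_params, "weight_decay": 0},
            {"params": other_params, "weight_decay": weight_decay}]


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
optimizer = torch.optim.SGD(split_bn_params(model, 5e-4), lr=0.1, momentum=0.9)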
5 changes: 3 additions & 2 deletions examples/imagenet.py
@@ -18,7 +18,7 @@ class Config:
debug: bool = False
use_amp: bool = False
use_sync_bn: bool = False
num_workers: int = 4
num_workers: int = 16

init_method: str = "env://"
backend: str = "nccl"
@@ -51,9 +51,10 @@ def main(cfg: Config):
use_sync_bn=cfg.use_sync_bn,
report_accuracy_topk=5) as trainer:

for epoch in trainer.epoch_range(cfg.epochs):
for _ in trainer.epoch_range(cfg.epochs):
trainer.train(train_loader)
trainer.test(test_loader)
trainer.scheduler.step()

print(f"Max Test Accuracy={max(trainer.reporter.history('accuracy/test')):.3f}")

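The loop variable is no longer used, and the scheduler is now stepped explicitly once per epoch after training and evaluation. A plain-PyTorch sketch of that ordering, with a toy model and random data purely for illustration:

import torch
from torch import nn, optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=90)

for _ in range(90):  # the epoch index itself is unused
    model.train()
    x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()   # per-batch optimizer update
    model.eval()       # ... evaluation would run here ...
    scheduler.step()   # per-epoch LR update, after train/test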
10 changes: 5 additions & 5 deletions homura/liblog.py
@@ -5,7 +5,7 @@
import sys
import threading
import warnings
from typing import Optional, TextIO
from typing import TextIO

import tqdm as _tqdm
from tqdm.contrib import DummyTqdmFile
@@ -117,7 +117,7 @@ def disable_propagation() -> None:


def set_file_handler(log_file: str or TextIO, level: str or int = logging.DEBUG,
formatter: Optional[logging.Formatter] = None) -> None:
formatter: logging.Formatter = None) -> None:
_configure_root_logger()
fh = logging.FileHandler(log_file)
if isinstance(level, str):
@@ -131,7 +131,7 @@ def set_tqdm_handler(level: str or int = logging.INFO,

# internal APIs
def set_tqdm_handler(level: str or int = logging.INFO,
formatter: Optional[logging.Formatter] = None) -> None:
formatter: logging.Formatter = None) -> None:
""" An alternative handler to avoid disturbing tqdm
"""

@@ -189,7 +189,7 @@ def tqdm(*args, **kwargs):

def log_once(logger,
message: str,
key=Optional[str]) -> None:
key=str) -> None:
""" Log message only once.
:param logger: e.g., `print`, `logger.info`
@@ -207,7 +207,7 @@ def log_once(logger,


def print_once(message: str,
key=Optional[str]) -> None:
key=str) -> None:
""" `print` version of `log_once`
"""

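log_once and print_once emit a message only the first time a given key is seen. A minimal sketch of that idea, using a module-level cache keyed by the message or an explicit key (an illustration of the pattern, not homura's implementation):

_seen_keys = set()


def log_once(logger, message: str, key: str = None) -> None:
    # Fall back to the message itself as the deduplication key.
    key = message if key is None else key
    if key not in _seen_keys:
        _seen_keys.add(key)
        logger(message)


log_once(print, "cuDNN benchmark enabled")  # printed
log_once(print, "cuDNN benchmark enabled")  # suppressed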
28 changes: 24 additions & 4 deletions homura/lr_scheduler.py
@@ -1,8 +1,7 @@
import bisect
import math
import warnings
import bisect
from functools import partial
from typing import List

from torch.optim import lr_scheduler as _lr_scheduler

@@ -20,7 +19,7 @@ def MultiStepLR(milestones,


def MultiStepWithWarmup(warmup: int,
milestones: List[int],
milestones: list[int],
gamma: float = 0.1,
last_epoch: int = -1):
return partial(_lr_scheduler.LambdaLR,
@@ -51,6 +50,17 @@ def ReduceLROnPlateau(mode='min',
return partial(_lr_scheduler.ReduceLROnPlateau, **locals())


def InverseSquareRootWithWarmup(warmup_epochs: int,
last_epoch: int = -1):
""" inverse square root with warmup: $\\sqrt{w} \\min(1/\\sqrt{e}, e/\\sqrt{e}^3)$, where $w$ is `warmup_epochs` and
`e` is the current epoch
"""
return partial(_lr_scheduler.LambdaLR,
lr_lambda=inverse_square_root_with_warmup(warmup_epochs),
last_epoch=last_epoch)


def CosineAnnealingWithWarmup(total_epochs: int,
warmup_epochs: int,
min_lr: float = 0,
@@ -92,7 +102,7 @@ def f(epoch):


def multistep_with_warmup(warmup_epochs: int,
milestones: List[int],
milestones: list[int],
gamma: float = 0.1,
):
def f(epoch):
@@ -101,3 +111,13 @@ def f(epoch):
return gamma ** bisect.bisect_right(milestones, epoch)

return f


def inverse_square_root_with_warmup(warmup_epochs: int,
):
def f(epoch):
epoch += 1
factor = warmup_epochs ** 0.5
return factor * min(epoch ** -0.5, epoch * warmup_epochs ** -1.5)

return f
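
With warmup_epochs = w, the multiplier grows linearly from 1/w to 1 over the first w epochs and then decays as sqrt(w / e). A small self-contained check that attaches the same rule to a plain optimizer via LambdaLR (the inner function is copied from the hunk above; the toy model is illustrative):

import torch
from torch import nn, optim


def inverse_square_root_with_warmup(warmup_epochs: int):
    def f(epoch):
        epoch += 1
        factor = warmup_epochs ** 0.5
        return factor * min(epoch ** -0.5, epoch * warmup_epochs ** -1.5)

    return f


model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                        lr_lambda=inverse_square_root_with_warmup(5))

for epoch in range(10):
    optimizer.step()
    scheduler.step()
# multipliers per epoch: 0.2, 0.4, 0.6, 0.8, 1.0, then ~0.91, ~0.85, ...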
4 changes: 2 additions & 2 deletions homura/metrics/commons.py
@@ -1,4 +1,4 @@
from typing import Tuple
from __future__ import annotations

import torch
from torch import Tensor
@@ -14,7 +14,7 @@

def _base(input: Tensor,
target: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
classes = torch.arange(input.size(1), device=input.device)
pred = input.argmax(dim=1).view(-1, 1)
target = target.view(-1, 1)
8 changes: 4 additions & 4 deletions homura/modules/attention.py
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from __future__ import annotations

import torch
from torch import nn
@@ -25,9 +25,9 @@ def forward(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
additive_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
mask: torch.Tensor = None,
additive_mask: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
""" See `functional.attention.kv_attention` for details
:param query:
3 changes: 2 additions & 1 deletion homura/modules/ema.py
@@ -64,7 +64,8 @@ def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
return self._original_model.parameters(recurse)

def requires_grad_(self, requires_grad: bool = True) -> nn.Module:
return self._original_model.requires_grad_(requires_grad)
self._original_model.requires_grad_(requires_grad)
return self

@torch.no_grad()
def _update(self):
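Returning self instead of the wrapped model makes requires_grad_ follow the usual nn.Module convention, so calls on the EMA wrapper chain and keep referring to the wrapper. The convention on a plain module:

from torch import nn

model = nn.Linear(4, 2)
# nn.Module.requires_grad_ and eval both return the module itself, so they chain.
frozen = model.requires_grad_(False).eval()
assert frozen is model
assert not any(p.requires_grad for p in frozen.parameters())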
8 changes: 4 additions & 4 deletions homura/modules/functional/attention.py
@@ -1,4 +1,4 @@
from typing import Optional
from __future__ import annotations

import torch
from torch.nn import functional as F
@@ -7,12 +7,12 @@
def kv_attention(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
additive_mask: Optional[torch.Tensor] = None,
mask: torch.Tensor = None,
additive_mask: torch.Tensor = None,
training: bool = True,
dropout_prob: float = 0,
scaling: bool = True
) -> (torch.Tensor, torch.Tensor):
) -> tuple[torch.Tensor, torch.Tensor]:
"""Attention using queries, keys and value
:param query: `...JxM`
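A short usage sketch of kv_attention with batched tensors. The query shape ...JxM comes from the docstring above; the key/value layout and the (output, attention_weights) return pair are assumptions read off the signature:

import torch
from homura.modules.functional.attention import kv_attention

query = torch.randn(2, 4, 16)  # batch x J x M
key = torch.randn(2, 8, 16)    # batch x K x M (assumed layout)
value = torch.randn(2, 8, 16)  # batch x K x M (assumed layout)

output, weights = kv_attention(query, key, value)
print(output.shape, weights.shape)  # expected: (2, 4, 16) and (2, 4, 8)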
4 changes: 2 additions & 2 deletions homura/modules/functional/grad_approximation.py
@@ -1,4 +1,4 @@
from typing import Tuple
from __future__ import annotations

import torch
from torch.autograd import Function
@@ -21,7 +21,7 @@ def forward(ctx,

@staticmethod
def backward(ctx,
grad_in: torch.Tensor) -> Tuple[None, torch.Tensor]:
grad_in: torch.Tensor) -> tuple[None, torch.Tensor]:
return None, grad_in.sum_to_size(ctx.shape)


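The backward above passes the incoming gradient straight through to one input (reduced to its saved shape with sum_to_size) and returns None for the other. The same torch.autograd.Function pattern, shown on a generic straight-through estimator (illustrative; not the class defined in this file):

import torch
from torch.autograd import Function


class RoundSTE(Function):
    """Rounds in the forward pass but lets gradients flow through unchanged."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x.round()

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        # Straight-through: pretend the forward op was the identity.
        return grad_out


x = torch.randn(3, requires_grad=True)
RoundSTE.apply(x).sum().backward()
print(x.grad)  # all ones, as if no rounding had happened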
8 changes: 4 additions & 4 deletions homura/modules/functional/knn.py
@@ -1,4 +1,4 @@
from typing import Tuple
from __future__ import annotations

import torch

@@ -22,7 +22,7 @@ def torch_knn(keys: torch.Tensor,
queries: torch.Tensor,
num_neighbors: int,
distance: str
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
""" k nearest neighbor using torch. Users are recommended to use `k_nearest_neighbor` instead.
"""

@@ -48,7 +48,7 @@ def faiss_knn(keys: torch.Tensor,
queries: torch.Tensor,
num_neighbors: int,
distance: str
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
""" k nearest neighbor using faiss. Users are recommended to use `k_nearest_neighbor` instead.
:param keys: tensor of (num_keys, dim)
@@ -96,7 +96,7 @@ def k_nearest_neighbor(keys: torch.Tensor,
queries: torch.Tensor,
num_neighbors: int,
distance: str, *,
backend: str = "torch") -> Tuple[torch.Tensor, torch.Tensor]:
backend: str = "torch") -> tuple[torch.Tensor, torch.Tensor]:
""" k-Nearest Neighbor search. Faiss backend requires GPU. torch backend is JITtable
:param keys: tensor of (num_keys, dim)
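A usage sketch of k_nearest_neighbor with the torch backend. The distance string "l2" and the (scores, indices) return order are assumptions, since neither is shown in this hunk:

import torch
from homura.modules.functional.knn import k_nearest_neighbor

keys = torch.randn(1000, 64)   # (num_keys, dim)
queries = torch.randn(10, 64)  # (num_queries, dim)

scores, indices = k_nearest_neighbor(keys, queries, num_neighbors=5,
                                     distance="l2", backend="torch")
print(indices.shape)  # expected: (10, 5)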
12 changes: 7 additions & 5 deletions homura/register.py
@@ -1,8 +1,10 @@
from __future__ import annotations

import contextlib
import functools
import types
from pathlib import Path
from typing import Dict, Optional, Type, TypeVar
from typing import Type, TypeVar

T = TypeVar("T")

@@ -27,7 +29,7 @@ def your_model(*args, **kwargs):

def __new__(cls,
name: str,
type: Optional[Type[T]] = None
type: Type[T] = None
):
if name in Registry._available_registries:
return Registry._available_registries[name]
@@ -36,21 +38,21 @@ def __new__(cls,

def __init__(self,
name: str,
type: Optional[Type[T]] = None):
type: Type[T] = None):
self.name = name
Registry._available_registries[name] = self
self.type = type
self._registry = {}

def register_from_dict(self,
name_to_func: Dict[str, T]):
name_to_func: dict[str, T]):
for k, v in name_to_func.items():
self.register(v, name=k)

def register(self,
func: T = None,
*,
name: Optional[str] = None
name: str = None
) -> T:
if func is None:
return functools.partial(self.register, name=name)
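A usage sketch of the Registry API as it appears in this diff and in the cifar10 example above: registries are created by name, entries are added with register (usable as a bare decorator thanks to the functools.partial fallback), and register_from_dict adds several at once. The lookup-by-call at the end mirrors MODEL_REGISTRY(cfg.model)(...) from the example script and is an assumption about the part of the class not shown here:

from homura.register import Registry

MODEL_REGISTRY = Registry("model")  # asking for the same name returns the same registry


@MODEL_REGISTRY.register
def tiny_mlp(num_classes: int = 10):
    ...  # build and return a model here


MODEL_REGISTRY.register_from_dict({"tiny_mlp_v2": tiny_mlp})

model_fn = MODEL_REGISTRY("tiny_mlp")  # lookup, as in MODEL_REGISTRY(cfg.model)(...)
model = model_fn(num_classes=10)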