Skip to content

Commit

Permalink
fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
moskomule committed Dec 8, 2021
1 parent 01f8825 commit e7c9471
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
6 changes: 3 additions & 3 deletions examples/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@chika.config
class Config:
model: str = chika.choices(*MODEL_REGISTRY.choices())
model: str = chika.choices("wrn28_2", "wrn28_10")
batch_size: int = 128

epochs: int = 200
Expand Down Expand Up @@ -38,7 +38,7 @@ def main(cfg):
model = MODEL_REGISTRY(cfg.model)(num_classes=data.num_classes)
optimizer = None if cfg.bn_no_wd else optim.SGD(lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay,
multi_tensor=cfg.use_multi_tensor)
scheduler = lr_scheduler.CosineAnnealingWithWarmup(cfg.epochs, 4, 5)
scheduler = lr_scheduler.CosineAnnealingWithWarmup(cfg.epochs, 5)

if cfg.bn_no_wd:
def set_optimizer(trainer):
Expand All @@ -53,7 +53,7 @@ def set_optimizer(trainer):
{"params": bn_params, "weight_decay": 0},
{"params": non_bn_parameters, "weight_decay": cfg.weight_decay},
]
trainer.optimizer = torch.optim.SGD(optim_params, lr=1e-1, momentum=0.9)
trainer.optimizer = torch.optim.SGD(optim_params, lr=cfg.lr, momentum=0.9)

trainers.SupervisedTrainer.set_optimizer = set_optimizer

Expand Down
5 changes: 3 additions & 2 deletions examples/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Config:
debug: bool = False
use_amp: bool = False
use_sync_bn: bool = False
num_workers: int = 4
num_workers: int = 16

init_method: str = "env:https://"
backend: str = "nccl"
Expand Down Expand Up @@ -51,9 +51,10 @@ def main(cfg: Config):
use_sync_bn=cfg.use_sync_bn,
report_accuracy_topk=5) as trainer:

for epoch in trainer.epoch_range(cfg.epochs):
for _ in trainer.epoch_range(cfg.epochs):
trainer.train(train_loader)
trainer.test(test_loader)
trainer.scheduler.step()

print(f"Max Test Accuracy={max(trainer.reporter.history('accuracy/test')):.3f}")

Expand Down

0 comments on commit e7c9471

Please sign in to comment.