
Commit

fix bug (#81)
Co-authored-by: wangyuxin <[email protected]>
wangyuxinwhy and wangyuxin committed Aug 8, 2023
1 parent 57b0dfa commit d502a86
Showing 1 changed file with 35 additions and 1 deletion.
uniem/trainer.py
@@ -1,5 +1,9 @@
 from __future__ import annotations
 
+import logging
+import os
+import re
+import shutil
 from typing import Any, Callable, Sequence, Sized
 
 import torch
@@ -13,6 +17,8 @@
 except ImportError:
     from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
+logger = logging.getLogger(__name__)
+
 
 class Trainer:
     def __init__(
@@ -89,7 +95,7 @@ def train(self):
                 self.accelerator.log(validation_metrics, step=current_epoch)
 
             if self.save_on_epoch_end:
-                self.accelerator.save_state()
+                self.accelerator.save_state(self.get_checkpoint_dir())
 
             if self.epoch_end_callbacks:
                 for callback in self.epoch_end_callbacks:
@@ -105,6 +111,34 @@ def log_metrics(self, metrics: dict[str, float], step: int):
     def add_prefix(values: dict[str, Any], prefix: str):
         return {f'{prefix}/{k}': v for k, v in values.items()}
 
+    def get_checkpoint_dir(self):
+        # COPY FROM accelerator to fix Checkpoint bug
+        self.accelerator.project_configuration.automatic_checkpoint_naming = False
+        output_dir = os.path.join(self.accelerator.project_dir, 'checkpoints')
+        if self.accelerator.is_local_main_process:
+            os.makedirs(output_dir, exist_ok=True)
+        folders = [os.path.join(output_dir, folder) for folder in os.listdir(output_dir)]
+        if self.accelerator.project_configuration.total_limit is not None and (
+            len(folders) + 1 > self.accelerator.project_configuration.total_limit
+        ):
+
+            def _inner(folder):
+                return list(map(int, re.findall(r'[\/]?([0-9]+)(?=[^\/]*$)', folder)))[0]
+
+            folders.sort(key=_inner)
+            logger.warning(
+                f'Deleting {len(folders) + 1 - self.accelerator.project_configuration.total_limit}'
+                ' checkpoints to make room for new checkpoint.'
+            )
+            for folder in folders[: len(folders) + 1 - self.accelerator.project_configuration.total_limit]:
+                shutil.rmtree(folder)
+
+        output_dir = os.path.join(output_dir, f'checkpoint_{self.accelerator.save_iteration}')
+        if self.accelerator.is_local_main_process:
+            os.makedirs(output_dir, exist_ok=True)
+        logger.info(f'Saving current state to {output_dir}')
+        return output_dir
+
 
 def evaluate(
     model: torch.nn.Module,
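The fix passes an explicit directory built by get_checkpoint_dir() to accelerator.save_state() rather than relying on accelerate's automatic checkpoint naming, which the new method switches off before reproducing the same naming and rotation behaviour: states are written to <project_dir>/checkpoints/checkpoint_<save_iteration>, and the oldest folders are pruned once total_limit is exceeded. Below is a minimal sketch of that copied rotation logic; it is not part of the commit, and the folder names and total_limit value are made up for illustration.

import re

# Same key as _inner in get_checkpoint_dir: take the last integer in the final
# path component so folders sort numerically rather than lexically.
def checkpoint_index(folder: str) -> int:
    return list(map(int, re.findall(r'[\/]?([0-9]+)(?=[^\/]*$)', folder)))[0]

folders = ['checkpoints/checkpoint_10', 'checkpoints/checkpoint_2', 'checkpoints/checkpoint_9']
folders.sort(key=checkpoint_index)
# -> ['checkpoints/checkpoint_2', 'checkpoints/checkpoint_9', 'checkpoints/checkpoint_10']

total_limit = 2  # hypothetical ProjectConfiguration.total_limit
stale = folders[: len(folders) + 1 - total_limit]
# -> the two oldest folders; get_checkpoint_dir would shutil.rmtree these before
#    saving the new state under checkpoints/checkpoint_<save_iteration>

A plain string sort would place checkpoint_10 before checkpoint_2 and prune the wrong folders, which is presumably why the numeric sort key from the accelerate implementation is kept.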