Merge pull request huggingface#144 from huggingface/xrsrke/resume_training_data_stage_without_breaking

[Bug] Resuming training for data stages
NouamaneTazi committed Apr 22, 2024
2 parents 90780b5 + 3c0e379 commit 97ea5f6
Showing 8 changed files with 235 additions and 72 deletions.
6 changes: 3 additions & 3 deletions examples/config_tiny_llama.yaml
@@ -85,7 +85,7 @@ optimizer:
parallelism:
dp: 2
expert_parallel_size: 1
pp: 2
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
@@ -110,7 +110,7 @@ data_stages:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_or_datasets: HuggingFaceH4/testing_codealpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
@@ -132,7 +132,7 @@ checkpoints:
checkpoint_interval: 10
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
resume_checkpoint_path: checkpoints
save_initial_state: false
profiler: null
logging:
67 changes: 53 additions & 14 deletions run_train.py
@@ -22,12 +22,14 @@
get_datasets,
get_train_dataloader,
)
from nanotron.helpers import (
compute_remain_train_steps_of_a_data_stage_from_ckp,
get_consumed_train_samples_of_a_data_stage_from_ckp,
)
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from nanotron.trainer import DistributedTrainer
from nanotron.utils import (
main_rank_first,
)
from nanotron.utils import main_rank_first
from torch.utils.data import DataLoader

try:
@@ -41,8 +43,21 @@
logger = logging.get_logger(__name__)


def get_dataloader_from_data_stage(trainer: DistributedTrainer, data: DataArgs):
"""Returns a dataloader for training."""
def get_dataloader_from_data_stage(
trainer: DistributedTrainer,
data: DataArgs,
consumed_train_samples: int,
num_remaining_train_steps: int,
):
"""
Returns a dataloader for a given data stage.
data: The data configuration for the current stage.
consumed_train_samples: The number of samples consumed by the model in this stage (each stage starts from zero).
num_remaining_train_steps: The number of remaining training steps for this stage.
"""
assert consumed_train_samples >= 0, "consumed_train_samples should be greater than or equal to 0"
assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be greater than or equal to 0"

# First, we need to know which ranks to feed the dataloader to
input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)
@@ -105,17 +120,16 @@ def get_dataloader_from_data_stage(trainer: DistributedTrainer, data: DataArgs):
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
consumed_train_samples=trainer.consumed_train_samples,
consumed_train_samples=consumed_train_samples,
dataloader_num_workers=data.num_loading_workers,
seed_worker=data.seed,
dataloader_drop_last=True,
)

# Check if we have enough samples for train_steps
total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
num_tokens_needed_for_training = (
(trainer.config.tokens.train_steps - trainer.start_iteration_step)
* trainer.global_batch_size
* trainer.sequence_length
num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length
)
assert num_tokens_needed_for_training <= total_tokens_dataset, (
f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
@@ -128,16 +142,41 @@ def get_dataloader_from_data_stage(trainer: DistributedTrainer, data: DataArgs):


def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
sorted_stages = sorted(trainer.config.data_stages, key=lambda stage: stage.start_training_step)
dataloaders = {}
for idx, stage in enumerate(sorted_stages):

for stage_idx, stage in enumerate(trainer.config.data_stages):
# NOTE: we only create the dataloader for the first stage,
# then we lazy initialize the dataloader for the other stages
stage = cast(DatasetStageArgs, stage)
consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata)
assert (
consumed_train_samples is not None
), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint"

num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
stage, trainer.config, trainer.metadata
)
log_rank(
f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples",
logger=logger,
level=logging.INFO,
rank=0,
)

dataloader = (
get_dataloader_from_data_stage(trainer, stage.data)
if idx == 0
else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data)
get_dataloader_from_data_stage(
trainer,
stage.data,
consumed_train_samples=consumed_train_samples,
num_remaining_train_steps=num_remaining_train_steps,
)
if stage_idx == 0
else lambda stage=stage: get_dataloader_from_data_stage(
trainer,
stage.data,
consumed_train_samples=consumed_train_samples,
num_remaining_train_steps=num_remaining_train_steps,
)
)
dataloaders[stage.name] = dataloader
return dataloaders
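The loop above builds the first stage's dataloader eagerly and wraps every later stage in a zero-argument lambda, so those datasets are only downloaded and tokenized once their stage actually starts. The following is a self-contained sketch of the same lazy-initialization pattern, not part of this commit; build_loader and the toy data are illustrative stand-ins for get_dataloader_from_data_stage.

from typing import Callable, Dict, Union

import torch
from torch.utils.data import DataLoader, TensorDataset


def build_loader(name: str, skip_samples: int) -> DataLoader:
    # Stand-in for get_dataloader_from_data_stage: pretend `skip_samples`
    # were consumed before the checkpoint and resume right after them.
    data = torch.arange(100).unsqueeze(1)
    return DataLoader(TensorDataset(data[skip_samples:]), batch_size=8)


# (name, consumed_train_samples) pairs, already sorted by start step.
stages = [("stable", 16), ("annealing", 0)]
dataloaders: Dict[str, Union[DataLoader, Callable[[], DataLoader]]] = {}

for idx, (name, consumed) in enumerate(stages):
    dataloaders[name] = (
        build_loader(name, consumed)
        if idx == 0
        # Default arguments freeze the loop variables, mirroring the
        # `lambda stage=stage:` trick used in get_dataloader above.
        else (lambda name=name, consumed=consumed: build_loader(name, consumed))
    )

# When a stage begins, resolve the entry if it is still a callable.
loader = dataloaders["annealing"]
if callable(loader):
    loader = loader()
print(next(iter(loader)))  # first batch of the lazily built stage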
7 changes: 7 additions & 0 deletions src/nanotron/config/config.py
@@ -338,6 +338,7 @@ def __post_init__(self):
)

if self.data_stages is not None:
self.data_stages = sorted(self.data_stages, key=lambda stage: stage.start_training_step)
names = [stage.name for stage in self.data_stages]
training_steps = [stage.start_training_step for stage in self.data_stages]
assert any(
@@ -353,6 +354,12 @@
f"Each stage should have unique starting training step, please change the starting training step for stage {stage.name}"
)

# NOTE: must order the stages by start_training_step from lowest to highest
assert all(
self.data_stages[i].start_training_step < self.data_stages[i + 1].start_training_step
for i in range(len(self.data_stages) - 1)
), "The stages are not sorted by start_training_step in increasing order"

# # if lighteval, we need tokenizer to be defined
# if self.checkpoints.lighteval is not None:
# assert self.tokenizer.tokenizer_name_or_path is not None
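The __post_init__ change above first sorts data_stages by start_training_step and then asserts that the sequence is strictly increasing, which also rejects duplicate start steps. A minimal sketch of that validation, using plain tuples instead of DatasetStageArgs:

# (name, start_training_step) pairs; field names simplified from DatasetStageArgs.
data_stages = [("annealing", 10), ("stable", 1)]

data_stages = sorted(data_stages, key=lambda stage: stage[1])
# After sorting, a strictly increasing check also rules out duplicate start steps.
assert all(
    data_stages[i][1] < data_stages[i + 1][1]
    for i in range(len(data_stages) - 1)
), "The stages are not sorted by start_training_step in increasing order"
print(data_stages)  # [('stable', 1), ('annealing', 10)]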
7 changes: 6 additions & 1 deletion src/nanotron/constants.py
@@ -2,6 +2,11 @@

from packaging.version import Version, parse

CHECKPOINT_VERSION = Version("1.2")
CHECKPOINT_VERSION = Version("1.3")

PY_VERSION = parse(platform.python_version())

#### FOR SERIALIZATION ####

CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
MODEL_CONFIG_FILE_NAME = "model_config.json"
32 changes: 31 additions & 1 deletion src/nanotron/helpers.py
@@ -18,7 +18,7 @@

from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import Config, LRSchedulerArgs, OptimizerArgs, ParallelismArgs
from nanotron.config import Config, DatasetStageArgs, LRSchedulerArgs, OptimizerArgs, ParallelismArgs
from nanotron.distributed import ProcessGroup
from nanotron.logging import LogItem, log_rank
from nanotron.models.base import NanotronModel
@@ -42,6 +42,7 @@
get_synced_random_state,
)
from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod
from nanotron.serialize.metadata import TrainingMetadata

logger = logging.get_logger(__name__)

@@ -669,3 +670,32 @@ def log_throughput(

if dist.get_rank(parallel_context.world_pg) == 0:
write_to_csv(config.general.benchmark_csv_path, table_log, model_tflops, slurm_job_id)


def compute_remain_train_steps_of_a_data_stage_from_ckp(
stage: DatasetStageArgs, config: Config, metadata: TrainingMetadata
) -> int:
def is_last_stage():
sorted_stages = sorted(config.data_stages, key=lambda x: x.start_training_step)
return sorted_stages[-1].start_training_step == stage.start_training_step

def is_resume_from_training():
return metadata.last_train_step > 0

if is_last_stage() is True:
total_train_steps = config.tokens.train_steps
else:
next_stage = next((s for s in config.data_stages if s.start_training_step > stage.start_training_step), None)
total_train_steps = next_stage.start_training_step
last_train_steps = metadata.last_train_step if is_resume_from_training() else stage.start_training_step
return total_train_steps - last_train_steps


def get_consumed_train_samples_of_a_data_stage_from_ckp(
stage: DatasetStageArgs, metadata: TrainingMetadata
) -> Optional[int]:
start_training_step = stage.start_training_step
return next(
(s.consumed_train_samples for s in metadata.data_stages if s.start_training_step == start_training_step),
None,
)
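To make the bookkeeping concrete, here is a hedged, self-contained rework of the two helpers above with simplified stand-in types (class names and numbers are illustrative only): with stages starting at steps 1 and 100, train_steps = 250, and a checkpoint taken at step 150, the last stage has 250 - 150 = 100 steps left and resumes after the samples it already consumed, while a fresh run gives the first stage 100 - 1 = 99 steps.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Stage:  # simplified stand-in for DatasetStageArgs
    name: str
    start_training_step: int


@dataclass
class StageMetadata:  # simplified stand-in for the per-stage checkpoint metadata
    start_training_step: int
    consumed_train_samples: int


@dataclass
class Metadata:  # simplified stand-in for TrainingMetadata
    last_train_step: int
    data_stages: List[StageMetadata] = field(default_factory=list)


def remaining_steps(stage: Stage, stages: List[Stage], train_steps: int, meta: Metadata) -> int:
    # The last stage runs until train_steps; earlier stages run until the next stage starts.
    if stage.start_training_step == max(s.start_training_step for s in stages):
        total = train_steps
    else:
        total = min(s.start_training_step for s in stages if s.start_training_step > stage.start_training_step)
    # When resuming, count from the checkpointed step; otherwise from the stage's own start.
    last = meta.last_train_step if meta.last_train_step > 0 else stage.start_training_step
    return total - last


def consumed_samples(stage: Stage, meta: Metadata) -> Optional[int]:
    # Look up the stage in the checkpoint metadata by its start step.
    return next(
        (s.consumed_train_samples for s in meta.data_stages
         if s.start_training_step == stage.start_training_step),
        None,
    )


stages = [Stage("stable", 1), Stage("annealing", 100)]
fresh = Metadata(last_train_step=0)
resumed = Metadata(
    last_train_step=150,
    data_stages=[StageMetadata(1, 99_000), StageMetadata(100, 25_000)],
)

print(remaining_steps(stages[0], stages, train_steps=250, meta=fresh))    # 99  (fresh run)
print(remaining_steps(stages[1], stages, train_steps=250, meta=resumed))  # 100 (resumed at step 150)
print(consumed_samples(stages[1], resumed))                               # 25000 samples to skip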
21 changes: 13 additions & 8 deletions src/nanotron/serialize/main.py
@@ -1,14 +1,16 @@
from pathlib import Path
from typing import Optional
from typing import Optional, cast

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LambdaLR

from nanotron import distributed as dist
from nanotron import logging
from nanotron import optim as optim
from nanotron.config import Config
from nanotron.constants import MODEL_CONFIG_FILE_NAME
from nanotron.distributed import get_global_rank
from nanotron.logging import log_rank
from nanotron.parallel import ParallelContext
@@ -17,7 +19,7 @@
assert_tensor_synced_across_pg,
check_optim_state_in_sync,
)
from nanotron.serialize.metadata import CheckpointMetadata, load_meta, save_meta
from nanotron.serialize.metadata import CheckpointMetadata, TrainingMetadata, load_meta, save_meta
from nanotron.serialize.optimizer import (
load_lr_scheduler,
load_optimizer,
@@ -51,16 +53,15 @@ def save(
optimizer: optim.BaseOptimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
parallel_context: ParallelContext,
training_metadata: TrainingMetadata,
root_folder: Path,
should_save_config: bool = True,
should_save_model: bool = True,
should_save_optimizer: bool = True,
should_save_lr_scheduler: bool = True,
checkpoint_metadata: dict = None,
sanity_checks: bool = True,
) -> None:
if checkpoint_metadata is None:
checkpoint_metadata = {}
assert isinstance(training_metadata, TrainingMetadata)

try:
if should_save_config:
@@ -98,6 +99,11 @@ def save(
raise e
try:
if should_save_lr_scheduler:
lr_scheduler = cast(LambdaLR, lr_scheduler)
assert len(lr_scheduler.lr_lambdas) == len(
optimizer.param_groups
), "The number of lambdas functions in the scheduler should be equal to the number of parameter groups in the optimizer."

save_lr_scheduler(
lr_scheduler=lr_scheduler,
parallel_context=parallel_context,
@@ -112,7 +118,7 @@ def save(
)
raise e

save_meta(root_folder=root_folder, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata)
save_meta(root_folder=root_folder, parallel_context=parallel_context, training_metadata=training_metadata)

# TODO @thomas21: sanity check, not sure whether that needs to happen at testing or now (depends how much it costs)
###
@@ -194,7 +200,6 @@ def save(
rtol=0,
msg=lambda msg: f"tensor at {current_state_dict['names'][index]} doesn't match with our reference. Optimizer key: {name}\nCur: {tensor}\nRef: {reference_tensor}\n{msg}",
)
###

dist.barrier(parallel_context.world_pg)

@@ -256,7 +261,7 @@ def parse_ckpt_path(config: Config) -> Optional[Path]:
load_from_candidate = int(fi.read())
checkpoint_path = config.checkpoints.resume_checkpoint_path / str(load_from_candidate)

elif (config.checkpoints.resume_checkpoint_path / "model_config.json").exists():
elif (config.checkpoints.resume_checkpoint_path / MODEL_CONFIG_FILE_NAME).exists():
# we assume that the checkpoint path is a path to a checkpoint
checkpoint_path = config.checkpoints.resume_checkpoint_path

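The assertion added to save() earlier in this file requires one scheduling lambda per optimizer parameter group before the scheduler state is written out. A small PyTorch illustration of that invariant (the model and hyperparameters are arbitrary):

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(
    [
        {"params": [model.weight], "lr": 3e-4},
        {"params": [model.bias], "lr": 1e-4},
    ]
)
# LambdaLR accepts either a single lambda shared by all groups or a list with
# exactly one lambda per group; the schedules end up in `lr_lambdas`.
scheduler = LambdaLR(optimizer, lr_lambda=[lambda step: 0.5 ** step, lambda step: 1.0])

assert len(scheduler.lr_lambdas) == len(optimizer.param_groups)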