
RuntimeError: stack expects each tensor to be equal size #929

Closed
cateto opened this issue May 10, 2023 · 4 comments
Labels
bug Something isn't working

Comments

@cateto

cateto commented May 10, 2023

Describe the bug
When I train this model with the config below (train_micro_batch_size_per_gpu = 100), it raises a runtime error. If I set train_micro_batch_size_per_gpu below 100, it works, but I want to use the full GPU memory! Please let me know what I can do.

```
ML-01: training ...
ML-01: Traceback (most recent call last):
ML-01:   File "train.py", line 27, in <module>
ML-01:     pretrain(neox_args=neox_args)
ML-01:   File "/home/research/gpt-neox/megatron/training.py", line 226, in pretrain
ML-01:     iteration = train(
ML-01:   File "/home/research/gpt-neox/megatron/training.py", line 778, in train
ML-01:     loss_dict, skipped_iter = train_step(
ML-01:   File "/home/research/gpt-neox/megatron/training.py", line 684, in train_step
ML-01:     reduced_loss = train_step_pipe(
ML-01:   File "/home/research/gpt-neox/megatron/training.py", line 734, in train_step_pipe
ML-01:     loss = model.train_batch(data_iter=data_iterator)
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 346, in train_batch
ML-01:     self._exec_schedule(sched)
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 1374, in _exec_schedule
ML-01:     self._exec_instr(**cmd.kwargs)
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 790, in _exec_load_micro_batch
ML-01:     batch = self._next_batch()
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 622, in _next_batch
ML-01:     batch = next(self.data_iterator)
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
ML-01:     data = self._next_data()
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1376, in _next_data
ML-01:     return self._process_data(data)
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1402, in _process_data
ML-01:     data.reraise()
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/torch/_utils.py", line 461, in reraise
ML-01:     raise exception
ML-01: RuntimeError: Caught RuntimeError in DataLoader worker process 0.
ML-01: Original Traceback (most recent call last):
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
ML-01:     data = fetcher.fetch(index)
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
ML-01:     return self.collate_fn(data)
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 162, in default_collate
ML-01:     return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 162, in <dictcomp>
ML-01:     return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 151, in default_collate
ML-01:     return default_collate([torch.as_tensor(b) for b in batch])
ML-01:   File "/home/research/anaconda3/envs/gpt_neox_py38/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 143, in default_collate
ML-01:     return torch.stack(batch, 0, out=out)
ML-01: RuntimeError: stack expects each tensor to be equal size, but got [2049] at entry 0 and [6805] at entry 30
```
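For context, the failure happens in PyTorch's stock `default_collate`, which `torch.stack`s every sample in a batch and therefore requires all sample tensors to have the same length. A minimal standalone sketch that reproduces the same error (illustrative only, not gpt-neox code; the dataset class is made up):

```python
# Minimal sketch of the failure mode: default_collate torch.stack()s the
# per-sample tensors, so it raises when their lengths differ, e.g. [2049] vs [6805].
import torch
from torch.utils.data import DataLoader, Dataset


class VariableLengthDataset(Dataset):
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        # Mimics a sample dict whose token tensor length varies per sample.
        return {"text": torch.zeros(self.lengths[idx], dtype=torch.long)}


loader = DataLoader(VariableLengthDataset([2049, 6805]), batch_size=2)
next(iter(loader))  # RuntimeError: stack expects each tensor to be equal size ...
```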

Environment (please complete the following information):

  • GPUs: A100 80GB * 4
  • Configs:
```

    -------------------- arguments --------------------
      attention_config ................ ['global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global', 'global']updated
      attention_dropout ............... 0.0.........................updated
      batch_size ...................... 100..........................updated
      checkpoint_activations .......... True........................updated
      checkpoint_factor ............... 10000.......................updated
      clip_grad ....................... 1.0.........................updated
      config_files .................... {'125M-multinode.yml': '# GPT-2 pretraining setup\n{\n   # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages\n   # across the node boundaries )\n   "pipe-parallel-size": 2,\n   "model-parallel-size": 2,\n   "num_gpus": 2,\n   "num_nodes": 2,\n\n   # model settings\n   "num-layers": 12,\n   "hidden-size": 768,\n   "num-attention-heads": 12,\n   "seq-length": 2048,\n   "max-position-embeddings": 2048,\n   "norm": "layernorm",\n   "pos-emb": "rotary",\n   "no-weight-tying": true,\n   "gpt_j_residual": false,\n   "output_layer_parallelism": "column",\n\n   # these should provide some speedup but takes a while to build, set to true if desired\n   "scaled-upper-triang-masked-softmax-fusion": false,\n   "bias-gelu-fusion": false,\n\n   # init methods\n   "init_method": "small_init",\n   "output_layer_init_method": "wang_init",\n\n\n   # optimizer settings\n   "optimizer": {\n     "type": "Adam",\n     "params": {\n       "lr": 0.0006,\n       "betas": [0.9, 0.95],\n       "eps": 1.0e-8,\n     }\n   },\n   "min_lr": 0.00006,\n\n   # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training\n   "zero_optimization": {\n    "stage": 1,\n    "allgather_partitions": True,\n    "allgather_bucket_size": 500000000,\n    "overlap_comm": True,\n    "reduce_scatter": True,\n    "reduce_bucket_size": 500000000,\n    "contiguous_gradients": True,\n  },\n\n   # batch / data settings\n   "train_micro_batch_size_per_gpu": 100,\n   "data-impl": "mmap",\n   #"split": "949,50,1",\n\n   # activation checkpointing\n   "checkpoint-activations": true,\n   "checkpoint-num-layers": 1,\n   "partition-activations": true,\n   "synchronize-each-layer": true,\n\n   # regularization\n   "gradient_clipping": 1.0,\n   "weight-decay": 0.1,\n   "hidden-dropout": 0.0,\n   "attention-dropout": 0.0,\n\n   # precision settings\n   "fp16": {\n     "enabled": true,\n     "loss_scale": 0,\n     "loss_scale_window": 1000,\n     "hysteresis": 2,\n     "min_loss_scale": 1\n   },\n\n   # misc. 
training settings\n   "train-iters": 1440000,\n   "lr-decay-iters": 1440000,\n   "distributed-backend": "nccl",\n   "lr-decay-style": "cosine",\n   "warmup": 0.01,\n   "checkpoint-factor": 10000,\n   "eval-interval": 1000,\n   "eval-iters": 10,\n\n   # logging\n   "log-interval": 100,\n   "steps_per_print": 10,\n   "keep-last-n-checkpoints": 450,\n   "wall_clock_breakdown": true,\n\n  #  networking\n  "hostfile": "./hostfile",\n\n  # tokenizer\n  "tokenizer_type" : "HFGPT2Tokenizer"\n}\n', 'local_setup-multinode.yml': '# Suggested data paths when using GPT-NeoX locally\n{\n  "data-path": "data/mydataset/mydataset_text_document",\n\n  # or for weighted datasets:\n  # "train-data-paths": ["data/enwik8/enwik8_text_document", "data/enwik8/enwik8_text_document"],\n  # "test-data-paths": ["data/enwik8/enwik8_text_document", "data/enwik8/enwik8_text_document"],\n  # "valid-data-paths": ["data/enwik8/enwik8_text_document", "data/enwik8/enwik8_text_document"],\n  # "train-data-weights": [1., 2.],\n  # "test-data-weights": [2., 1.],\n  # "valid-data-weights": [0.5, 0.4],\n\n  # If weight_by_num_documents is True, Builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group.\n  # WARNING: setting this to True will override any user provided weights\n  # "weight_by_num_documents": false,\n  # "weighted_sampler_alpha": 0.3,\n\n  "vocab-file": "data/mydataset/",\n\n  "save": "checkpoints_multinode",\n  "load": "checkpoints_multinode",\n  "checkpoint_validation_with_forward_pass": False,\n\n  "tensorboard-dir": "tensorboard",\n  "log-dir": "logs_multinode",\n  "use_wandb": True,\n  "wandb_host": "https://api.wandb.ai",\n  "wandb_project": "neox"\n}\n'}updated
      data_impl ....................... mmap........................updated
      data_path ....................... data/mydataset/mydataset_text_documentupdated
      dynamic_loss_scale .............. True........................updated
      eval_iters ...................... 10..........................updated
      fp16 ............................ {'enabled': True, 'loss_scale': 0, 'loss_scale_window': 1000, 'hysteresis': 2, 'min_loss_scale': 1}updated
      gas ............................. 1...........................updated
      global_num_gpus ................. 4...........................updated
      gradient_clipping ............... 1.0.........................updated
      hidden_dropout .................. 0.0.........................updated
      hidden_size ..................... 768.........................updated
      hostfile ........................ ./hostfile..................updated
      init_method ..................... small_init..................updated
      is_pipe_parallel ................ True........................updated
      keep_last_n_checkpoints ......... 450.........................updated
      load ............................ checkpoints_multinode.......updated
      log_dir ......................... logs_multinode..............updated
      log_interval .................... 100.........................updated
      lr .............................. 0.0006......................updated
      lr_decay_iters .................. 1440000.....................updated
      lr_decay_style .................. cosine......................updated
      max_position_embeddings ......... 2048........................updated
      min_lr .......................... 6e-05.......................updated
      model_parallel_size ............. 2...........................updated
      no_weight_tying ................. True........................updated
      num_attention_heads ............. 12..........................updated
      num_gpus ........................ 2...........................updated
      num_layers ...................... 12..........................updated
      num_nodes ....................... 2...........................updated
      optimizer ....................... {'type': 'Adam', 'params': {'lr': 0.0006, 'betas': [0.9, 0.95], 'eps': 1e-08}}updated
      optimizer_type .................. Adam........................updated
      output_layer_init_method ........ wang_init...................updated
      output_layer_parallelism ........ column......................updated
      partition_activations ........... True........................updated
      pipe_parallel_size .............. 2...........................updated
      pos_emb ......................... rotary......................updated
      precision ....................... fp16........................updated
      save ............................ checkpoints_multinode.......updated
      save_iters ...................... [10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000, 100000, 110000, 120000, 130000, 140000, 150000, 160000, 170000, 180000, 190000, 200000, 210000, 220000, 230000, 240000, 250000, 260000, 270000, 280000, 290000, 300000, 310000, 320000, 330000, 340000, 350000, 360000, 370000, 380000, 390000, 400000, 410000, 420000, 430000, 440000, 450000, 460000, 470000, 480000, 490000, 500000, 510000, 520000, 530000, 540000, 550000, 560000, 570000, 580000, 590000, 600000, 610000, 620000, 630000, 640000, 650000, 660000, 670000, 680000, 690000, 700000, 710000, 720000, 730000, 740000, 750000, 760000, 770000, 780000, 790000, 800000, 810000, 820000, 830000, 840000, 850000, 860000, 870000, 880000, 890000, 900000, 910000, 920000, 930000, 940000, 950000, 960000, 970000, 980000, 990000, 1000000, 1010000, 1020000, 1030000, 1040000, 1050000, 1060000, 1070000, 1080000, 1090000, 1100000, 1110000, 1120000, 1130000, 1140000, 1150000, 1160000, 1170000, 1180000, 1190000, 1200000, 1210000, 1220000, 1230000, 1240000, 1250000, 1260000, 1270000, 1280000, 1290000, 1300000, 1310000, 1320000, 1330000, 1340000, 1350000, 1360000, 1370000, 1380000, 1390000, 1400000, 1410000, 1420000, 1430000]updated
      seq_length ...................... 2048........................updated
      sparsity_config ................. {}..........................updated
      synchronize_each_layer .......... True........................updated
      tensorboard_dir ................. tensorboard.................updated
      text_gen_type ................... unconditional...............updated
      tokenizer_type .................. HFGPT2Tokenizer.............updated
      train_batch_size ................ 100..........................updated
      train_iters ..................... 1440000.....................updated
      train_micro_batch_size_per_gpu .. 100..........................updated
      use_wandb ....................... True........................updated
      user_script ..................... train.py....................updated
      vocab_file ...................... data/mydataset/.............updated
      wall_clock_breakdown ............ True........................updated
      wandb_group ..................... g3da4r9m....................updated
      weight_decay .................... 0.1.........................updated
      zero_allgather_bucket_size ...... 500000000...................updated
      zero_contiguous_gradients ....... True........................updated
      zero_optimization ............... {'stage': 1, 'allgather_partitions': True, 'allgather_bucket_size': 500000000, 'overlap_comm': True, 'reduce_scatter': True, 'reduce_bucket_size': 500000000, 'contiguous_gradients': True}updated
      zero_reduce_bucket_size ......... 500000000...................updated
      zero_reduce_scatter ............. True........................updated
      zero_stage ...................... 1...........................updated
      activation ...................... gelu........................default
      adlr_autoresume ................. False.......................default
      adlr_autoresume_interval ........ 1000........................default
      amp ............................. None........................default
      apply_query_key_layer_scaling ... False.......................default
      attention_softmax_in_fp32 ....... False.......................default
      autotuning ...................... None........................default
      autotuning_run .................. None........................default
      base_shapes_file ................ None........................default
      bias_dropout_fusion ............. False.......................default
      bias_gelu_fusion ................ False.......................default
      char_level_ppl .................. False.......................default
      checkpoint_in_cpu ............... False.......................default
      checkpoint_num_layers ........... 1...........................default
      checkpoint_scale ................ linear......................default
      checkpoint_validation_with_forward_pass  False................default
      comment ......................... None........................default
      contiguous_checkpointing ........ False.......................default
      coord_check ..................... False.......................default
      curriculum_learning ............. None........................default
      curriculum_seqlen ............... 0...........................default
      deepscale ....................... False.......................default
      deepscale_config ................ None........................default
      deepspeed ....................... True........................default
      deepspeed_activation_checkpointing  True......................default
      deepspeed_mpi ................... False.......................default
      deepspeed_slurm ................. False.......................default
      detect_nvlink_pairs ............. False.......................default
      distributed_backend ............. nccl........................default
      do_test ......................... None........................default
      do_train ........................ None........................default
      do_valid ........................ None........................default
      dump_state ...................... False.......................default
      eod_mask_loss ................... False.......................default
      eval_interval ................... 1000........................default
      eval_results_prefix ............. ............................default
      eval_tasks ...................... None........................default
      exclude ......................... None........................default
      exit_interval ................... None........................default
      extra_save_iters ................ None........................default
      finetune ........................ False.......................default
      flops_profiler .................. None........................default
      fp16_lm_cross_entropy ........... False.......................default
      fp32_allreduce .................. False.......................default
      git_hash ........................ e25ded7.....................default
      gmlp_attn_dim ................... 64..........................default
      gpt_j_residual .................. False.......................default
      gpt_j_tied ...................... False.......................default
      gradient_accumulation_steps ..... 1...........................default
      gradient_noise_scale_cpu_offload  False.......................default
      gradient_noise_scale_n_batches .. 5...........................default
      gradient_predivide_factor ....... 1.0.........................default
      hysteresis ...................... 2...........................default
      include ......................... None........................default
      init_method_std ................. 0.02........................default
      iteration ....................... None........................default
      launcher ........................ pdsh........................default
      layernorm_epsilon ............... 1e-05.......................default
      lazy_mpu_init ................... False.......................default
      local_rank ...................... None........................default
      log_grad_norm ................... False.......................default
      log_grad_pct_zeros .............. False.......................default
      log_gradient_noise_scale ........ False.......................default
      log_optimizer_states ............ False.......................default
      log_param_norm .................. False.......................default
      loss_scale ...................... None........................default
      loss_scale_window ............... 1000.0......................default
      make_vocab_size_divisible_by .... 128.........................default
      master_addr ..................... None........................default
      master_port ..................... 29500.......................default
      maximum_tokens .................. 64..........................default
      merge_file ...................... None........................default
      min_scale ....................... 1.0.........................default
      mmap_warmup ..................... False.......................default
      mup_attn_temp ................... 1.0.........................default
      mup_embedding_mult .............. 1.0.........................default
      mup_init_scale .................. 1.0.........................default
      mup_output_temp ................. 1.0.........................default
      mup_rp_embedding_mult ........... 1.0.........................default
      mup_width_scale ................. 2...........................default
      no_load_optim ................... False.......................default
      no_load_rng ..................... False.......................default
      no_save_optim ................... False.......................default
      no_save_rng ..................... False.......................default
      no_ssh_check .................... False.......................default
      norm ............................ layernorm...................default
      num_samples ..................... 1...........................default
      num_unique_layers ............... None........................default
      num_workers ..................... 2...........................default
      onnx_safe ....................... False.......................default
      opt_pos_emb_offset .............. 0...........................default
      override_lr_scheduler ........... False.......................default
      padded_vocab_size ............... None........................default
      param_sharing_style ............. grouped.....................default
      pipe_partition_method ........... type:transformer|mlp........default
      prescale_gradients .............. False.......................default
      profile_backward ................ False.......................default
      prompt_end ......................
    ...........................default
      rank ............................ None........................default
      recompute ....................... False.......................default
      return_logits ................... False.......................default
      rms_norm_epsilon ................ 1e-08.......................default
      rotary_emb_base ................. 10000.......................default
      rotary_pct ...................... 1.0.........................default
      rpe_max_distance ................ 128.........................default
      rpe_num_buckets ................. 32..........................default
      sample_input_file ............... None........................default
      sample_output_file .............. samples.txt.................default
      save_base_shapes ................ False.......................default
      scaled_masked_softmax_fusion .... False.......................default
      scaled_upper_triang_masked_softmax_fusion  False..............default
      scalenorm_epsilon ............... 1e-08.......................default
      scheduler ....................... None........................default
      seed ............................ 1234........................default
      short_seq_prob .................. 0.1.........................default
      soft_prompt_tuning .............. None........................default
      sparse_gradients ................ False.......................default
      split ........................... 969, 30, 1..................default
      steps_per_print ................. 10..........................default
      temperature ..................... 0.0.........................default
      test_data_paths ................. None........................default
      test_data_weights ............... None........................default
      top_k ........................... 0...........................default
      top_p ........................... 0.0.........................default
      train_data_paths ................ None........................default
      train_data_weights .............. None........................default
      use_bnb_optimizer ............... False.......................default
      use_checkpoint_lr_scheduler ..... False.......................default
      use_cpu_initialization .......... False.......................default
      use_mup ......................... False.......................default
      use_shared_fs ................... True........................default
      valid_data_paths ................ None........................default
      valid_data_weights .............. None........................default
      wandb_host ...................... https://api.wandb.ai........default
      wandb_init_all_ranks ............ False.......................default
      wandb_project ................... neox........................default
      wandb_team ...................... None........................default
      warmup .......................... 0.01........................default
      weight_by_num_documents ......... False.......................default
      weighted_sampler_alpha .......... 0.3.........................default
      world_size ...................... None........................default
      zero_allow_untested_optimizer ... False.......................default
    ---------------- end of arguments ---------------- 
    
```


@cateto cateto added the bug Something isn't working label May 10, 2023
@StellaAthena
Member

There are several strange things in your config file. Can you tell us a bit more about what you’re trying to do? Specifically:

  1. You probably shouldn’t ever be using an MBS of 100.
  2. You’re using MP = 2 and PP = 2, but your model is small enough that it should be trainable without either. You can fit the entire model on a single GPU and then do pure data parallelism.
  3. Is it correct that you have two nodes each with two A100s, and not four A100s on a single node?

It seems plausible that we have a data loader bug that only appears with an MBS above 100, since we’ve probably never tested one that large. However, our data loader is largely the same as Megatron-DeepSpeed’s; have you tried that codebase to see if it works with an MBS of 100?

@Quentin-Anthony
Member

> There are several strange things in your config file. Can you tell us a bit more about what you’re trying to do? Specifically:
>
>   1. You probably shouldn’t ever be using an MBS of 100.
>   2. You’re using MP = 2 and PP = 2, but your model is small enough that it should be trainable without either. You can fit the entire model on a single GPU and then do pure data parallelism.
>   3. Is it correct that you have two nodes each with two A100s, and not four A100s on a single node?
>
> It seems plausible that we have a data loader bug that only appears with an MBS above 100, since we’ve probably never tested one that large. However, our data loader is largely the same as Megatron-DeepSpeed’s; have you tried that codebase to see if it works with an MBS of 100?

This issue should have been fixed in #835

@cateto -- Following up on @StellaAthena's questions, please try setting model-parallel-size and pipe-parallel-size to 1, which will improve performance but force you to reduce the batch size. Also, please share your commit hash here so that we can investigate further if the above PR didn't fix this issue properly.
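For reference, a rough sketch of the batch-size accounting behind that suggestion. The 4-GPU world size is taken from the environment above, the helper name is made up, and the formula (global batch = micro batch per GPU × gradient accumulation steps × data-parallel degree) is the usual DeepSpeed-style accounting, not output from this repo:

```python
# Rough sketch of DeepSpeed-style batch accounting (assumed 4 GPUs total).
def global_batch_size(micro_batch_per_gpu, grad_accum_steps, num_gpus, mp, pp):
    # The data-parallel degree is whatever is left of the world size after
    # model parallelism (mp) and pipeline parallelism (pp) are carved out.
    dp = num_gpus // (mp * pp)
    return micro_batch_per_gpu * grad_accum_steps * dp

# Original config: MBS 100, MP 2, PP 2 on 4 GPUs -> DP 1, global batch 100.
print(global_batch_size(100, 1, 4, 2, 2))  # 100

# Suggested change: MP 1, PP 1 -> DP 4, so a much smaller MBS keeps the same
# global batch, e.g. MBS 25 -> 100.
print(global_batch_size(25, 1, 4, 1, 1))   # 100
```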

@StellaAthena
Member

@cateto hey, any updates on this?

@cateto
Author

cateto commented Jun 5, 2023

Solved it! Thank you.
Here is the working config, with 2 nodes (2 × A100 80GB each):

```
# GPT-2 pretraining setup
{
   # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
   # across the node boundaries )
   "pipe-parallel-size": 1,
   "model-parallel-size": 1,
   "num_nodes": 2,

   # model settings
   "num-layers": 12,
   "hidden-size": 768,
   "num-attention-heads": 12,
   "seq-length": 2048,
   "max-position-embeddings": 2048,
   "norm": "layernorm",
   "pos-emb": "rotary",
   "no-weight-tying": true,
   "gpt_j_residual": false,
   "output_layer_parallelism": "column",

   # these should provide some speedup but takes a while to build, set to true if desired
   "scaled-upper-triang-masked-softmax-fusion": false,
   "bias-gelu-fusion": false,

   # init methods
   "init_method": "small_init",
   "output_layer_init_method": "wang_init",


   # optimizer settings
   "optimizer": {
     "type": "Adam",
     "params": {
       "lr": 0.0006,
       "betas": [0.9, 0.95],
       "eps": 1.0e-8,
     }
   },
  "min_lr": 0.00006,

   # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
   "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

   # batch / data settings
   "train_micro_batch_size_per_gpu": 55,
   "gradient_accumulation_steps": 2,
   "data-impl": "mmap",
   #"split": "949,50,1",

   # activation checkpointing
   "checkpoint-activations": true,
   "checkpoint-num-layers": 1,
   "partition-activations": true,
   "synchronize-each-layer": true,

   # regularization
   "gradient_clipping": 1.0,
   "weight-decay": 0.1,
   "hidden-dropout": 0.0,
   "attention-dropout": 0.0,

   # precision settings
   "fp16": {
     "fp16": true,
     "enabled": true,
     "loss_scale": 0,
     "loss_scale_window": 1000,
     "hysteresis": 2,
     "min_loss_scale": 1
   },

   # misc. training settings
   "train-iters": 100000,
   "lr-decay-iters": 100000,
   "distributed-backend": "nccl",
   "lr-decay-style": "cosine",
   "warmup": 0.01,
   "checkpoint-factor": 2500,
   "eval-interval": 1000,
   "eval-iters": 10,

   # logging
   "log-interval": 100,
   "steps_per_print": 10,
   "keep-last-n-checkpoints": 450,
   "wall_clock_breakdown": true,

}
```
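A quick check of what this working setup implies, assuming 2 GPUs per node (4 total) and the same accounting as the earlier sketch:

```python
# Assumed: 2 nodes x 2 A100s = 4 GPUs, MP = PP = 1 -> data-parallel degree 4.
micro_batch_per_gpu, grad_accum_steps, dp = 55, 2, 4
print(micro_batch_per_gpu * grad_accum_steps * dp)  # 440 sequences per optimizer step
```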


@cateto cateto closed this as completed Jun 5, 2023