
Weight mismatch when using deepspeed zero-stage 3 and pretrained codegen model #22017

Closed
2 of 4 tasks
KaiLv69 opened this issue Mar 8, 2023 · 6 comments

Comments


KaiLv69 commented Mar 8, 2023

System Info

  • transformers version: 4.26.1
  • Platform: Linux-4.15.0-189-generic-x86_64-with-glibc2.17
  • Python version: 3.8.16
  • Huggingface_hub version: 0.12.1
  • PyTorch version (GPU?): 1.12.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: True
  • Using distributed or parallel set-up in script?: True

Who can help?

@stas @ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. My code for loading the model:
from transformers import AutoModelForCausalLM, AutoConfig
from transformers.models.codegen.modeling_codegen import CodeGenMLP
import argparse
import torch
import time, datetime
import deepspeed
from deepspeed.accelerator import get_accelerator
from torch.utils.data import Dataset
from transformers.activations import ClippedGELUActivation, LinearActivation
from lion_pytorch import Lion
SEQ_LEN = 300
VOCAB_SIZE = 10000
DATA_SIZE = 100

class FakeDataset(Dataset):
    def __init__(self, length, seq_len, vocab_size):
        self.length = length
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        input_ids = torch.randint(0, self.vocab_size, (self.seq_len, ))
        attention_mask = torch.ones_like(input_ids)
        return input_ids, attention_mask


def train():
    with deepspeed.zero.Init():
        model = AutoModelForCausalLM.from_pretrained(
            "Salesforce/codegen-350M-mono",
            ignore_mismatched_sizes=True  # if False, it raises an error
        )
       
    optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)

    print(f"[{datetime.datetime.today()}] Loading dataset.")
    dataset = FakeDataset(DATA_SIZE, SEQ_LEN, VOCAB_SIZE)

    print(f"[{datetime.datetime.today()}] Initializing DeepSpeed Engine.")
    model_engine, optimizer, trainloader, _ = deepspeed.initialize(
        args=args,
        model=model,
        optimizer=optimizer,
        model_parameters=model.parameters(),
        training_data=dataset)

    model.train()
    for i, data in enumerate(trainloader):
        model_engine.zero_grad()
        optimizer.zero_grad()
        input_ids, attn_mask = data[0].cuda(), data[1].cuda()
        output = model_engine(input_ids=input_ids,
                              attention_mask=attn_mask,
                              labels=input_ids)

        model_engine.backward(output['loss'])

        model_engine.step()
       
        # 2 pytorch allocator cache flushes since last step. this happens when
        # there is high memory pressure and is detrimental to performance. if
        # this is happening frequently consider adjusting settings to reduce
        # memory consumption. If you are unable to make the cache flushes go
        # away consider adding get_accelerator().empty_cache() calls in your
        # training loop to ensure that all ranks flush their caches at the
        # same time
        get_accelerator().empty_cache()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    train()
  2. DeepSpeed config:
{
  "gradient_accumulation_steps": 1,
  "train_micro_batch_size_per_gpu": 1,
  "steps_per_print": 1,
  "wall_clock_breakdown": true,
  "fp16": {
    "enabled": true,
    "auto_cast": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.8,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  },
  "zero_allow_untested_optimizer": true,
  "zero_optimization": {
    "stage": 3,
    "contiguous_gradients": true,
    "overlap_comm": true,
    "reduce_scatter": true,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  }
}
  3. Bash script to run the training:
deepspeed --include localhost:0,1,2,3,4,5,6,7 train.py --deepspeed_config 350m.json
  4. Relevant output snippets. They show the weird behaviour wherein the model isn't being properly initialized with the pretrained weights.

[screenshot of the console output]

Expected behavior

The model should be properly initialized with the pretrained weights when using DeepSpeed ZeRO Stage 3. At the moment the parameters appear to be randomly initialized instead.

@ArthurZucker
Collaborator

Hey, there is indeed something wrong:

ignore_mismatched_sizes=True # if False, it would run in error
The error you run into should indicate how to fix the problem (most probably a malformed configuration file).

Author

KaiLv69 commented Mar 8, 2023

Thanks for your quick reply.
Actually, I followed the error message and set it to True. When I set 'ignore_mismatched_sizes' to False, it prints the following:
[screenshot of the error message]

Collaborator

ArthurZucker commented Mar 8, 2023

Ah sorry, you were right to ignore the mismatches! Yes, there is a special argument to initialise your model using DeepSpeed in transformers, but it does not support DeepSpeed ZeRO stage 3:

        * `low_cpu_mem_usage` algorithm:

        This is an experimental function that loads the model using ~1x model size CPU memory

        Here is how it works:

        1. save which state_dict keys we have
        2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
        3. after the model has been instantiated switch to the meta device all params/buffers that
        are going to be replaced from the loaded state_dict
        4. load state_dict 2nd time
        5. replace the params/buffers from the state_dict

        Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors

The documentation mentions this.
So this is expected, but @stas00 is the DeepSpeed boss so pinging him for help; this is more a feature request than a bug IMO.
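
For reference, the argument the quoted docstring describes is low_cpu_mem_usage on from_pretrained; a minimal sketch of using it on its own (outside of ZeRO-3, where it is not supported):

from transformers import AutoModelForCausalLM

# loads the checkpoint with roughly 1x the model size of CPU RAM;
# as the docstring above notes, it cannot be combined with ZeRO stage 3
model = AutoModelForCausalLM.from_pretrained(
    "Salesforce/codegen-350M-mono",
    low_cpu_mem_usage=True,
)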

Contributor

stas00 commented Mar 9, 2023

For non-HF-Trainer integration, please see:
https://huggingface.co/docs/transformers/main/main_classes/deepspeed#nontrainer-deepspeed-integration

zero.Init is already done for you inside the modeling code - you just need to set dschf = HfDeepSpeedConfig(args.deepspeed_config) and keep it alive before you call from_pretrained - that's it.

I fixed your program to work:

from transformers import AutoModelForCausalLM, AutoConfig
from transformers.models.codegen.modeling_codegen import CodeGenMLP
import argparse
import torch
import time, datetime
import deepspeed
from deepspeed.accelerator import get_accelerator
from torch.utils.data import Dataset
from transformers.activations import ClippedGELUActivation, LinearActivation
from lion_pytorch import Lion
SEQ_LEN = 300
VOCAB_SIZE = 10000
DATA_SIZE = 100

class FakeDataset(Dataset):
    def __init__(self, length, seq_len, vocab_size):
        self.length = length
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        input_ids = torch.randint(0, self.vocab_size, (self.seq_len, ))
        attention_mask = torch.ones_like(input_ids)
        return input_ids, attention_mask


def train(args):
    # creating the HfDeepSpeedConfig object before from_pretrained tells
    # transformers to load the checkpoint the ZeRO-3 way (zero.Init is done
    # inside the modeling code), so the pretrained weights end up sharded
    # across the ranks instead of randomly initialized
    from transformers.deepspeed import HfDeepSpeedConfig
    dschf = HfDeepSpeedConfig(args.deepspeed_config)  # keep this object alive

    model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")

    optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)

    print(f"[{datetime.datetime.today()}] Loading dataset.")
    dataset = FakeDataset(DATA_SIZE, SEQ_LEN, VOCAB_SIZE)

    print(f"[{datetime.datetime.today()}] Initializing DeepSpeed Engine.")
    model_engine, optimizer, trainloader, _ = deepspeed.initialize(
        args=args,
        model=model,
        optimizer=optimizer,
        model_parameters=model.parameters(),
        training_data=dataset)

    model.train()
    for i, data in enumerate(trainloader):
        model_engine.zero_grad()
        optimizer.zero_grad()
        input_ids, attn_mask = data[0].cuda(), data[1].cuda()
        output = model_engine(input_ids=input_ids,
                              attention_mask=attn_mask,
                              labels=input_ids)

        model_engine.backward(output['loss'])

        model_engine.step()

        # 2 pytorch allocator cache flushes since last step. this happens when
        # there is high memory pressure and is detrimental to performance. if
        # this is happening frequently consider adjusting settings to reduce
        # memory consumption. If you are unable to make the cache flushes go
        # away consider adding get_accelerator().empty_cache() calls in your
        # training loop to ensure that all ranks flush their caches at the
        # same time
        get_accelerator().empty_cache()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--deepspeed_config', type=str)
    args = parser.parse_args()
    train(args)
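
The script is launched the same way as before, e.g.:

deepspeed --include localhost:0,1,2,3,4,5,6,7 train.py --deepspeed_config 350m.json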

Contributor

stas00 commented Mar 9, 2023

BTW, when you use DeepSpeed offload w/ Lion it will be slow.

You want DeepSpeed's Adam instead, or turn off offload. You shouldn't need it with 8 GPUs and this small model, unless you were just using it for a repro case; still, 8 GPUs is a lot of sharding.
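
For example (a sketch only, not the full script): drop the external Lion instance and let deepspeed.initialize build its own optimizer from the "optimizer" section already present in the config above, which works with CPU offload:

    # sketch: with no external optimizer passed in, DeepSpeed constructs the Adam
    # optimizer defined in the json config (using its CPU-offload-friendly
    # implementation when offload_optimizer is enabled)
    model_engine, optimizer, trainloader, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=model.parameters(),
        training_data=dataset)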

The DeepSpeed team is working on flagging this incompatibility here: microsoft/DeepSpeed#2971

Make sure to enable gradient checkpointing - it will save you a ton of GPU memory at a small cost in speed (unrelated to DeepSpeed).
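
A minimal sketch of turning that on on the transformers side, before handing the model to deepspeed.initialize:

    # trade a bit of recomputation for a large saving in activation memory
    model.gradient_checkpointing_enable()
    # the generation cache is not needed for training and conflicts with checkpointing
    model.config.use_cache = False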

Author

KaiLv69 commented Mar 9, 2023

Thanks very much. The problem has been solved.

KaiLv69 closed this as completed Mar 9, 2023