
[ADD] support Distributed Data Parallel #137

Merged 32 commits into main on Mar 10, 2023

Conversation

@jinwonkim93 (Member) commented Feb 22, 2023

## Title

Colossal AI-based Distributed Data Parallel with oslo interface

## Description

The purpose of this implementation is to enable DDP in Oslo, with the reducer method identical to that of Colossal AI but adapted to fit Oslo's interface. To enhance the user experience, we replaced model.backward() with loss.backward() and temporarily added model.zero_grad() to the code. Any feedback is welcome :)

If you don't call model.zero_grad(), an unexpected error will occur.
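For context, the core job of any DDP reducer is to average each parameter's gradient across the data-parallel group after the backward pass. The snippet below is only a generic illustration of that idea, not OSLO's or Colossal AI's actual bucketed reducer:

```python
# Generic gradient averaging -- illustrative only, not the OSLO/Colossal AI reducer.
import torch
import torch.distributed as dist


def allreduce_gradients(model: torch.nn.Module, group=None) -> None:
    """All-reduce every parameter gradient and divide by the world size."""
    world_size = dist.get_world_size(group=group)
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)
            param.grad.div_(world_size)


# Typical call site in a plain training step:
#   loss.backward()
#   allreduce_gradients(model)
#   optimizer.step()
```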

test_data_parallel.py

```python
import os
import torch.multiprocessing as mp

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import nn
from torch import optim
import torch.distributed as dist

from oslo.torch.distributed.parallel_context import ParallelContext


def setup(rank, world_size):
    # Environment variables that torch.distributed / OSLO read when the
    # process group is created.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def train(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(data_parallel_size=world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.zeros(20, 10).to(rank))
    labels = torch.zeros(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    print(outputs)
    cleanup()


def main(world_size):
    # One process per rank.
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(2)
```

test_oslo_data_parallel.py

```python
import os
import torch.multiprocessing as mp

import torch
from torch import nn
from torch import optim
import torch.distributed as dist

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.nn.parallel.data_parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    # Environment variables that torch.distributed / OSLO read when the
    # process group is created.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def train(rank, world_size):
    print(f"Running oslo DDP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(data_parallel_size=world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, parallel_context)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    oslo.ready(ddp_model, parallel_context)
    optimizer.zero_grad()
    outputs = ddp_model(torch.zeros(20, 10).to(rank))
    labels = torch.zeros(20, 5).to(rank)
    loss = loss_fn(outputs, labels)
    # This example calls backward() on the wrapped model rather than loss.backward().
    ddp_model.backward(loss)
    optimizer.step()
    print(outputs)
    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(2)
```

![image](https://user-images.githubusercontent.com/26476095/220687694-4236dcaa-ae66-4332-8159-3e206b04df49.png)

![image](https://user-images.githubusercontent.com/26476095/220687852-578584a5-db9a-4a90-ab3e-bdc779bb39a2.png)

PyTorch DDP
![ddp_before_backward](https://user-images.githubusercontent.com/26476095/221404650-2525413c-ce86-44e9-bd53-897ac4077b4a.png)
![ddp_after_backward](https://user-images.githubusercontent.com/26476095/221404654-ce1e2d45-9304-4d13-aa83-c5a5f8d06689.png)

Oslo DDP
![oslo_before_backward](https://user-images.githubusercontent.com/26476095/221404663-e85a0462-6fd2-4a6d-85a3-7fdcf9a5e9a7.png)
![oslo_after_backward](https://user-images.githubusercontent.com/26476095/221404668-8cdee44d-3d76-4d23-adc0-68983ea7b173.png)

Checking the model's parameters confirms that Oslo DDP is working as expected.

After cleaning:

![image](https://user-images.githubusercontent.com/26476095/222415778-3358b862-a8c4-416e-9bc1-338d915d5e79.png)

Oslo DDP: [oslo-ddp-time.log](https://github.com/EleutherAI/oslo/files/10887632/oslo-ddp-time.log)

Torch DDP: [torch-ddp-time.log](https://github.com/EleutherAI/oslo/files/10887634/torch-ddp-time.log)
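For reference, one simple way to collect per-step timings like those in the logs above (an assumed methodology, not necessarily how these logs were produced) is to synchronize CUDA around each step and record the wall-clock time:

```python
# Illustrative per-step timing helper; train_one_step() below is a hypothetical
# function standing in for whatever the benchmark actually ran.
import time

import torch


def timed_step(step_fn) -> float:
    """Run one training step and return its wall-clock duration in seconds."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()      # flush queued CUDA work before timing
    start = time.perf_counter()
    step_fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()      # wait for the step's CUDA work to finish
    return time.perf_counter() - start


# elapsed = timed_step(lambda: train_one_step(model, batch))
# with open(f"ddp-time-rank{rank}.log", "a") as f:
#     f.write(f"{elapsed:.6f}\n")
```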

## Linked Issues

- resolved #00

@jinwonkim93 added the Data Parallelism label Feb 22, 2023
@jinwonkim93 closed this Mar 1, 2023
@jinwonkim93 reopened this Mar 2, 2023
@hyunwoongko (Member) commented:

Could you run `pre-commit run --all-files`?

jinwonkim93 and others added 4 commits March 2, 2023 15:30
## Title
Deleted legacy code.

## Description
Only newly added code written by jinwonkim93 remains.

---------

Co-authored-by: KKIEEK <[email protected]>
@KKIEEK (Contributor) commented Mar 2, 2023

I think it would be better to merge the _DistributedDataParallelWrapper class into _DistributedDataParallel.

Related to #137
For now, our DDP implementation does not support long tensor inputs, so I fixed it.

---------

Co-authored-by: Hakjin Lee <[email protected]>
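The actual change lives in the linked commits; as a hedged sketch of the general idea behind long-tensor support, a DDP wrapper that pre-processes inputs should cast only floating-point tensors and let integer/long tensors (for example token ids) pass through with their dtype unchanged:

```python
# Hedged sketch of the general idea, not the actual patch: cast only
# floating-point tensors, keep long/int tensors (e.g. token ids) untouched.
import torch


def move_input(x, device, float_dtype=torch.float32):
    """Recursively move inputs to `device`, casting only floating-point tensors."""
    if torch.is_tensor(x):
        x = x.to(device)
        return x.to(float_dtype) if x.is_floating_point() else x
    if isinstance(x, (list, tuple)):
        return type(x)(move_input(v, device, float_dtype) for v in x)
    if isinstance(x, dict):
        return {k: move_input(v, device, float_dtype) for k, v in x.items()}
    return x


# token ids stay torch.long, float features are cast to float_dtype:
# batch = move_input({"input_ids": ids, "pixel_values": imgs}, torch.device("cuda:0"))
```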
@jinwonkim93 marked this pull request as ready for review March 3, 2023 16:45
@hyunwoongko (Member) commented:

Please resolve the conflicts! @jinwonkim93

@jinwonkim93 (Member, Author) commented Mar 8, 2023

> Please resolve the conflicts! @jinwonkim93

Completed. One question: is there a reason for `__ALL__` rather than `__all__`?

@hyunwoongko (Member) commented Mar 8, 2023

No. I prefer `__ALL__`, but we don't use either of them because we think imports are enough.
Is there any file that contains `__all__`?

@jinwonkim93 (Member, Author) commented:

> No. I prefer `__ALL__`, but we don't use either of them because we think imports are enough. Is there any file that contains `__all__`?

```python
__all__ = ["ZeroRedundancyOptimizer"]

__all__ = ["get_free_port", "set_seed"]
```

@hyunwoongko (Member) commented:

I didn't add them; all of that code was added by the new DP members, so it's okay to change them to uppercase.

@jinwonkim93 (Member, Author) commented:

> I didn't add them; all of that code was added by the new DP members, so it's okay to change them to uppercase.

Okay. I think it is ready to be merged. What do you think?

@hyunwoongko (Member) commented:

@jinwonkim93 looks good to me.

@hyunwoongko merged commit f129a90 into main Mar 10, 2023
dyanos pushed a commit that referenced this pull request Jun 8, 2023