Add tensor parallelism for RWKV #1237

Draft
wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions megatron/model/gpt2_model.py
@@ -258,6 +258,7 @@ def init_specs(self):
LayerSpec(
RWKVResidualLayerPipe,
neox_args=self.neox_args,
init_method=self.init_method,
layer_number=i,
)
)
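Note (reviewer sketch, not part of the diff): the only change in this file is threading `init_method` through the `LayerSpec` so it reaches the constructor of `RWKVResidualLayerPipe`, which now needs it to initialize the tensor-parallel linear layers added below. A simplified sketch of the DeepSpeed-style `LayerSpec` pattern, to show why adding a keyword argument at this call site is sufficient (the class body is illustrative, not the actual DeepSpeed code):

```python
# Illustrative LayerSpec: it only records the class and its kwargs, and the
# pipeline engine calls build() later on the owning rank, so an extra kwarg
# such as init_method is forwarded straight to RWKVResidualLayerPipe.__init__.
class LayerSpec:
    def __init__(self, typename, *module_args, **module_kwargs):
        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs

    def build(self):
        return self.typename(*self.module_args, **self.module_kwargs)
```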
60 changes: 45 additions & 15 deletions megatron/model/rwkv/v6/rwkv.py
@@ -7,7 +7,7 @@
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.cpp_extension import load

from megatron import mpu

class WKV(torch.autograd.Function):
"""
@@ -179,7 +179,7 @@ def __init__(self, neox_args, layer_number):

self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False)
self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False)
self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) # column
self.ln_x = nn.GroupNorm(
neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2)
)
@@ -228,15 +228,19 @@ def forward(self, x):
return self.jit_func_2(x, g)


class RWKV_ChannelMix(nn.Module):
class ParallelRWKV_ChannelMix(nn.Module):
"""
Channel Mix layer. The ffn in RWKV
"""

def __init__(self, neox_args, layer_number):
def __init__(self, neox_args, layer_number, init_method):
super().__init__()
self.neox_args = neox_args
self.layer_number = layer_number

world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)

self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

with torch.no_grad(): # fancy init of time_mix
@@ -247,29 +251,54 @@ def __init__(self, neox_args, layer_number):
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
self.receptance = nn.Linear(
neox_args.hidden_size, neox_args.hidden_size, bias=False
)
self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)

#self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
self.key = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.dim_ffn,
gather_output=False,
init_method=init_method,
bias=False,
)
#self.receptance = nn.Linear(
# neox_args.hidden_size, neox_args.hidden_size, bias=False
#)
self.receptance = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.hidden_size,
gather_output=True,
init_method=init_method,
bias=False
)
#self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
self.value = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=neox_args.dim_ffn,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=init_method,
parallel_output=False,
bias=False
)
def forward(self, x):
xx = self.time_shift(x) - x
xk = x + xx * self.time_maa_k
xr = x + xx * self.time_maa_r

k = self.key(xk)
k, _ = self.key(xk)
k = torch.relu(k) ** 2
kv = self.value(k)
return torch.sigmoid(self.receptance(xr)) * kv
kv, _ = self.value(k)
receptance, _ = self.receptance(xr)
return torch.sigmoid(receptance) * kv


class RWKVResidualLayer(nn.Module):
"""
RWKV layer definition
"""

def __init__(self, neox_args, layer_number):
def __init__(self, neox_args, init_method, layer_number):
super().__init__()
self.neox_args = neox_args
self.layer_number = layer_number
Expand All @@ -288,6 +317,7 @@ def __init__(self, neox_args, layer_number):
self.num_attention_heads = neox_args.num_attention_heads
assert neox_args.dim_att % self.num_attention_heads == 0

self.init_method = init_method
if neox_args.attention_dropout > 0:
self.drop0 = nn.Dropout(p=neox_args.attention_dropout)

@@ -296,7 +326,7 @@ def __init__(self, neox_args, layer_number):

self.att = RWKV_TimeMix(neox_args, layer_number)

self.ffn = RWKV_ChannelMix(neox_args, layer_number)
self.ffn = ParallelRWKV_ChannelMix(neox_args, layer_number, init_method=init_method)

if neox_args.attention_dropout > 0:
self.drop0 = nn.Dropout(p=neox_args.attention_dropout)
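Note (reviewer sketch, not part of the diff): the rewritten channel mix follows the usual Megatron tensor-parallel MLP layout. `key` is a `ColumnParallelLinear` with `gather_output=False`, so each rank holds a `dim_ffn / world_size` slice and the element-wise `relu(k) ** 2` stays local; `value` is a `RowParallelLinear` with `input_is_parallel=True`, whose reduction restores the full `hidden_size` output; `receptance` is column-parallel with `gather_output=True` because its sigmoid gate multiplies that full-size output element-wise. The `k, _ = self.key(xk)` unpacking reflects the gpt-neox `mpu` convention of returning an `(output, bias)` tuple from the parallel linear layers. Below is a minimal single-process sketch of why the column/row split commutes with the squared ReLU (plain tensors and hypothetical sizes, not the PR's code):

```python
import torch

torch.manual_seed(0)
hidden, ffn, world = 8, 32, 2
x = torch.randn(4, hidden)            # (tokens, hidden)
W_k = torch.randn(hidden, ffn)        # key: column-parallel -> output dim is sliced
W_v = torch.randn(ffn, hidden)        # value: row-parallel  -> input dim is sliced

# Single-rank reference: kv = relu(x @ W_k)**2 @ W_v
ref = (torch.relu(x @ W_k) ** 2) @ W_v

# Tensor-parallel version: rank r keeps columns [r*shard, (r+1)*shard) of W_k
# and the matching rows of W_v. The squared ReLU is element-wise on the sliced
# dimension, so it needs no communication; summing the partial products stands
# in for RowParallelLinear's all-reduce.
shard = ffn // world
partials = []
for r in range(world):
    k_r = x @ W_k[:, r * shard:(r + 1) * shard]
    partials.append((torch.relu(k_r) ** 2) @ W_v[r * shard:(r + 1) * shard, :])
out = sum(partials)

print(torch.allclose(ref, out, atol=1e-5))  # True (up to floating-point error)
```

The same argument is why `gather_output=False` on `key` is safe: the nonlinearity acts per element of the sliced `dim_ffn` dimension, so no communication is needed until the row-parallel `value` projection.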
8 changes: 1 addition & 7 deletions megatron/neox_arguments/arguments.py
@@ -1066,17 +1066,11 @@ def calculate_derived(self):
if isinstance(self.zero_stage, int):
assert self.zero_stage <= 2, "Zero stage 3 not compatible with Mamba"
assert (
self.hidden_dropout == 0.0,
self.hidden_dropout != 0.0,
), "Mamba does not yet have dropout implemented"
if "rwkv" in self.attention_config:
assert (
not self.is_pipe_parallel and self.model_parallel_size == 1
), "RWKV not currently compatible with parallelism"
if isinstance(self.zero_stage, int):
assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV"
assert (
self.hidden_dropout == 0.0,
), "RWKV does not yet have dropout implemented"

# Sparsity config
if self.sparsity_config is None:
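Side note (a reviewer observation about the surrounding code, not something this diff states): this hunk appears to relax the guard that previously required `model_parallel_size == 1` for RWKV, consistent with the PR title. Separately, the dropout checks here are written as `assert (cond,)`, an assert over a one-element tuple; in Python a non-empty tuple is always truthy, so such an assert can never fail whichever way the comparison points. A short demonstration:

```python
# The comparison is evaluated, but only to build the tuple; the assert tests
# the tuple itself, which is non-empty and therefore always truthy.
assert (1 + 1 == 3,)  # passes silently

try:
    assert 1 + 1 == 3  # without the trailing comma the condition is checked
except AssertionError:
    print("assert fired as expected")
```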