Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Deepseek-V2 #4650

Merged
merged 19 commits into from
Jun 28, 2024
6 changes: 6 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ def get_hidden_size(self) -> int:
return self.hf_text_config.hidden_size

def get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
# FlashAttention supports only head_size 32, 64, 128, 256,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this true? According to

def get_supported_head_sizes() -> List[int]:
, flash attention supports head size of 192 too. So I think you can remove this, right? And also the related padding code in deepseek_v2.py -- that should make it quite a bit simpler :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, i will test it later with the latest flash attn

# we need to pad head_size 192 to 256
return 256
if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models.
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name)
fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)

__all__ = [
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
"grouped_topk",
]
31 changes: 31 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,37 @@ def fused_topk(
return topk_weights, topk_ids


# This is used by the Deepseek-V2 model
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
):
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids


def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
Expand Down
126 changes: 126 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,119 @@ def forward(
return query.flatten(-2), key.flatten(-2)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0


class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is extremely similar to YaRNScalingRotaryEmbedding, can you extend that one instead to support mscale?

"""RotaryEmbedding extended with YaRN method.

Credits to Peng et al. github.com/jquesnelle/yarn
"""

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq

def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
cache = torch.cat((cos, sin), dim=-1)
print("Cache shape", cache.shape)
return cache

def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]

self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin

if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key


class GemmaRotaryEmbedding(RotaryEmbedding):

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
Expand Down Expand Up @@ -679,6 +792,19 @@ def get_rope(
base, is_neox_style,
scaling_factor, dtype,
**extra_kwargs)
elif scaling_type == "deepseek_yarn":
original_max_position = rope_scaling[
"original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
"beta_slow", "mscale", "mscale_all_dim")
}
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size, rotary_dim, original_max_position, base,
is_neox_style, scaling_factor, dtype, **extra_kwargs)
# The correct one should be "longrope" but keep "su" here
# for backward compatible
elif scaling_type == "su" or scaling_type == "longrope":
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
Expand Down
Loading
Loading