Skip to content

Commit

Permalink
Add StableLM model.
Browse files Browse the repository at this point in the history
  • Loading branch information
janimo committed Mar 14, 2024
1 parent ed31579 commit 1a7665f
Showing 1 changed file with 293 additions and 0 deletions.
293 changes: 293 additions & 0 deletions python/sglang/srt/models/stablelm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
# This code is based on:
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/stablelm.py
"""Inference-only StabeLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
model compatible with HuggingFace weights."""
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
ParallelLMHead,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)


class StablelmMLP(nn.Module):
def __init__(
self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size,
[config.intermediate_size] * 2,
bias=False,
linear_method=linear_method,
)
self.down_proj = RowParallelLinear(
config.intermediate_size, config.hidden_size, bias=False
)
self.act_fn = SiluAndMul()

def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x


class StablelmAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_id: int = 0,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
self.num_heads = self.total_num_heads // tp_size

self.total_num_key_value_heads = config.num_key_value_heads
if self.total_num_key_value_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_key_value_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_key_value_heads == 0
self.num_key_value_heads = max(1, self.total_num_key_value_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
rope_pct = getattr(
config, "rope_pct", getattr(config, "partial_rotary_factor", 1)
)
self.rotary_ndims = int(self.head_dim * rope_pct)
self.scaling = self.head_dim**-0.5
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_key_value_heads * self.head_dim
self.qkv_bias = getattr(config, "use_qkv_bias", False)
if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)

self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_key_value_heads,
self.qkv_bias,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_ndims,
max_position=self.config.max_position_embeddings,
base=self.config.rope_theta,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_key_value_heads,
layer_id=layer_id,
)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
output, _ = self.o_proj(attn_output)
return output


class StablelmDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_id: int = 0,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config, layer_id=layer_id)
self.mlp = StablelmMLP(config, linear_method)
norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states, residual


class StableLMEpochModel(nn.Module):
def __init__(
self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
StablelmDecoderLayer(config, i, linear_method)
for i in range(config.num_hidden_layers)
]
)
norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
)
hidden_states = self.norm(hidden_states)
return hidden_states


class StableLmForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = StableLMEpochModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)

def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)


EntryClass = StableLmForCausalLM

0 comments on commit 1a7665f

Please sign in to comment.