support cpu
pengxiaotao authored and pengxiaotao committed Mar 24, 2023
1 parent 57b0eb6 commit 2b71519
Showing 8 changed files with 263 additions and 56 deletions.
12 changes: 8 additions & 4 deletions example.py
@@ -21,6 +21,7 @@ def setup_model_parallel() -> Tuple[int, int]:
world_size = int(os.environ.get("WORLD_SIZE", -1))

torch.distributed.init_process_group("nccl")
#torch.distributed.init_process_group("gloo")
initialize_model_parallel(world_size)
torch.cuda.set_device(local_rank)

@@ -36,12 +37,14 @@ def load(
world_size: int,
max_seq_len: int,
max_batch_size: int,
use_cpu: bool = True
) -> LLaMA:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert world_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
if not use_cpu:
assert world_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
ckpt_path = checkpoints[local_rank]
print("Loading")
checkpoint = torch.load(ckpt_path, map_location="cpu")
@@ -53,7 +56,8 @@ def load(
)
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
if not use_cpu:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
torch.set_default_tensor_type(torch.FloatTensor)
model.load_state_dict(checkpoint, strict=False)
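For context, a minimal sketch of the CPU-only process-group setup that the commented-out "gloo" line above points toward (not part of this commit; it assumes the MASTER_ADDR/MASTER_PORT environment that torchrun normally provides, and the helper name is hypothetical):

import os
from typing import Tuple

import torch

def setup_model_parallel_cpu() -> Tuple[int, int]:
    # No NCCL on CPU: use the "gloo" backend and skip torch.cuda.set_device.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    torch.distributed.init_process_group("gloo", rank=local_rank, world_size=world_size)
    torch.manual_seed(1)  # keep sampling reproducible across processes (seed value assumed)
    return local_rank, world_size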
6 changes: 5 additions & 1 deletion llama/generation.py
@@ -20,6 +20,7 @@ def generate(
max_gen_len: int,
temperature: float = 0.8,
top_p: float = 0.95,
use_cpu: bool = True
) -> List[str]:
bsz = len(prompts)
params = self.model.params
@@ -32,7 +33,10 @@ def generate(

total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
if not use_cpu:
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
else:
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cpu().long()
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t).long()
input_text_mask = tokens != self.tokenizer.pad_id
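An equivalent, slightly more compact way to write the branch above (a sketch, not the committed code): pick the device once and hand it to torch.full.

device = torch.device("cpu") if use_cpu else torch.device("cuda")
tokens = torch.full(
    (bsz, total_len), self.tokenizer.pad_id, dtype=torch.long, device=device
)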
158 changes: 107 additions & 51 deletions llama/model.py
@@ -28,6 +28,7 @@ class ModelArgs:

max_batch_size: int = 32
max_seq_len: int = 2048
use_cpu: bool = True


class RMSNorm(torch.nn.Module):
@@ -77,44 +78,77 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()

self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
self.head_dim = args.dim // args.n_heads
if args.use_cpu:
self.n_local_heads = args.n_heads
else:
self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()

self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
self.head_dim = args.dim // args.n_heads

self.cache_k = torch.zeros(
(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
).cuda()
self.cache_v = torch.zeros(
(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
).cuda()
if not args.use_cpu:
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)

self.cache_k = torch.zeros(
(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
).cuda()
self.cache_v = torch.zeros(
(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
).cuda()
else:
self.wq = nn.Linear(
args.dim,
args.n_heads * self.head_dim,
bias=False
)
self.wk = nn.Linear(
args.dim,
args.n_heads * self.head_dim,
bias=False
)
self.wv = nn.Linear(
args.dim,
args.n_heads * self.head_dim,
bias=False
)
self.wo = nn.Linear(
args.n_heads * self.head_dim,
args.dim,
bias=False
)

self.cache_k = torch.zeros(
(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
)
self.cache_v = torch.zeros(
(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
)

def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
bsz, seqlen, _ = x.shape
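The same if/else is repeated for each projection (and again in FeedForward below); one possible consolidation, sketched here with a hypothetical make_linear helper that is not part of this commit, picks the constructor once:

import torch.nn as nn

def make_linear(in_dim: int, out_dim: int, use_cpu: bool, parallel_cls=None, **parallel_kwargs):
    # Plain nn.Linear on CPU; otherwise the given fairscale layer
    # (ColumnParallelLinear or RowParallelLinear) with its extra kwargs.
    if use_cpu or parallel_cls is None:
        return nn.Linear(in_dim, out_dim, bias=False)
    return parallel_cls(in_dim, out_dim, bias=False, **parallel_kwargs)

With such a helper, wq for example would reduce to make_linear(args.dim, args.n_heads * self.head_dim, args.use_cpu, ColumnParallelLinear, gather_output=False, init_method=lambda x: x).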
@@ -156,20 +190,32 @@ def __init__(
dim: int,
hidden_dim: int,
multiple_of: int,
use_cpu: bool = True
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
self.w2 = RowParallelLinear(
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
)
self.w3 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
if not use_cpu:
self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
self.w2 = RowParallelLinear(
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
)
self.w3 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
else:
self.w1 = nn.Linear(
dim, hidden_dim, bias=False
)
self.w2 = nn.Linear(
hidden_dim, dim, bias=False
)
self.w3 = nn.Linear(
dim, hidden_dim, bias=False
)

def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
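For reference, the hidden_dim rounding above worked through with assumed 7B-style inputs (hidden_dim = 4 * dim = 16384 and multiple_of = 256; these values are an assumption, not taken from this diff):

hidden_dim = int(2 * 16384 / 3)                      # 10922
hidden_dim = 256 * ((hidden_dim + 256 - 1) // 256)   # ceil to a multiple of 256 -> 11008

Under those assumed sizes, w1 and w3 map 4096 -> 11008 and w2 maps 11008 -> 4096, on CPU as plain nn.Linear layers.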
@@ -202,18 +248,28 @@ def __init__(self, params: ModelArgs):
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers

self.tok_embeddings = ParallelEmbedding(
params.vocab_size, params.dim, init_method=lambda x: x
)
if not params.use_cpu:
self.tok_embeddings = ParallelEmbedding(
params.vocab_size, params.dim, init_method=lambda x: x
)
else:
self.tok_embeddings = nn.Embedding(
params.vocab_size, params.dim
)

self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))

self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = ColumnParallelLinear(
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
)
if not params.use_cpu:
self.output = ColumnParallelLinear(
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
)
else:
self.output = nn.Linear(
params.dim, params.vocab_size, bias=False
)

self.freqs_cis = precompute_freqs_cis(
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
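A minimal CPU-only smoke test of the patched Transformer (an illustrative sketch: it relies on the ModelArgs defaults in this file with use_cpu=True, sets vocab_size explicitly as tests/test_model.py does, and assumes forward returns the logits for the last position):

import torch
from llama.model import ModelArgs, Transformer

args = ModelArgs(use_cpu=True)
args.vocab_size = 1000                               # small hypothetical vocabulary
model = Transformer(args)

tokens = torch.randint(0, args.vocab_size, (1, 8))   # batch of 1, prompt of 8 tokens
with torch.no_grad():
    logits = model(tokens, 0)                        # second argument is start_pos
print(logits.shape)                                  # expected: torch.Size([1, 1000])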
12 changes: 12 additions & 0 deletions settings.py
@@ -0,0 +1,12 @@
import sys
import os
import logging

basedir = os.path.dirname(__file__)
sys.path.append(basedir)

logger = logging.getLogger("llama")

TOKEN_MODEL_PATH = "/Users/pengxiaotao/Documents/llama/models/tokenizer.model"
LLM_MODEL_PATH = "/Users/pengxiaotao/Documents/llama/models/7B"

1 change: 1 addition & 0 deletions tests/__init__.py
@@ -0,0 +1 @@
from ..settings import *
31 changes: 31 additions & 0 deletions tests/test_generation.py
@@ -0,0 +1,31 @@
import sys
import os

basedir = os.path.dirname(__file__)
llama_lib = os.path.join(basedir, "..")
sys.path.append(llama_lib)

import unittest
from unittest import TestCase
from settings import TOKEN_MODEL_PATH, LLM_MODEL_PATH
from llama.generation import LLaMA
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer
from example import load


class TestAttention(TestCase):
def setUp(self):
self.llama = load(LLM_MODEL_PATH, TOKEN_MODEL_PATH, 0, 0, 512, 32, True)

def test_generate(self):
prompts = [
"implement lr with python"
]
generate_text = self.llama.generate(prompts, 128)
self.assertGreater(len(generate_text), 0)
print(generate_text)


if __name__ == "__main__":
unittest.main()
61 changes: 61 additions & 0 deletions tests/test_model.py
@@ -0,0 +1,61 @@
import sys
import os

basedir = os.path.dirname(__file__)
llama_lib = os.path.join(basedir, "..")
sys.path.append(llama_lib)

import unittest
from unittest import TestCase
from settings import TOKEN_MODEL_PATH
from llama.model import Attention, ModelArgs, FeedForward, TransformerBlock, Transformer


class TestAttention(TestCase):
def setUp(self):
self.args = ModelArgs()

def test_init(self):
network = Attention(self.args)
print(network)
self.assertIsNotNone(network)

def test_forward(self):
pass


class TestFeedForward(TestCase):

def setUp(self):
self.args = ModelArgs()

def test_init(self):
network = FeedForward(self.args.dim, self.args.dim * 2, self.args.multiple_of)
print(network)
self.assertIsNotNone(network)


class TestTransformerBlock(TestCase):
def setUp(self) -> None:
self.args = ModelArgs()

def test_init(self):
network = TransformerBlock(1, self.args)
print(network)
self.assertIsNotNone(network)


class TestTransformer(TestCase):

def setUp(self) -> None:
self.args = ModelArgs()
self.args.vocab_size = 32000

def test_init(self):
network = Transformer(self.args)
print(network)
self.assertIsNotNone(network)


if __name__ == "__main__":
unittest.main()
