Skip to content

Commit

Permalink
support llava
Browse files Browse the repository at this point in the history
  • Loading branch information
helloyongyang committed Jun 16, 2024
1 parent f4d48fa commit a9f0a4b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
13 changes: 10 additions & 3 deletions llmc/compression/quantization/base_blockwise_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,13 @@ def copy_tokenizer(self, path):

@torch.no_grad()
def save_model(self, path):
    """Save the processed model and its companion files to ``path``.

    For a model whose configured type is ``"Llava"``, the language model
    returned by ``self.model.get_model()`` is assigned back onto the
    wrapping ``llava_model`` and the whole wrapper is saved. For every
    other model type only ``self.model.get_model()`` is saved.

    In both cases the tokenizer is copied afterwards via
    ``self.copy_tokenizer``. Llava additionally calls ``copy_files`` with
    the source model path, ``path`` and ``"preprocessor_config"``
    (presumably to carry over the image preprocessor config -- confirm
    against ``copy_files``).

    Args:
        path: Destination directory handed to ``save_pretrained``.
    """
    is_llava = self.config.model.type == "Llava"

    if is_llava:
        # Re-attach the language model so the wrapper is saved as a
        # single checkpoint containing it.
        self.model.llava_model.language_model = self.model.get_model()
        self.model.llava_model.save_pretrained(path)
    else:
        self.model.get_model().save_pretrained(path)

    # Shared by both branches; the model is saved exactly once above.
    logger.info("save model done --")
    self.copy_tokenizer(path)

    if is_llava:
        copy_files(self.config.model.path, path, "preprocessor_config")
1 change: 1 addition & 0 deletions llmc/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .internlm2 import InternLM2
from .qwen2 import Qwen2
from .mixtral import Mixtral
from .llava import Llava
22 changes: 22 additions & 0 deletions llmc/models/llava.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from .llama import Llama
from llmc.utils.registry_factory import MODEL_REGISTRY
from transformers import LlavaForConditionalGeneration, AutoConfig


@MODEL_REGISTRY
class Llava(Llama):
    """Llava model wrapper built on top of the ``Llama`` model class.

    ``build_model`` loads a ``LlavaForConditionalGeneration`` checkpoint,
    keeps the full wrapper on ``self.llava_model`` and exposes its
    ``language_model`` submodule as ``self.model``.
    """

    def __init__(self, model_path, torch_dtype):
        """Forward ``model_path`` and ``torch_dtype`` to ``Llama.__init__``."""
        super().__init__(model_path, torch_dtype)

    def build_model(self):
        """Load the Llava checkpoint and expose its language model.

        Sets ``self.model_config``, ``self.llava_model`` and ``self.model``.
        """
        config = AutoConfig.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self.model_config = config
        # Turn off use_cache on the text (language-model) sub-config
        # before the weights are loaded.
        config.text_config.use_cache = False

        load_kwargs = {
            "config": config,
            "torch_dtype": self.torch_dtype,
            "low_cpu_mem_usage": True,
        }
        llava = LlavaForConditionalGeneration.from_pretrained(
            self.model_path, **load_kwargs
        )
        self.llava_model = llava
        self.model = llava.language_model

0 comments on commit a9f0a4b

Please sign in to comment.