Samples #40 (Open)
wants to merge 8 commits into base: inference
Fix the error when passing load_4bit=True to load_pretrained_model
When load_4bit=True is passed to load_pretrained_model(), we get the
following error:
  File "LLaVA-NeXT/scripts/image/./gradio-ui.py", line 30, in load_model
    tokenizer, model, image_processor, max_length = load_pretrained_model(
                                                    ^^^^^^^^^^^^^^^^^^^^^^
  File "LLaVA-NeXT/llava/model/builder.py", line 175, in load_pretrained_model
    model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "LLaVA-NeXT/venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 2977, in from_pretrained
    raise ValueError(
ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time.

Recent transformers versions reject passing "load_in_4bit" as a kwarg alongside
"quantization_config", and the BitsAndBytesConfig already requests 4-bit
loading, so the kwarg is redundant. This commit fixes the error by removing the
"load_in_4bit" kwarg and relying on the quantization_config only.

Signed-off-by: Alastair D'Silva <[email protected]>
deece committed May 24, 2024
commit e0d6a2b1591ab20dc907e1660a19a2cbd7b575c7
llava/model/builder.py (1 addition, 1 deletion)
@@ -30,7 +30,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
     if load_8bit:
         kwargs["load_in_8bit"] = True
     elif load_4bit:
-        kwargs["load_in_4bit"] = True
+        #kwargs["load_in_4bit"] = True
         kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
     else:
         kwargs["torch_dtype"] = torch.float16
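
For reference, here is a minimal, self-contained sketch of the 4-bit loading pattern this change relies on, using the Hugging Face transformers API. The model id and variable names are illustrative placeholders, not taken from this PR:

# Minimal sketch: 4-bit loading driven by quantization_config alone.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize weights to 4 bits
    bnb_4bit_compute_dtype=torch.float16,  # run compute in fp16
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
)

# Pass quantization_config on its own; also passing load_in_4bit=True as a
# kwarg triggers the ValueError shown in the traceback above.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm",             # placeholder model id
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
)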