Does LLaVA-Next support in-context(few-shot) inference? #61

Open
waltonfuture opened this issue Jun 15, 2024 · 4 comments

Comments

@waltonfuture

Thanks for your work! Can I input multiple images and multiple instructions for few-shot inference?

@carlos-havier

Also very interested in this for few-shot image classification. So far, I haven't been able to get good results. Is it possible to do this with LLaVA-NeXT out of the box, or would it need fine-tuning for this use?

@YepJin

YepJin commented Jun 23, 2024

Same here; I hope the authors can give some feedback.

@ChunyuanLI
Collaborator

The recently released LLaVA-NeXT (Interleave) supports a variety of daily-life multi-image scenarios, but it is NOT specifically trained for in-context learning.
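
For reference, a rough sketch (not an official recipe; the demonstration labels and question text here are made up) of how a few-shot style prompt can be laid out for the Interleave model, using the same <image> placeholder convention as the script later in this thread:

from llava.constants import DEFAULT_IMAGE_TOKEN

# Illustrative few-shot layout: one <image> placeholder per demonstration image,
# each followed by its label text, then the query image and the question.
demo_labels = ["a cat", "a dog"]  # hypothetical labels for two demonstration images
demo_blocks = [f"{DEFAULT_IMAGE_TOKEN}\nThis is {label}." for label in demo_labels]
query_block = f"{DEFAULT_IMAGE_TOKEN}\nWhat is this?"
inp = "\n".join(demo_blocks + [query_block])
# The demonstration images and the query image are then passed to model.generate
# via the `images` argument, in the same order as the placeholders appear.

Since the model is not specifically trained for in-context learning, there is no guarantee it will actually follow the demonstrations.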

@waltonfuture
Author


# from .demo_modelpart import InferenceDemo
import gradio as gr
import os
# import time
import cv2
import torch
# import random
import numpy as np

from llava import conversation as conversation_lib
from llava.constants import DEFAULT_IMAGE_TOKEN


from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer

torch.manual_seed(42)

class InferenceDemo(object):
    def __init__(self,args,model_path) -> None:
        disable_torch_init()

        model_name = get_model_name_from_path(args.model_path)
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(args.model_path, args.model_base, model_name)

        # if "llama-2" in model_name.lower():
        #     conv_mode = "llava_llama_2"
        # elif "v1" in model_name.lower():
        #     conv_mode = "llava_v1"
        # elif "mpt" in model_name.lower():
        #     conv_mode = "mpt"
        # elif 'qwen' in model_name.lower():
        #     conv_mode = "qwen_1_5"
        # else:
        #     conv_mode = "llava_v0"
        conv_mode = "qwen_1_5"
        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode))
        else:
            args.conv_mode = conv_mode
            pass
        self.conv_mode=conv_mode
        self.conversation = conv_templates[args.conv_mode].copy()
        self.num_frames = args.num_frames
        pass

def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            print('failed to load the image')
    else:
        print('Load image from local file')
        print(image_file)
        image = Image.open(image_file).convert("RGB")
        
    return image

if __name__ == "__main__":
    import argparse
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--model_path", default="/data/weilai/weilai_code/model_weights/llava-next-interleave-7b", type=str)
    # argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    argparser.add_argument("--model-base", type=str, default=None)
    argparser.add_argument("--num-gpus", type=int, default=1)
    argparser.add_argument("--conv-mode", type=str, default=None)
    argparser.add_argument("--temperature", type=float, default=0.2)
    argparser.add_argument("--max-new-tokens", type=int, default=512)
    argparser.add_argument("--num_frames", type=int, default=16)
    argparser.add_argument("--load-8bit", action="store_true")
    argparser.add_argument("--load-4bit", action="store_true")
    argparser.add_argument("--debug", action="store_true")
    
    args = argparser.parse_args()
    model_path = args.model_path
    filt_invalid="cut"
    our_chatbot = InferenceDemo(args,model_path)
    images_this_term = ["/data/weilai/weilai_code/datasets/CodeLLaVA/images/matplotlib_images/matplotlib_Aligning_Labels_0029.jpg","/data/weilai/weilai_code/datasets/CodeLLaVA/images/matplotlib_images/matplotlib_Aligning_Labels_0031.jpg"]
    image_list=[]
    for f in images_this_term:
        image_list.append(load_image(f))
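    # preprocess each PIL image into a half-precision pixel tensor on the model's device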
    image_tensor = [our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0].half().to(our_chatbot.model.device) for f in image_list]

    image_tensor = torch.stack(image_tensor)
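    # one <image> placeholder (DEFAULT_IMAGE_TOKEN) per input image, prepended to the instruction below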
    image_token = DEFAULT_IMAGE_TOKEN*len(image_list)
    # if our_chatbot.model.config.mm_use_im_start_end:
    #     inp = DEFAULT_IM_START_TOKEN + image_token + DEFAULT_IM_END_TOKEN + "\n" + inp
    # else:
    inp='How to edit image1 to make it look like image2?'
    inp = image_token + "\n" + inp
    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
    prompt = our_chatbot.conversation.get_prompt()
    print(prompt)
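    # tokenizer_image_token replaces each <image> placeholder with IMAGE_TOKEN_INDEX in the input ids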
    input_ids = tokenizer_image_token(prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(our_chatbot.model.device)
    stop_str = our_chatbot.conversation.sep if our_chatbot.conversation.sep_style != SeparatorStyle.TWO else our_chatbot.conversation.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, our_chatbot.tokenizer, input_ids)
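    # note: stopping_criteria is built here but is not passed to generate below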
    # import pdb;pdb.set_trace()
    with torch.inference_mode():
        output_ids = our_chatbot.model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024,use_cache=True)

    outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
    # if outputs.endswith(stop_str):
    #     outputs = outputs[:-len(stop_str)]
    print(outputs)

I use this script for inference, but the model only outputs "<|im_end|>". What's wrong with my script? Thanks a lot for your help! @ChunyuanLI
