[Model] Support Monkey (#45)
Co-authored-by: yuluoyun <[email protected]>
ShuoZhang2003 and echo840 authored Jan 10, 2024
1 parent 032b1f3 commit 546eebf
Showing 3 changed files with 45 additions and 0 deletions.
1 change: 1 addition & 0 deletions vlmeval/config.py
@@ -43,6 +43,7 @@
    'cogvlm-grounding-generalist': partial(CogVlm, name='cogvlm-grounding-generalist', tokenizer_name='lmsys/vicuna-7b-v1.5'),
    'cogvlm-chat': partial(CogVlm, name='cogvlm-chat', tokenizer_name='lmsys/vicuna-7b-v1.5'),
    'sharedcaptioner': partial(SharedCaptioner, model_path='Lin-Chen/ShareCaptioner'),
    'monkey': partial(Monkey, model_path='echo840/Monkey'),
}

api_models = {
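Registering the entry in this dict is what makes the model addressable by name. A minimal sketch of how the new entry could be resolved, assuming the enclosing dict is the one exported from vlmeval.config as supported_VLM (its name sits outside this hunk, so treat it as an assumption):

from vlmeval.config import supported_VLM  # assumed export name; not visible in this hunk

# The registry maps a model name to a partial constructor; calling the
# partial instantiates the wrapper and loads the checkpoint onto CUDA.
model = supported_VLM['monkey']()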
1 change: 1 addition & 0 deletions vlmeval/vlm/__init__.py
@@ -15,3 +15,4 @@
from .llava_xtuner import LLaVA_XTuner
from .cogvlm import CogVlm
from .sharedcaptioner import SharedCaptioner
from .monkey import Monkey
43 changes: 43 additions & 0 deletions vlmeval/vlm/monkey.py
@@ -0,0 +1,43 @@
import torch
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer

class Monkey:

    INSTALL_REQ = False

    def __init__(self, model_path='echo840/Monkey', **kwargs):
        assert model_path is not None
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cuda', trust_remote_code=True).eval()
        self.kwargs = kwargs
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config.')
        torch.cuda.empty_cache()

    def generate(self, image_path, prompt, dataset=None):
        # Monkey (Qwen-VL style) takes the image path inline, wrapped in <img> tags.
        cur_prompt = f'<img>{image_path}</img> {prompt} Answer: '
        encoded = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest')
        input_ids = encoded.input_ids
        attention_mask = encoded.attention_mask

        output_ids = self.model.generate(
            input_ids=input_ids.cuda(),
            attention_mask=attention_mask.cuda(),
            do_sample=False,  # greedy decoding for reproducible evaluation
            num_beams=1,
            max_new_tokens=512,
            min_new_tokens=1,
            length_penalty=3,
            num_return_sequences=1,
            output_hidden_states=True,
            use_cache=True,
            pad_token_id=self.tokenizer.eod_id,  # the Qwen tokenizer uses eod_id for both padding and EOS
            eos_token_id=self.tokenizer.eod_id,
        )
        # Drop the prompt tokens and decode only the newly generated answer.
        response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()

        return response
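A minimal usage sketch of the wrapper above, under stated assumptions: a CUDA device is available (the constructor hard-codes device_map='cuda'), the echo840/Monkey checkpoint is fetched from the Hugging Face Hub, and 'demo.jpg' is a placeholder image path, not a file shipped with the repository:

from vlmeval.vlm.monkey import Monkey

model = Monkey(model_path='echo840/Monkey')  # loads tokenizer and model onto CUDA
answer = model.generate('demo.jpg', 'What is shown in this image?')  # 'demo.jpg' is a placeholder
print(answer)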
