Skip to content

Commit

Permalink
return empty string when context length exceed
Browse files Browse the repository at this point in the history
  • Loading branch information
ander1119 committed May 19, 2024
1 parent 5a2e686 commit f0e7cc4
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def forward(self, image, prompt, task='score', return_index=True, negative_categ
class MaskRCNNModel(BaseModel):
name = 'maskrcnn'

def __init__(self, gpu_number=1, threshold=config.detect_thresholds.maskrcnn):
def __init__(self, gpu_number=0, threshold=config.detect_thresholds.maskrcnn):
super().__init__(gpu_number)
# with HiddenPrints('MaskRCNN'):
obj_detect = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights='COCO_V1').to(self.dev)
Expand Down Expand Up @@ -429,7 +429,7 @@ def forward(self, image: torch.Tensor, text: List[str], return_labels: bool = Fa
class GLIPModel(BaseModel):
name = 'glip'

def __init__(self, model_size='large', gpu_number=1, *args):
def __init__(self, model_size='large', gpu_number=0, *args):
BaseModel.__init__(self, gpu_number)

with contextlib.redirect_stderr(open(os.devnull, "w")): # Do not print nltk_data messages when importing
Expand Down Expand Up @@ -910,13 +910,17 @@ def get_summarization(self, prompts) -> list[dict]:
"content": prompt
}
]
response = openai.ChatCompletion.create(
model=self.model,
messages=message,
# response_format={"type": "json_object"},
temperature=self.temperature,
)
responses.append(response.choices[0].message.content)
try:
response = openai.ChatCompletion.create(
model=self.model,
messages=message,
# response_format={"type": "json_object"},
temperature=self.temperature,
)
responses.append(response.choices[0].message.content)
except:
responses.append("")

return responses

def query_gpt3(self, prompt, model="text-davinci-003", max_tokens=16, logprobs=None, stream=False,
Expand Down Expand Up @@ -1394,7 +1398,7 @@ class BLIPModel(BaseModel):
max_batch_size = 32
seconds_collect_data = 0.2 # The queue has additionally the time it is executing the previous forward pass

def __init__(self, gpu_number=1, half_precision=config.blip_half_precision,
def __init__(self, gpu_number=0, half_precision=config.blip_half_precision,
blip_v2_model_type=config.blip_v2_model_type):
super().__init__(gpu_number)

Expand Down

0 comments on commit f0e7cc4

Please sign in to comment.