
Inference on one image and one text #21

Open
GabrielSoranzoUPEC opened this issue Apr 29, 2024 · 0 comments


@GabrielSoranzoUPEC

Hi,
Thanks for your work.
I wrote some code to run inference on a single image and a single text, in case it helps someone:

Getting the model:

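```python
import torch

# CFG and CLIPModel come from this repository's own modules.
def get_model(model_path):
    model = CLIPModel().to(CFG.device)
    # Load the trained weights onto CFG.device and switch to eval mode.
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()
    return model
```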

Getting an image embedding:

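```python
import cv2
import torch

# get_transforms and CFG come from this repository's own modules.
def get_image_embedding(model, image_path):
    image = cv2.imread(image_path)  # OpenCV loads images as BGR
    if image is None:
        raise FileNotFoundError(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    transforms = get_transforms(mode="valid")
    image = transforms(image=image)["image"]
    # (H, W, C) -> (1, C, H, W): channels first plus a batch dimension
    image_tensor = torch.tensor(image).permute(2, 0, 1).float().unsqueeze(0)
    with torch.no_grad():  # no gradients are needed at inference time
        image_features = model.image_encoder(image_tensor.to(CFG.device))
        image_embeddings = model.image_projection(image_features)
    return image_embeddings
```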
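Why the `permute(2, 0, 1).unsqueeze(0)` step: OpenCV returns images as (height, width, channels), while PyTorch convolutional image encoders typically expect (batch, channels, height, width). The stdlib-only sketch below uses a hypothetical helper `hwc_to_nchw` (not part of this project) to reproduce that layout change on a tiny nested-list image, so the index shuffling is visible:

```python
# Pure-Python illustration of the layout change performed by
# torch.tensor(image).permute(2, 0, 1).unsqueeze(0):
# an (H, W, C) image becomes a (1, C, H, W) batch.

def hwc_to_nchw(image):
    """Convert a nested list indexed [row][col][channel] into
    [batch][channel][row][col] with a batch size of 1."""
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    chw = [
        [[image[y][x][c] for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]
    return [chw]  # like unsqueeze(0): add a leading batch axis


def shape(nested):
    """Return the shape of a regular (non-ragged) nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# A tiny 2x3 "RGB" image: each pixel is [R, G, B].
img = [
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
]
batch = hwc_to_nchw(img)
print(shape(img))    # (2, 3, 3)
print(shape(batch))  # (1, 3, 2, 3)
print(batch[0][0])   # red channel: [[1, 4, 7], [10, 13, 16]]
```

Each channel ends up as its own 2D plane, which is what the image encoder's first convolution consumes.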

Getting a text embedding:

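```python
import torch
from transformers import DistilBertTokenizer

# CFG comes from this repository's own modules.
def get_text_embedding(model, query):
    # Loading the tokenizer on every call is slow; cache it if you embed many texts.
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    # Turn each tokenizer output (input_ids, attention_mask) into a tensor.
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)
    return text_embeddings
```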
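Only `input_ids` and `attention_mask` are passed to the text encoder. With a single query no padding is needed, but embedding several queries in one batch requires padding them to a common length (with the Hugging Face tokenizer, e.g. `tokenizer(queries, padding=True, truncation=True)`), and the attention mask then marks which positions are real tokens. The sketch below mimics that padding step in plain Python; `pad_batch` is a hypothetical helper, not a transformers function, and the token ids and `pad_id=0` are assumptions chosen for illustration:

```python
# Illustrative sketch: right-pad token id sequences to a common length and
# build the matching attention mask (1 = real token, 0 = padding).
# pad_id=0 and the token ids below are assumed, illustrative values.

def pad_batch(sequences, pad_id=0):
    """Right-pad lists of token ids and return a dict shaped like
    tokenizer output: {"input_ids": ..., "attention_mask": ...}."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = []
    attention_mask = []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(list(seq) + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Made-up token ids for two queries of different lengths.
encoded = pad_batch([[101, 1037, 7933, 102], [101, 1037, 6062, 1997, 2833, 102]])
print(encoded["input_ids"])
# [[101, 1037, 7933, 102, 0, 0], [101, 1037, 6062, 1997, 2833, 102]]
print(encoded["attention_mask"])
# [[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]]
```

Each list in the resulting dict has a uniform length, so it can be converted with `torch.tensor(values)` in the same dict comprehension used in `get_text_embedding`.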

Running inference:

```python
import torch.nn.functional as F

model = get_model("best.pt")

image_embedding = get_image_embedding(model, "cuillere.jpg")
text_embedding = get_text_embedding(model, "a spoon")

# L2-normalize both embeddings so the dot product is a cosine similarity.
image_embeddings_n = F.normalize(image_embedding, p=2, dim=-1)
text_embeddings_n = F.normalize(text_embedding, p=2, dim=-1)
dot_similarity = text_embeddings_n @ image_embeddings_n.T
print(dot_similarity.item())
```
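The final lines compute a cosine similarity: after L2 normalization both embeddings have unit length, so their dot product is the cosine of the angle between them, a value in [-1, 1] where higher means a better text-image match. A stdlib-only sketch of the same computation on plain lists follows; the vectors are toy values, not real CLIP embeddings, and comparing one image against several candidate texts is shown only as an illustrative usage:

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit length, like F.normalize(x, p=2, dim=-1)."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]


def cosine_similarity(a, b):
    """Dot product of the two L2-normalized vectors."""
    a_n = l2_normalize(a)
    b_n = l2_normalize(b)
    return sum(x * y for x, y in zip(a_n, b_n))


image_emb = [3.0, 4.0]       # toy "image embedding"
text_embs = {
    "a spoon": [6.0, 8.0],   # same direction as the image: similarity ~1.0
    "a dog": [-4.0, 3.0],    # orthogonal to the image: similarity ~0.0
}
scores = {text: cosine_similarity(emb, image_emb) for text, emb in text_embs.items()}
best = max(scores, key=scores.get)
print(scores)
print(best)  # a spoon
```

Because the vectors are normalized, only their direction matters: `[6.0, 8.0]` scores the same as `[3.0, 4.0]` despite being twice as long.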

Hope this helps.
