How to reproduce Fig.3 Qualitative result about similarity map? #32

Open
euncheolChoi opened this issue May 20, 2024 · 1 comment

euncheolChoi commented May 20, 2024

Hello. Thank you for sharing your great work!
I have spent the last few days trying to reproduce a similarity map like the one in Fig. 3 of the AnyLoc paper, but I have not been able to get the right result.
The code I have been working with is included below. I run PCA on the x_norm_patchtokens obtained from DINOv2 and compute a similarity map on the resulting feature map. The results I get are shown below.

The AnyLoc repository doesn't seem to include the code that generates this similarity map. Could you provide the code to reproduce the visualization in Fig. 3, or explain how to produce it?
Thank you.

Sample image: [image: dog_2]

Feature map after PCA and interpolation: [image: pca_resized_dog_2]

Similarity map: [image: dog_2_similarity_map]

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import os
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

patch_h = 16
patch_w = 16
feat_dim = 1536 # vitg14

transform = T.Compose([
    T.Resize((patch_h * 14, patch_w * 14)),
    T.CenterCrop((patch_h * 14, patch_w * 14)),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

print(dinov2_vitg14)

features = torch.zeros(1, patch_h * patch_w, feat_dim)
imgs_tensor = torch.zeros(1, 3, patch_h * 14, patch_w * 14)
for i in range(1):
    img_path = "/root/workspace/test_debug/dog_2.jpg"
    img = Image.open(img_path).convert('RGB')
    imgs_tensor[i] = transform(img)[:3]
with torch.no_grad():
    # features_dict = dinov2_vits14.forward_features(imgs_tensor) # torch.Size([1, 256, 384])
    features_dict = dinov2_vitg14.forward_features(imgs_tensor) # 'x_norm_patchtokens': torch.Size([1, 256, 1536])
    features = features_dict['x_norm_patchtokens']

from sklearn.decomposition import PCA

features = features.reshape(1 * patch_h * patch_w, feat_dim)

pca = PCA(n_components=3)  
pca.fit(features)  
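# pca.transform returns (256, 3), one row per patch: reshape to the patch grid first, then move channels to the front
pca_features = pca.transform(features).reshape(patch_h, patch_w, 3).transpose(2, 0, 1)
print(f"pca_feature_shape : {pca_features.shape}")   # (256, 1536) => (256, 3) => (3, 16, 16)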

pca_features_tensor = torch.tensor(pca_features).unsqueeze(0)  # Add batch dimension
print(f"pca_features_tensor.shape : {pca_features_tensor.shape}")
pca_features_resized = F.interpolate(pca_features_tensor, size=(576,1024), mode='bilinear', align_corners=True)
pca_features_resized = pca_features_resized.squeeze(0).permute(1, 2, 0).numpy()

pca_features_resized = (pca_features_resized - pca_features_resized.min()) / (pca_features_resized.max() - pca_features_resized.min())

print(f"pca_features_resized shape ?? : {pca_features_resized.shape}")
pca_image = (pca_features_resized * 255).astype(np.uint8)
pca_image = Image.fromarray(pca_image)
pca_image.save('/root/workspace/result/pca_resized_dog_2.png')

def compute_similarity_map(feature_map, ref_point):
    ref_value = feature_map[ref_point]
    diff_map = np.abs(feature_map - ref_value)
    similarity_map = 1 - (diff_map / diff_map.max())
    return similarity_map

ref_point = (280, 800)
similarity_map = compute_similarity_map(pca_features_resized, ref_point)
output_dir = "/root/workspace/aerial_pr/dinov2_mixvpr_template/test_debug/result/"

plt.imshow(similarity_map, cmap='hot')
plt.colorbar()
plt.axis('off')
similarity_map_path = os.path.join(output_dir, 'dog_2_similarity_map.png')
plt.savefig(similarity_map_path, bbox_inches='tight', pad_inches=0)
plt.close()

TheProjectsGuy (Collaborator) commented Jun 12, 2024

Hey @euncheolChoi, thank you for your interest in our work.

Figure 3 of our paper compares the similarity of the facets across multiple layers. We use the dino_v2_sim_facets.py script to generate a joblib dump that stores the similarity maps. The dump is written here:

# All results as joblib dump
res = {"source": simg_np, "target": timg_np, "similarities": sims,
       "max": {"key": key_max, "query": query_max,
               "value": value_max, "token": token_max},
       "pix_loc": pix_loc}
joblib.dump(res, f"{dst_dir}/{save_fname}.gz")
if show_plts:
    fig.show()
print(f"Saved in file: {dst_dir}/{save_fname}.[png,gz]")

See the script's arguments for more information. We then use a custom script such as facet_sim_visualization.py to produce the final figure.
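
For anyone who wants to inspect such a dump before plotting, here is a minimal round-trip sketch. It is not taken from either script: the dictionary below is a dummy stand-in that only mirrors the top-level keys from the snippet above, and the structure shown under "similarities", the array sizes, and the file name example.gz are assumptions made purely for illustration.

```python
import os
import tempfile

import joblib
import numpy as np

# Dummy stand-in with the same top-level keys as the real dump.
# The contents (array sizes, the layout of "similarities") are assumptions.
res = {
    "source": np.zeros((4, 4, 3), dtype=np.uint8),
    "target": np.zeros((4, 4, 3), dtype=np.uint8),
    "similarities": {"placeholder": np.zeros((2, 2))},
    "max": {"key": 1.0, "query": 1.0, "value": 1.0, "token": 1.0},
    "pix_loc": (1, 2),
}

with tempfile.TemporaryDirectory() as dst_dir:
    path = os.path.join(dst_dir, "example.gz")  # hypothetical file name
    joblib.dump(res, path)      # the .gz extension selects gzip compression
    loaded = joblib.load(path)  # returns the original dictionary

top_level_keys = sorted(loaded)
print(top_level_keys)
```

With a real dump, printing the keys and the types of the values under "similarities" is a quick way to see what a plotting script has to work with.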

The script you posted appears to reduce dimensionality with PCA and then compare a query point against all other points in the reduced space. What we do instead is take the full-dimensional features from a particular layer and facet, and visualize their similarity to a query feature (the feature at a chosen point).
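
The difference can be illustrated with a small sketch that computes similarity in the full feature space, with no PCA step. This is not the code from dino_v2_sim_facets.py: the function name cosine_similarity_map is hypothetical, cosine similarity is assumed as the measure, random data stands in for real descriptors, and extracting a specific layer and facet (key, query, value, or token) from the model is not shown.

```python
import numpy as np


def cosine_similarity_map(features, grid_hw, query_rc):
    """Cosine similarity between one query patch and every patch.

    features: (h*w, d) patch descriptors in row-major patch order.
    grid_hw:  (h, w) size of the patch grid.
    query_rc: (row, col) of the query patch in that grid.
    Returns an (h, w) array with values in [-1, 1].
    """
    h, w = grid_hw
    feats = np.asarray(features, dtype=np.float64)
    assert feats.shape[0] == h * w, "expected one descriptor per patch"
    # L2-normalize so that a dot product equals cosine similarity
    feats = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-12)
    row, col = query_rc
    query = feats[row * w + col]
    return (feats @ query).reshape(h, w)


# Random stand-in for real descriptors (16x16 patches, 1536 dimensions)
rng = np.random.default_rng(0)
patch_h, patch_w, feat_dim = 16, 16, 1536
features = rng.standard_normal((patch_h * patch_w, feat_dim))

# Map the pixel (280, 800) of a 576x1024 image to its patch index
query_rc = (280 * patch_h // 576, 800 * patch_w // 1024)
sim_map = cosine_similarity_map(features, (patch_h, patch_w), query_rc)
print(query_rc, sim_map.shape)
```

The resulting map has one value per patch and can be upsampled to the image size and passed to plt.imshow in the same way as in the script above.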
