Merge pull request #90 from jmisilo/refactor/run-ruff
refactor: run ruff on entire repository
jmisilo committed Dec 17, 2023
2 parents 4f74c7a + e05c332 commit a6ebb0e
Showing 12 changed files with 418 additions and 359 deletions.
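
For context, a minimal sketch of how such a formatting pass can be run locally. The commands and target below are assumptions for illustration; only the fact that ruff was run over the repository comes from the commit message.

import subprocess

# Lint with autofixes, then apply the formatter to the whole repository.
subprocess.run(["ruff", "check", "--fix", "."], check=True)
subprocess.run(["ruff", "format", "."], check=True)
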
2 changes: 1 addition & 1 deletion src/data/__init__.py
@@ -1 +1 @@
from data.dataset import *
from data.dataset import *
31 changes: 18 additions & 13 deletions src/data/dataset.py
@@ -1,10 +1,10 @@
'''
"""
Module contains Dataset class, collate function for DataLoader and loader getter function.
* MiniFlickrDataset loads data from pickle file and returns image embedding and caption.
* cl_fn is used to process batch of data and return tensors.
* get_loader returns DataLoader object.
'''
"""

import os
import pickle
@@ -15,13 +15,14 @@
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer


class MiniFlickrDataset(Dataset):
def __init__(self, path):
def __init__(self, path):
# check if file is file
if not os.path.isfile(path):
raise OSError('Dataset file not found. Downloading...')
raise OSError("Dataset file not found. Downloading...")

with open(path, 'rb') as f:
with open(path, "rb") as f:
self.data = pickle.load(f)

def __len__(self):
@@ -30,29 +31,33 @@ def __len__(self):
def __getitem__(self, idx):
return self.data[idx]


# collate_fn for DataLoader
def cl_fn(batch, tokenizer):
batch = list(zip(*batch))

_, img_emb, cap = batch
del batch

img_emb = torch.tensor(np.array(img_emb)) # better to convert list to numpy array first, then to tensor
cap = tokenizer(cap, padding=True, return_tensors='pt')
img_emb = torch.tensor(
np.array(img_emb)
) # better to convert list to numpy array first, then to tensor
cap = tokenizer(cap, padding=True, return_tensors="pt")

input_ids, attention_mask = cap['input_ids'], cap['attention_mask']
input_ids, attention_mask = cap["input_ids"], cap["attention_mask"]

return img_emb, input_ids, attention_mask


def get_loader(dataset, bs_exp=5, shuffle=True, num_workers=0, pin_memory=False):
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
tokenizer.pad_token = tokenizer.eos_token

return DataLoader(
dataset,
batch_size=2**bs_exp,
dataset,
batch_size=2**bs_exp,
collate_fn=lambda b: cl_fn(b, tokenizer),
shuffle=shuffle,
num_workers=num_workers,
pin_memory=pin_memory
)
pin_memory=pin_memory,
)
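
The module above exposes MiniFlickrDataset, cl_fn and get_loader. A hedged usage sketch, assuming the processed pickle path used elsewhere in the repository and that get_loader is re-exported from the data package:

import os

from data import MiniFlickrDataset, get_loader

dataset = MiniFlickrDataset(os.path.join("data", "processed", "dataset.pkl"))
loader = get_loader(dataset, bs_exp=5, shuffle=True)  # batch size 2**5 = 32

# cl_fn packs each batch as (image embeddings, token ids, attention mask)
img_emb, input_ids, attention_mask = next(iter(loader))
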
45 changes: 27 additions & 18 deletions src/dataset_generation.py
@@ -1,6 +1,6 @@
'''
"""
Script to generate the actual dataset for the model. It generates pickle file dataset with image name, image embedding and caption for each image in passed dataset.
'''
"""

import os
import pickle
@@ -14,54 +14,63 @@
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm

if __name__ == '__main__':
if __name__ == "__main__":
# Set constants
SEED = 100
DATA_PATH = os.path.join('data')
DATA_PATH = os.path.join("data")

# Set random seed
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load CLIP model and processor
preprocessor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14').vision_model.to(device)
preprocessor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").vision_model.to(
device
)

# Load dataset
df = pd.read_csv(os.path.join(DATA_PATH, 'raw', 'results.csv'), sep='|')
df = pd.read_csv(os.path.join(DATA_PATH, "raw", "results.csv"), sep="|")
df.columns = [col.strip() for col in df.columns]

df = df.drop(['comment_number'], axis=1)
df = df.drop(["comment_number"], axis=1)

# get every 5 element of the df (5 captions per image) and save image name with corresponding captions
ds = [(img_name, df[df['image_name'] == img_name]['comment'].values) for img_name, _ in df[0::5].to_numpy()]
ds = [
(img_name, df[df["image_name"] == img_name]["comment"].values)
for img_name, _ in df[0::5].to_numpy()
]

# Based on loaded dataset, create a list of (image name, image embedding, caption) tuples
results = []
loop = tqdm(ds, total=len(ds), position=0, leave=True)
for img_name, cap in loop:
try:
img = Image.open(os.path.join(DATA_PATH, 'raw', 'flickr30k_images', img_name))
img = Image.open(
os.path.join(DATA_PATH, "raw", "flickr30k_images", img_name)
)

with torch.no_grad():
img_prep = preprocessor(images=img, return_tensors='pt').to(device)
img_prep = preprocessor(images=img, return_tensors="pt").to(device)

img_features = model(**img_prep)
img_features = img_features.pooler_output
img_features = img_features.squeeze()
img_features = img_features.numpy()

for c in cap:
results.append((img_name, img_features, c[1:])) # because of the separator there is a space at the beginning of the caption

results.append(
(img_name, img_features, c[1:])
) # because of the separator there is a space at the beginning of the caption

except:
print(f'Lack of image {img_name}')
print(f"Lack of image {img_name}")

# save data into pickle file
# img_name, img_features, caption
with open(os.path.join(DATA_PATH, 'processed', 'dataset.pkl'), 'wb') as f:
pickle.dump(results, f)
with open(os.path.join(DATA_PATH, "processed", "dataset.pkl"), "wb") as f:
pickle.dump(results, f)
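
A small verification sketch (an assumption, not part of this diff): each entry of the generated pickle is an (image name, CLIP image embedding, caption) tuple, so the output can be inspected like this:

import os
import pickle

with open(os.path.join("data", "processed", "dataset.pkl"), "rb") as f:
    data = pickle.load(f)

img_name, img_features, caption = data[0]
print(img_name, img_features.shape, caption)  # embedding is a 1-D numpy array
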
96 changes: 43 additions & 53 deletions src/evaluate.py
@@ -1,6 +1,6 @@
'''
"""
Script to evaluate the model on the whole test set and save the results in folder.
'''
"""

import argparse
import os
@@ -15,59 +15,43 @@
from tqdm import tqdm

from data import MiniFlickrDataset
from model import Net
from model import Net
from utils import ConfigS, ConfigL, download_weights

parser = argparse.ArgumentParser()

parser.add_argument(
'-C',
'--checkpoint-name',
type=str,
default='model.pt',
help='Checkpoint name'
"-C", "--checkpoint-name", type=str, default="model.pt", help="Checkpoint name"
)

parser.add_argument(
'-S',
'--size',
"-S",
"--size",
type=str,
default='S',
help='Model size [S, L]',
choices=['S', 'L', 's', 'l']
default="S",
help="Model size [S, L]",
choices=["S", "L", "s", "l"],
)

parser.add_argument(
'-I',
'--img-path',
type=str,
default='',
help='Path to the test image folder'
"-I", "--img-path", type=str, default="", help="Path to the test image folder"
)

parser.add_argument(
'-R',
'--res-path',
type=str,
default='',
help='Path to the results folder'
"-R", "--res-path", type=str, default="", help="Path to the results folder"
)

parser.add_argument(
'-T',
'--temperature',
type=float,
default=1.0,
help='Temperature for sampling'
"-T", "--temperature", type=float, default=1.0, help="Temperature for sampling"
)

args = parser.parse_args()

config = ConfigL() if args.size.upper() == 'L' else ConfigS()
config = ConfigL() if args.size.upper() == "L" else ConfigS()

ckp_path = os.path.join(config.weights_dir, args.checkpoint_name)

assert os.path.exists(args.img_path), 'Path to the test image folder does not exist'
assert os.path.exists(args.img_path), "Path to the test image folder does not exist"

# set seed
random.seed(config.seed)
@@ -77,18 +61,19 @@
torch.backends.cudnn.deterministic = True

is_cuda = torch.cuda.is_available()
device = 'cuda' if is_cuda else 'cpu'
device = "cuda" if is_cuda else "cpu"


def evaluate_dataset(model, dataset, img_path, save_path, temperature=1.0):
'''
Evaluate model on dataset.
Args:
model: model to evaluate
dataset: dataset to evaluate on
img_path: path to images
save_path: path to save results
'''
"""
Evaluate model on dataset.
Args:
model: model to evaluate
dataset: dataset to evaluate on
img_path: path to images
save_path: path to save results
"""
model.eval()

loop = tqdm(dataset, total=len(dataset))
@@ -100,33 +85,36 @@ def evaluate_dataset(model, dataset, img_path, save_path, temperature=1.0):

plt.imshow(img)
plt.title(caption)
plt.axis('off')
plt.axis("off")

plt.savefig(os.path.join(save_path, img_name), bbox_inches='tight')
plt.savefig(os.path.join(save_path, img_name), bbox_inches="tight")

plt.clf()
plt.close()

if __name__ == '__main__':

if __name__ == "__main__":
model = Net(
clip_model=config.clip_model,
text_model=config.text_model,
ep_len=config.ep_len,
num_layers=config.num_layers,
n_heads=config.n_heads,
forward_expansion=config.forward_expansion,
dropout=config.dropout,
num_layers=config.num_layers,
n_heads=config.n_heads,
forward_expansion=config.forward_expansion,
dropout=config.dropout,
max_len=config.max_len,
device=device
device=device,
)

dataset = MiniFlickrDataset(os.path.join('data', 'processed', 'dataset.pkl'))
dataset = MiniFlickrDataset(os.path.join("data", "processed", "dataset.pkl"))

config.train_size = int(config.train_size * len(dataset))
config.val_size = int(config.val_size * len(dataset))
config.test_size = len(dataset) - config.train_size - config.val_size

_, _, test_dataset = random_split(dataset, [config.train_size, config.val_size, config.test_size])
_, _, test_dataset = random_split(
dataset, [config.train_size, config.val_size, config.test_size]
)

if not os.path.exists(config.weights_dir):
os.makedirs(config.weights_dir)
@@ -135,11 +123,13 @@ def evaluate_dataset(model, dataset, img_path, save_path, temperature=1.0):
download_weights(ckp_path, args.size)

checkpoint = torch.load(ckp_path, map_location=device)
model.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)

save_path = os.path.join(args.res_path, f'{args.checkpoint_name[:-3]}_{args.size.upper()}')
save_path = os.path.join(
args.res_path, f"{args.checkpoint_name[:-3]}_{args.size.upper()}"
)

if not os.path.exists(save_path):
os.mkdir(save_path)

evaluate_dataset(model, test_dataset, args.img_path, save_path, args.temperature)
evaluate_dataset(model, test_dataset, args.img_path, save_path, args.temperature)
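
A hedged example of invoking the evaluation script; the flags come from the argparse definitions above, while the concrete paths are placeholders, not taken from the PR.

import subprocess

subprocess.run(
    [
        "python", "src/evaluate.py",
        "-C", "model.pt",                    # checkpoint name under config.weights_dir
        "-S", "L",                           # model size: S or L
        "-I", "data/raw/flickr30k_images",   # folder with the test images
        "-R", "results",                     # folder to write captioned figures into
        "-T", "1.0",                         # sampling temperature
    ],
    check=True,
)
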
2 changes: 1 addition & 1 deletion src/model/__init__.py
@@ -1,2 +1,2 @@
from model.trainer import *
from model.model import *
from model.model import *
