Modify for CV project.
wzf2000 committed Jun 9, 2024
1 parent c3dcd9f commit 7ecb9e8
Showing 17 changed files with 564 additions and 75 deletions.
10 changes: 10 additions & 0 deletions .gitignore
@@ -0,0 +1,10 @@
dataset/HM/*.zip
dataset/HM/**/*.jpg
dataset/HM/*.lmdb*
dataset/HM/*.tsv
pretrained_models/**/*.bin
pretrained_models/**/*.pth
__pycache__/
logs_*/
log_*
checkpoint_*/
10 changes: 7 additions & 3 deletions inbatch_sasrec_e2e_vision/data_utils/metrics.py
@@ -63,9 +63,13 @@ def get_itemId_embeddings(model, item_num, test_batch_size, args, local_rank):

def get_itemLMDB_embeddings(model, item_num, item_id_to_keys, lmdb_data, test_batch_size, args, local_rank):
model.eval()
item_dataset = Build_Lmdb_Eval_Dataset(data=np.arange(item_num + 1), item_id_to_keys=item_id_to_keys,
db_path=os.path.join(args.root_data_dir, args.dataset, lmdb_data),
resize=args.CV_resize)
if args.testing:
item_dataset = Build_Id_Eval_Dataset(data=np.arange(item_num + 1))
else:
item_dataset = Build_Lmdb_Eval_Dataset(data=np.arange(item_num + 1),
item_id_to_keys=item_id_to_keys,
db_path=os.path.join(args.root_data_dir, args.dataset, lmdb_data),
resize=args.CV_resize)
item_dataloader = DataLoader(item_dataset, batch_size=test_batch_size,
num_workers=args.num_workers, pin_memory=True)
item_embeddings = []
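With --testing set, get_itemLMDB_embeddings no longer touches the LMDB image store: the eval dataset only needs to yield integer item IDs, which the embedding-table encoder built in model.py looks up directly. For reference, here is a minimal sketch of what such an ID-only dataset amounts to; it is a hypothetical stand-in for Build_Id_Eval_Dataset, whose real implementation in this repo may differ.

import numpy as np
import torch
from torch.utils.data import Dataset

class IdOnlyEvalDataset(Dataset):
    """Hypothetical stand-in for Build_Id_Eval_Dataset: one integer item ID per sample."""

    def __init__(self, data):
        self.data = data  # e.g. np.arange(item_num + 1); index 0 is the padding item

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx], dtype=torch.long)

# Usage mirroring the call above (item_num assumed):
item_num = 100
item_dataset = IdOnlyEvalDataset(np.arange(item_num + 1))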
23 changes: 23 additions & 0 deletions inbatch_sasrec_e2e_vision/model/encoders.py
@@ -3,6 +3,29 @@
from .modules import TransformerEncoder
from torch.nn.init import xavier_normal_, constant_

class MLP_Layers(torch.nn.Module):
def __init__(self, word_embedding_dim, item_embedding_dim, layers, drop_rate):
super(MLP_Layers, self).__init__()
self.layers = [word_embedding_dim] + layers + [item_embedding_dim]
mlp_modules = []
for idx, (input_size, output_size) in enumerate(zip(self.layers[:-1], self.layers[1:])):
mlp_modules.append(nn.Dropout(p=drop_rate))
mlp_modules.append(nn.Linear(input_size, output_size))
mlp_modules.append(nn.BatchNorm1d(output_size))
mlp_modules.append(nn.GELU())
self.mlp_layers = nn.Sequential(*mlp_modules)
self.apply(self._init_weights)

def _init_weights(self, module):
if isinstance(module, nn.Embedding):
xavier_normal_(module.weight.data)
elif isinstance(module, nn.Linear):
xavier_normal_(module.weight.data)
if module.bias is not None:
constant_(module.bias.data, 0)

def forward(self, sample_items):
return self.mlp_layers(sample_items)

class MAE_Encoder(torch.nn.Module):
def __init__(self, image_net, item_dim):
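For orientation, a small usage sketch of the new MLP_Layers block, with invented sizes: for every consecutive pair in [word_embedding_dim] + layers + [item_embedding_dim] it stacks Dropout → Linear → BatchNorm1d → GELU, so the example below maps 2048-dim CV features down to 512-dim item embeddings. The import path is an assumption (running from inside inbatch_sasrec_e2e_vision/).

import torch
from model.encoders import MLP_Layers  # assumed path, relative to inbatch_sasrec_e2e_vision/

# 2048-dim precomputed CV features -> 512-dim item embeddings; hidden sizes are invented.
mlp = MLP_Layers(word_embedding_dim=2048, item_embedding_dim=512,
                 layers=[2048, 2048, 1024], drop_rate=0.1)

features = torch.randn(8, 2048)   # a batch of 8 item feature vectors
item_emb = mlp(features)          # BatchNorm1d needs batch size > 1 in training mode
print(item_emb.shape)             # -> torch.Size([8, 512])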
30 changes: 22 additions & 8 deletions inbatch_sasrec_e2e_vision/model/model.py
@@ -1,11 +1,11 @@
import torch
from torch import nn
from torch.nn.init import xavier_normal_
from .encoders import Resnet_Encoder, Vit_Encoder, User_Encoder, MAE_Encoder
from .encoders import Resnet_Encoder, Vit_Encoder, User_Encoder, MAE_Encoder, MLP_Layers


class Model(torch.nn.Module):
def __init__(self, args, item_num, use_modal, image_net, pop_prob_list):
def __init__(self, args, item_num, use_modal, image_net, pop_prob_list, input_dim: int | None = None, cv_embeddings: torch.Tensor | None = None):
super(Model, self).__init__()
self.args = args
self.use_modal = use_modal
@@ -21,12 +21,26 @@ def __init__(self, args, item_num, use_modal, image_net, pop_prob_list):
n_layers=args.transformer_block)

if self.use_modal:
if 'resnet' in args.CV_model_load:
self.cv_encoder = Resnet_Encoder(image_net=image_net)
elif 'beit' in args.CV_model_load or 'swin' in args.CV_model_load:
self.cv_encoder = Vit_Encoder(image_net=image_net)
elif 'mae' in args.CV_model_load or "checkpoint" in args.CV_model_load:
self.cv_encoder = MAE_Encoder(image_net=image_net, item_dim=args.embedding_dim)
if args.testing:
assert input_dim is not None
if input_dim <= 128:
layers = [256, 512, 1024, 2048]
elif input_dim <= 512:
layers = [1024, 1024, 2048, 2048]
else:
layers = [2048, 2048, 2048, 2048]
self.cv_encoder = nn.Sequential(
nn.Embedding.from_pretrained(cv_embeddings, freeze=not args.train_emb) if cv_embeddings is not None else nn.Embedding(item_num + 1, input_dim, padding_idx=0),
nn.Linear(input_dim, args.embedding_dim) if not args.enhance else MLP_Layers(input_dim, args.embedding_dim, layers, args.drop_rate),
nn.GELU(),
)
else:
if 'resnet' in args.CV_model_load:
self.cv_encoder = Resnet_Encoder(image_net=image_net)
elif 'beit' in args.CV_model_load or 'swin' in args.CV_model_load:
self.cv_encoder = Vit_Encoder(image_net=image_net)
elif 'mae' in args.CV_model_load or "checkpoint" in args.CV_model_load:
self.cv_encoder = MAE_Encoder(image_net=image_net, item_dim=args.embedding_dim)
else:
self.id_embedding = nn.Embedding(item_num + 1, args.embedding_dim, padding_idx=0)
xavier_normal_(self.id_embedding.weight.data)
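Condensed out of the constructor above: with --testing, the image backbone is replaced by an item-ID lookup over (optionally precomputed) CV features plus a projection head. The function below is a rough standalone sketch of that wiring, not the project's API; the function name, defaults, and the MLP_Layers import path are assumptions kept close to the diff.

import torch
from torch import nn
from model.encoders import MLP_Layers  # assumed path, as in the commit

def build_testing_cv_encoder(item_num, input_dim, embedding_dim,
                             cv_embeddings=None, train_emb=False, enhance=False,
                             drop_rate=0.1):
    # Lookup table: precomputed CV features (frozen unless --train_emb) or a fresh table.
    if cv_embeddings is not None:
        table = nn.Embedding.from_pretrained(cv_embeddings, freeze=not train_emb)
    else:
        table = nn.Embedding(item_num + 1, input_dim, padding_idx=0)
    # Projection head: a single Linear by default, the MLP_Layers stack with --enhance.
    if enhance:
        if input_dim <= 128:
            layers = [256, 512, 1024, 2048]
        elif input_dim <= 512:
            layers = [1024, 1024, 2048, 2048]
        else:
            layers = [2048, 2048, 2048, 2048]
        proj = MLP_Layers(input_dim, embedding_dim, layers, drop_rate)
    else:
        proj = nn.Linear(input_dim, embedding_dim)
    return nn.Sequential(table, proj, nn.GELU())

Index 0 stays the padding item in either case, and the trailing GELU matches both projection branches in the diff.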
5 changes: 4 additions & 1 deletion inbatch_sasrec_e2e_vision/parameters.py
@@ -44,7 +44,10 @@ def parse_args():
parser.add_argument("--label_screen", type=str, default='None')
parser.add_argument("--logging_num", type=int, default=8)
parser.add_argument("--testing_num", type=int, default=1)
parser.add_argument("--local_rank", default=-1, type=int)
# boolean switches for the CV-project variant; see model.py and run.py for how they are used
parser.add_argument("--testing", action='store_true')
parser.add_argument("--train_emb", action='store_true')
parser.add_argument("--enhance", action='store_true')

args = parser.parse_args()

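The three new switches are plain store_true flags; their meaning is not documented in parameters.py but can be read off model.py above and run.py below. A minimal, self-contained reproduction with help texts inferred from this commit (not official documentation):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--testing", action="store_true",
                    help="replace the image backbone with an item-ID table of precomputed CV features")
parser.add_argument("--train_emb", action="store_true",
                    help="keep that embedding table trainable instead of frozen")
parser.add_argument("--enhance", action="store_true",
                    help="project the features with the MLP_Layers stack instead of a single Linear")

args = parser.parse_args(["--testing", "--enhance"])
print(args.testing, args.train_emb, args.enhance)  # -> True False True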
88 changes: 67 additions & 21 deletions inbatch_sasrec_e2e_vision/run.py
@@ -7,7 +7,7 @@

from parameters import parse_args
from model import Model
from data_utils import read_images, read_behaviors, Build_Lmdb_Dataset, Build_Id_Dataset, LMDB_Image, \
from data_utils import read_images, read_behaviors, Build_Lmdb_Dataset, Build_Id_Dataset, Build_Lmdb_Eval_Dataset, LMDB_Image, \
eval_model, get_itemId_embeddings, get_itemLMDB_embeddings
from data_utils.utils import *
import torchvision.models as models
@@ -19,13 +19,15 @@
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.init import xavier_normal_, constant_

from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def train(args, use_modal, local_rank):
if use_modal:
if 'resnet' in args.CV_model_load:
cv_model_load = '../../pretrained_models/' + args.CV_model_load
cv_model_load = '../pretrained_models/' + args.CV_model_load
if '18' in cv_model_load:
cv_model = models.resnet18(pretrained=False)
elif '34' in cv_model_load:
@@ -40,18 +42,24 @@ def train(args, use_modal, local_rank):
cv_model = None
cv_model.load_state_dict(torch.load(cv_model_load))
num_fc_ftr = cv_model.fc.in_features
cv_model.fc = nn.Linear(num_fc_ftr, args.embedding_dim)
xavier_normal_(cv_model.fc.weight.data)
if cv_model.fc.bias is not None:
constant_(cv_model.fc.bias.data, 0)
if args.testing:
cv_model.fc = nn.Identity()
else:
cv_model.fc = nn.Linear(num_fc_ftr, args.embedding_dim)
xavier_normal_(cv_model.fc.weight.data)
if cv_model.fc.bias is not None:
constant_(cv_model.fc.bias.data, 0)
elif 'swin' in args.CV_model_load:
cv_model_load = '../../pretrained_models/' + args.CV_model_load
cv_model_load = '../pretrained_models/' + args.CV_model_load
cv_model = SwinForImageClassification.from_pretrained(cv_model_load)
num_fc_ftr = cv_model.classifier.in_features
cv_model.classifier = nn.Linear(num_fc_ftr, args.embedding_dim)
xavier_normal_(cv_model.classifier.weight.data)
if cv_model.classifier.bias is not None:
constant_(cv_model.classifier.bias.data, 0)
if args.testing:
cv_model.classifier = nn.Identity()
else:
cv_model.classifier = nn.Linear(num_fc_ftr, args.embedding_dim)
xavier_normal_(cv_model.classifier.weight.data)
if cv_model.classifier.bias is not None:
constant_(cv_model.classifier.bias.data, 0)
else:
cv_model = None

@@ -72,10 +80,18 @@

Log_file.info('build dataset...')
if use_modal:
train_dataset = Build_Lmdb_Dataset(u2seq=users_train, item_num=item_num, max_seq_len=args.max_seq_len,
db_path=os.path.join(args.root_data_dir, args.dataset, args.lmdb_data),
item_id_to_keys=item_id_to_keys, resize=args.CV_resize,
neg_sampling_list=neg_sampling_list)
if not args.testing:
train_dataset = Build_Lmdb_Dataset(u2seq=users_train, item_num=item_num, max_seq_len=args.max_seq_len,
db_path=os.path.join(args.root_data_dir, args.dataset, args.lmdb_data),
item_id_to_keys=item_id_to_keys, resize=args.CV_resize,
neg_sampling_list=neg_sampling_list)
else:
train_dataset = Build_Id_Dataset(u2seq=users_train, item_num=item_num, max_seq_len=args.max_seq_len,
neg_sampling_list=neg_sampling_list)
image_dataset = Build_Lmdb_Eval_Dataset(data=np.arange(item_num + 1),
item_id_to_keys=item_id_to_keys,
db_path=os.path.join(args.root_data_dir, args.dataset, args.lmdb_data),
resize=args.CV_resize)
else:
train_dataset = Build_Id_Dataset(u2seq=users_train, item_num=item_num, max_seq_len=args.max_seq_len,
neg_sampling_list=neg_sampling_list)
Expand All @@ -92,9 +108,32 @@ def worker_init_reset_seed(worker_id):
Log_file.info('build dataloader...')
train_dl = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers,
worker_init_fn=worker_init_reset_seed, pin_memory=True, sampler=train_sampler)
if use_modal and args.testing:
assert cv_model is not None
assert 'num_fc_ftr' in locals()
image_dl = DataLoader(image_dataset, batch_size=256, num_workers=args.num_workers,
pin_memory=True, shuffle=False)
cv_model.to(local_rank)
with torch.no_grad():
tensors = []
for data in tqdm(image_dl, desc='Getting embeddings'):
data = data.to(local_rank)
data = data.view(-1, 3, args.CV_resize, args.CV_resize)
if 'swin' in args.CV_model_load:
tensor = cv_model(data)[0]  # Swin ModelOutput: [0] is the logits field -> (batch_size, num_fc_ftr) with the Identity head
else:
tensor = cv_model(data)
tensors.append(tensor)
tensors = torch.cat(tensors, dim=0)  # (item_num + 1, num_fc_ftr)
Log_file.info(f"tensors shape: {tensors.shape}")
Log_file.info(f"item_num: {item_num}")
Log_file.info(f"num_fc_ftr: {num_fc_ftr}")

Log_file.info('build model...')
model = Model(args, item_num, use_modal, cv_model, pop_prob_list).to(local_rank)
if not args.testing or not use_modal:
model = Model(args, item_num, use_modal, cv_model, pop_prob_list).to(local_rank)
else:
model = Model(args, item_num, use_modal, cv_model, pop_prob_list, input_dim=num_fc_ftr, cv_embeddings=tensors).to(local_rank)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank)

if 'None' not in args.load_ckpt_name:
@@ -134,9 +173,10 @@ def worker_init_reset_seed(worker_id):
{'params': recsys_params, 'lr': args.lr, 'weight_decay': args.l2_weight}
])

Log_file.info("***** {} parameters in images, {} parameters in model *****".format(
len(list(model.module.cv_encoder.image_net.parameters())),
len(list(model.module.parameters()))))
if not args.testing:
Log_file.info("***** {} parameters in images, {} parameters in model *****".format(
len(list(model.module.cv_encoder.image_net.parameters())),
len(list(model.module.parameters()))))

for children_model in optimizer.state_dict()['param_groups']:
Log_file.info("***** {} parameters have learning rate {}, weight_decay {} *****".format(
@@ -200,7 +240,7 @@ def worker_init_reset_seed(worker_id):
sample_items_id, sample_items, log_mask = data
sample_items_id, sample_items, log_mask = \
sample_items_id.to(local_rank), sample_items.to(local_rank), log_mask.to(local_rank)
if use_modal:
if use_modal and not args.testing:
sample_items = sample_items.view(-1, 3, args.CV_resize, args.CV_resize)
else:
sample_items = sample_items.view(-1)
@@ -298,7 +338,7 @@ def setup_seed(seed):

if __name__ == "__main__":
args = parse_args()
local_rank = args.local_rank
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
setup_seed(12345)
@@ -313,6 +353,12 @@ def setup_seed(seed):
f'_ed_{args.embedding_dim}_bs_{args.batch_size*gpus}' \
f'_lr_{args.lr}_Flr_{args.fine_tune_lr}' \
f'_L2_{args.l2_weight}_FL2_{args.fine_tune_l2_weight}'
if args.testing:
log_paras += '_testing'
if args.train_emb:
log_paras += '_train_emb'
if args.enhance:
log_paras += '_enhance'
else:
is_use_modal = False
model_load = 'id'
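The new --testing flow in run.py boils down to: run every item image through the frozen backbone once under torch.no_grad(), stack the results into a single matrix, and pass that matrix to Model as cv_embeddings, where it becomes an nn.Embedding lookup. Below is a self-contained miniature of the same pattern with a toy backbone and toy sizes; nothing in it is the project's real code.

import torch
from torch import nn

item_num, feat_dim = 4, 16                       # toy sizes
images = torch.randn(item_num + 1, 3, 8, 8)      # one "image" per item, index 0 = padding
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, feat_dim)).eval()

feats = []
with torch.no_grad():                            # features are precomputed once, no gradients
    for i in range(0, item_num + 1, 2):          # mini "DataLoader" with batches of 2
        feats.append(backbone(images[i:i + 2]))
feats = torch.cat(feats, dim=0)                  # shape: (item_num + 1, feat_dim)

# Wire the matrix into an embedding table, frozen unless --train_emb is set.
item_emb = nn.Embedding.from_pretrained(feats, freeze=True)
print(item_emb(torch.tensor([0, 3])).shape)      # -> torch.Size([2, 16])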
39 changes: 27 additions & 12 deletions inbatch_sasrec_e2e_vision/run_test.py
@@ -25,7 +25,7 @@
def test(args, use_modal, local_rank):
if use_modal:
if 'resnet' in args.CV_model_load:
cv_model_load = '../../pretrained_models/' + args.CV_model_load
cv_model_load = '../pretrained_models/' + args.CV_model_load
if '18' in cv_model_load:
cv_model = models.resnet18(pretrained=False)
elif '34' in cv_model_load:
@@ -40,18 +40,24 @@ def test(args, use_modal, local_rank):
cv_model = None
cv_model.load_state_dict(torch.load(cv_model_load))
num_fc_ftr = cv_model.fc.in_features
cv_model.fc = nn.Linear(num_fc_ftr, args.embedding_dim)
xavier_normal_(cv_model.fc.weight.data)
if cv_model.fc.bias is not None:
constant_(cv_model.fc.bias.data, 0)
if args.testing:
cv_model.fc = nn.Identity()
else:
cv_model.fc = nn.Linear(num_fc_ftr, args.embedding_dim)
xavier_normal_(cv_model.fc.weight.data)
if cv_model.fc.bias is not None:
constant_(cv_model.fc.bias.data, 0)
elif 'swin' in args.CV_model_load:
cv_model_load = '../../pretrained_models/' + args.CV_model_load
cv_model_load = '../pretrained_models/' + args.CV_model_load
cv_model = SwinForImageClassification.from_pretrained(cv_model_load)
num_fc_ftr = cv_model.classifier.in_features
cv_model.classifier = nn.Linear(num_fc_ftr, args.embedding_dim)
xavier_normal_(cv_model.classifier.weight.data)
if cv_model.classifier.bias is not None:
constant_(cv_model.classifier.bias.data, 0)
if args.testing:
cv_model.classifier = nn.Identity()
else:
cv_model.classifier = nn.Linear(num_fc_ftr, args.embedding_dim)
xavier_normal_(cv_model.classifier.weight.data)
if cv_model.classifier.bias is not None:
constant_(cv_model.classifier.bias.data, 0)
else:
cv_model = None

@@ -73,7 +79,10 @@ def test(args, use_modal, local_rank):


Log_file.info('build model...')
model = Model(args, item_num, use_modal, cv_model, pop_prob_list).to(local_rank)
if args.testing:
model = Model(args, item_num, use_modal, cv_model, pop_prob_list, input_dim=num_fc_ftr).to(local_rank)
else:
model = Model(args, item_num, use_modal, cv_model, pop_prob_list).to(local_rank)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank)

Log_file.info('load ckpt if not None...')
@@ -120,7 +129,7 @@ def setup_seed(seed):

if __name__ == "__main__":
args = parse_args()
local_rank = args.local_rank
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
setup_seed(12345)
@@ -135,6 +144,12 @@ def setup_seed(seed):
f'_ed_{args.embedding_dim}_bs_{args.batch_size*gpus}' \
f'_lr_{args.lr}_Flr_{args.fine_tune_lr}' \
f'_L2_{args.l2_weight}_FL2_{args.fine_tune_l2_weight}'
if args.testing:
log_paras += '_testing'
if args.train_emb:
log_paras += '_train_emb'
if args.enhance:
log_paras += '_enhance'
else:
is_use_modal = False
model_load = 'id'
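Both run.py and run_test.py now read the device index from the LOCAL_RANK environment variable rather than a --local_rank argument. That matches the torchrun-style launcher, which exports LOCAL_RANK for every worker process (the older torch.distributed.launch passed --local_rank instead). A minimal sketch of that assumption; the launch command is only an example, adjust the GPU count and flags to the actual setup.

# Launched e.g. as:  torchrun --nproc_per_node=4 run.py --testing --enhance ...
import os

import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun for each worker process
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")      # NCCL backend, as in the scripts above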
(Diffs for the remaining 10 changed files are not shown.)
