From be7a8b723dd3916e88e90bbacffe704538cef872 Mon Sep 17 00:00:00 2001
From: "Yixin Wan" <elaine1wan@g.ucla.edu>
Date: Sun, 9 Jul 2023 01:17:03 -0700
Subject: [PATCH] modified model file

---
 src/pip_model.py | 164 +++++++++++++++--------------------------------
 1 file changed, 50 insertions(+), 114 deletions(-)

diff --git a/src/pip_model.py b/src/pip_model.py
index fbf0694..bbb0a49 100644
--- a/src/pip_model.py
+++ b/src/pip_model.py
@@ -20,7 +20,7 @@ def __init__(self, config, tokenizer, device, debug=True):
         self.config = config
         self.tokenizer = tokenizer
         self.parse_vocab_size = 83
-        self.input_size = 768 # Bert Hidden size
+        self.input_size = 768 # Bart Hidden size

         self.bart_config = BartConfig.from_pretrained('facebook/bart-base')

@@ -38,63 +38,68 @@ def __init__(self, config, tokenizer, device, debug=True):
         self.n_embd_per_head = self.input_size // self.n_heads
         self.prefix_length = self.config.prefix_length

-        if self.config.prefix_type in ["attention0", "ptuning"]:
-            # # self.attention = nn.MultiheadAttention(embed_dim = 1, num_heads = 1)
-            # self.attention = nn.MultiheadAttention(embed_dim = self.input_size, num_heads = self.n_heads)
-            self.register_buffer("prefix_ids", torch.arange(self.config.prefix_length).expand((1, -1))) # (1, prefix_len)
-            self.wte_1 = nn.Embedding(self.config.prefix_length, self.input_size)
-            self.wte_2 = nn.Embedding(self.config.prefix_length, self.input_size)
-            self.wte_3 = nn.Embedding(self.config.prefix_length, self.input_size)
-            # self.wte_1 = nn.Embedding(self.parse_vocab_size, self.input_size)
-            # self.wte_2 = nn.Embedding(self.parse_vocab_size, self.input_size)
-            # self.wte_3 = nn.Embedding(self.parse_vocab_size, self.input_size)
-
-        if self.config.prefix_type == "attention0":
-            # # self.linear = nn.Linear(self.prefix_length, self.prefix_length)
-            # # self.linear_1 = nn.Linear(self.input_size, self.input_size)
-            # # self.linear_2 = nn.Linear(self.input_size, self.input_size)
-            # self.linear = nn.Linear(self.input_size, self.input_size)
-            # # self.mu = nn.Parameter(torch.Tensor(1),requires_grad=True)
-            # self.mu = 1
+        # initialize prefix
+        self.register_buffer("prefix_ids", torch.arange(self.config.prefix_length).expand((1, -1))) # (1, prefix_len)
+        self.wte_1 = nn.Embedding(self.config.prefix_length, self.input_size)
+        self.wte_2 = nn.Embedding(self.config.prefix_length, self.input_size)
+        self.wte_3 = nn.Embedding(self.config.prefix_length, self.input_size)
+
+        if self.config.prefix_type in ["attention0", "attention0_direct"]:
+            if self.config.prefix_type == "attention0":
+                self.attention = nn.MultiheadAttention(embed_dim = self.input_size, num_heads = self.n_heads)
+                self.linear = nn.Linear(self.input_size, self.input_size)
+                self.mu = 1
             self.control_trans_1 = nn.Sequential(
                 nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
                 nn.Tanh(),
-                # nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
                 nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
             )
             self.control_trans_2 = nn.Sequential(
                 nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
                 nn.Tanh(),
-                # nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
                 nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
             )
             self.control_trans_3 = nn.Sequential(
                 nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
                 nn.Tanh(),
-                # nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
                 nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
             )
+        elif self.config.prefix_type == "ptuning":
+            self.control_trans_1 = nn.Sequential(
+                nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
+                nn.Tanh(),
+                nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
+            )
+            self.control_trans_2 = nn.Sequential(
+                nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
+                nn.Tanh(),
+                nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
+            )
+            self.control_trans_3 = nn.Sequential(
+                nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
+                nn.Tanh(),
+                nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
+            )
+
+        elif self.config.prefix_type == "ptuning_large":
             self.control_trans_1 = nn.Sequential(
                 nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
                 nn.Tanh(),
                 nn.Linear(self.prefix_config.bottleneck_size, 2 * self.prefix_config.bottleneck_size),
-                # nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
                 nn.Linear(2 * self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
             )
             self.control_trans_2 = nn.Sequential(
                 nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
                 nn.Tanh(),
                 nn.Linear(self.prefix_config.bottleneck_size, 2 * self.prefix_config.bottleneck_size),
-                # nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
                 nn.Linear(2 * self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
             )
             self.control_trans_3 = nn.Sequential(
                 nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
                 nn.Tanh(),
                 nn.Linear(self.prefix_config.bottleneck_size, 2 * self.prefix_config.bottleneck_size),
-                # nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
                 nn.Linear(2 * self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
             )
@@ -222,47 +227,16 @@ def process_pip_data(self, src_sents, src_synts, tgt_synts, tgt_sents=None):
         # prefix_inputs = self.prefix_tokenizer(prefix_inputs, return_tensors='pt', padding=True) # (bsz, seq_len) # orginal

         if self.config.prefix_type == "attention0":
-            # for ptuning-based
-            # prefix_inputs = self.get_synt_tok(prefix_inputs, self.prefix_tok_word2idx) # (bsz, prefix_len)
-            # prefix_inputs = self.prefix_tokenizer(prefix_inputs, return_tensors='pt', padding=True) # (bsz, seq_len) # orginal
-            # for cossim-based p tuning
             prefix_inputs = self.prefix_tokenizer(prefix_inputs, return_tensors='pt', max_length = self.config.prefix_length, padding='max_length')
             prefix_inputs = prefix_inputs['input_ids'].to(self.device)
-            # self.enc_outputs_1 = enc_outputs.hidden_states[0].to(self.device)
-            # self.enc_outputs_2 = enc_outputs.last_hidden_state.to(self.device)
-
-            # for encoder input cos sim
-            # enc_outputs = self.model.model.encoder(prefix_inputs,output_hidden_states=True)
-            # self.enc_inputs = enc_outputs.hidden_states[0].to(self.device)
             # for encoder output cos sim
             enc_outputs = self.model.model.encoder(prefix_inputs,output_hidden_states=True)
             self.enc_outputs = enc_outputs.last_hidden_state.to(self.device)
-        if self.config.prefix_type not in ["attention0", "ptuning"]:
-            prefix_inputs = self.model.model.encoder(prefix_inputs)["last_hidden_state"]
-
-        if self.config.prefix_type == "attention0":
-            # original
-            # prefix_inputs = torch.cat([self.prefix_ids.expand(prefix_inputs.size()[0], -1), prefix_inputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
-            # prefix_inputs = self.attention(prefix_inputs.unsqueeze(-1).permute(1,0,2), prefix_inputs.unsqueeze(-1).permute(1,0,2), prefix_inputs.unsqueeze(-1).permute(1,0,2))[0].permute(1,0,2)[:, :self.config.prefix_length, :].squeeze(2)
-            # prefix_1 = self.wte_1(self.prefix_ids.expand(prefix_inputs.size()[0], -1)) # (batch_size, prefix_length, input_size)
-            # prefix_2 = self.wte_2(self.prefix_ids.expand(prefix_inputs.size()[0], -1)) # (batch_size, prefix_length, input_size)
-            # prefix_3 = self.wte_3(self.prefix_ids.expand(prefix_inputs.size()[0], -1)) # (batch_size, prefix_length, input_size)
-
-            # 1 for mu
-            # prefix_inputs = (1 - self.mu) * self.prefix_ids.expand(prefix_inputs.size()[0], -1) + self.mu * prefix_inputs # original
-            # 2
+        if self.config.prefix_type in ["attention0", "attention0_direct"]:
             prefix_inputs = self.prefix_ids.expand(len(prefix_inputs), -1) # for prefix tuning
-            # 3 for linear
-            # prefix_inputs = self.prefix_ids.expand(prefix_inputs.size()[0], -1) + self.linear(prefix_inputs)
-            # prefix_inputs = prefix_inputs - torch.min(prefix_inputs)
-            # prefix_inputs = prefix_inputs * (self.prefix_length - 1) // torch.max(prefix_inputs)
-            # 4 for cross similarity
-            # prefix_inputs = self.attention(prefix_inputs.unsqueeze(-1).permute(1,0,2).to(torch.float32), self.prefix_ids.expand(prefix_inputs.size()[0], -1).unsqueeze(-1).permute(1,0,2).to(torch.float32), self.prefix_ids.expand(prefix_inputs.size()[0], -1).unsqueeze(-1).permute(1,0,2).to(torch.float32))[0].permute(1,0,2).squeeze(2)
-            # prefix_inputs = prefix_inputs - torch.min(prefix_inputs)
-            # prefix_inputs = prefix_inputs * (self.prefix_length - 1) // torch.max(prefix_inputs)
             prefix_inputs = prefix_inputs.to(torch.long)
             # print('prefix', prefix_inputs[0])
@@ -298,44 +272,24 @@ def process_pip_data(self, src_sents, src_synts, tgt_synts, tgt_sents=None):

             # For regular (only enc) prefix
             key_values_1 = key_values_1.permute(0, 2, 1, 3, 4).split(self.n_layers, dim = 1)
-            # # for cos sim based p-tuning
-            # batch_size = key_values_3.size()[0]
-            # # for modified cosine similarity
-            # # for enc input value cos sim
-            # # prefix_enc_inputs = torch.cat([key_values_1[1][:,0,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_inputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
-            # # self.prefix_enc_inputs = self.attention(prefix_enc_inputs.permute(1,0,2), prefix_enc_inputs.permute(1,0,2), prefix_enc_inputs.permute(1,0,2))[0].permute(1,0,2)[:, :self.config.prefix_length, :].to(self.device) # for padded prefix
-            # # self.prefix_enc_inputs = self.linear(self.prefix_enc_inputs)
-            # # for enc input key cos sim
-            # # prefix_enc_inputs = torch.cat([key_values_1[0][:,0,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_inputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
-            # # self.prefix_enc_inputs = self.attention(prefix_enc_inputs.permute(1,0,2), prefix_enc_inputs.permute(1,0,2), prefix_enc_inputs.permute(1,0,2))[0].permute(1,0,2)[:, :self.config.prefix_length, :].to(self.device) # for padded prefix
-            # # self.prefix_enc_inputs = self.linear(self.prefix_enc_inputs)
-            # # for enc output cos sim
-            # # prefix_enc_outputs = torch.cat([key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
-            # # self.prefix_enc_outputs = self.attention(prefix_enc_outputs.permute(1,0,2), prefix_enc_outputs.permute(1,0,2), prefix_enc_outputs.permute(1,0,2))[0].permute(1,0,2)[:, :self.config.prefix_length, :].to(self.device) # for padded prefix
-            # # self.prefix_enc_outputs = self.linear(self.prefix_enc_outputs)
-            # # # for value-attended enc output vs enc output cos sim
-            # # prefix_enc_outputs = torch.cat([key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
-            # # self.prefix_enc_outputs = self.attention(prefix_enc_outputs.permute(1,0,2), prefix_enc_outputs.permute(1,0,2), prefix_enc_outputs.permute(1,0,2))[0].permute(1,0,2)[:, self.config.prefix_length:, :].to(self.device) # for padded prefix
-            # # self.prefix_enc_outputs = self.linear(self.prefix_enc_outputs)
-            # # for key-attended enc output vs enc output cos sim
-            # prefix_enc_outputs_key = torch.cat([key_values_1[0][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
-            # prefix_enc_outputs_val = torch.cat([key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
-            # self.prefix_enc_outputs = self.attention(self.enc_outputs.permute(1,0,2), prefix_enc_outputs_key.permute(1,0,2), prefix_enc_outputs_val.permute(1,0,2))[0].permute(1,0,2) #[:, self.config.prefix_length:, :].to(self.device) # for padded prefix
-            # self.prefix_enc_outputs = self.linear(self.prefix_enc_outputs)
+            if self.config.prefix_type == "attention0":
+                # for cos sim based p-tuning
+                batch_size = key_values_3.size()[0]
+                prefix_enc_outputs_key = torch.cat([key_values_1[0][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
+                prefix_enc_outputs_val = torch.cat([key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
+                self.prefix_enc_outputs = self.attention(self.enc_outputs.permute(1,0,2), prefix_enc_outputs_key.permute(1,0,2), prefix_enc_outputs_val.permute(1,0,2))[0].permute(1,0,2) #[:, self.config.prefix_length:, :].to(self.device) # for padded prefix
+                self.prefix_enc_outputs = self.linear(self.prefix_enc_outputs)
-            # for enc output sub
-            key_values_1_0 = key_values_1[0].clone()
-            key_values_1_1 = key_values_1[1].clone()
-            # for enc0 substitution
-            # key_values_1_1[:,0,:,:,:] = self.enc_inputs.reshape(key_values_3.size()[0], self.prefix_length, self.n_heads, self.n_embd_per_head)
-            # for enc5 substitution
-            # key_values_1_0[:,5,:,:,:] = self.enc_outputs.reshape(key_values_3.size()[0], self.prefix_length, self.n_heads, self.n_embd_per_head)
-            key_values_1_1[:,5,:,:,:] = self.enc_outputs.reshape(key_values_3.size()[0], self.prefix_length, self.n_heads, self.n_embd_per_head)
-            key_values_1 = (key_values_1_0, key_values_1_1)
-
-            # # self.prefix_enc_outputs = self.linear(key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)
-            # # self.prefix_enc_outputs_1 = self.linear_2(key_values_1[1][:,0,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)
-            # # self.prefix_enc_outputs_2 = self.linear_1(key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)
+            elif self.config.prefix_type == "attention0_direct":
+                # for enc output sub
+                key_values_1_0 = key_values_1[0].clone()
+                key_values_1_1 = key_values_1[1].clone()
+                key_values_1_1[:,5,:,:,:] = self.enc_outputs.reshape(key_values_3.size()[0], self.prefix_length, self.n_heads, self.n_embd_per_head)
+                key_values_1 = (key_values_1_0, key_values_1_1)
+
+                # # self.prefix_enc_outputs = self.linear(key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)
+                # # self.prefix_enc_outputs_1 = self.linear_2(key_values_1[1][:,0,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)
+                # # self.prefix_enc_outputs_2 = self.linear_1(key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)

             key_values_2 = key_values_2.permute(0, 2, 1, 3, 4).split(self.n_layers, dim = 1)
             key_values_3 = key_values_3.permute(0, 2, 1, 3, 4).split(self.n_layers, dim = 1)
@@ -344,13 +298,6 @@ def process_pip_data(self, src_sents, src_synts, tgt_synts, tgt_sents=None):
             prefix_dict['cross_prefix'] = key_values_2
             prefix_dict['decoder_prefix'] = key_values_3

-            # for decoder output cos sim
-            # dec_outputs = self.model.model(prefix_inputs, prefix=prefix_dict)
-            # self.dec_outputs = dec_outputs.last_hidden_state.to(self.device)
-            # prefix_dec_outputs = torch.cat([key_values_3[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.dec_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
-            # self.prefix_dec_outputs = self.attention(prefix_dec_outputs.permute(1,0,2), prefix_dec_outputs.permute(1,0,2), prefix_dec_outputs.permute(1,0,2))[0].permute(1,0,2)[:, :self.config.prefix_length, :].to(self.device) # for padded prefix
-            # self.prefix_dec_outputs = self.linear(self.prefix_dec_outputs)
-
         elif self.config.prefix_type == "ptuning":
             prefix_inputs = self.prefix_ids.expand(len(prefix_inputs), -1) # for prefix tuning
             prefix_1 = self.wte_1(prefix_inputs) # (batch_size, prefix_length, input_size)
@@ -471,19 +418,8 @@ def forward(self, enc_idxs, enc_attn, dec_idxs, dec_attn, lbl_idxs, prefix_dict)
                               return_dict=True)

         if self.config.prefix_type == "attention0":
-            # # sim1 = self.prefix_criterion(F.log_softmax(self.prefix_enc_outputs/ 1, dim=1), F.softmax(self.enc_outputs/ 1, dim=1))
-            # # for enc input cos sim
-            # # sim1 = torch.mean(1 - torch.abs(self.prefix_criterion(self.prefix_enc_inputs, self.enc_inputs))).to(outputs['loss'].device)
-            # # for enc output cos sim
-            # sim1 = torch.mean(1 - torch.abs(self.prefix_criterion(self.prefix_enc_outputs, self.enc_outputs))).to(outputs['loss'].device)
-            # # for dec output cos sim
-            # # sim1 = torch.mean(1 - torch.abs(self.prefix_criterion(self.prefix_dec_outputs, self.dec_outputs))).to(outputs['loss'].device)
-
-            # # sim2 = torch.abs(torch.mean(self.prefix_criterion(self.prefix_enc_outputs_2, self.enc_outputs_2)))
-            # # loss = outputs['loss'] + sim1.to(outputs['loss'].device) + sim2.to(outputs['loss'].device)
-            loss = outputs['loss'] # + self.mu * sim1
-            # loss = outputs['loss'] + sim1.to(outputs['loss'].device)
-
+            sim1 = torch.mean(1 - torch.abs(self.prefix_criterion(self.prefix_enc_outputs, self.enc_outputs))).to(outputs['loss'].device)
+            loss = outputs['loss'] + self.mu * sim1
         else:
             loss = outputs['loss']
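
For reference, the control_trans_* bottleneck MLPs kept by this patch (paired with wte_1/wte_2/wte_3 for the encoder, cross-attention, and decoder prefixes) follow the standard prefix-tuning reparameterization: frozen position ids are embedded, projected through a bottleneck, expanded to n_layers * 2 * input_size, and reshaped and split into per-layer key/value prefixes, as process_pip_data does with .permute(0, 2, 1, 3, 4).split(self.n_layers, dim = 1). A minimal sketch of one such copy, assuming bart-base dimensions (6 layers, 12 heads, hidden size 768) and illustrative prefix_length / bottleneck_size values not taken from the repo's config:

import torch
import torch.nn as nn

# assumed illustrative sizes; only input_size = 768 and 'facebook/bart-base' appear in the patch
input_size, n_layers, n_heads = 768, 6, 12
n_embd_per_head = input_size // n_heads                    # 64
prefix_length, bottleneck_size, batch_size = 20, 512, 4

prefix_ids = torch.arange(prefix_length).expand(1, -1)     # (1, prefix_len), mirrors the registered buffer
wte = nn.Embedding(prefix_length, input_size)              # one of wte_1/2/3
control_trans = nn.Sequential(                             # one of control_trans_1/2/3 ("ptuning" shape)
    nn.Linear(input_size, bottleneck_size),
    nn.Tanh(),
    nn.Linear(bottleneck_size, n_layers * 2 * input_size), # enc+dec: keys and values for every layer
)

prefix = wte(prefix_ids.expand(batch_size, -1))            # (bsz, prefix_len, input_size)
key_values = control_trans(prefix)                         # (bsz, prefix_len, n_layers*2*input_size)
key_values = key_values.view(batch_size, prefix_length, n_layers * 2, n_heads, n_embd_per_head)
keys, values = key_values.permute(0, 2, 1, 3, 4).split(n_layers, dim=1)
print(keys.shape, values.shape)                            # each (bsz, n_layers, prefix_len, n_heads, head_dim)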
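In the "attention0" branch, the target-syntax prompt is tokenized to exactly prefix_length tokens (padding='max_length'), encoded into self.enc_outputs, and those states then attend over the concatenation of the last-layer (index 5) prefix keys/values with the encoder states themselves; a linear projection of the result becomes self.prefix_enc_outputs. A hedged sketch with stand-in random tensors (attention and linear play the roles of self.attention and self.linear from __init__; the flattened layer-5 slices stand in for key_values_1[0][:,5] / key_values_1[1][:,5]):

import torch
import torch.nn as nn

batch_size, prefix_length, input_size, n_heads = 4, 20, 768, 12
attention = nn.MultiheadAttention(embed_dim=input_size, num_heads=n_heads)   # self.attention
linear = nn.Linear(input_size, input_size)                                   # self.linear

enc_outputs = torch.randn(batch_size, prefix_length, input_size)      # syntax-prompt encoder states (seq_len == prefix_len)
prefix_keys_l5 = torch.randn(batch_size, prefix_length, input_size)   # key_values_1[0][:, 5] flattened over heads
prefix_vals_l5 = torch.randn(batch_size, prefix_length, input_size)   # key_values_1[1][:, 5] flattened over heads

k = torch.cat([prefix_keys_l5, enc_outputs], dim=1)   # (bsz, prefix_len + seq_len, input_size)
v = torch.cat([prefix_vals_l5, enc_outputs], dim=1)
# nn.MultiheadAttention defaults to (seq, batch, embed) layout, hence the permutes in the patch
mixed, _ = attention(enc_outputs.permute(1, 0, 2), k.permute(1, 0, 2), v.permute(1, 0, 2))
prefix_enc_outputs = linear(mixed.permute(1, 0, 2))   # (bsz, prefix_len, input_size), compared to enc_outputs in the loss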
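The new "attention0_direct" branch skips the attention mixing and instead writes the syntax-prompt encoder states directly into the value slot of the last (index 5) encoder-prefix layer, cloning first so the reparameterized tensors are not modified in place. A sketch with stand-in tensors shaped as in the patch (bart-base sizes; prefix_length illustrative):

import torch

batch_size, prefix_length, n_layers, n_heads, head_dim = 4, 20, 6, 12, 64

keys = torch.randn(batch_size, n_layers, prefix_length, n_heads, head_dim)     # key_values_1[0]
values = torch.randn(batch_size, n_layers, prefix_length, n_heads, head_dim)   # key_values_1[1]
enc_outputs = torch.randn(batch_size, prefix_length, n_heads * head_dim)       # self.enc_outputs

keys_sub, values_sub = keys.clone(), values.clone()
values_sub[:, 5, :, :, :] = enc_outputs.reshape(batch_size, prefix_length, n_heads, head_dim)
encoder_prefix = (keys_sub, values_sub)   # what ends up in prefix_dict['encoder_prefix']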
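Finally, forward() re-enables the auxiliary term for "attention0": the cross-entropy loss is augmented with mu times the mean of 1 - |prefix_criterion(prefix_enc_outputs, enc_outputs)|, which pushes the attended prefix values toward the syntax-prompt encoder states. self.prefix_criterion is not defined anywhere in this patch; the sketch below assumes a cosine similarity over the hidden dimension (nn.CosineSimilarity(dim=-1)) and uses mu = 1 as set in __init__:

import torch
import torch.nn as nn

prefix_criterion = nn.CosineSimilarity(dim=-1)   # assumption: definition not shown in the patch
mu = 1                                           # set in __init__ for "attention0"

enc_outputs = torch.randn(4, 20, 768)            # syntax-prompt encoder states
prefix_enc_outputs = torch.randn(4, 20, 768)     # attended + projected prefix values

ce_loss = torch.tensor(2.3)                      # stand-in for outputs['loss']
sim1 = torch.mean(1 - torch.abs(prefix_criterion(prefix_enc_outputs, enc_outputs)))   # 0 when |cos| == 1
loss = ce_loss + mu * sim1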