
Commit

modified model file
Yixin Wan [email protected] committed Jul 9, 2023
1 parent c06bd5e commit be7a8b7
Showing 1 changed file with 50 additions and 114 deletions.
164 changes: 50 additions & 114 deletions src/pip_model.py
@@ -20,7 +20,7 @@ def __init__(self, config, tokenizer, device, debug=True):
self.config = config
self.tokenizer = tokenizer
self.parse_vocab_size = 83
self.input_size = 768 # Bert Hidden size
self.input_size = 768 # Bart Hidden size

self.bart_config = BartConfig.from_pretrained('facebook/bart-base')
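
For reference (not part of this commit): the 768 hard-coded above is BART-base's hidden size, which can also be read from the config object loaded just above instead of being fixed by hand. A minimal sketch, assuming only the transformers package; the variable names are illustrative:

    from transformers import BartConfig

    # facebook/bart-base reports d_model == 768, the value hard-coded as input_size above
    bart_config = BartConfig.from_pretrained('facebook/bart-base')
    input_size = bart_config.d_model               # 768
    n_layers = bart_config.encoder_layers          # 6 (decoder_layers is also 6)
    n_heads = bart_config.encoder_attention_heads  # 12
    n_embd_per_head = input_size // n_heads        # 64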

@@ -38,63 +38,68 @@ def __init__(self, config, tokenizer, device, debug=True):
self.n_embd_per_head = self.input_size // self.n_heads
self.prefix_length = self.config.prefix_length

if self.config.prefix_type in ["attention0", "ptuning"]:
# # self.attention = nn.MultiheadAttention(embed_dim = 1, num_heads = 1)
# self.attention = nn.MultiheadAttention(embed_dim = self.input_size, num_heads = self.n_heads)
self.register_buffer("prefix_ids", torch.arange(self.config.prefix_length).expand((1, -1))) # (1, prefix_len)
self.wte_1 = nn.Embedding(self.config.prefix_length, self.input_size)
self.wte_2 = nn.Embedding(self.config.prefix_length, self.input_size)
self.wte_3 = nn.Embedding(self.config.prefix_length, self.input_size)
# self.wte_1 = nn.Embedding(self.parse_vocab_size, self.input_size)
# self.wte_2 = nn.Embedding(self.parse_vocab_size, self.input_size)
# self.wte_3 = nn.Embedding(self.parse_vocab_size, self.input_size)

if self.config.prefix_type == "attention0":
# # self.linear = nn.Linear(self.prefix_length, self.prefix_length)
# # self.linear_1 = nn.Linear(self.input_size, self.input_size)
# # self.linear_2 = nn.Linear(self.input_size, self.input_size)
# self.linear = nn.Linear(self.input_size, self.input_size)
# # self.mu = nn.Parameter(torch.Tensor(1),requires_grad=True)
# self.mu = 1
# initialize prefix
self.register_buffer("prefix_ids", torch.arange(self.config.prefix_length).expand((1, -1))) # (1, prefix_len)
self.wte_1 = nn.Embedding(self.config.prefix_length, self.input_size)
self.wte_2 = nn.Embedding(self.config.prefix_length, self.input_size)
self.wte_3 = nn.Embedding(self.config.prefix_length, self.input_size)

if self.config.prefix_type in ["attention0", "attention0_direct"]:
if self.config.prefix_type == "attention0":
self.attention = nn.MultiheadAttention(embed_dim = self.input_size, num_heads = self.n_heads)
self.linear = nn.Linear(self.input_size, self.input_size)
self.mu = 1

self.control_trans_1 = nn.Sequential(
nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
nn.Tanh(),
# nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
)
self.control_trans_2 = nn.Sequential(
nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
nn.Tanh(),
# nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
)
self.control_trans_3 = nn.Sequential(
nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
nn.Tanh(),
# nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
)

elif self.config.prefix_type == "ptuning":
self.control_trans_1 = nn.Sequential(
nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
nn.Tanh(),
nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
)
self.control_trans_2 = nn.Sequential(
nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
nn.Tanh(),
nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
)
self.control_trans_3 = nn.Sequential(
nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
nn.Tanh(),
nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
)

elif self.config.prefix_type == "ptuning_large":
self.control_trans_1 = nn.Sequential(
nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
nn.Tanh(),
nn.Linear(self.prefix_config.bottleneck_size, 2 * self.prefix_config.bottleneck_size),
# nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
nn.Linear(2 * self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
)
self.control_trans_2 = nn.Sequential(
nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
nn.Tanh(),
nn.Linear(self.prefix_config.bottleneck_size, 2 * self.prefix_config.bottleneck_size),
# nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
nn.Linear(2 * self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
)
self.control_trans_3 = nn.Sequential(
nn.Linear(self.input_size, self.prefix_config.bottleneck_size),
nn.Tanh(),
nn.Linear(self.prefix_config.bottleneck_size, 2 * self.prefix_config.bottleneck_size),
# nn.Linear(self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # original
nn.Linear(2 * self.prefix_config.bottleneck_size, self.n_layers * 2 * self.input_size), # enc+dec
)
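
The control_trans_* stacks above are the usual prefix-tuning reparameterization: a prefix embedding of width input_size passes through a Tanh bottleneck and expands to n_layers * 2 * input_size, i.e. one key and one value vector per layer (the ptuning_large variant adds an extra widening layer). A shape-bookkeeping sketch of how such an output becomes per-layer key/value prefixes; the .view call is an assumption, since this diff only shows the later permute/split step:

    import torch
    import torch.nn as nn

    # Illustrative sizes taken from the surrounding code: 768 hidden, 6 layers, 12 heads.
    batch, prefix_length, input_size, n_layers, n_heads, bottleneck = 4, 10, 768, 6, 12, 256
    n_embd_per_head = input_size // n_heads

    wte = nn.Embedding(prefix_length, input_size)
    control_trans = nn.Sequential(
        nn.Linear(input_size, bottleneck),
        nn.Tanh(),
        nn.Linear(bottleneck, n_layers * 2 * input_size),
    )

    prefix_ids = torch.arange(prefix_length).expand(batch, -1)   # (batch, P)
    prefix = wte(prefix_ids)                                     # (batch, P, 768)
    key_values = control_trans(prefix)                           # (batch, P, 6 * 2 * 768)
    key_values = key_values.view(batch, prefix_length, n_layers * 2, n_heads, n_embd_per_head)
    keys, values = key_values.permute(0, 2, 1, 3, 4).split(n_layers, dim=1)
    print(keys.shape)  # torch.Size([4, 6, 10, 12, 64]); values has the same shape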

@@ -222,47 +227,16 @@ def process_pip_data(self, src_sents, src_synts, tgt_synts, tgt_sents=None):

# prefix_inputs = self.prefix_tokenizer(prefix_inputs, return_tensors='pt', padding=True) # (bsz, seq_len) # original
if self.config.prefix_type == "attention0":
# for ptuning-based
# prefix_inputs = self.get_synt_tok(prefix_inputs, self.prefix_tok_word2idx) # (bsz, prefix_len)
# prefix_inputs = self.prefix_tokenizer(prefix_inputs, return_tensors='pt', padding=True) # (bsz, seq_len) # original

# for cossim-based p tuning
prefix_inputs = self.prefix_tokenizer(prefix_inputs, return_tensors='pt', max_length = self.config.prefix_length, padding='max_length')
prefix_inputs = prefix_inputs['input_ids'].to(self.device)

# self.enc_outputs_1 = enc_outputs.hidden_states[0].to(self.device)
# self.enc_outputs_2 = enc_outputs.last_hidden_state.to(self.device)

# for encoder input cos sim
# enc_outputs = self.model.model.encoder(prefix_inputs,output_hidden_states=True)
# self.enc_inputs = enc_outputs.hidden_states[0].to(self.device)
# for encoder output cos sim
enc_outputs = self.model.model.encoder(prefix_inputs,output_hidden_states=True)
self.enc_outputs = enc_outputs.last_hidden_state.to(self.device)

if self.config.prefix_type not in ["attention0", "ptuning"]:
prefix_inputs = self.model.model.encoder(prefix_inputs)["last_hidden_state"]

if self.config.prefix_type == "attention0":
# original
# prefix_inputs = torch.cat([self.prefix_ids.expand(prefix_inputs.size()[0], -1), prefix_inputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
# prefix_inputs = self.attention(prefix_inputs.unsqueeze(-1).permute(1,0,2), prefix_inputs.unsqueeze(-1).permute(1,0,2), prefix_inputs.unsqueeze(-1).permute(1,0,2))[0].permute(1,0,2)[:, :self.config.prefix_length, :].squeeze(2)
# prefix_1 = self.wte_1(self.prefix_ids.expand(prefix_inputs.size()[0], -1)) # (batch_size, prefix_length, input_size)
# prefix_2 = self.wte_2(self.prefix_ids.expand(prefix_inputs.size()[0], -1)) # (batch_size, prefix_length, input_size)
# prefix_3 = self.wte_3(self.prefix_ids.expand(prefix_inputs.size()[0], -1)) # (batch_size, prefix_length, input_size)

# 1 for mu
# prefix_inputs = (1 - self.mu) * self.prefix_ids.expand(prefix_inputs.size()[0], -1) + self.mu * prefix_inputs # original
# 2
if self.config.prefix_type in ["attention0", "attention0_direct"]:
prefix_inputs = self.prefix_ids.expand(len(prefix_inputs), -1) # for prefix tuning
# 3 for linear
# prefix_inputs = self.prefix_ids.expand(prefix_inputs.size()[0], -1) + self.linear(prefix_inputs)
# prefix_inputs = prefix_inputs - torch.min(prefix_inputs)
# prefix_inputs = prefix_inputs * (self.prefix_length - 1) // torch.max(prefix_inputs)
# 4 for cross similarity
# prefix_inputs = self.attention(prefix_inputs.unsqueeze(-1).permute(1,0,2).to(torch.float32), self.prefix_ids.expand(prefix_inputs.size()[0], -1).unsqueeze(-1).permute(1,0,2).to(torch.float32), self.prefix_ids.expand(prefix_inputs.size()[0], -1).unsqueeze(-1).permute(1,0,2).to(torch.float32))[0].permute(1,0,2).squeeze(2)
# prefix_inputs = prefix_inputs - torch.min(prefix_inputs)
# prefix_inputs = prefix_inputs * (self.prefix_length - 1) // torch.max(prefix_inputs)

prefix_inputs = prefix_inputs.to(torch.long)
# print('prefix', prefix_inputs[0])
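
In the attention0 branch above, the prompt is tokenized with max_length=prefix_length and padding='max_length', so the encoder output has exactly prefix_length positions and can later be written into one layer of the prefix cache. A minimal sketch of that call pattern; the tokenizer/model construction and the example string are assumptions, as the diff only shows the calls:

    from transformers import BartTokenizer, BartForConditionalGeneration

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
    prefix_length = 10

    # Pad (and, in this sketch, truncate) the syntax prompt to exactly prefix_length tokens.
    batch = tokenizer(["( S ( NP ) ( VP ) )"], return_tensors='pt',
                      max_length=prefix_length, padding='max_length', truncation=True)
    enc_outputs = model.get_encoder()(batch['input_ids'], output_hidden_states=True)
    print(enc_outputs.last_hidden_state.shape)  # torch.Size([1, 10, 768])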
@@ -298,44 +272,24 @@ def process_pip_data(self, src_sents, src_synts, tgt_synts, tgt_sents=None):
# For regular (only enc) prefix
key_values_1 = key_values_1.permute(0, 2, 1, 3, 4).split(self.n_layers, dim = 1)

# # for cos sim based p-tuning
# batch_size = key_values_3.size()[0]
# # for modified cosine similarity
# # for enc input value cos sim
# # prefix_enc_inputs = torch.cat([key_values_1[1][:,0,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_inputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
# # self.prefix_enc_inputs = self.attention(prefix_enc_inputs.permute(1,0,2), prefix_enc_inputs.permute(1,0,2), prefix_enc_inputs.permute(1,0,2))[0].permute(1,0,2)[:, :self.config.prefix_length, :].to(self.device) # for padded prefix
# # self.prefix_enc_inputs = self.linear(self.prefix_enc_inputs)
# # for enc input key cos sim
# # prefix_enc_inputs = torch.cat([key_values_1[0][:,0,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_inputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
# # self.prefix_enc_inputs = self.attention(prefix_enc_inputs.permute(1,0,2), prefix_enc_inputs.permute(1,0,2), prefix_enc_inputs.permute(1,0,2))[0].permute(1,0,2)[:, :self.config.prefix_length, :].to(self.device) # for padded prefix
# # self.prefix_enc_inputs = self.linear(self.prefix_enc_inputs)
# # for enc output cos sim
# # prefix_enc_outputs = torch.cat([key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
# # self.prefix_enc_outputs = self.attention(prefix_enc_outputs.permute(1,0,2), prefix_enc_outputs.permute(1,0,2), prefix_enc_outputs.permute(1,0,2))[0].permute(1,0,2)[:, :self.config.prefix_length, :].to(self.device) # for padded prefix
# # self.prefix_enc_outputs = self.linear(self.prefix_enc_outputs)
# # # for value-attended enc output vs enc output cos sim
# # prefix_enc_outputs = torch.cat([key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
# # self.prefix_enc_outputs = self.attention(prefix_enc_outputs.permute(1,0,2), prefix_enc_outputs.permute(1,0,2), prefix_enc_outputs.permute(1,0,2))[0].permute(1,0,2)[:, self.config.prefix_length:, :].to(self.device) # for padded prefix
# # self.prefix_enc_outputs = self.linear(self.prefix_enc_outputs)
# # for key-attended enc output vs enc output cos sim
# prefix_enc_outputs_key = torch.cat([key_values_1[0][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
# prefix_enc_outputs_val = torch.cat([key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
# self.prefix_enc_outputs = self.attention(self.enc_outputs.permute(1,0,2), prefix_enc_outputs_key.permute(1,0,2), prefix_enc_outputs_val.permute(1,0,2))[0].permute(1,0,2) #[:, self.config.prefix_length:, :].to(self.device) # for padded prefix
# self.prefix_enc_outputs = self.linear(self.prefix_enc_outputs)
if self.config.prefix_type == "attention0":
# for cos sim based p-tuning
batch_size = key_values_3.size()[0]
prefix_enc_outputs_key = torch.cat([key_values_1[0][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
prefix_enc_outputs_val = torch.cat([key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.enc_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
self.prefix_enc_outputs = self.attention(self.enc_outputs.permute(1,0,2), prefix_enc_outputs_key.permute(1,0,2), prefix_enc_outputs_val.permute(1,0,2))[0].permute(1,0,2) #[:, self.config.prefix_length:, :].to(self.device) # for padded prefix
self.prefix_enc_outputs = self.linear(self.prefix_enc_outputs)

# for enc output sub
key_values_1_0 = key_values_1[0].clone()
key_values_1_1 = key_values_1[1].clone()
# for enc0 substitution
# key_values_1_1[:,0,:,:,:] = self.enc_inputs.reshape(key_values_3.size()[0], self.prefix_length, self.n_heads, self.n_embd_per_head)
# for enc5 substitution
# key_values_1_0[:,5,:,:,:] = self.enc_outputs.reshape(key_values_3.size()[0], self.prefix_length, self.n_heads, self.n_embd_per_head)
key_values_1_1[:,5,:,:,:] = self.enc_outputs.reshape(key_values_3.size()[0], self.prefix_length, self.n_heads, self.n_embd_per_head)
key_values_1 = (key_values_1_0, key_values_1_1)

# # self.prefix_enc_outputs = self.linear(key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)
# # self.prefix_enc_outputs_1 = self.linear_2(key_values_1[1][:,0,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)
# # self.prefix_enc_outputs_2 = self.linear_1(key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)
elif self.config.prefix_type == "attention0_direct":
# for enc output sub
key_values_1_0 = key_values_1[0].clone()
key_values_1_1 = key_values_1[1].clone()
key_values_1_1[:,5,:,:,:] = self.enc_outputs.reshape(key_values_3.size()[0], self.prefix_length, self.n_heads, self.n_embd_per_head)
key_values_1 = (key_values_1_0, key_values_1_1)

# # self.prefix_enc_outputs = self.linear(key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)
# # self.prefix_enc_outputs_1 = self.linear_2(key_values_1[1][:,0,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)
# # self.prefix_enc_outputs_2 = self.linear_1(key_values_1[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1)).to(self.device)

key_values_2 = key_values_2.permute(0, 2, 1, 3, 4).split(self.n_layers, dim = 1)
key_values_3 = key_values_3.permute(0, 2, 1, 3, 4).split(self.n_layers, dim = 1)
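
In the attention0_direct branch above, the learned value prefix of encoder layer 5 is replaced by the encoder output of the prompt: the clone() calls make independent copies before the in-place write into the tensors returned by split(), and the (batch, prefix_length, 768) encoder states are reshaped into per-head form. A small stand-alone sketch of just that substitution, with shapes assumed from the surrounding code:

    import torch

    batch, prefix_length, n_layers, n_heads, head_dim = 4, 10, 6, 12, 64

    # Stand-ins for key_values_1 (keys, values) and for self.enc_outputs.
    keys = torch.zeros(batch, n_layers, prefix_length, n_heads, head_dim)
    values = torch.zeros(batch, n_layers, prefix_length, n_heads, head_dim)
    enc_outputs = torch.randn(batch, prefix_length, n_heads * head_dim)

    values = values.clone()  # mirrors the original, which writes into views produced by split()
    values[:, 5, :, :, :] = enc_outputs.reshape(batch, prefix_length, n_heads, head_dim)
    encoder_prefix = (keys, values)
    print(encoder_prefix[1][:, 5].shape)  # torch.Size([4, 10, 12, 64])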
@@ -344,13 +298,6 @@ def process_pip_data(self, src_sents, src_synts, tgt_synts, tgt_sents=None):
prefix_dict['cross_prefix'] = key_values_2
prefix_dict['decoder_prefix'] = key_values_3

# for decoder output cos sim
# dec_outputs = self.model.model(prefix_inputs, prefix=prefix_dict)
# self.dec_outputs = dec_outputs.last_hidden_state.to(self.device)
# prefix_dec_outputs = torch.cat([key_values_3[1][:,5,:,:,:].squeeze(1).reshape(batch_size, self.prefix_length, -1), self.dec_outputs], dim=1) # (batch_size, prefix_length + seq_length, input_size)
# self.prefix_dec_outputs = self.attention(prefix_dec_outputs.permute(1,0,2), prefix_dec_outputs.permute(1,0,2), prefix_dec_outputs.permute(1,0,2))[0].permute(1,0,2)[:, :self.config.prefix_length, :].to(self.device) # for padded prefix
# self.prefix_dec_outputs = self.linear(self.prefix_dec_outputs)

elif self.config.prefix_type == "ptuning":
prefix_inputs = self.prefix_ids.expand(len(prefix_inputs), -1) # for prefix tuning
prefix_1 = self.wte_1(prefix_inputs) # (batch_size, prefix_length, input_size)
@@ -471,19 +418,8 @@ def forward(self, enc_idxs, enc_attn, dec_idxs, dec_attn, lbl_idxs, prefix_dict)
return_dict=True)

if self.config.prefix_type == "attention0":
# # sim1 = self.prefix_criterion(F.log_softmax(self.prefix_enc_outputs/ 1, dim=1), F.softmax(self.enc_outputs/ 1, dim=1))
# # for enc input cos sim
# # sim1 = torch.mean(1 - torch.abs(self.prefix_criterion(self.prefix_enc_inputs, self.enc_inputs))).to(outputs['loss'].device)
# # for enc output cos sim
# sim1 = torch.mean(1 - torch.abs(self.prefix_criterion(self.prefix_enc_outputs, self.enc_outputs))).to(outputs['loss'].device)
# # for dec output cos sim
# # sim1 = torch.mean(1 - torch.abs(self.prefix_criterion(self.prefix_dec_outputs, self.dec_outputs))).to(outputs['loss'].device)

# # sim2 = torch.abs(torch.mean(self.prefix_criterion(self.prefix_enc_outputs_2, self.enc_outputs_2)))
# # loss = outputs['loss'] + sim1.to(outputs['loss'].device) + sim2.to(outputs['loss'].device)
loss = outputs['loss'] # + self.mu * sim1
# loss = outputs['loss'] + sim1.to(outputs['loss'].device)

sim1 = torch.mean(1 - torch.abs(self.prefix_criterion(self.prefix_enc_outputs, self.enc_outputs))).to(outputs['loss'].device)
loss = outputs['loss'] + self.mu * sim1
else:
loss = outputs['loss']
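
In the attention0 branch of the loss above, the auxiliary term penalizes misalignment between the prefix-attended encoder states and the plain encoder states: one minus the absolute cosine similarity, averaged over positions and scaled by mu. A minimal sketch, assuming prefix_criterion is nn.CosineSimilarity(dim=-1); its construction is not shown in this diff:

    import torch
    import torch.nn as nn

    cos = nn.CosineSimilarity(dim=-1)             # assumed stand-in for self.prefix_criterion
    prefix_enc_outputs = torch.randn(4, 10, 768)  # prefix-attended encoder states
    enc_outputs = torch.randn(4, 10, 768)         # plain encoder states
    mu = 1.0

    sim1 = torch.mean(1 - torch.abs(cos(prefix_enc_outputs, enc_outputs)))  # in [0, 1]
    ce_loss = torch.tensor(2.3)                   # stand-in for outputs['loss']
    loss = ce_loss + mu * sim1                    # shrinks as the two representations align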

