Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
heiheiyoyo committed Oct 2, 2022
2 parents e6f4fd7 + 3725fd9 commit 669a86e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,11 @@ def extract_and_convert(input_dir, output_dir):
open(os.path.join(input_dir, 'model_state.pdparams'), 'rb'))
del paddle_paddle_params['StructuredToParameterName@@']
for weight_name, weight_value in paddle_paddle_params.items():
+ transposed = ''
if 'weight' in weight_name:
- if 'encoder.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name:
+ if '.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name:
weight_value = weight_value.transpose()
+ transposed = '.T'
# Fix: embedding error
if 'word_embeddings.weight' in weight_name:
weight_value[0, :] = 0
Expand All @@ -322,7 +324,7 @@ def extract_and_convert(input_dir, output_dir):
continue
state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
logger.info(
- f"{weight_name} -> {weight_map[weight_name]} {weight_value.shape}")
+ f"{weight_name}{transposed} -> {weight_map[weight_name]} {weight_value.shape}")
torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))


Expand Down

0 comments on commit 669a86e

Please sign in to comment.