Skip to content

Commit

Permalink
🧑‍💻 Update fine-tuning example to allow local save of the model and flexible target seq len
Browse files Browse the repository at this point in the history

Signed-off-by: gkumbhat <[email protected]>
  • Loading branch information
gkumbhat committed Aug 14, 2023
1 parent 9739ecd commit 8ca5da9
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions examples/run_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
print_colored("\n".join([print_str for print_str in print_strs if print_str]))


-def get_model_preds_and_references(model, validation_stream, truncate_input_tokens):
+def get_model_preds_and_references(model, validation_stream, truncate_input_tokens, max_new_tokens):
"""Given a model & a validation stream, run the model against every example in the validation
stream and compare the outputs to the target/output sequence.
Expand All @@ -256,7 +256,7 @@ def get_model_preds_and_references(model, validation_stream, truncate_input_toke
# Local .run() currently prepends the input text to the generated string;
# Ensure that we're just splitting the first predicted token & beyond.
raw_model_text = model.run(
-datum.input, truncate_input_tokens=truncate_input_tokens
+datum.input, truncate_input_tokens=truncate_input_tokens, max_new_tokens=max_new_tokens
).generated_text
parse_pred_text = raw_model_text.split(datum.input)[-1].strip()
model_preds.append(parse_pred_text)
Expand Down Expand Up @@ -333,10 +333,10 @@ def export_model_preds(preds_file, predictions, validation_stream):

print("Generated text: ", prediction_results)

-if args.tgis:
-# Saving model
-model.save(args.output_dir)

+# Saving model
+model.save(args.output_dir)
+if args.tgis:

# Load model in TGIS
# HACK: export args.output_dir as MODEL_NAME for TGIS
Expand All @@ -360,7 +360,7 @@ def export_model_preds(preds_file, predictions, validation_stream):
print_colored("Getting model predictions...")
truncate_input_tokens = args.max_source_length + args.max_target_length
predictions, references = get_model_preds_and_references(
-loaded_model, validation_stream, truncate_input_tokens
+loaded_model, validation_stream, truncate_input_tokens, args.max_target_length
)

export_model_preds(args.preds_file, predictions, validation_stream)
Expand All @@ -370,4 +370,4 @@ def export_model_preds(preds_file, predictions, validation_stream):

for metric_func in metric_funcs:
metric_res = metric_func(predictions=predictions, references=references)
-print_colored(metric_res)
+print_colored(metric_res)

0 comments on commit 8ca5da9

Please sign in to comment.