🐛 Fix training arguments for service generation to work correctly
Signed-off-by: gkumbhat <[email protected]>
gkumbhat committed Aug 9, 2023
1 parent 8ece6e3 commit c8868b0
Showing 1 changed file with 4 additions and 4 deletions.
caikit_nlp/modules/text_generation/text_generation_local.py (8 changes: 4 additions & 4 deletions)
@@ -147,8 +147,8 @@ def train(
         lr: float = 2e-5,
         # Directory where model predictions and checkpoints will be written
         checkpoint_dir: str = "/tmp",
-        **training_arguments,
-    ):
+        **kwargs,
+    ) -> "TextGeneration":
         """
         Fine-tune a CausalLM or Seq2seq text generation model.
@@ -177,7 +177,7 @@ def train(
                 Learning rate to be used while tuning model. Default: 2e-5.
             checkpoint_dir: str
                 Directory where model predictions and checkpoints will be written
-            **training_arguments:
+            **kwargs:
                 Arguments supported by HF Training Arguments.
                 TrainingArguments:
                     https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments
@@ -274,7 +274,7 @@ def train(
                 "eval_accumulation_steps": accumulate_steps,
                 # eval_steps=1,
                 # load_best_model_at_end
-                **training_arguments,
+                **kwargs,
                 **dtype_based_params,
             }
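For context on what the renamed **kwargs feed into, here is a minimal illustrative sketch, not the module's actual code: the helper name build_training_args and its parameter list are assumptions that only loosely mirror the merged-dict pattern visible in the last hunk, where caller-supplied keyword arguments are forwarded into transformers.TrainingArguments.

# Illustrative sketch only (hypothetical helper, not caikit_nlp code): extra keyword
# arguments accepted by train() are merged into the dict that configures
# transformers.TrainingArguments, alongside dtype-based defaults.
from transformers import TrainingArguments

def build_training_args(checkpoint_dir, lr, accumulate_steps, dtype_based_params, **kwargs):
    training_args = {
        "output_dir": checkpoint_dir,
        "learning_rate": lr,
        "eval_accumulation_steps": accumulate_steps,
        **kwargs,              # any field supported by TrainingArguments, e.g. num_train_epochs
        **dtype_based_params,  # e.g. {"fp16": True}, chosen from the torch dtype
    }
    return TrainingArguments(**training_args)

# Example call: callers can pass TrainingArguments fields directly as keywords.
# args = build_training_args("/tmp", 2e-5, 1, {}, num_train_epochs=1, logging_steps=10)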
