Handle drop_last=False in step calculation
Signed-off-by: Alex-Brooks <[email protected]>
alex-jw-brooks committed Sep 12, 2023
1 parent dc7044e commit e39e57e
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion caikit_nlp/modules/text_generation/text_generation_local.py
```diff
@@ -589,8 +589,13 @@ def infer_max_steps(
     data_len = 0
     for _ in training_dataset:
         data_len += 1
-    # Figure out how many batches we'll have per epoch; we assume drop_last=True for now
+    # Figure out how many batches we'll have per epoch
     num_batches = data_len // batch_size
+    # Assume drop_last=False; in general, this doesn't really matter.
+    # We mostly do this to avoid strange behavior when the dataset
+    # size is smaller than the batch size.
+    if num_batches * batch_size != data_len:
+        num_batches += 1
     num_steps = num_batches * num_epochs
     log.debug("Number of inferred steps: [%s]", num_steps)
     return num_steps
```
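The added branch is a ceiling division: with drop_last=False, any leftover samples that don't fill a whole batch still form one final (smaller) batch, so the step count must round up rather than truncate. A minimal standalone sketch of that logic (the `data_len` parameter here stands in for iterating over the training dataset, and the function signature is simplified for illustration):

```python
def infer_max_steps(num_epochs: int, batch_size: int, data_len: int) -> int:
    """Estimate total training steps, assuming drop_last=False.

    Simplified sketch: takes the dataset length directly instead of
    counting items from a dataset iterable as the original does.
    """
    # Full batches per epoch (floor division)
    num_batches = data_len // batch_size
    # Leftover samples form one additional, partial batch
    if num_batches * batch_size != data_len:
        num_batches += 1
    return num_batches * num_epochs
```

For example, 20 samples with a batch size of 8 yield 3 batches per epoch (two full, one partial), and a dataset smaller than the batch size still yields 1 batch per epoch instead of 0, which is the "strange behavior" the comment in the diff refers to.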
