Readability updates for get_sample_start_indexes
Signed-off-by: Mynhardt Burger <[email protected]>
mynhardtburger committed Mar 9, 2024
1 parent e89a84d commit 942826f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions caikit_nlp/modules/text_embedding/embedding.py
@@ -687,9 +687,10 @@ def get_sample_start_indexes(tokenized: BatchEncoding) -> List[int]:
     # note: tokenized["overflow_to_sample_mapping"] is a torch.Tensor
 
     samples_start_indexes: Dict[int, int] = {}
-    for i, sample in enumerate(tokenized["overflow_to_sample_mapping"]):
-        if int(sample) not in samples_start_indexes:
-            samples_start_indexes[int(sample)] = i
+    for i, tensor_sample in enumerate(tokenized["overflow_to_sample_mapping"]):
+        int_sample = int(tensor_sample)
+        if int_sample not in samples_start_indexes:
+            samples_start_indexes[int_sample] = i
 
     return list(samples_start_indexes.values())
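
For readers who want to see what the loop in this hunk computes, here is a minimal standalone sketch. It is an illustration under stated assumptions, not the module's code: the function takes the mapping directly rather than a BatchEncoding, a plain Python list stands in for the torch.Tensor, and the sample mapping values are made up.

```python
from typing import Dict, List, Sequence


def get_sample_start_indexes(overflow_to_sample_mapping: Sequence[int]) -> List[int]:
    """Return the position of the first chunk belonging to each input sample.

    Assumption: the mapping is passed in directly; the function in the commit
    reads it from tokenized["overflow_to_sample_mapping"] instead.
    """
    samples_start_indexes: Dict[int, int] = {}
    for i, sample in enumerate(overflow_to_sample_mapping):
        # Convert once and reuse, mirroring the readability change in the commit.
        int_sample = int(sample)
        if int_sample not in samples_start_indexes:
            # Only the first occurrence of each sample id is recorded.
            samples_start_indexes[int_sample] = i
    # Dicts preserve insertion order, so values come out in order of first appearance.
    return list(samples_start_indexes.values())


# Hypothetical mapping: sample 0 was split into 3 chunks, sample 1 into 1, sample 2 into 2.
mapping = [0, 0, 0, 1, 2, 2]
start_indexes = get_sample_start_indexes(mapping)
print(start_indexes)  # [0, 3, 4]
```

Converting each element with int() once per iteration, as the commit does, avoids the repeated conversion and makes it explicit that the dictionary is keyed on plain integers rather than tensor elements.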
