From 942826f2be469cf593c88e3cd7fe411fe0cd4777 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Sat, 9 Mar 2024 10:36:13 -0500 Subject: [PATCH] Readability updates for get_sample_start_indexes Signed-off-by: Mynhardt Burger --- caikit_nlp/modules/text_embedding/embedding.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 73d3e77f..14132f96 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -687,9 +687,10 @@ def get_sample_start_indexes(tokenized: BatchEncoding) -> List[int]: # note: tokenized["overflow_to_sample_mapping"] is a torch.Tensor samples_start_indexes: Dict[int, int] = {} - for i, sample in enumerate(tokenized["overflow_to_sample_mapping"]): - if int(sample) not in samples_start_indexes: - samples_start_indexes[int(sample)] = i + for i, tensor_sample in enumerate(tokenized["overflow_to_sample_mapping"]): + int_sample = int(tensor_sample) + if int_sample not in samples_start_indexes: + samples_start_indexes[int_sample] = i return list(samples_start_indexes.values())