diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 73d3e77f..14132f96 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -687,9 +687,10 @@ def get_sample_start_indexes(tokenized: BatchEncoding) -> List[int]: # note: tokenized["overflow_to_sample_mapping"] is a torch.Tensor samples_start_indexes: Dict[int, int] = {} - for i, sample in enumerate(tokenized["overflow_to_sample_mapping"]): - if int(sample) not in samples_start_indexes: - samples_start_indexes[int(sample)] = i + for i, tensor_sample in enumerate(tokenized["overflow_to_sample_mapping"]): + int_sample = int(tensor_sample) + if int_sample not in samples_start_indexes: + samples_start_indexes[int_sample] = i return list(samples_start_indexes.values())