Skip to content

Commit

Permalink
Fix get_sample_start_indexes
Browse files Browse the repository at this point in the history
Signed-off-by: Mynhardt Burger <[email protected]>
  • Loading branch information
mynhardtburger committed Mar 9, 2024
1 parent c8b2a8f commit 11d6135
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
10 changes: 5 additions & 5 deletions caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,12 +686,12 @@ def get_sample_start_indexes(tokenized: BatchEncoding) -> List[int]:

# note: tokenized["overflow_to_sample_mapping"] is a torch.Tensor

samples_start_idx: Dict[int, int] = {}
for i, sample_idx in enumerate(tokenized["overflow_to_sample_mapping"]):
if sample_idx not in samples_start_idx:
samples_start_idx[sample_idx] = i
samples_start_indexes: Dict[int, int] = {}
for i, sample in enumerate(tokenized["overflow_to_sample_mapping"]):
if int(sample) not in samples_start_indexes:
samples_start_indexes[int(sample)] = i

return list(samples_start_idx.values())
return list(samples_start_indexes.values())


class TruncateCountBehavior(Enum):
Expand Down
20 changes: 14 additions & 6 deletions tests/modules/text_embedding/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch.backends import mps
import numpy as np
import pytest
import torch

# First Party
from caikit.core import ModuleConfig
Expand Down Expand Up @@ -852,12 +853,17 @@ def test_sum_token_count_no_truncation(texts, expected_count, loaded_model):
[
# Only tokens requiring model attention is counted.
# [PAD] doesn't attract model attention, but [CLS] and [SEP] do
# [CLS] 5 normal [SEP]
#
# All encodings: [CLS] 12345 [SEP]
# No truncation
(["12345"], 10, 7),
# [CLS] 3 normal [SEP] + [CLS] 2 normal [SEP] [PAD]
(["12345"], 5, 5 + 4),
# [CLS] 3 normal [SEP] + [CLS] 2 normal [SEP] [PAD], [CLS] 3 normal [SEP] + [CLS] 1 normal [SEP] [PAD] [PAD]
(["12 345", "6 789"], 5, 9 + 8),
# All encodings: [CLS] 123 [SEP] + [CLS] 45 [SEP] [PAD]
# Only truncated: [CLS] 123 [SEP]
(["12345"], 5, 3 + 2),
#
# All encodings: [CLS] 123 [SEP] + [CLS] 45 [SEP] [PAD], [CLS] 678 [SEP] + [CLS] 9 [SEP] [PAD] [PAD]
# Only truncated: [CLS] 123 [SEP] , [CLS] 678 [SEP]
(["12 345", "6 789"], 5, (3 + 2) + (3 + 2)),
],
)
def test_sum_token_count_with_truncation(texts, truncate, expected_count, loaded_model):
Expand Down Expand Up @@ -913,5 +919,7 @@ def test_encoding_order(loaded_model: EmbeddingModule):
],
)
def test_get_sample_start_indexes(mapping, expected):
mock_tokenized = {"overflow_to_sample_mapping": mapping}
mock_tokenized = {
"overflow_to_sample_mapping": torch.Tensor(mapping).type(torch.int8)
}
assert get_sample_start_indexes(mock_tokenized) == expected # type: ignore

0 comments on commit 11d6135

Please sign in to comment.