Skip to content

Commit

Permalink
construct torch array always on right device
Browse files Browse the repository at this point in the history
  • Loading branch information
JoaoLages committed Jan 3, 2022
1 parent 666fc82 commit 4566440
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/ecco/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,12 +722,13 @@ def sample_output_token(scores, do_sample, temperature, top_k, top_p):
return prediction_id


def _one_hot(token_ids, vocab_size):
return torch.zeros(len(token_ids), vocab_size).scatter_(1, token_ids.unsqueeze(1), 1.)
def _one_hot(token_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
return torch.zeros(len(token_ids), vocab_size, device=token_ids.device).scatter_(1, token_ids.unsqueeze(1), 1.)

def _one_hot_batched(token_ids, vocab_size):

def _one_hot_batched(token_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
batch_size, num_tokens = token_ids.shape
return torch.zeros(batch_size, num_tokens, vocab_size).scatter_(-1, token_ids.unsqueeze(-1), 1.)
return torch.zeros(batch_size, num_tokens, vocab_size, device=token_ids.device).scatter_(-1, token_ids.unsqueeze(-1), 1.)


def activations_dict_to_array(activations_dict):
Expand Down

0 comments on commit 4566440

Please sign in to comment.