Skip to content
This repository has been archived by the owner on Aug 11, 2022. It is now read-only.

Commit

Permalink
fix device
Browse files Browse the repository at this point in the history
  • Loading branch information
vpj committed May 3, 2022
1 parent 821f247 commit 426cb87
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/neox/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def run_eval(self, name: str, eval_tasks: List[str]):
return results


def run_eval_harness(model: nn.Module, name: str, eval_tasks: List[str], batch_size: int = 8):
def run_eval_harness(model: nn.Module, name: str, eval_tasks: List[str], device: torch.device, batch_size: int = 8):
"""
## Run evaluation harness with a given model
"""
Expand All @@ -237,7 +237,7 @@ def run_eval_harness(model: nn.Module, name: str, eval_tasks: List[str], batch_s
tokenizer = Tokenizer.from_file(str(vocab_file))

# Create the adapter
adapter = EvalHarnessAdapter(model, tokenizer, 50_432, batch_size, torch.device('cpu'))
adapter = EvalHarnessAdapter(model, tokenizer, 50_432, batch_size, device)

# Run
return adapter.run_eval(name, eval_tasks)
2 changes: 1 addition & 1 deletion src/neox/evaluation/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ def forward(self, x: torch.Tensor):


if __name__ == '__main__':
print(run_eval_harness(DummyModel(50_432), 'dummy', ['lambada']))
print(run_eval_harness(DummyModel(50_432), 'dummy', ['lambada'], torch.device('cpu')))
2 changes: 1 addition & 1 deletion src/neox/evaluation/half_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
with monit.section('Sequential'):
model = nn.Sequential(*layers).half().to(torch.device('cuda:0'))

print(run_eval_harness(model, 'half_precision', []))
print(run_eval_harness(model, 'half_precision', [], torch.device('cuda:0')))
2 changes: 1 addition & 1 deletion src/neox/evaluation/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
devices=devices,
chunks=4)

print(run_eval_harness(model, 'pipeline_parallel', []))
print(run_eval_harness(model, 'pipeline_parallel', [], torch.device('cuda:0')))

0 comments on commit 426cb87

Please sign in to comment.