Skip to content

Commit

Permalink
check it again
Browse files Browse the repository at this point in the history
  • Loading branch information
thejaminator committed Apr 27, 2023
1 parent 72fc632 commit 1fca929
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
11 changes: 11 additions & 0 deletions elk/utils/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,17 @@ def map(
"""
return list(self.imap(closure, dataset))


def map_for_non_fsdp(
self,
closure: Callable[[ModelOutput], A],
dataset: Dataset,
) -> list[A]:
"""Run inference on the given inputs, running a closure on the outputs.
Note that the order of the outputs is not guaranteed to match
"""
return list(self.imap_for_non_fsdp(closure, dataset))

def one(
self,
dataset: Dataset,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_inference_server_fsdp_one():
)


def test_inference_server_fsdp_two():
def test_inference_server_fsdp_other_map_imp():
model_str = "sshleifer/tiny-gpt2"
single_model = transformers.AutoModelForCausalLM.from_pretrained(model_str)
server = InferenceServer(
Expand All @@ -82,7 +82,7 @@ def test_inference_server_fsdp_two():
[{"input_ids": input_ids_one}, {"input_ids": input_ids_two}]
)
input_dataset.set_format(type="torch")
outputs = server.map(dataset=input_dataset, closure=lambda x: x)
outputs = server.imap_for_non_fsdp(dataset=input_dataset, closure=lambda x: x)
assert len(outputs) == 2
first_output = outputs[0]
assert (
Expand Down

0 comments on commit 1fca929

Please sign in to comment.