add context-based requests processing #1571

Open
wants to merge 15 commits into base: main
Changes from 1 commit
add context-based requests processing
artemorloff committed Mar 13, 2024
commit 8053761fdb17955285cb880717c63ed7ac1ddd1e
26 changes: 25 additions & 1 deletion lm_eval/api/instance.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple
from typing import Callable, Literal, Optional, Tuple


OutputType = Literal[
@@ -36,3 +36,27 @@ def args(self):
return (
self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
)


class ContextInstance(Instance):
def __init__(
self,
requests_updater: Optional[Callable] = None,
storage_updater: Optional[Callable] = None,
**kwargs,
):
super().__init__(**kwargs)
self._update_request = requests_updater
self._update_storage = storage_updater

@property
def update_request(self):
if getattr(self, "_update_request") is not None:
return self._update_request
raise NotImplementedError("Method for updating request is not defined.")

@property
def update_storage(self):
if getattr(self, "_update_storage") is not None:
return self._update_storage
raise NotImplementedError("Method for updating storage is not defined.")
85 changes: 63 additions & 22 deletions lm_eval/evaluator.py
@@ -6,6 +6,7 @@

import numpy as np
import torch
from tqdm import tqdm

import lm_eval.api.metrics
import lm_eval.api.registry
@@ -298,13 +299,26 @@ def evaluate(
Dictionary of results
"""

# names of the request meta-types
CONTEXT_BASED_TYPE_ID = "context-based"
DEFAULT_TYPE_ID = "regular"
# name of the task attribute that opts a task into context-based processing
CONTEXT_BASED_TYPE_ATTR = "CONTEXT_BASED"

eval_logger.setLevel(getattr(logging, f"{verbosity}"))

### prepare to split all requests into two meta-groups
# tracks all Instances/requests a model must generate output on.
requests = defaultdict(list)
requests = {
CONTEXT_BASED_TYPE_ID: defaultdict(list),
DEFAULT_TYPE_ID: defaultdict(list),
}
# stores the amount to pad out reqs per req. type so that
# number of fwd passes per distributed rank is equal
padding_requests = defaultdict(int)
padding_requests = {
CONTEXT_BASED_TYPE_ID: defaultdict(int),
DEFAULT_TYPE_ID: defaultdict(int),
}

# get lists of group hierarchy and each type of request
task_hierarchy, eval_tasks = get_task_list(task_dict)
@@ -330,10 +344,16 @@

if write_out:
print_writeout(task)
# aggregate Instances by LM method requested to get output.
# aggregate Instances by the LM method requested to get output, and by request meta-type
task_type_id = (
CONTEXT_BASED_TYPE_ID
if getattr(task, CONTEXT_BASED_TYPE_ATTR, False)
else DEFAULT_TYPE_ID
)
for instance in task.instances:
reqtype = instance.request_type
requests[reqtype].append(instance)
# split requests into two groups: with and without context
requests[task_type_id][reqtype].append(instance)

if lm.world_size > 1:
instances_rnk = torch.tensor(len(task._instances), device=lm.device)
@@ -349,30 +369,51 @@
# compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
numpad = max(gathered_item) - gathered_item[lm.rank]
# todo: may not account for padding in cases like SquadV2 which has multiple req types
padding_requests[reqtype] += numpad
# pad each group separately
padding_requests[task_type_id][reqtype] += numpad

### Run LM on inputs, get all outputs ###
# execute each type of request
for reqtype, reqs in requests.items():
eval_logger.info(f"Running {reqtype} requests")
# create `K` copies of each request `req` based off `K = req.repeats`
cloned_reqs = []
for req in reqs:
cloned_reqs.extend([req] * req.repeats)

if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
for _ in range(padding_requests[reqtype]):
# execute each group of requests: context-based and regular
for task_type, type_requests in requests.items():
# iterate over each request type within the group
for reqtype, reqs in type_requests.items():
eval_logger.info(f"Running {task_type} {reqtype} requests")
# create `K` copies of each request `req` based off `K = req.repeats`
cloned_reqs = []
for req in reqs:
cloned_reqs.extend([req] * req.repeats)

# run requests through model
resps = getattr(lm, reqtype)(cloned_reqs)
if (lm.world_size > 1) and (padding_requests[task_type][reqtype] > 0):
for _ in range(padding_requests[task_type][reqtype]):
cloned_reqs.extend([req] * req.repeats)

# put responses from model into a list of length K for each request.
for x, req in zip(resps, cloned_reqs):
req.resps.append(x)
# regular requests are processed as before, in a single batched call
if task_type == DEFAULT_TYPE_ID:
# run all requests through model
resps = getattr(lm, reqtype)(cloned_reqs)

if lm.world_size > 1:
lm.accelerator.wait_for_everyone()
# put responses from model into a list of length K for each request.
for x, req in zip(resps, cloned_reqs):
req.resps.append(x)
# context-based tasks require per-request processing
else:
# storage keeps LM outputs so they are computed only once
storage = {}
# iterate over all requests
# this tqdm does not overwrite the internal tqdm bars of getattr(lm, reqtype)
for req in tqdm(cloned_reqs, desc=f"Running {reqtype} requests"):
# one request per iteration; req.args is updated each time
req = req.update_request(storage, req)
# only one resp for a single request
resp = getattr(lm, reqtype)([req])
# add the output to the Instance's resps (matching the append used for regular requests)
req.resps.extend(resp)
# push changes into storage
# also discard storage after the current set ends
storage = req.update_storage(storage, req)

if lm.world_size > 1:
lm.accelerator.wait_for_everyone()

RANK = lm.rank
WORLD_SIZE = lm.world_size
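Taken together, the contract implied by the loop above is: a task exposes a truthy `CONTEXT_BASED` attribute (the `CONTEXT_BASED_TYPE_ATTR` checked when instances are collected), and each of its instances supplies `update_request(storage, req) -> req` and `update_storage(storage, req) -> dict`. A hedged sketch under those assumptions, with hypothetical helper and class names that the diff itself does not include:

```python
# Hedged sketch, not from this PR: example updaters matching the per-request
# loop, and a task opting in via the CONTEXT_BASED attribute.
from lm_eval.api.task import ConfigurableTask


def chain_requests_updater(storage, req):
    """Rewrite the request's prompt using what earlier requests produced."""
    context, gen_kwargs = req.arguments
    req.arguments = (storage.get("history", "") + context, gen_kwargs)
    return req


def chain_storage_updater(storage, req):
    """Fold the newest model output into storage for the next request."""
    storage["history"] = storage.get("history", "") + str(req.resps[-1])
    return storage


class MyMultiTurnTask(ConfigurableTask):
    # the evaluator only checks this attribute to route the task's
    # requests through the per-request loop
    CONTEXT_BASED = True
```

The `storage` dict starts empty for each request type and is never inspected by the evaluator itself, so its layout (here a single `history` string) is entirely up to the task's updaters.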