Commit

fix
rlbayes committed Mar 17, 2023
1 parent e42d0f6 commit 118a985
Showing 4 changed files with 19 additions and 18 deletions.
8 changes: 4 additions & 4 deletions evals/data.py
@@ -200,12 +200,12 @@ def default(self, o: Any) -> str:
         return _to_py_types(o)


-def jsondumps(o: Any, **kwargs: Any) -> str:
-    return json.dumps(o, cls=EnhancedJSONEncoder, **kwargs)
+def jsondumps(o: Any, ensure_ascii: bool = False, **kwargs: Any) -> str:
+    return json.dumps(o, cls=EnhancedJSONEncoder, ensure_ascii=ensure_ascii, **kwargs)


-def jsondump(o: Any, fp: Any, **kwargs: Any) -> None:
-    json.dump(o, fp, cls=EnhancedJSONEncoder, **kwargs)
+def jsondump(o: Any, fp: Any, ensure_ascii: bool = False, **kwargs: Any) -> None:
+    json.dump(o, fp, cls=EnhancedJSONEncoder, ensure_ascii=ensure_ascii, **kwargs)


 def jsonloads(s: str, **kwargs: Any) -> Any:
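The data.py change makes ensure_ascii=False the default in the JSON helpers, so call sites no longer need to pass it and non-ASCII text is written as raw UTF-8 rather than \uXXXX escapes. A minimal sketch of the behavioral difference, using only the stdlib json module (the sample string is illustrative, not taken from the eval data):

import json

sample = {"input": "ユーモアのある文を書いてください"}  # illustrative Japanese prompt

print(json.dumps(sample))                      # escaped: {"input": "\u30e6\u30fc\u30e2\u30a2..."}
print(json.dumps(sample, ensure_ascii=False))  # raw UTF-8: {"input": "ユーモアのある文を書いてください"}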
24 changes: 10 additions & 14 deletions evals/record.py
@@ -293,16 +293,12 @@ def __init__(self, log_path: Optional[str], run_spec: RunSpec):
         self.event_file_path = log_path
         if log_path is not None:
             with bf.BlobFile(log_path, "wb") as f:
-                f.write(
-                    (
-                        jsondumps({"spec": dataclasses.asdict(run_spec)}, ensure_ascii=False) + "\n"
-                    ).encode("utf-8")
-                )
+                f.write((jsondumps({"spec": dataclasses.asdict(run_spec)}) + "\n").encode("utf-8"))

     def _flush_events_internal(self, events_to_write: Sequence[Event]):
         start = time.time()
         try:
-            lines = [jsondumps(event, ensure_ascii=False) + "\n" for event in events_to_write]
+            lines = [jsondumps(event) + "\n" for event in events_to_write]
         except TypeError as e:
             logger.error(f"Failed to serialize events: {events_to_write}")
             raise e
@@ -318,8 +314,8 @@ def _flush_events_internal(self, events_to_write: Sequence[Event]):
         self._flushes_done += 1

     def record_final_report(self, final_report: Any):
-        with bf.BlobFile(self.event_file_path, "a") as f:
-            f.write(jsondumps({"final_report": final_report}) + "\n")
+        with bf.BlobFile(self.event_file_path, "ab") as f:
+            f.write((jsondumps({"final_report": final_report}) + "\n").encode("utf-8"))

         logging.info(f"Final report: {final_report}. Logged to {self.event_file_path}")

@@ -345,8 +341,8 @@ def __init__(
         self._conn = snowflake_connection

         if log_path is not None:
-            with bf.BlobFile(log_path, "w") as f:
-                f.write(jsondumps({"spec": dataclasses.asdict(run_spec)}) + "\n")
+            with bf.BlobFile(log_path, "wb") as f:
+                f.write((jsondumps({"spec": dataclasses.asdict(run_spec)}) + "\n").encode("utf-8"))

         query = """
             INSERT ALL INTO runs (run_id, model_name, eval_name, base_eval, split, run_config, settings, created_by, created_at)
@@ -411,15 +407,15 @@ def _flush_events_internal(self, events_to_write: Sequence[Event]):
                 )
             idx_l = idx_r

-        with bf.BlobFile(self.event_file_path, "a") as f:
-            f.writelines(lines)
+        with bf.BlobFile(self.event_file_path, "ab") as f:
+            f.write(b"".join([l.encode("utf-8") for l in lines]))
         self._last_flush_time = time.time()
         self._flushes_done += 1

     def record_final_report(self, final_report: Any):
         with self._writing_lock:
-            with bf.BlobFile(self.event_file_path, "a") as f:
-                f.write(jsondumps({"final_report": final_report}) + "\n")
+            with bf.BlobFile(self.event_file_path, "ab") as f:
+                f.write((jsondumps({"final_report": final_report}) + "\n").encode("utf-8"))
             query = """
                 UPDATE runs
                 SET final_report = PARSE_JSON(%(final_report)s)
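The record.py changes open event logs in binary mode ("wb"/"ab") and encode each line explicitly, so the bytes written are UTF-8 regardless of the platform's default text encoding. A minimal sketch of the pattern, with plain open() standing in for bf.BlobFile (assumed here to accept the same mode strings), a simplified jsondumps helper, and an illustrative report payload:

import json

def jsondumps(o, ensure_ascii=False, **kwargs):
    # mirrors the evals/data.py helper: non-ASCII stays as UTF-8 by default
    return json.dumps(o, ensure_ascii=ensure_ascii, **kwargs)

with open("events.jsonl", "ab") as f:
    # append one JSON line, encoding explicitly instead of relying on locale defaults
    f.write((jsondumps({"final_report": {"score": 1.0}}) + "\n").encode("utf-8"))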
1 change: 1 addition & 0 deletions evals/registry/eval_sets/test-modelgraded.yaml
@@ -12,3 +12,4 @@ test-modelgraded:
 - rap-people-vs-people
 - rap-animals-vs-fruits
 - rap-people-vs-fruits
+- mg-humor-people_jp
4 changes: 4 additions & 0 deletions scripts/modelgraded_generator.py
@@ -180,6 +180,7 @@ def format(template: str, **kwargs: dict[str, str]) -> str:

 data_dir = f"{REGISTRY_PATH}/data/test_modelgraded"
 yaml_str = f"# This file is generated by {os.path.basename(__file__)}\n\n"
+evals = []
 for prompt_name, subject in unlabeled_target_sets:
     prompt = unlabeled_prompts[prompt_name]["prompt"]
     samples = [{"input": format(prompt, subject=s)} for s in subjects[subject]]
@@ -201,9 +202,12 @@ def format(template: str, **kwargs: dict[str, str]) -> str:
         )
         + "\n\n"
     )
+    evals += [f"mg-{prompt_name}-{subject}: {file_name}"]


 yaml_file = f"{REGISTRY_PATH}/evals/test-modelgraded-generated.yaml"
 with open(yaml_file, "w") as f:
     f.write(yaml_str)
 print(f"wrote {yaml_file}")
+for e in evals:
+    print(e)
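The generator change collects one registry line per generated eval and prints them after writing the YAML, so they can be copied into an eval-set file such as test-modelgraded.yaml. A minimal sketch of that accumulate-and-print step; the prompt/subject pair and file path below are illustrative, not the generator's actual values:

evals = []
for prompt_name, subject in [("humor", "people_jp")]:  # illustrative pair
    file_name = f"test_modelgraded/{prompt_name}_{subject}.yaml"  # hypothetical path
    # ... generate samples and append the eval definition to the yaml string ...
    evals += [f"mg-{prompt_name}-{subject}: {file_name}"]

for e in evals:
    print(e)  # e.g. mg-humor-people_jp: test_modelgraded/humor_people_jp.yaml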
