Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

train probe per prompt #271

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add assert
  • Loading branch information
derpyplops committed Jul 20, 2023
commit f533418f9907c8816233d912c3ae47a50a06ee96
1 change: 1 addition & 0 deletions elk/training/multi_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

class MultiReporter:
def __init__(self, reporter_results: list[ReporterTrainResult]):
assert len(reporter_results) > 0, "Must have at least one reporter"
self.reporter_results: list[ReporterTrainResult] = reporter_results
self.reporters = [r.reporter for r in reporter_results]
train_losses = (
Expand All @@ -26,7 +27,7 @@
else None
)
self.train_loss = (
sum(train_losses) / len(train_losses) if train_losses is not None else None

Check failure on line 30 in elk/training/multi_reporter.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

Argument of type "list[float | None]" cannot be assigned to parameter "__iterable" of type "Iterable[_SupportsSumNoDefaultT@sum]" in function "sum"   "list[float | None]" is incompatible with "Iterable[_SupportsSumNoDefaultT@sum]"     TypeVar "_T_co@Iterable" is covariant       Type "float | None" cannot be assigned to type "_SupportsSumWithNoDefaultGiven"         Type "float | None" cannot be assigned to type "_SupportsSumWithNoDefaultGiven"           "__add__" is not present           "__radd__" is not present (reportGeneralTypeIssues)
)

def __call__(self, h):
Expand All @@ -46,4 +47,4 @@
reporter = t.load(path, map_location=device)
reporters.append(reporter)
# TODO for now I don't care about the train losses
return MultiReporter([ReporterTrainResult(r, None) for r in reporters])

Check failure on line 50 in elk/training/multi_reporter.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

Argument missing for parameter "prompt_index" (reportGeneralTypeIssues)
Loading