Initial implementation of semi-supervised training #58
Conversation
lgtm
(haven't tested it since I don't have multiple GPUs)
```diff
-        self.loss = js_loss if loss == "js" else ccs_squared_loss
+        self.unsupervised_loss = js_loss if loss == "js" else ccs_squared_loss
+        self.supervised_weight = supervised_weight
```
This logic should probably happen outside of the CCS class, which could simply take a loss fn as an argument. But this cleanup can wait for a future PR.
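A minimal sketch of the suggested refactor, with hypothetical signatures (the real CCS class in the PR operates on tensors, not floats; `js_loss` and `ccs_squared_loss` here are simplified stand-ins):

```python
# Sketch: the caller picks the loss function, so CCS does no string dispatch.
# All signatures here are simplified assumptions for illustration.
from typing import Callable

def ccs_squared_loss(p0: float, p1: float) -> float:
    # Consistency term: p0 and p1 should sum to 1.
    consistency = (p0 - (1.0 - p1)) ** 2
    # Confidence term: push the smaller probability away from 0.5.
    confidence = min(p0, p1) ** 2
    return consistency + confidence

class CCS:
    def __init__(
        self,
        unsupervised_loss: Callable[[float, float], float],
        supervised_weight: float = 0.0,
    ):
        # The "js" vs "squared" branching would live in the caller instead.
        self.unsupervised_loss = unsupervised_loss
        self.supervised_weight = supervised_weight
```

This keeps the class agnostic to which losses exist, so adding a new loss doesn't require touching CCS itself.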
```diff
@@ -9,7 +9,7 @@
 import torch.distributed as dist

-@torch.autocast("cuda", enabled=torch.cuda.is_available())
+@torch.autocast("cuda", enabled=torch.cuda.is_available())  # type: ignore
```
I'm not seeing typing issues here. Are you using mypy?
I think this was a pylance error. I'll check
Pylance
Pytorch 1.13.1+cu117
I don't see any issues here
This reverts commit 47a6c76.
Another issue I just caught: shuffling now happens after selection of the first max_examples, which means that for datasets where points are sorted by label (such as imdb) you won't be able to use a low number of max examples.
Fixed
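A pure-Python toy illustration of the bug above (no `datasets` dependency; the list of labels stands in for a label-sorted dataset like imdb):

```python
# Why selecting the first max_examples BEFORE shuffling is broken for
# label-sorted data: the selected slice contains only one label.
import random

dataset = [0] * 50 + [1] * 50  # labels sorted, as in imdb
max_examples = 10

# Buggy order: select, then shuffle -> every example has label 0.
buggy = dataset[:max_examples]
random.Random(0).shuffle(buggy)

# Fixed order: shuffle, then select -> both labels represented
# (with high probability; not guaranteed for tiny max_examples).
fixed = dataset[:]
random.Random(0).shuffle(fixed)
fixed = fixed[:max_examples]
```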
elk/extraction/prompt_collator.py
Outdated
```diff
-            self.dataset = self.dataset.select(range(max_examples))
         if dist.is_initialized():
             self.dataset = self.dataset.shard(dist.get_world_size(), dist.get_rank())

+        self.dataset = self.dataset.shuffle(seed=seed)
+        if max_examples:
+            self.dataset = self.dataset.select(range(max_examples))
```
Doesn't that select world_size x max_examples examples?
Yep!
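A quick pure-Python sketch of the reviewer's point: sharding first and then selecting `max_examples` on every rank means the run as a whole processes `world_size * max_examples` examples. (The `shard` helper here is a stand-in approximating `datasets.Dataset.shard` with strided indices.)

```python
# Stand-in for datasets.Dataset.shard(num_shards, index): every
# world_size-th element, offset by rank.
def shard(data, world_size, rank):
    return data[rank::world_size]

data = list(range(1000))
world_size, max_examples = 4, 100

# Each rank shards, then keeps its first max_examples examples.
per_rank = [shard(data, world_size, r)[:max_examples] for r in range(world_size)]

# Total processed across all ranks: world_size * max_examples, not max_examples.
total = sum(len(s) for s in per_rank)
```

If the intent is `max_examples` total, the select has to happen before sharding, or each rank should take `max_examples // world_size`.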
```diff
@@ -85,6 +90,9 @@ def reset_parameters(self):
         for layer in self.probe:
             if isinstance(layer, nn.Linear):
                 layer.reset_parameters()
+        elif self.init == "zero":
```
Just curious, why did you add this?
I noticed that you tend to get non-trivial probes with high accuracy even without any confidence loss, and I wanted to test the hypothesis that this is due to an initialization issue: it's hard for the optimizer to reach the trivial solution of outputting 0.5 all the time. Turns out to be true. See the messages in the ELK channel from a couple of days ago.
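A minimal numeric check of why zero init matters here (pure Python; the real probe is an `nn.Linear` stack, this is a simplified single-layer stand-in): with zero weights and bias, the probe outputs sigmoid(0) = 0.5 for every input, so optimization starts exactly at the trivial constant solution rather than having to find it.

```python
# Simplified one-layer probe: linear map followed by a sigmoid.
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def probe(x, weights, bias):
    return sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias)

dim = 8
w_zero, b_zero = [0.0] * dim, 0.0

# Zero-initialized probe: output is 0.5 regardless of the input.
out = probe([0.3] * dim, w_zero, b_zero)
```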
```python
import torch.nn as nn


def maybe_all_gather(x: Tensor) -> Tensor:
```
Probably a dumb question, but why the "maybe"s? Is it because you have maybe multiple processes or maybe just one?
Yeah, it's because these methods do nothing when the script isn't run with torchrun.
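A sketch of that "maybe" pattern, using a stand-in for `torch.distributed` so it runs in a single process (the real helper would call `dist.all_gather` on tensors; the class and names below are illustrative):

```python
# Stand-in for torch.distributed in a single-process run: is_initialized()
# is False unless the script was launched under torchrun.
class _FakeDist:
    @staticmethod
    def is_initialized() -> bool:
        return False

dist = _FakeDist()

def maybe_all_gather(x):
    # Single-process run: nothing to gather, return the input unchanged.
    if not dist.is_initialized():
        return x
    # Under torchrun, the real code would dist.all_gather across ranks
    # and concatenate the results (omitted in this sketch).
    raise NotImplementedError

y = maybe_all_gather([1, 2, 3])  # unchanged without torchrun
```

The same guard applies to the other `maybe_*` helpers: each collective becomes a no-op outside a distributed context, so the same code path works with and without DDP.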
Draft PR; fixes #54. As a bonus, it also partially fixes #36 (for DDP only).