
Initial implementation of semi-supervised training #58

Merged · 16 commits into main · Feb 14, 2023
Conversation

@norabelrose (Member) commented Feb 13, 2023

Draft PR, fixes #54. As a bonus, also partially fixes #36 (for DDP only)

@norabelrose norabelrose marked this pull request as draft February 13, 2023 18:02
@norabelrose norabelrose marked this pull request as ready for review February 14, 2023 06:35
@FabienRoger (Collaborator) previously approved these changes Feb 14, 2023

lgtm
(haven't tested it since I don't have multiple GPUs)

Resolved review threads (outdated): elk/extraction/prompt_collator.py, elk/extraction/extraction_main.py
Comment on lines -66 to +71
-        self.loss = js_loss if loss == "js" else ccs_squared_loss
+        self.unsupervised_loss = js_loss if loss == "js" else ccs_squared_loss
+        self.supervised_weight = supervised_weight
Collaborator:
This logic should probably happen outside of the CCS class, which could simply take a loss fn as argument. But this cleanup can wait a future PR.
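The suggested cleanup might look roughly like this. A hypothetical sketch, not the repo's actual API: the `Probe` class name and constructor signature are illustrative, and `ccs_squared_loss` follows the standard CCS consistency-plus-confidence formulation.

```python
import torch
from torch import Tensor


def ccs_squared_loss(p0: Tensor, p1: Tensor) -> Tensor:
    # Consistency: p(x+) and 1 - p(x-) should agree.
    consistency = ((p0 - (1 - p1)) ** 2).mean()
    # Confidence: penalize hedging toward p0 = p1 = 0.5.
    confidence = torch.minimum(p0, p1).pow(2).mean()
    return consistency + confidence


class Probe:
    """Hypothetical: takes the loss fn as an argument, so the
    js/ccs selection logic lives in the caller, not the class."""

    def __init__(self, loss_fn=ccs_squared_loss):
        self.loss_fn = loss_fn

    def loss(self, p0: Tensor, p1: Tensor) -> Tensor:
        return self.loss_fn(p0, p1)
```

The caller would then construct `Probe(loss_fn=js_loss if loss == "js" else ccs_squared_loss)` instead of passing a string.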

@@ -9,7 +9,7 @@
 import torch.distributed as dist


-@torch.autocast("cuda", enabled=torch.cuda.is_available())
+@torch.autocast("cuda", enabled=torch.cuda.is_available())  # type: ignore
Member Author:

I'm not seeing typing issues here. Are you using mypy?

Collaborator:

I think this was a pylance error. I'll check

Collaborator:

Pylance
Pytorch 1.13.1+cu117

Collaborator:

I don't see any issues here
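For context, the decorator under discussion can be exercised like this; a minimal sketch, where `train_step` is a made-up function name. With `enabled=False` on CPU-only machines the autocast is a no-op, which is why the decorator is safe to apply unconditionally. The `# type: ignore` is presumably needed because some PyTorch stubs don't declare `autocast` as usable as a decorator.

```python
import torch


@torch.autocast("cuda", enabled=torch.cuda.is_available())
def train_step(x: torch.Tensor) -> torch.Tensor:
    # Inside the autocast region, eligible CUDA ops run in reduced
    # precision; on CPU-only machines nothing changes.
    return x @ x.T
```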

Resolved review threads (outdated): elk/training/ccs.py, elk/training/train.py
@FabienRoger (Collaborator):

Another issue I just caught: shuffling now happens after selecting the first max_examples, which means that for datasets where points are sorted by label (such as IMDB) you can't use a small max_examples.

@FabienRoger FabienRoger dismissed their stale review February 14, 2023 16:52

I failed to see some problems

@norabelrose (Member Author):

Another issue I just caught: shuffling now happens after selecting the first max_examples, which means that for datasets where points are sorted by label (such as IMDB) you can't use a small max_examples.

Fixed

Comment on lines 56 to 60
-        self.dataset = self.dataset.select(range(max_examples))
         if dist.is_initialized():
             self.dataset = self.dataset.shard(dist.get_world_size(), dist.get_rank())

         self.dataset = self.dataset.shuffle(seed=seed)
+        if max_examples:
+            self.dataset = self.dataset.select(range(max_examples))
Collaborator:
Doesn't that select world_size x max_examples examples?

Member Author:

Yep!
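The arithmetic behind the question can be sketched without torch.distributed; `shard` below mimics the strided split of `datasets.Dataset.shard`, and the numbers are hypothetical:

```python
def shard(dataset, num_shards, index):
    # Strided split, like datasets.Dataset.shard(contiguous=False).
    return dataset[index::num_shards]


world_size = 4
max_examples = 10
dataset = list(range(1_000))

# Each rank shards first, then selects max_examples from its own shard,
# so the job as a whole consumes world_size * max_examples examples.
per_rank = [shard(dataset, world_size, rank)[:max_examples] for rank in range(world_size)]
total = sum(len(chunk) for chunk in per_rank)  # 4 * 10 = 40, not 10
```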

@@ -85,6 +90,9 @@ def reset_parameters(self):
         for layer in self.probe:
             if isinstance(layer, nn.Linear):
                 layer.reset_parameters()
+        elif self.init == "zero":
Collaborator:

Just curious, why did you add this?

Member Author:

I noticed that you tend to get non-trivial probes with high accuracy even without any confidence loss, and I wanted to test the hypothesis that it's an initialization issue: it's hard for the optimizer to reach the trivial solution of outputting 0.5 all the time. Turns out to be true; see the messages in the ELK channel from a couple days ago.
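The hypothesis is easy to demonstrate: with all-zero weights and biases, a linear-plus-sigmoid probe emits exactly 0.5 for every input, i.e. the trivial solution the optimizer otherwise struggles to reach. A sketch with a made-up probe shape:

```python
import torch
import torch.nn as nn

probe = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())

# Zero-init, as an `init == "zero"` branch would do.
for layer in probe:
    if isinstance(layer, nn.Linear):
        nn.init.zeros_(layer.weight)
        nn.init.zeros_(layer.bias)

x = torch.randn(8, 16)
out = probe(x)  # every entry is sigmoid(0) = 0.5
```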

import torch.nn as nn


def maybe_all_gather(x: Tensor) -> Tensor:
Collaborator:

Probably a dumb question, but why the "maybe"s? Is it because there are maybe multiple processes or maybe just one?

Member Author:

Yeah, it's because these methods do nothing when the script isn't run with torchrun.
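A sketch of the pattern; this illustrates the idea rather than the exact body in the repo, and `all_gather_into_tensor` assumes a recent PyTorch (≥ 1.13):

```python
import torch
import torch.distributed as dist
from torch import Tensor


def maybe_all_gather(x: Tensor) -> Tensor:
    # No-op when the script isn't launched with torchrun, so
    # single-process and DDP runs share one code path.
    if not dist.is_initialized():
        return x

    # Gather each rank's rows into one big tensor on every rank.
    buffer = x.new_empty(dist.get_world_size() * x.shape[0], *x.shape[1:])
    dist.all_gather_into_tensor(buffer, x)
    return buffer
```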

@norabelrose norabelrose merged commit 32e633b into main Feb 14, 2023
@norabelrose norabelrose deleted the semi-supervised branch February 14, 2023 21:10