Add recursive CCS #57

Closed
wants to merge 7 commits into from

Conversation

FabienRoger (Collaborator)

No description provided.

and data[0].dtype == data[1].dtype == self.dtype
), "Data must be a tuple of two tensors of the same shape and dtype"

def correct_dtypes(
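
The diff truncates the helper here. A plausible completion, inferred from the assertion above (the body is a guess, not the PR's actual code):

def correct_dtypes(self, data):
    # Cast both halves of the contrast pair to the probe's dtype so the
    # linear probe sees matching input and weight dtypes.
    x0, x1 = data
    return x0.to(self.dtype), x1.to(self.dtype)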
Member

Why is this necessary at all? I don't see why we need to cast to a single dtype

Collaborator Author

When removing this cast I get:

Traceback (most recent call last):
  File "/home/ubuntu/elk/elk/extensions/recursive_ccs/train.py", line 110, in <module>
    train(args)
  File "/home/ubuntu/elk/elk/extensions/recursive_ccs/train.py", line 74, in train
    probe, train_loss = rccs.fit_next_probe(
  File "/home/ubuntu/elk/elk/extensions/recursive_ccs/rccs.py", line 33, in fit_next_probe
    train_loss = ccs.fit(data, **train_params)
  File "/home/ubuntu/elk/elk/training/ccs.py", line 142, in fit
    loss = self.train_loop_lbfgs(x0, x1, num_epochs, weight_decay)
  File "/home/ubuntu/elk/elk/training/ccs.py", line 238, in train_loop_lbfgs
    optimizer.step(closure)
  File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/optim/optimizer.py", line 140, in wrapper
    out = func(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/optim/lbfgs.py", line 312, in step
    orig_loss = closure()
  File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/elk/elk/training/ccs.py", line 224, in closure
    logit0, logit1 = self(x0), self(x1)
  File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/elk/elk/training/ccs.py", line 103, in forward
    return self.probe(x)
  File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype
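
For context, this error comes from torch.nn.functional.linear when the input and weight dtypes differ; a minimal repro (shapes are illustrative, not from the PR):

import torch
from torch import nn

probe = nn.Linear(4, 1)                     # weights are float32 by default
x = torch.randn(8, 4, dtype=torch.float16)  # e.g. fp16 hidden states
probe(x)  # RuntimeError: mat1 and mat2 must have the same dtype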

Member

If you just call .float() on the inputs before calling fit, this should work.
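
A minimal sketch of that suggestion (variable names follow the traceback above; this is not code from the PR):

x0, x1 = data
data = (x0.float(), x1.float())  # .float() casts to torch.float32
train_loss = ccs.fit(data, **train_params)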

Resolved review threads (outdated):
elk/extensions/recursive_ccs/train.py
elk/training/ccs.py
elk/extensions/recursive_ccs/parser.py

@@ -0,0 +1,110 @@
import csv
Member

It's not clear why this entire file is necessary. It seems to be mostly copied over from the primary train.py. Could we add a flag or a subparser to the main command instead?
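
A rough sketch of the subparser approach (command and flag names are hypothetical, not from this PR):

import argparse

parser = argparse.ArgumentParser(prog="elk")
subparsers = parser.add_subparsers(dest="command", required=True)

subparsers.add_parser("train", help="standard CCS training")
rccs = subparsers.add_parser("rccs", help="recursive CCS training")
rccs.add_argument("--num-probes", type=int, default=10)  # hypothetical flag

args = parser.parse_args()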

Collaborator Author

That would make the main training command more complicated. I think this change should not happen before the first release.
