Add recursive CCS #57
Conversation
and data[0].dtype == data[1].dtype == self.dtype
), "Data must be a tuple of two tensors of the same shape and dtype"

def correct_dtypes(
Why is this necessary at all? I don't see why we need to cast to a single dtype
When I remove this cast, I get:
Traceback (most recent call last):
File "/home/ubuntu/elk/elk/extensions/recursive_ccs/train.py", line 110, in <module>
train(args)
File "/home/ubuntu/elk/elk/extensions/recursive_ccs/train.py", line 74, in train
probe, train_loss = rccs.fit_next_probe(
File "/home/ubuntu/elk/elk/extensions/recursive_ccs/rccs.py", line 33, in fit_next_probe
train_loss = ccs.fit(data, **train_params)
File "/home/ubuntu/elk/elk/training/ccs.py", line 142, in fit
loss = self.train_loop_lbfgs(x0, x1, num_epochs, weight_decay)
File "/home/ubuntu/elk/elk/training/ccs.py", line 238, in train_loop_lbfgs
optimizer.step(closure)
File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/optim/optimizer.py", line 140, in wrapper
out = func(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/optim/lbfgs.py", line 312, in step
orig_loss = closure()
File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/ubuntu/elk/elk/training/ccs.py", line 224, in closure
logit0, logit1 = self(x0), self(x1)
File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/elk/elk/training/ccs.py", line 103, in forward
return self.probe(x)
File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
input = module(input)
File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/miniconda3/envs/elk/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype
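For reference, the error is easy to reproduce in isolation: a float32 probe applied to half-precision inputs fails in exactly this way. A minimal sketch, assuming the hidden states were extracted in float16 (the dimensions here are illustrative, not the actual shapes):

```python
import torch

# A float32 linear probe, as nn.Linear parameters default to float32.
probe = torch.nn.Linear(768, 1)

# Half-precision inputs, standing in for hidden states extracted in fp16.
x = torch.randn(8, 768, dtype=torch.float16)

try:
    probe(x)
except RuntimeError as e:
    print(e)  # mat1 and mat2 must have the same dtype
```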
If you just call .float() on the input before calling fit, this should work.
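Something along these lines at the call site; a sketch only, since it just demonstrates the cast on a (x0, x1) pair like the one the assert above checks:

```python
import torch

# Stand-ins for hidden states extracted in half precision.
x0 = torch.randn(8, 768, dtype=torch.float16)
x1 = torch.randn(8, 768, dtype=torch.float16)

# Cast both halves of the pair to float32 before passing them to fit,
# so the probe's Linear layers never see mixed dtypes.
data = (x0.float(), x1.float())
assert data[0].dtype == data[1].dtype == torch.float32
```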
@@ -0,0 +1,110 @@
import csv
It's not clear why this entire file is necessary. It seems to be mostly copied over from the primary train.py. Could we add a flag or a subparser to the main command instead?
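For the subparser option, something like the following is what I have in mind; a sketch only, and the subcommand names ("train", "rccs") and flags are hypothetical, not the existing elk CLI:

```python
import argparse

parser = argparse.ArgumentParser(prog="elk")
subparsers = parser.add_subparsers(dest="command", required=True)

# Existing training entry point (hypothetical flags).
train = subparsers.add_parser("train", help="train a standard CCS probe")
train.add_argument("--num-epochs", type=int, default=1000)

# Recursive CCS as a sibling subcommand instead of a separate train.py.
rccs = subparsers.add_parser("rccs", help="train recursive CCS probes")
rccs.add_argument("--num-probes", type=int, default=10,
                  help="number of successive orthogonal probes to fit")

# Example invocation (would normally come from sys.argv).
args = parser.parse_args(["rccs", "--num-probes", "5"])
print(args.command, args.num_probes)  # rccs 5
```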
It makes training more complicated. I think this should not happen before the first release.