Hands-on with an in-house dataset - Bad performance on M1 Max #94
Comments
Hi, if your input is 5-dimensional, then kan_model = KAN(width=[1,1,1], grid=2, k=3, seed=0) followed by kan_model.train(...) should raise an error immediately, because your KAN takes only one input. Valid widths are, e.g., …
Mmm, indeed there was no error, but that was the problem. Everything is running now ;)
Hi, I have some general advice on hyperparameter tuning here.
Shouldn't providing data whose dimension differs from the width of the first KAN layer raise an error?
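Until such a check exists in the library, a caller can run one before training. The sketch below is a hypothetical helper (`check_kan_dataset` is not part of pykan); it assumes the dataset dict uses the `train_input`/`train_label`/`test_input`/`test_label` keys from this issue and that inputs and labels are 2-D, i.e. `(n_samples, width[0])` and `(n_samples, width[-1])`. It works on NumPy arrays and torch tensors alike, since both expose `.ndim` and `.shape`.

```python
import numpy as np


def check_kan_dataset(dataset, width):
    """Hypothetical pre-flight check (not part of pykan): raise if the data
    dimensions disagree with the first/last entries of a KAN `width` list."""
    n_in, n_out = width[0], width[-1]
    for split in ("train", "test"):
        x = dataset[f"{split}_input"]
        y = dataset[f"{split}_label"]
        # Inputs are assumed to be 2-D: (n_samples, n_in)
        if x.ndim != 2 or x.shape[1] != n_in:
            raise ValueError(
                f"{split}_input has shape {tuple(x.shape)}, "
                f"but width[0] = {n_in} expects (n_samples, {n_in})"
            )
        # Labels are assumed to be 2-D: (n_samples, n_out)
        if y.ndim != 2 or y.shape[1] != n_out:
            raise ValueError(
                f"{split}_label has shape {tuple(y.shape)}, "
                f"but width[-1] = {n_out} expects (n_samples, {n_out})"
            )
        # Every input row needs exactly one label row
        if x.shape[0] != y.shape[0]:
            raise ValueError(
                f"{split}: {x.shape[0]} input rows vs {y.shape[0]} label rows"
            )
```

With the 5-column data from this issue, `check_kan_dataset(my_ds, [1, 1, 1])` would fail with a message pointing at `width[0]`, while a width starting with 5 would pass the input check.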
I am testing on an in-house dataset that our research group uses as a benchmark. It has many inputs and one output, but to keep it simple for now I am using 5 inputs, 1 output, and about 100K observations.
Data conversion from pandas DataFrame
my_ds= {"train_input":torch.from_numpy(np.array(train_data_x)[:, :5]),
"test_input":torch.from_numpy(np.array(test_data_x)[:, :5]),
"train_label":torch.from_numpy(np.array(train_data_y)),
"test_label":torch.from_numpy(np.array(test_data_y))}
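One detail worth checking in this conversion: `torch.from_numpy` keeps the NumPy dtype, so data coming from pandas as float64 becomes `torch.float64` tensors, whereas torch module parameters default to float32. The sketch below is a hypothetical NumPy helper (`prepare_kan_arrays` is not from pykan or the original post) that selects the first `n_inputs` columns, casts to float32, and turns a 1-D label vector into a column, assuming the model expects labels of shape `(n_samples, 1)`.

```python
import numpy as np


def prepare_kan_arrays(x, y, n_inputs):
    """Hypothetical helper: keep the first `n_inputs` columns, cast to float32
    and make labels 2-D, so torch.from_numpy yields float32 tensors of shape
    (n, n_inputs) and (n, n_outputs)."""
    # np.asarray accepts pandas DataFrames/Series as well as lists and arrays
    x = np.ascontiguousarray(np.asarray(x, dtype=np.float32)[:, :n_inputs])
    y = np.asarray(y, dtype=np.float32)
    if y.ndim == 1:
        # A single-output label vector becomes a column of shape (n, 1)
        y = y.reshape(-1, 1)
    return x, y
```

Usage would look like `train_x, train_y = prepare_kan_arrays(train_data_x, train_data_y, 5)`, followed by `torch.from_numpy(train_x)` and `torch.from_numpy(train_y)` when building `my_ds`.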
Model creation
kan_model = KAN(width=[1,1,1], grid=2, k=3, seed=0)
Model fit
kan_model.train(my_ds, opt="LBFGS", steps=2, lamb=0.01, lamb_entropy=10.)
Perhaps I am missing something, some parameter I don't know about.
P.S.: I also tested with 10K observations and got the same behaviour.
I don't know if it matters for optimization, but I'm using an Apple M1 Max with 64 GB.
The versions used are:
torch -> 2.3.0
numpy -> 1.24.4
I left the training running in a Jupyter notebook for more than an hour and it never got past 0%. To be honest, I think it is related to the chip architecture, but it could also be the data... I do not know whether 10K observations is a lot for the code at this stage.
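One way to tell a data-size problem from a setup problem (such as the width mismatch discussed in the comments) is to first train on a small random subset and check that the progress bar moves at all. The helper below is a hypothetical NumPy sketch (`subsample` is not part of pykan); it assumes `x` and `y` are row-aligned arrays and draws the same rows from both, without replacement.

```python
import numpy as np


def subsample(x, y, n, seed=0):
    """Hypothetical helper: draw up to `n` rows without replacement, using
    the same row indices for inputs and labels, for a quick smoke test."""
    rng = np.random.default_rng(seed)
    # Never ask for more rows than exist
    idx = rng.choice(len(x), size=min(n, len(x)), replace=False)
    return x[idx], y[idx]
```

For example, `subsample(train_x, train_y, 1000)` gives a 1,000-row slice to try with `steps=2` before scaling back up to 10K or 100K observations.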