Strange result with a simple 2-layer NN #138
Comments
Thanks for bringing this to our attention. This definitely looks like a bug: I think we are making an implicit assumption of being in the overparameterized regime and just returning [...]
Hi, Yes, I have concluded that your equation
So I have come to the conclusion that, in place of

    fx_train_inf, fx_test_inf = predict_fn(None, fx_train_0, fx_test_0, kntk_emp_test_train)

I should use

    _, fx_test_inf = predict_fn(None, fx_train_0, fx_test_0, kntk_emp_test_train)
    _, fx_train_inf = predict_fn(None, fx_train_0, fx_train_0, kntk_emp_train_train)

and got this figure (d=15).

Moreover, I was expecting to reach the asymptotic regime (N \rightarrow \infty) using

    predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, X_train, Y_train, diag_reg=1e-9)
    ntk_test_mean = predict_fn(x_test=X_test, get='ntk', compute_cov=False)
    print("loss_test =", loss(ntk_test_mean, Y_test))

But, for -==>

I wonder whether I am using the library correctly, since according to https://arxiv.org/pdf/2007.12826.pdf I should find the infinite-width kernel result as the asymptote. So, have I made a mistake? Thanks
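The workaround above (passing the train-train kernel in the test slot to get train predictions) can be read off the standard closed form for the infinite-time mean prediction of MSE gradient flow on a linearized model, f_test = f0_test + K_test_train K_train_train^{-1} (y - f0_train). The following is a minimal NumPy sketch of that formula, not the library's implementation: `predict_inf`, the random stand-in Jacobians, and the way `diag_reg` is scaled by the mean diagonal are all assumptions made for illustration. It uses an overparameterized toy setting (more parameters than samples) so that the kernel is invertible.

```python
import numpy as np

def predict_inf(k_test_train, k_train_train, y_train, fx_train_0, fx_test_0, diag_reg=0.0):
    """Closed-form t -> infinity prediction for MSE gradient flow on a linearized model.

    Hypothetical helper for illustration only; not the neural_tangents API.
    """
    n = k_train_train.shape[0]
    # Assumption: regularizer scaled by the mean diagonal entry of the kernel.
    reg = diag_reg * np.trace(k_train_train) / n
    alpha = np.linalg.solve(k_train_train + reg * np.eye(n), y_train - fx_train_0)
    return fx_test_0 + k_test_train @ alpha

rng = np.random.default_rng(0)
n_train, n_test, n_params = 20, 5, 200  # overparameterized: n_params > n_train

# Random stand-ins for the Jacobians of the network outputs w.r.t. the parameters.
J_train = rng.standard_normal((n_train, n_params))
J_test = rng.standard_normal((n_test, n_params))
y_train = rng.standard_normal(n_train)
fx_train_0 = np.zeros(n_train)  # assume zero outputs at initialization
fx_test_0 = np.zeros(n_test)

k_train_train = J_train @ J_train.T  # empirical NTK on the training set
k_test_train = J_test @ J_train.T    # empirical NTK between test and train

# Test predictions use the test-train kernel.
fx_test_inf = predict_inf(k_test_train, k_train_train, y_train, fx_train_0, fx_test_0)
# Train predictions: feed the training set through the "test" slot.
fx_train_inf = predict_inf(k_train_train, k_train_train, y_train, fx_train_0, fx_train_0)

print("max |fx_train_inf - y_train| =", np.max(np.abs(fx_train_inf - y_train)))
```

In this invertible setting the train predictions reproduce the labels up to round-off, which is the interpolation behaviour the closed form assumes.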
Hello,
I will give a snippet
I get (DeviceArray(0., dtype=float64), DeviceArray(nan, dtype=float64)).
But since Nd=90 (the number of parameters of the first Dense layer, without bias) is smaller than the number of samples (165), I am not in the overparametrized regime, so I would expect the train MSE to be nonzero. I would also not expect the test MSE to diverge, since Nd =/= ns.
So I am puzzled, and I have certainly missed something. What I wanted to do is compute the MSE of the infinite-time inference with the finite-width Neural Tangent Kernel. By the way, I am trying to reproduce, more or less, the results of Figures 1 and 2 of https://arxiv.org/pdf/2007.12826.pdf by Andrea Montanari and Yiqiao Zhong.
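The expectation above can be checked on a toy linearized model. With fewer trained parameters than samples, the empirical NTK Gram matrix K = J Jᵀ has rank at most Nd, so it cannot be inverted, and the infinite-time limit of gradient flow only fits the component of the residual that lies in the column space of J. The sketch below assumes a random Gaussian matrix as a stand-in for the real Jacobian and zero outputs at initialization; only the sizes 165 and 90 come from the post, everything else is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_train, n_params = 165, 90  # sizes quoted in the post: ns = 165 samples, Nd = 90 parameters

# Hypothetical stand-in for the Jacobian of the outputs w.r.t. the trained parameters.
J = rng.standard_normal((n_train, n_params))
y = rng.standard_normal(n_train)
f0 = np.zeros(n_train)  # assume zero outputs at initialization

# Empirical NTK Gram matrix: 165 x 165, but rank at most 90, hence singular.
K = J @ J.T
rank = np.linalg.matrix_rank(K)

# t -> infinity limit of gradient flow on the linearized model:
# the minimum-norm least-squares parameter update, i.e. the projection of
# the residual (y - f0) onto the column space of J.
theta, *_ = np.linalg.lstsq(J, y - f0, rcond=None)
f_inf = f0 + J @ theta
train_mse = np.mean((f_inf - y) ** 2)

print("rank(K) =", rank, "out of", n_train)
print("train MSE at t -> infinity =", train_mse)
```

In this toy setting the train MSE stays strictly positive. A predictor that instead solves a linear system with the singular K directly is ill-posed, which is one plausible way to end up with a spurious zero train loss or `nan`.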