
Strange result with a simple 2-layers NN #138

Open
jecampagne opened this issue Jan 19, 2022 · 3 comments
Labels
bug Something isn't working

Comments

jecampagne commented Jan 19, 2022

Hello,

Here is a snippet:

import jax
import jax.numpy as jnp
from jax import jit
from jax import grad
from jax.example_libraries import optimizers

from jax.config import config
config.update("jax_enable_x64", True)  # double precision for the matrix operations

import numpy as np

import neural_tangents as nt
from neural_tangents import stax
###########

key = jax.random.PRNGKey(0)  #initial seed

# Some dimensions
d=15
N=6
ns=165
n_test=1_000
batch_size=5

# A vector beta, drawn once and for all
beta = jax.random.normal(key, shape=(1,d))
norm = jnp.linalg.norm(beta, axis=1)
beta =  beta / norm

# Utils to generate a dataset
def gen_x(key=None, r=1.0, d=20, ns=50):
    x = jax.random.normal(key, shape=(ns,d))
    norm = jnp.linalg.norm(x, axis=1)
    x_normed = r * x / norm.reshape(x.shape[0],1)
    return x_normed


def gen_y(key, X, beta, sigma_eps=0.5):
    " Target generation"
    Xbeta = X @ beta.T  # <beta, Xi>
    y = jnp.sin(Xbeta)
    noise = jax.random.normal(key,shape=(X.shape[0],1)) * sigma_eps
    return y + noise

# The MSE loss
loss = lambda fx, y_hat: 0.5*jnp.mean((fx - y_hat) ** 2)
grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    
# Test dataset
key, x_key, y_key = jax.random.split(key, 3)
X_test = gen_x(x_key, r=np.sqrt(d), d=d, ns=n_test)
Y_test = gen_y(y_key, X_test, beta, sigma_eps=0.)  # noise-free targets

# Train dataset
key, x_key, y_key = jax.random.split(key, 3)
X_train = gen_x(x_key, r=np.sqrt(d), d=d, ns=ns)
Y_train = gen_y(y_key, X_train, beta, sigma_eps=0.5)
        
               
# 2-layer NN for regression with 1 output
init_fn, apply_fn, kernel_fn = stax.serial(
                stax.Dense(N, W_std=1., parameterization='standard'), 
                stax.Relu(),
                stax.Dense(1, W_std=1., parameterization='standard')
)

# Finite-width NTK, batched since the number of samples can be large
emp_ntk_kernel_fn = nt.batch(nt.empirical_ntk_fn(apply_fn), device_count=-1, batch_size=batch_size)
            
# Initialize the parameters and the NTK_train_train / NTK_test_train kernel matrices
_, params = init_fn(key, (-1, d))
kntk_emp_train_train = emp_ntk_kernel_fn(X_train, None, params)
kntk_emp_test_train  = emp_ntk_kernel_fn(X_test, X_train, params)
            
predict_fn = nt.predict.gradient_descent_mse(kntk_emp_train_train, Y_train,  diag_reg=1.e-9)

# Initial (t=0) network outputs on the train & test sets
fx_train_0 = apply_fn(params, X_train)
fx_test_0  = apply_fn(params, X_test)
            
# MSE inference at t = infinity (= min-norm ridgeless regression)
fx_train_inf, fx_test_inf = predict_fn(None, fx_train_0, fx_test_0,  kntk_emp_test_train)

# The MSE loss on Train & Test datasets
loss(fx_train_inf, Y_train), loss(fx_test_inf, Y_test)

I get (DeviceArray(0., dtype=float64), DeviceArray(nan, dtype=float64)).

But, since Nd = 90 (the number of parameters of the first Dense layer, without bias) is smaller than the number of samples (165), I would expect the train MSE not to be 0 (I am not in the overparametrized regime) and the test MSE not to diverge, since Nd ≠ ns.

So I am puzzled and have certainly missed something. What I wanted to do is to compute the MSE of the t = infinity inference with the finite-width Neural Tangent Kernel. By the way, I am trying to reproduce, more or less, the results of Figures 1 & 2 of https://arxiv.org/pdf/2007.12826.pdf by Andrea Montanari and Yiqiao Zhong.
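For reference, a quick sketch of the arithmetic behind this expectation (my own counting, biases excluded):

d, N, ns = 15, 6, 165
Nd = d * N                # parameters of the first Dense layer, without bias
print(Nd, ns, Nd < ns)    # 90 165 True -> fewer first-layer parameters than training samples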

@romanngg romanngg added the bug Something isn't working label Jan 21, 2022
@romanngg
Contributor

Thanks for bringing this to our attention - this definitely looks like a bug. I think we are making an implicit assumption of being in the overparameterized regime and just returning Y_train / doing overparametrized linear regression on X_test. Will look into this (although unfortunately not before ICML).
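For reference, here is a rough sketch (not the library's current code path) of what the t -> infinity solution should look like when the train-train NTK is rank-deficient, reusing the variables from your snippet and a pseudoinverse instead of an inverse:

import jax.numpy as jnp

k_dd = kntk_emp_train_train   # (ns, ns) empirical NTK, rank <= number of parameters
k_td = kntk_emp_test_train    # (n_test, ns) empirical NTK

# Project the residual onto the range of k_dd rather than assuming invertibility.
correction = jnp.linalg.pinv(k_dd) @ (Y_train - fx_train_0)
fx_train_ref = fx_train_0 + k_dd @ correction   # train MSE stays > 0 when k_dd is rank-deficient
fx_test_ref  = fx_test_0  + k_td @ correction

print(loss(fx_train_ref, Y_train), loss(fx_test_ref, Y_test))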

jecampagne commented Jan 22, 2022

Hi,
Thanks a lot for taking the time to look into my posts. Your library is awesome, so I would like to use it properly.

Yes, I have concluded that your equation

f_t^{lin}(\mathcal{X}) = \left(\mathbf{1} - e^{-\eta \Theta_0 t} \right)\mathcal{Y} + e^{-\eta \Theta_0 t} f_0(\mathcal{X})

holds only when \Theta_0 is invertible, i.e. in the over-parametrized regime.
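Indeed, taking t \to \infty in that expression, f_t^{lin}(\mathcal{X}) \to \mathcal{Y} only if e^{-\eta \Theta_0 t} \to 0, which requires all eigenvalues of \Theta_0 to be strictly positive; if \Theta_0 is rank-deficient (the under-parametrized case), the component of f_0(\mathcal{X}) - \mathcal{Y} lying in the null space of \Theta_0 is never removed, so the train MSE does not go to 0.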

So, I come to the conclusion that in place of

fx_train_inf, fx_test_inf = predict_fn(None, fx_train_0, fx_test_0,  kntk_emp_test_train)

I should use

    _, fx_test_inf  = predict_fn(None, fx_train_0, fx_test_0,   kntk_emp_test_train)
    _, fx_train_inf = predict_fn(None, fx_train_0, fx_train_0,  kntk_emp_train_train)

and got this figure (d=15):

[figure: test loss vs. ln(Nd)/ln(n) at d=15; red: ns=100, green: ns=225, blue: ns=869]

which looks reasonable, but I may have made a mistake, so if you have a comment, please let me know.

Moreover, I was expecting to get the asymptotic regime (N\rightarrow \infty) using

    predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, X_train, Y_train, diag_reg=1e-9)    
    ntk_test_mean = predict_fn(x_test=X_test, get='ntk', compute_cov=False)
    print("loss_test =", loss(ntk_test_mean, Y_test))

But, for ns=100 and d=15 (ln n / ln d = 1.7), corresponding to the "red curve" in the above figure, I get loss_test = 0.28 (obtained with different seeds and values of N, i.e. the number of neurons of the second layer). This is quite a bit less than the 0.4 value obtained with N=667 (which corresponds to ln Nd / ln n = 2, i.e. the value at the right of the figure).
For ns=225 (ln n / ln d = 2.0) (the green curve) I also find loss_test = 0.27, which is bigger than the 0.2 value for the finite-size kernel. Finally, with ns=869 (ln n / ln d = 2.5) (the blue curve) I get loss_test = 0.11, which in this case is in agreement with the finite-size kernel.

==> I wonder whether I am using the library correctly, since according to https://arxiv.org/pdf/2007.12826.pdf I should find the infinite-width kernel result as the asymptote. So, have I made a mistake? Thanks

@jecampagne
Author

Ha! I got it... this is due to parameterization='standard' vs 'ntk' in the Dense layers. Here is what I get when I use 'ntk':
[figure: same plot as above, now with parameterization='ntk']
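For completeness, a minimal sketch of the change (my own summary; only the parameterization argument differs from the original snippet):

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(N, W_std=1., parameterization='ntk'),
    stax.Relu(),
    stax.Dense(1, W_std=1., parameterization='ntk')
)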
