False INTERNAL ASSERT FAILED bug whilst training Neural Network #128778
Labels
module: linear algebra
Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply (matmul)
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
# Reproduction script from the bug report. NOTE(review): the paste lost the
# for-loop's indentation, and the loop body is truncated here — the remaining
# statements (closure/train()/i=i+1) appear only in the traceback quoted below.
from kan import *
import pickle
import control as ct
import scipy.io as sio
# Scalar problem parameters loaded from a MATLAB .mat file; each entry is a
# 2-D MATLAB array, hence the [0,0] scalar extraction.
vars = sio.loadmat('variables.mat')
predict_time = int(vars['predict_time'][0,0])
dp = int(vars['dp'][0,0]) - (predict_time)
dt = float(vars['dt'][0,0])
tf = int(vars['tf'][0,0])
m = int(vars['m'][0,0])
nx = int(vars['nx'][0,0])
# Training data, also from .mat files, converted to float32 tensors.
# Assumes each .mat key holds a 2-D numeric array — TODO confirm shapes.
data_x = sio.loadmat('data_x.mat')
train_x = torch.FloatTensor(data_x['data_x'])
data_y = sio.loadmat('data_y.mat')
train_y = torch.FloatTensor(data_y['data_y'])
data_y2 = sio.loadmat('data_y2.mat')
train_y2 = torch.FloatTensor(data_y2['data_y2'])
data_u = sio.loadmat('data_u.mat')
train_u = torch.FloatTensor(data_u['data_u'])
# Hyperparameters for the KAN training run.
lifted_space = 3
hidden_size = 2
learning_rate = 1
Loss_prev = 1e38
# P = [I | 0]: selects the first nx coordinates of the (nx + lifted_space) state.
P = torch.cat((torch.eye(nx),torch.zeros(nx,lifted_space)),1)
criterion = torch.nn.MSELoss()
log = 1
grids = [5,5]
steps = 20
recon_losses = []
pred_losses = []
# grid_eps=0 makes grid updates fully sample-adaptive (per the
# KANLayer.update_grid_from_samples frame in the traceback); NOTE(review): a
# degenerate/collapsed sample range could make the lstsq system in
# curve2coef ill-posed, which may be related to the INTERNAL ASSERT — confirm.
model = KAN(width=[nx,hidden_size,lifted_space], grid=grids[0], k=3, grid_eps=0, noise_scale_base=0.25)
i=0
# Grid-refinement loop: first pass keeps the model built above (the `a = 0`
# branch is a no-op placeholder); later passes re-initialize a KAN on the new
# grid from the previous model's activations.
for grid in grids:
if i == 0:
a = 0
else:
model = KAN(width=[nx,hidden_size,lifted_space], grid=grid, k=3).initialize_from_another_model(model, train_x)
'''---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[3], line 74
72 recon_losses.append(recon_loss.detach().numpy())
73 pred_losses.append(pred_loss.detach().numpy())
---> 74 train()
75 i=i+1
Cell In[3], line 65
62 return loss
64 if step % 1 == 0 and step < 50:
---> 65 model.update_grid_from_samples(train_x)
67 optimizer.step(closure)
69 if step % log == 0:
File c:\Users\georg\Documents\University\FIT\Research\KAN Network\Quadcopter.venv\Lib\site-packages\kan\KAN.py:244, in KAN.update_grid_from_samples(self, x)
242 for l in range(self.depth):
243 self.forward(x)
--> 244 self.act_fun[l].update_grid_from_samples(self.acts[l])
File c:\Users\georg\Documents\University\FIT\Research\KAN Network\Quadcopter.venv\Lib\site-packages\kan\KANLayer.py:218, in KANLayer.update_grid_from_samples(self, x)
216 grid_uniform = torch.cat([grid_adaptive[:, [0]] - margin + (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) * a for a in np.linspace(0, 1, num=self.grid.shape[1])], dim=1)
217 self.grid.data = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
--> 218 self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k, device=self.device)
File c:\Users\georg\Documents\University\FIT\Research\KAN Network\Quadcopter.venv\Lib\site-packages\kan\spline.py:137, in curve2coef(x_eval, y_eval, grid, k, device)
135 # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar
136 mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1)
--> 137 coef = torch.linalg.lstsq(mat.to('cpu'), y_eval.unsqueeze(dim=2).to('cpu')).solution[:, :, 0] # sometimes 'cuda' version may diverge
138 return coef.to(device)
RuntimeError: false INTERNAL ASSERT FAILED at "..\aten\src\ATen\native\BatchLinearAlgebra.cpp":1538, please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 8): Argument 4 has illegal value. Most certainly there is a bug in the implementation calling the backend library.'''
Versions
StatusCode : 200
StatusDescription : OK
Content : # mypy: allow-untyped-defs
RawContent : HTTP/1.1 200 OK
Connection: keep-alive
Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox
Strict-Transport-Security: max-age=31536000
X-Content-Type-Options: nosniff
...
Forms : {}
Headers : {[Connection, keep-alive], [Content-Security-Policy, default-src 'none'; style-src 'unsafe-inline'; sandbox],
[Strict-Transport-Security, max-age=31536000], [X-Content-Type-Options, nosniff]...}
Images : {}
InputFields : {}
Links : {}
ParsedHtml : mshtml.HTMLDocumentClass
RawContentLength : 23357
cc @jianyuh @nikitaved @pearu @mruberry @walterddr @xwang233 @lezcano
The text was updated successfully, but these errors were encountered: