Question: update_grid_from_samples function #253

Closed
hi-jin opened this issue Jun 3, 2024 · 1 comment

hi-jin commented Jun 3, 2024

Hello,

I'm trying to fully understand the update_grid_from_samples function in KANLayer.py, but I'm having trouble with a particular part of the code.

Specifically, I don't understand the purpose of sorting x along the batch axis. Here is the relevant code snippet:

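```python
# Copy each input once per output neuron, giving shape (in_dim * out_dim, batch)
x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
# Sort every row independently along the batch axis (dim=1); keep values, drop indices
x_pos = torch.sort(x, dim=1)[0]
# Evaluate the current splines at the sorted sample positions
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k, device=self.device)
```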

After the first line, x has shape (self.in_dim * self.out_dim, batch).
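To make the shape claim concrete, here is a small sketch of just the first line of the snippet. It uses NumPy instead of torch so it runs without a device (an assumption: `np.einsum` and `transpose` behave like `torch.einsum` and `permute` for this purpose), and `replicate_inputs` is a hypothetical helper name, not part of pykan. Row `k * in_dim + j` ends up holding input `j` for every sample in the batch.

```python
import numpy as np

def replicate_inputs(x: np.ndarray, out_dim: int) -> np.ndarray:
    """Hypothetical helper mirroring the first snippet line: (batch, in_dim) -> (in_dim * out_dim, batch)."""
    batch, in_dim = x.shape
    size = in_dim * out_dim
    # 'ij,k->ikj' builds (batch, out_dim, in_dim): each input value copied once per output
    expanded = np.einsum('ij,k->ikj', x, np.ones(out_dim))
    # Merge (out_dim, in_dim) into one axis, then move batch to the last axis
    return expanded.reshape(batch, size).transpose(1, 0)

x = np.array([[0.3, -1.0],
              [0.1,  2.0],
              [0.5,  0.0]])  # batch=3, in_dim=2
rows = replicate_inputs(x, out_dim=2)
print(rows.shape)  # (4, 3)
```

Each row therefore feeds exactly one univariate activation (one input/output pair), which is what makes per-row operations meaningful.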

My concern is that sorting x along the batch axis (dim=1) reorders each row independently, so the batch information (which values came from the same sample) might be lost. I would appreciate any clarification on why sorting along this axis is necessary and how it affects the behavior of the update_grid_from_samples function.
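One way to see why the lost pairing is harmless: each row belongs to a single univariate activation, so only the multiset of that row's input values matters, and `y_eval` is computed from the already sorted `x_pos`, so every `(x_pos[i], y_eval[i])` pair is still a valid point on the current curve. Sorting also makes it cheap to pick grid points at empirical quantiles by plain indexing. The sketch below illustrates that idea under stated assumptions: it uses NumPy, `np.tanh` stands in for `coef2curve`, and the quantile-index formula and the helper name `adaptive_grid_from_samples` are illustrative, not pykan's exact code.

```python
import numpy as np

def adaptive_grid_from_samples(x_row: np.ndarray, num_interval: int):
    """Hypothetical sketch: place num_interval + 1 grid points at empirical quantiles of one row."""
    batch = x_row.shape[0]
    # Safe to sort: this row feeds one univariate function, so sample identity is irrelevant
    x_pos = np.sort(x_row)
    # Evenly spaced indices into the sorted samples approximate quantiles; last index is the max
    ids = [int(batch / num_interval * i) for i in range(num_interval)] + [batch - 1]
    return x_pos, x_pos[ids]

f = np.tanh  # assumption: stand-in for evaluating the current spline (coef2curve)
rng = np.random.default_rng(0)
x_row = rng.normal(size=100)
x_pos, grid = adaptive_grid_from_samples(x_row, num_interval=5)
y_eval = f(x_pos)  # computed on sorted inputs, so (x_pos[i], y_eval[i]) remain paired

# The (x, f(x)) pairs are the same as before sorting, only their order changed
before = np.array(sorted(zip(x_row, f(x_row))))
after = np.column_stack([x_pos, y_eval])
assert np.allclose(before, after)
print(grid)  # 6 non-decreasing points, denser where samples are dense
```

In this reading, the sorted samples and their curve values are then enough to refit coefficients on the new grid, since a least-squares fit does not depend on the order of its data points.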

Thank you for your assistance.

hi-jin commented Jun 3, 2024

Oh, I have just figured out the purpose of the code, which resolves my question. Therefore, I am closing this issue.

hi-jin closed this as completed Jun 3, 2024