What is unit of batch-size and how should I set its value? #426

Open
feiyang-k opened this issue Sep 8, 2023 · 2 comments
Labels
question Further information is requested

Comments

feiyang-k commented Sep 8, 2023

Hi!

Following up on the solutions kindly provided, I'm experimenting with the batch-size method. It is very helpful! We are now able to compute OT problems at least 10 times larger than before (and are still trying larger sizes).

  • A quick question I have here is: what is the unit of batch_size?

It seems a bit counterintuitive. For example, when computing the optimal transport between two distributions X_1 and X_2, I found that if I fix X_1 and increase the size of X_2, I need to reduce batch_size.

Right now, it seems I need to do a manual line search over this parameter to find an "optimal feasible" value :) Is there a more straightforward way to set it?

Thanks!

Collaborator

michalk8 commented Sep 8, 2023

A quick question I have here is: what is the unit of batch_size?

The unit is the number of elements of X_1/X_2 to evaluate in chunks (it might be more apt to rename batch_size -> chunk_size). For example, with batch_size=8192, the LSE (or kernel) matrix-vector product is evaluated in a batched way, materializing either (batch_size, X_2.shape[0]) or (X_1.shape[0], batch_size) matrices.
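To make the memory behavior above concrete, here is a minimal numpy sketch of a chunked LSE matrix-vector product — an illustration of the batching pattern described, not ott's actual implementation. It assumes a squared-Euclidean cost and only ever materializes a (batch_size, X_2.shape[0]) cost block:

```python
import numpy as np

def chunked_lse_apply(x1, x2, g, eps, batch_size):
    """Evaluate out_i = eps * log sum_j exp((g_j - c(x1_i, x2_j)) / eps)
    row-chunk by row-chunk, so only a (batch_size, n2) cost block is
    materialized at a time (illustrative sketch, not ott's code)."""
    n1 = x1.shape[0]
    out = np.empty(n1)
    for start in range(0, n1, batch_size):
        chunk = x1[start:start + batch_size]                 # (b, d)
        # squared-Euclidean cost block of shape (b, n2)
        cost = ((chunk[:, None, :] - x2[None, :, :]) ** 2).sum(-1)
        z = (g[None, :] - cost) / eps                        # (b, n2)
        m = z.max(axis=1, keepdims=True)                     # stabilize LSE
        out[start:start + batch_size] = eps * (
            m[:, 0] + np.log(np.exp(z - m).sum(axis=1)))
    return out
```

The result is independent of batch_size; only the peak memory of the (b, n2) temporaries changes, which is why a larger X_2 forces a smaller batch_size for the same budget.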

There's no option to specify the batch size as a fraction of the input, and we're currently looking into a more robust implementation of a chunked vmap (see google/jax#11319 or netket's implementation).
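Since the dominant temporary is a (batch_size, n_other) block, the "manual line search" from the question can be replaced by a back-of-the-envelope bound. This is a hypothetical helper, not part of ott's API, and the real peak also includes inputs, outputs, and solver state, so leave headroom:

```python
def max_feasible_batch_size(n_other, mem_budget_bytes, dtype_bytes=4):
    """Rough upper bound on batch_size: the dominant temporary is a
    (batch_size, n_other) block, so solve
    batch_size * n_other * dtype_bytes <= mem_budget_bytes.
    Hypothetical helper -- actual peak memory is higher, so keep headroom."""
    return max(1, mem_budget_bytes // (n_other * dtype_bytes))
```

For a 2 GiB budget in float32, one million points on the other side caps batch_size around 536; doubling the other side roughly halves the feasible batch_size, which matches the inverse relation observed in the question.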

@michalk8 michalk8 added the question Further information is requested label Sep 8, 2023

ntenenz commented Sep 8, 2023

As the dataset grows (in sample count or feature dimension), the dataset itself can quickly become limited by GPU memory as well. In that case, it may also be worth exploring a JAX callback to bring (a chunk of) the data to the device.
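The streaming pattern suggested here can be sketched as follows. In JAX proper the host-side fetch would sit inside a callback (e.g. jax.pure_callback) so the full array never needs to fit on the device; this is a plain-numpy stand-in for that pattern, and stream_reduce/sum_sq are hypothetical names, not library functions:

```python
import numpy as np

def stream_reduce(host_data, chunk_size, reduce_fn):
    """Fold a large host-resident array into an accumulator chunk by
    chunk, keeping only one chunk "on device" at a time (numpy stand-in
    for a JAX host-callback-based data feed)."""
    acc = None
    for start in range(0, host_data.shape[0], chunk_size):
        chunk = host_data[start:start + chunk_size]  # host -> "device" copy
        acc = reduce_fn(acc, chunk)
    return acc

def sum_sq(acc, chunk):
    """Example reducer: running sum of squared entries."""
    partial = float((chunk ** 2).sum())
    return partial if acc is None else acc + partial
```

The accumulator is the only state that persists across chunks, so device memory scales with chunk_size rather than with the dataset.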


3 participants