Memory complexity issue with pmap #8585

Open
mohamad-amin opened this issue Nov 18, 2021 · 8 comments
Labels: bug (Something isn't working), needs info (More information is required to diagnose & prioritize the issue), NVIDIA GPU (Issues specific to NVIDIA GPUs), P1 (soon) (Assignee is working on this now, among other tasks; assignee required)

Comments

@mohamad-amin commented Nov 18, 2021

Hey!

I'm trying to compute the results of multiple kernel ridge regressions in parallel. I've written the code and generated jax expressions of my functions using jax.make_jaxpr. According to the jax expressions, the data and computation should fit on my GPUs (I'm using 4 V100 GPUs with 16GB of RAM each, which amounts to 64GB of GPU RAM) and should be very far from the actual limits of what I have, but surprisingly it throws an OOM. (I'm using 64-bit precision.)

Basically, what I expect from the jax expressions is that the most expensive items here (memory-wise) should be the 4000 x 2000 x 10 x 10 array along with the 20000 x 20000 matrix that are broadcast to each GPU, which together amount to ~9GB of GPU RAM; other than that, I can't see why this code can't fit on the GPU. (P.S.: before entering the pmap, the GPU is in the state shown in the picture below.)

[Screenshot: GPU memory state before entering the pmap (Screen Shot 2021-11-18 at 12.21.16 AM)]
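As a rough sanity check on that estimate (back-of-the-envelope arithmetic only, not profiler output), the two largest per-device buffers in float64 come to about:

# float64 = 8 bytes per element
bytes_per_f64 = 8
batch_buffer = 4000 * 2000 * 10 * 10 * bytes_per_f64   # 6.4e9 bytes ~= 6.4 GB
kernel_matrix = 20000 * 20000 * bytes_per_f64           # 3.2e9 bytes ~= 3.2 GB
print((batch_buffer + kernel_matrix) / 1e9)             # ~9.6 GB, well under 16 GB per V100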

Error:

2021-11-18 00:23:23.244442: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:802] failed to alloc 34156314624 bytes on host: CUDA_ERROR_INVALID_VALUE: invalid argument
2021-11-18 00:23:23.244517: W external/org_tensorflow/tensorflow/core/common_runtime/device/device_host_allocator.h:46] could not allocate pinned host memory of size: 34156314624

Killed

My compiled functions:

{ lambda a:bool[10,10] b:f64[10,10] c:f64[10,10] d:bool[1,10,10] e:f64[1,10,10] f:bool[1,10,10]
    g:f64[1,10,10]; h:f64[4,1,10] i:f64[4,1,1,2000,10,10] j:f64[4,1,1,1,10,10] k:f64[4,1,1,4000,10,10]
    l:f64[4,1,1,10] m:f64[10,10] n:f64[2000,10] o:f64[20000,20000] p:bool[] q:f64[2000,10]
    r:f64[4000,10] s:f64[4000,2000,10,10]. let
    t:f64[4,1] = xla_pmap[
      axis_name=<axis 0x2aad77a59550>
            slice_sizes=(1, 1, 1, 1)
            unique_indices=False
          ] ez fz
          gb:f64[1,10,4000] = broadcast_in_dim[
            broadcast_dimensions=(0, 1, 2)
            shape=(1, 10, 4000)
          ] ga
          gc:bool[1,10,4000] = ge gb 1.0
          gd:f64[4000] = broadcast_in_dim[broadcast_dimensions=() shape=(4000,)] 1.0
          ge:f64[1,10,4000] = xla_call[
            backend=None
            call_jaxpr={ lambda ; gf:bool[1,10,4000] gg:f64[4000] gh:f64[1,10,4000]. let
                gi:f64[10,4000] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(10, 4000)
                ] gg
                gj:f64[1,10,4000] = broadcast_in_dim[
                  broadcast_dimensions=(1, 2)
                  shape=(1, 10, 4000)
                ] gi
                gk:f64[1,10,4000] = select gf gj gh
              in (gk,) }
            device=None
            donated_invars=(False, False, False)
            inline=False
            name=vmap(vmap(_where))
          ] gc gd gb
          gl:f64[1,10,4000] = sub 1.0 ge
          gm:f64[1,10] = reduce_sum[axes=(2,)] gl
          gn:f64[1] = dot_general[
            dimension_numbers=(((1,), (1,)), ((0,), (0,)))
            precision=None
            preferred_element_type=None
          ] bb gm
          go:f64[1] = mul gn -1.0
        in (go,) }
      devices=None
      donated_invars=(False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False)
      global_arg_shapes=(None, None, None, None, None, None, None, None, None, None, None, None)
      global_axis_size=None
      in_axes=(None, None, None, None, None, None, None, 0, 0, 0, 0, 0, None, None, None, None, None, None, None)
      name=compute_batch_uncertainty
      out_axes=(0,)
    ] a b c d e f g h i j k l m n o p q r s
  in (t,) }

In the compiled function above, there is an xla_call that calls this compiled function:

{ lambda a:bool[10,10] b:f64[10,10] c:f64[10,10] d:bool[10,10] e:f64[10,10]; f:f64[10,1,10]
    g:f64[4000,1,10,10] h:f64[2000,1,10,10] i:f64[1,10,10] j:f64[10] k:f64[2000,10]
    l:f64[20000,20000] m:bool[] n:f64[2000,10] o:f64[4000,10] p:f64[4000,2000,10,10]. let
    q:f64[2000,10,1,10] = transpose[permutation=(0, 2, 1, 3)] h
    r:f64[20000,10] = reshape[dimensions=None new_sizes=(20000, 10)] q
    s:f64[20000,10] = xla_call[
      backend=None
    call_jaxpr={ lambda ; t:f64[20000,20000] u:f64[20000,10]. let
        v:f64[20000,10] = triangular_solve[
        conjugate_a=False
        left_side=True
        lower=False
        transpose_a=True
        unit_diagonal=False
        ] t u
        w:f64[20000,10] = triangular_solve[
        conjugate_a=False
        left_side=True
        lower=False
        transpose_a=False
        unit_diagonal=False
        ] t v
    in (w,) }
      device=None
      donated_invars=(False, False)
      inline=False
      name=_cho_solve
    ] l r
    x:f64[2000,10,1,10] = reshape[dimensions=None new_sizes=(2000, 10, 1, 10)] s
    y:f64[2000,1,10,10] = transpose[permutation=(0, 2, 1, 3)] x
    z:f64[1,10,1,10] = dot_general[
      dimension_numbers=(((0, 2), (0, 2)), ((), ()))
      precision=None
      preferred_element_type=None
    ] h y
    ba:f64[1,1,10,10] = transpose[permutation=(0, 2, 1, 3)] z
    bb:f64[1,1,10,10] = broadcast_in_dim[
      broadcast_dimensions=(1, 2, 3)
      shape=(1, 1, 10, 10)
    ] i
    bc:f64[1,1,10,10] = sub bb ba
    bd:f64[1,10,1,10] = transpose[permutation=(0, 2, 1, 3)] bc
    be:f64[10,10] = reshape[dimensions=None new_sizes=(10, 10)] bd
    dt:i32[4000] = convert_element_type[new_dtype=int32 weak_type=False] dp
    du:i32[10,4000] = convert_element_type[new_dtype=int32 weak_type=False] ds
    dv:i32[4000,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4000, 1)] dt
    dw:i32[10,4000,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(10, 4000, 1)
    ] du
    dx:i32[10,4000,1] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(10, 4000, 1)
    ] dv
    dy:i32[10,4000,2] = concatenate[dimension=2] dx dw
    dz:i32[10,4000,1] = iota[dimension=0 dtype=int32 shape=(10, 4000, 1)]
    ea:i32[10,4000,3] = concatenate[dimension=2] dz dy
    eb:f64[10,4000] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1, 2), start_index_map=(0, 1, 2))
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1, 1)
      unique_indices=False
    ] df ea
    ec:f64[10,4000] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(10, 4000)
    ] eb
    ed:bool[10,4000] = ge ec 1.0
    ee:f64[4000] = broadcast_in_dim[broadcast_dimensions=() shape=(4000,)] 1.0
    ef:f64[10,4000] = xla_call[
      backend=None
      call_jaxpr={ lambda ; eg:bool[10,4000] eh:f64[4000] ei:f64[10,4000]. let
          ej:f64[10,4000] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(10, 4000)
          ] eh
          ek:f64[10,4000] = select eg ej ei
        in (ek,) }
      device=None
      donated_invars=(False, False, False)
      inline=False
      name=vmap(_where)
    ] ed ee ec
    el:f64[10,4000] = sub 1.0 ef
    em:f64[10] = reduce_sum[axes=(1,)] el
  in (em,) }
@mohamad-amin added the bug label on Nov 18, 2021
@mohamad-amin (Author):
I guess the question is: does XLA's triangular_solve operator really work in place? If not, this could be expected, but shouldn't it work in place?
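(As an aside, and not something taken from this thread: the only in-place behaviour a user can request at the JAX level is buffer donation at the jit/pmap boundary. Below is a minimal sketch for a cho_solve-style pair of triangular solves, where solve_with_chol is a made-up name.)

import functools
import jax
import jax.scipy.linalg as jsla

@functools.partial(jax.jit, donate_argnums=(1,))  # donate b's buffer so the output may reuse it
def solve_with_chol(chol_upper, b):
    # Same structure as the _cho_solve call in the jaxpr above: two triangular solves
    # against an upper-triangular Cholesky factor. Whether XLA reuses buffers internally
    # is up to the compiler; donation only covers the jit boundary.
    y = jsla.solve_triangular(chol_upper, b, lower=False, trans=1)
    return jsla.solve_triangular(chol_upper, y, lower=False, trans=0)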

@mohamad-amin (Author) commented Nov 19, 2021

I guess it's not triangular_solve; I checked the source code. The problem is that dot_general is not operating as expected: when computing a multiplication of shape (a, b) x (b, c), it creates a huge intermediate as big as (a, b). In my case, (a, b) is way bigger than (a, c). So I guess this is not an issue anymore (unless dot_general is not expected to have this behaviour?).

Is there any suggestion on how I can avoid memory problems while computing an (a, b) x (b, c) dot? Say (a, b) takes more than half of my memory; then computing (a, b) x (b, c) will be impossible using jax.jit, even though it is perfectly possible with a for loop. This gives rise to three questions:

  • Is this behaviour expected in JAX?
  • If so, is it okay for it to be this way? I don't think it should be, since JAX is designed for scientific computing, where we often run into huge matrix multiplications in which one operand takes more than half of memory but the final result is very small, and a sequential implementation would still be fast.
  • Is there any workaround for me to avoid this memory problem? (I can't really use the sequential version here; since jax arrays are immutable, it's a bit of a hassle.)

Again, I feel like this is an issue, just not the one I mentioned in the first post. It is now a problem with huge matrix multiplications in JAX.

@mohamad-amin (Author) commented Nov 19, 2021

Well, the workaround is this (at least, this is what comes to mind):

from jax import lax
import jax.numpy as np  # jax.numpy, aliased as np here

Z1 = lax.map(lambda X_i: np.einsum('j,jk->k', X_i, Y), X)  # one row of X at a time
Z2 = X @ Y
np.alltrue(Z1 == Z2)
# True

This works fine, but shouldn't this be automated when X @ Y doesn't fit in GPU (or whatever) memory?

Edit: I just realized that even this wouldn't work! JAX will compile this down to dot_general again!
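A block-chunked variant might still bound the peak footprint, though. This is only a sketch under the assumption that X itself fits in device memory, and chunked_matmul / num_chunks are names made up here:

from jax import lax

def chunked_matmul(X, Y, num_chunks):
    # Split X into row blocks and multiply one block at a time. lax.map runs the body
    # sequentially (it lowers to a scan), so each step only materializes one
    # (a // num_chunks, c) result block on top of X and Y themselves.
    a, b = X.shape
    assert a % num_chunks == 0, "pad X so its rows divide evenly into blocks"
    X_blocks = X.reshape(num_chunks, a // num_chunks, b)
    Z_blocks = lax.map(lambda X_block: X_block @ Y, X_blocks)
    return Z_blocks.reshape(a, Y.shape[1])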

@mohamad-amin (Author):
Moreover, I noticed that calling lax_linalg.cholesky(A, False) on a 10 x 10 matrix A causes JAX to soak up 100MB (!!) of memory. I'm really curious why lax needs 100MB of memory to compute the Cholesky factorization of a 10 x 10 matrix!
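(One thing worth ruling out here, sketched below but not verified on this exact setup: XLA's default GPU memory preallocation, which grabs a large pool up front and can make a tiny op look expensive in nvidia-smi. The two environment variables are JAX's standard GPU allocator settings.)

import os
# Must be set before importing jax: allocate on demand instead of preallocating a pool.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

A = 2.0 * jnp.eye(10)                         # small SPD matrix, ~800 bytes in f64
L = jax.scipy.linalg.cholesky(A, lower=True)  # any remaining MBs are workspace/allocator overhead
print(L.shape, L.dtype)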

@mohamad-amin (Author):
Suggestion:

  • Maybe JAX should also have something like https://pytorch.org/docs/stable/generated/torch.bmm.html and automatically use it in vmaps (or pmaps and xmaps) whenever the vectorization would cause OOM errors. This could maybe be configured through a parameter passed to vmap (optimize=False or True, maybe?).

@hawkinsp (Member):
It's impossible for us to debug your problem without complete, self-contained Python code that reproduces it. I don't know what is happening in place and what is happening out of place without debugging it, and I can't do that without a way to run the code.

I note that JAX does have a batched matrix multiplication operator, and vmap and einsum will use it when applicable, so that's probably not the problem.
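(To illustrate that point with a quick check, not taken from this thread: the jaxpr of a vmapped matmul contains a single dot_general with batch dimensions, i.e. a bmm-style op, rather than a loop of per-example matmuls.)

import jax
import jax.numpy as jnp

x = jnp.ones((8, 4, 5))  # batch of 8 matrices, 4x5
y = jnp.ones((8, 5, 3))  # batch of 8 matrices, 5x3

# Prints a jaxpr with one dot_general whose dimension_numbers include a batch dimension.
print(jax.make_jaxpr(jax.vmap(jnp.matmul))(x, y))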

@hawkinsp added the needs info label on Nov 22, 2021
@RylanSchaeffer:
@mohamad-amin, did you ever find a solution to this? We're hitting a similar OOM issue from pmap when trying to use a neural_tangents linearized vision transformer, even though we shouldn't have any issues.

@hawkinsp, I have a colab I can share: https://colab.research.google.com/drive/184moQLq3tjo-wEpc8gD7fXCFguAVDBOm#scrollTo=k4CjYqp5qLvj

@mohamad-amin (Author) commented Mar 7, 2022

@RylanSchaeffer Not yet; I solved my problem another way, though. I was also using pmap to compute the results of many kernel ridge regressions in parallel, and the kernel ridge regression code was mostly taken from the neural_tangents library's predict functions. How do you know that you shouldn't have any issues? Did you check the generated jax expressions?

@sudhakarsingh27 added the P1 (soon) and NVIDIA GPU labels on Aug 10, 2022