Triangular solve fails on batches of matrices of size > (*, 524280) #79191

Closed
CloudyDory opened this issue Jun 9, 2022 · 11 comments
Labels
module: cuda Related to torch.cuda, and CUDA support in general module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@CloudyDory

CloudyDory commented Jun 9, 2022

🐛 Describe the bug

The error "CUBLAS_STATUS_EXECUTION_FAILED when calling 'cublasStrsmBatched'" is raised when calculating the log probabilities of a MultivariateNormal distribution on the GPU if the number of data samples is larger than 524280.

Code to reproduce the problem:

import torch

device = torch.device("cuda")
# device = torch.device("cpu")
dtype = torch.float32

mean = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
sd = torch.diag_embed(torch.tensor([1.0, 1.0], dtype=dtype, device=device))
distribution = torch.distributions.MultivariateNormal(mean, sd)

data1 = torch.randn([524280,2], dtype=dtype, device=device)
logprob = distribution.log_prob(data1)

data2 = torch.randn([524281,2], dtype=dtype, device=device)
logprob = distribution.log_prob(data2)

Result:

Calculating the log_prob of data1 runs without problems. Calculating the log_prob of data2 produces the following error:

Traceback (most recent call last):

  File ~/project/test_distributions.py:23 in <module>
    logprob = distribution.log_prob(data2)

  File ~/anaconda3/lib/python3.9/site-packages/torch/distributions/multivariate_normal.py:208 in log_prob
    M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)

  File ~/anaconda3/lib/python3.9/site-packages/torch/distributions/multivariate_normal.py:57 in _batch_mahalanobis
    M_swap = torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)  # shape = b x c

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasStrsmBatched( handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount)`

The code also runs without problems when the device is switched to CPU.

Versions

The error can be reproduced on the following two systems.

System 1:

Collecting environment information...
PyTorch version: 1.11.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.20.0
Libc version: glibc-2.31

Python version: 3.9.12 (main, Apr  5 2022, 06:56:58)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.13.0-41-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.6.55
GPU models and configuration: 
GPU 0: NVIDIA A100-PCIE-40GB
GPU 1: NVIDIA A100-PCIE-40GB
GPU 2: NVIDIA A100-PCIE-40GB
GPU 3: NVIDIA A100-PCIE-40GB

Nvidia driver version: 510.39.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.5
[pip3] numpydoc==1.2
[pip3] torch==1.11.0
[pip3] torchaudio==0.11.0
[pip3] torchtext==0.12.0
[pip3] torchvision==0.12.0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] libblas                   3.9.0            12_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            12_linux64_mkl    conda-forge
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py39h7f8727e_0  
[conda] mkl_fft                   1.3.1            py39hd3c417c_0  
[conda] mkl_random                1.2.2            py39h51133e4_0  
[conda] numpy                     1.21.5           py39he7a7128_2  
[conda] numpy-base                1.21.5           py39hf524024_2  
[conda] numpydoc                  1.2                pyhd3eb1b0_0  
[conda] pytorch                   1.11.0          py3.9_cuda11.3_cudnn8.2.0_0    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                0.11.0               py39_cu113    pytorch
[conda] torchtext                 0.12.0                     py39    pytorch
[conda] torchvision               0.12.0               py39_cu113    pytorch

System 2:

Collecting environment information...
PyTorch version: 1.11.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Enterprise
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.19.6
Libc version: N/A

Python version: 3.9.12 (main, Apr  4 2022, 05:22:27) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19044-SP0
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2080 SUPER
Nvidia driver version: 510.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.5
[pip3] numpydoc==1.2
[pip3] torch==1.11.0
[pip3] torchaudio==0.11.0
[pip3] torchvision==0.12.0
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               11.3.1               h59b6b97_2
[conda] mkl                       2021.4.0           haa95532_640
[conda] mkl-service               2.4.0            py39h2bbff1b_0
[conda] mkl_fft                   1.3.1            py39h277e83a_0
[conda] mkl_random                1.2.2            py39hf11a4ad_0
[conda] numpy                     1.21.5           py39h7a0a035_2
[conda] numpy-base                1.21.5           py39hca35cd5_2
[conda] numpydoc                  1.2                pyhd3eb1b0_0
[conda] pytorch                   1.11.0          py3.9_cuda11.3_cudnn8_0    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                0.11.0               py39_cu113    pytorch
[conda] torchvision               0.12.0               py39_cu113    pytorch

cc @ezyang @gchanan @zou3519 @fritzo @neerajprad @alicanb @nikitaved @ngimel @jianyuh @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano

@CloudyDory CloudyDory changed the title CUBLAS_STATUS_EXECUTION_FAILED when calling 'cublasStrsmBatched' in torch.distributions.MultivariateNormal module when calculating log_prob of large matrix Pytorch fails to calculate the log_prob of a multivariate normal distribution when the number of data samples is larger than 524280 Jun 9, 2022
@CloudyDory CloudyDory changed the title Pytorch fails to calculate the log_prob of a multivariate normal distribution when the number of data samples is larger than 524280 PyTorch fails to calculate the log_prob of a multivariate normal distribution when the number of data samples is larger than 524280 Jun 9, 2022
@zou3519 zou3519 added module: distributions Related to torch.distributions triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Jun 9, 2022
@zou3519
Contributor

zou3519 commented Jun 9, 2022

I can reproduce this

@albanD albanD added module: cuda Related to torch.cuda, and CUDA support in general module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul and removed triage review high priority labels Jun 13, 2022
@ngimel
Collaborator

ngimel commented Jun 13, 2022

For the passing case, the profile shows

4.62631s  179.33us          (1 65535 1)        (64 4 1)        39  2.3750KB        0B         -           -           -           -  Tesla V100-SXM2         1         7  void batch_trsm_left_kernel<float, int=64, int=4, int=3, bool=0, bool=0, bool=0>(cublasTrsmBatchParams2<float>, float, float const *, int) [5332]

So with a bigger batch (?), gridDim.y becomes 65536, which exceeds the maximum of 65535.
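The arithmetic behind that observation can be sketched as follows. This is only an illustration: the figure of 8 right-hand-side columns per block is an assumption inferred from 524280 / 65535 = 8 and has not been checked against the cuBLAS source. The 65535 limit on gridDim.y is the documented CUDA limit.

```python
MAX_GRID_DIM_Y = 65535  # CUDA limit on gridDim.y
COLS_PER_BLOCK = 8      # assumption: inferred from 524280 / 65535, not from cuBLAS source


def grid_dim_y(num_rhs_columns: int) -> int:
    """Blocks needed along y if each block covers COLS_PER_BLOCK columns (ceiling division)."""
    return (num_rhs_columns + COLS_PER_BLOCK - 1) // COLS_PER_BLOCK


for n in (524280, 524281):
    blocks = grid_dim_y(n)
    status = "ok" if blocks <= MAX_GRID_DIM_Y else "exceeds limit"
    print(n, blocks, status)
```

Under that assumption, 524280 columns need exactly 65535 blocks, and one more column pushes the launch over the limit.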

@ngimel
Collaborator

ngimel commented Jun 13, 2022

Smaller repro

import torch
device="cuda"
A = torch.randn(1, 2, 2, device=device).tril_()
B = torch.randn(1, 2, 524281, device=device)
X = torch.linalg.solve_triangular(A, B, upper=False)

@lezcano
Collaborator

lezcano commented Jun 14, 2022

I believe we already encountered this in other operations and we resolved that there's not much we can do about it? @IvanYashchuk @xwang233

@ngimel
Collaborator

ngimel commented Jun 14, 2022

It's the number of right-hand sides (rhs's) you have for the triangular solve, so you could split the computation to handle the rhs's in batches?
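A minimal sketch of that splitting idea, assuming the failure depends only on the number of columns of B. The helper name `solve_triangular_chunked` and the default chunk width of 524280 (the largest size reported to work in this thread) are illustrative and not part of PyTorch.

```python
import torch


def solve_triangular_chunked(A, B, *, upper, max_cols=524280):
    """Solve A X = B, splitting the columns of B into chunks of at most max_cols.

    max_cols defaults to the largest width reported to work in this thread;
    it is an assumption, not a documented cuBLAS limit.
    """
    if B.shape[-1] <= max_cols:
        return torch.linalg.solve_triangular(A, B, upper=upper)
    # Each column of B is an independent right-hand side, so the chunks can be
    # solved separately and concatenated back along the column dimension.
    chunks = torch.split(B, max_cols, dim=-1)
    solved = [torch.linalg.solve_triangular(A, c, upper=upper) for c in chunks]
    return torch.cat(solved, dim=-1)


# Small CPU demonstration with an artificially low chunk width.
A = torch.tensor([[[2.0, 0.0], [1.0, 3.0]]])
B = torch.randn(1, 2, 10)
X = solve_triangular_chunked(A, B, upper=False, max_cols=3)
print(X.shape)
```

Because the columns are independent, the chunked result should match a single call wherever the single call succeeds.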

@ngimel ngimel removed the module: distributions Related to torch.distributions label Jun 14, 2022
@IvanYashchuk
Collaborator

Failure of trsmBatched was briefly mentioned in #75494 (comment). And the same problem was reported in #67013 but got ignored.

@CloudyDory
Author

So what is the cause of the problem? Are there any intuitive explanations? Is it fixable?

@timwaite

timwaite commented Aug 31, 2023

Inspired by this discussion I tried an approach based on repeating the mean and covariance matrix as a batch.

My approach seems to work reliably for data at least 128 times bigger than @CloudyDory's example, i.e. $n=2^{19+7}$ samples. I haven't tested further as I ran out of VRAM.

## as in the OP
import torch
device = torch.device("cuda")
dtype = torch.float32
mean = torch.tensor([0.0, 0.0], dtype=dtype, device=device)
sd = torch.diag_embed(torch.tensor([1.0, 1.0], dtype=dtype, device=device))

## solution
n=2**(19+7)
data_n = torch.randn([n,2], dtype=dtype, device=device)
distribution_n = torch.distributions.MultivariateNormal(mean.repeat(n,1), sd.repeat(n,1,1))
distribution_n.log_prob(data_n) 

@lezcano @IvanYashchuk @xwang233 do you think this repeating idea could be incorporated into the MultivariateNormal class so that OP's code works?

@lezcano lezcano changed the title PyTorch fails to calculate the log_prob of a multivariate normal distribution when the number of data samples is larger than 524280 Triangular solve fails on batches of matrices of size > (*, 524280) Aug 31, 2023
@lezcano
Collaborator

lezcano commented Aug 31, 2023

@timwaite alas, that's just a hack, and it'll probably be very slow. We should implement #79191 (comment) or #97211 (comment) as a workaround before this is fixed in cusolver (cc @xwang233).

@timwaite

timwaite commented Sep 5, 2023

@lezcano thanks for the reply.

The following slightly better workaround worked for my use case (computing the log probability of a low-dimensional MVN distribution with many samples). I have posted it in case it is useful to others.

The code needs some more input checking etc, but aside from those issues I don't think there is any drawback to doing the MVN calculation this way when on a CUDA device. As far as I can tell:

  1. the results are the same to within numerical tolerance (see checking below)
  2. there is no performance penalty to this approach (see checking below)
  3. this code would not encounter the same CUBLAS failure mode

I wonder if a simple fix along these lines would be good enough then for the MVN (regardless of what happens with triangular_solve)?

Of course it is possible I made a mistake, so it would be good for others to check.

The main difference, in my use case, is:

  • the current PyTorch implementation ultimately launches a single batch triangular solve with a wide RHS
  • the below just uses a batch of small matrices

Reading the cuBLAS documentation, the latter approach seems closer to how cublasStrsmBatched is intended to be used, so I wonder if this is actually the correct approach:

This function works for any sizes but is intended to be used for matrices of small sizes where the launch overhead is a significant factor. For bigger sizes, it might be advantageous to call batchCount times the regular [cublas<t>trsm](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-trsm) within a set of CUDA streams.

I don't think this code would ever hit the same CUBLAS failure mode, as the failure seems to be due to wide right-hand sides, whereas this method uses right-hand sides of width 1. I tested high-dimensional distributions until my VRAM ran out.
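To make the two formulations concrete, here is a small CPU-only sketch comparing one system with a wide right-hand side against a batch of small systems with right-hand sides of width 1. The factor `L` and the sizes are made-up illustration values, not taken from the thread.

```python
import torch

torch.manual_seed(0)
p, n = 2, 1000
L = torch.tensor([[0.5, 0.0], [-0.02, 1.0]])  # lower-triangular factor, made-up values
diff = torch.randn(n, p)

# (a) One system with a wide right-hand side: L is (p, p), the RHS is (p, n).
z_wide = torch.linalg.solve_triangular(L, diff.T, upper=False).T  # (n, p)

# (b) n small systems, each with a width-1 right-hand side:
#     L is expanded to (n, p, p), the RHS is (n, p, 1).
z_batch = torch.linalg.solve_triangular(
    L.expand(n, p, p), diff.unsqueeze(-1), upper=False
).squeeze(-1)  # (n, p)

print(torch.allclose(z_wide, z_batch, atol=1e-5))
```

Both calls solve the same n independent systems; only the way the work is laid out for the backend differs.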

Code for log probability density:

import torch
import math

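def mvnlogprob(dist: torch.distributions.multivariate_normal.MultivariateNormal, 
               inputs: torch.Tensor):
    # dimension of the distribution (size of the last axis of loc)
    p = dist.loc.size(-1)
    diff = inputs - dist.loc
    
    # leading (sample) dimensions of the input
    batch_shape = diff.shape[:-1]
    
    scale_shape = dist.scale_tril.size()
    
    # one small triangular system per sample, each with a right-hand side of width 1
    _scale_tril = dist.scale_tril.expand(batch_shape+scale_shape)
    z = torch.linalg.solve_triangular(_scale_tril,
                                      diff.unsqueeze(-1), 
                                      upper=False).squeeze(-1)
    
    # log-determinant of the triangular factor = sum of the logs of its diagonal
    half_log_det = dist.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    out = -0.5*p*math.log(2*math.pi) - half_log_det - 0.5*(z**2).sum(dim=-1)
    return out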

Comparison results:

n =  1000
Relative error:  1.1910433528328213e-07
Relative speed of new method:  0.9748199294571058

n =  10000
Relative error:  1.1919563291939994e-07
Relative speed of new method:  1.0344244369857059

n =  100000
Relative error:  1.1920259623821039e-07
Relative speed of new method:  1.14061501003811

Code to compare performance:


import torch
import time 
from mvnlogprobfix import mvnlogprob
    
def compare_logprob_calcs(sample_size, dist):
    sample = dist.sample([sample_size])
    # compute result using old and new method 
    test = mvnlogprob(dist, sample)
    ref= dist.log_prob(sample)
    
    # compare answers
    print("\n\nn = ", sample_size)
    relerr = (test-ref)/ref
    
    print("Relative error: ", relerr.abs().max().item())

    # timings 
    # gpu warm up just in case (tensors must live on the same device as the sample)
    A = torch.randn([1000,1000], device=sample.device)
    B = torch.randn([1000,1000], device=sample.device)
    torch.matmul(A,B)
    
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(5000):
        test = mvnlogprob(dist, sample)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    
    
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    for _ in range(5000):
        ref = dist.log_prob(sample)
    torch.cuda.synchronize()
    t3 = time.perf_counter()
    
    print("Relative speed of new method: ", (t3-t2)/(t1-t0))
    return 0


device='cuda'
prior_mean = torch.tensor([1.0,2.0], device=device)
prior_Sigma = torch.tensor([[0.5**2,-0.01], [-0.01,1.0**2]], device=device)
prior = torch.distributions.multivariate_normal.MultivariateNormal(prior_mean, prior_Sigma)

compare_logprob_calcs(1000, prior)
compare_logprob_calcs(10000, prior)
compare_logprob_calcs(100000, prior)

sample = prior.sample([2**19])
mvnlogprob(prior,sample) # runs fine
prior.log_prob(sample) # fails with CUBLAS error 

PS apologies if there are any glaring errors, I am quite new to Python and PyTorch!

@lezcano
Collaborator

lezcano commented Jan 16, 2024

This looks like it's been fixed for CUDA >= 12.1. I'll implement a fix for previous versions tho.
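For anyone who needs to decide at runtime whether a workaround is still required, a rough sketch is below. It assumes the 12.1 threshold mentioned above and uses `torch.version.cuda`, which reports the CUDA version PyTorch was built against. That may not match the cuBLAS actually loaded at runtime, so treat the result as a hint only. The helper name `cuda_at_least` is hypothetical.

```python
import torch


def cuda_at_least(version_string, major, minor):
    """Return True if a 'major.minor' CUDA version string is >= (major, minor)."""
    if version_string is None:  # CPU-only and ROCm builds report None
        return False
    parts = version_string.split(".")
    return (int(parts[0]), int(parts[1])) >= (major, minor)


# torch.version.cuda is a string such as "11.3", or None without CUDA support.
needs_workaround = torch.cuda.is_available() and not cuda_at_least(
    torch.version.cuda, 12, 1
)
print("apply wide-RHS workaround:", needs_workaround)
```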

lezcano added a commit that referenced this issue Jan 17, 2024
Fix #79191

ghstack-source-id: e6167046210e6b18780d62f695c35134f82d8a93
Pull Request resolved: #117636