Skip to content

Commit

Permalink
Added version 0.37.0.
Browse files Browse the repository at this point in the history
  • Loading branch information
TimDettmers committed Feb 2, 2023
1 parent de53588 commit 0f5c394
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,15 @@ Improvements:
- StableEmbedding layer now has device and dtype parameters to make it 1:1 replaceable with regular Embedding layers (@lostmsu)
- runtime performance of block-wise quantization slightly improved
- added error message for the case multiple libcudart.so are installed and bitsandbytes picks the wrong one


### 0.37.0

#### Int8 Matmul + backward support for all GPUs

Features:
- Int8 MatmulLt now supports backward through inversion of the ColTuring/ColAmpere format. Slow, but memory efficient. Big thanks to @borzunov
- Int8 now supported on all GPUs. On devices with compute capability < 7.5, the Int weights are cast to 16/32-bit for the matrix multiplication. Contributed by @borzunov

Improvements:
- Improved logging for the CUDA detection mechanism.
13 changes: 7 additions & 6 deletions bitsandbytes/cuda_setup/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ def generate_instructions(self):
self.add_log_entry('python setup.py install')

def initialize(self):
self.has_printed = False
self.lib = None
self.initialized = False
if not getattr(self, 'initialized', False):
self.has_printed = False
self.lib = None
self.initialized = False

def run_cuda_setup(self):
self.initialized = True
Expand All @@ -103,7 +104,7 @@ def run_cuda_setup(self):
legacy_binary_name = "libbitsandbytes_cpu.so"
self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
binary_path = package_dir / legacy_binary_name
if not binary_path.exists():
if not binary_path.exists() or torch.cuda.is_available():
self.add_log_entry('')
self.add_log_entry('='*48 + 'ERROR' + '='*37)
self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:')
Expand All @@ -112,6 +113,7 @@ def run_cuda_setup(self):
self.add_log_entry('3. You have multiple conflicting CUDA libraries')
self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!')
self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.')
self.add_log_entry('='*80)
self.add_log_entry('')
self.generate_instructions()
Expand Down Expand Up @@ -148,7 +150,7 @@ def is_cublasLt_compatible(cc):
if cc is not None:
cc_major, cc_minor = cc.split('.')
if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5):
cuda_setup.add_log_entry("WARNING: Compute capability < 7.5 detected! Proceeding to load CPU-only library...", is_warning=True)
cuda_setup.add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True)
else:
has_cublaslt = True
return has_cublaslt
Expand Down Expand Up @@ -362,7 +364,6 @@ def evaluate_cuda_setup():
print('')
print('='*35 + 'BUG REPORT' + '='*35)
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
print('='*80)
if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def read(fname):

setup(
name=f"bitsandbytes",
version=f"0.36.0-2",
version=f"0.37.0",
author="Tim Dettmers",
author_email="[email protected]",
description="8-bit optimizers and matrix multiplication routines.",
Expand Down

0 comments on commit 0f5c394

Please sign in to comment.