Disable FP8 cache support when CUDA version < 11.8
turboderp committed Oct 14, 2023
1 parent 5cce46a commit 4a5920a
Showing 3 changed files with 18 additions and 1 deletion.
3 changes: 3 additions & 0 deletions exllamav2/cache.py
@@ -98,6 +98,9 @@ def __init__(self, model, batch_size = 1, max_seq_len = -1, copy_from = None):
         self.is_8bit = self.model.config.kv_cache_mask is not None and self.model.config.kv_cache_mask == '8bit'
 
         if self.is_8bit:
+            cuda_version = torch.version.cuda.split('.')
+            cuda_version = [int(x) for x in cuda_version[:2]]
+            assert cuda_version >= [11, 8], " ## 8-bit (FP8) cache requires CUDA version 11.8 or greater"
             self.cached = Cache8Bit(model, self.batch_size, self.max_seq_len, num_key_value_heads, head_dim, num_hidden_layers, copy_from)
         else:
             self.cached = Cache16Bit(model, self.batch_size, self.max_seq_len, num_key_value_heads, head_dim, num_hidden_layers, copy_from)
12 changes: 12 additions & 0 deletions exllamav2/exllamav2_ext/cuda/cache.cu
@@ -1,5 +1,9 @@
#include "cache.cuh"

#if defined(CUDART_VERSION) && CUDART_VERSION >= 11080

#include <cuda_fp8.h>

// TODO: Kernel profiling

__global__ void nv_fp16_to_fp8(const half* pIn, unsigned char *pOut, int size) {
@@ -34,3 +38,11 @@ void array_fp8_to_fp16_cuda(const unsigned char* pIn, half* pOut, int size) {
     int blocks = (size + threads - 1) / threads;
     nv_fp8_to_fp16<<<blocks, threads>>>(pIn, pOut, size);
 }
+
+#else
+
+void array_fp16_to_fp8_cuda(const half* pIn, unsigned char *pOut, int size) { }
+
+void array_fp8_to_fp16_cuda(const unsigned char* pIn, half* pOut, int size) { }
+
+#endif
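
For reference, here is a minimal sketch of what the guarded conversion kernels can look like once cuda_fp8.h is available. The kernel bodies above are collapsed in this diff, so the kernel names, the E4M3 format, and the use of the __nv_cvt_* intrinsics below are illustrative assumptions, not the repository's actual implementation; the point is that these intrinsics only ship with CUDA 11.8 or newer, which is what the CUDART_VERSION guard protects against.

    // Illustrative sketch only: kernel names, E4M3 choice and intrinsics are assumptions,
    // not the kernels actually used by exllamav2.
    #include <cuda_runtime.h>   // defines CUDART_VERSION
    #include <cuda_fp16.h>

    #if defined(CUDART_VERSION) && CUDART_VERSION >= 11080

    #include <cuda_fp8.h>       // FP8 storage type and conversion intrinsics, CUDA 11.8+ only

    __global__ void fp16_to_fp8_sketch(const half* pIn, unsigned char* pOut, int size)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= size) return;
        __half_raw h = pIn[i];                                            // raw FP16 bits
        pOut[i] = __nv_cvt_halfraw_to_fp8(h, __NV_SATFINITE, __NV_E4M3);  // saturating FP16 -> FP8 (E4M3)
    }

    __global__ void fp8_to_fp16_sketch(const unsigned char* pIn, half* pOut, int size)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= size) return;
        pOut[i] = __half(__nv_cvt_fp8_to_halfraw(pIn[i], __NV_E4M3));     // FP8 (E4M3) -> FP16
    }

    #endif

On a pre-11.8 toolkit the guarded section compiles to nothing, which is why the #else branch above supplies empty array_fp16_to_fp8_cuda / array_fp8_to_fp16_cuda stubs so the extension still builds and links.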
4 changes: 3 additions & 1 deletion exllamav2/exllamav2_ext/cuda/cache.cuh
@@ -1,6 +1,8 @@
 #ifndef _cache_cuh
 #define _cache_cuh
-#include <cuda_fp8.h>
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
 
 void array_fp16_to_fp8_cuda(const half* pIn, unsigned char *pOut, int size);
 void array_fp8_to_fp16_cuda(const unsigned char* pIn, half* pOut, int size);
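
Since cuda_fp8.h is no longer pulled in by the public header, anything that includes cache.cuh now only needs cuda_runtime.h and cuda_fp16.h, so it compiles against toolkits older than 11.8 as well; there, the two entry points are the empty stubs from the #else branch, and the Python-side assert in cache.py is what actually reports the limitation to the user. A hypothetical caller, purely for illustration and not part of the commit:

    // Hypothetical caller, for illustration only (not part of this commit).
    #include <cuda_runtime.h>
    #include <cuda_fp16.h>
    #include "cache.cuh"

    int main()
    {
        const int size = 1024;
        half* fp16_buf = nullptr;
        unsigned char* fp8_buf = nullptr;
        cudaMalloc(&fp16_buf, size * sizeof(half));
        cudaMalloc(&fp8_buf, size * sizeof(unsigned char));

        // Real conversions on CUDA 11.8+, empty no-op stubs on older toolkits.
        array_fp16_to_fp8_cuda(fp16_buf, fp8_buf, size);
        array_fp8_to_fp16_cuda(fp8_buf, fp16_buf, size);

        cudaDeviceSynchronize();
        cudaFree(fp16_buf);
        cudaFree(fp8_buf);
        return 0;
    }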
