IPEX fix dtype errors when GPU supports 64 bit
Disty0 committed Dec 11, 2023
1 parent 0d805e3 commit 0fe2764
Showing 3 changed files with 42 additions and 30 deletions.
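At a glance, this commit moves the dtype-matching logic out of the Alchemist-only attention hijacks and into hijacks.py, so mismatched dtypes are fixed on every XPU device, while the 4GB block-slicing variants in attention.py are only used as the backing implementation when the GPU lacks FP64 support. A minimal sketch of the resulting pattern, reduced to torch.bmm only (it assumes an environment with intel_extension_for_pytorch installed so that torch.xpu exists; this is not the literal file contents):

    import torch
    import intel_extension_for_pytorch  # noqa: F401  -- assumed installed; provides torch.xpu

    # Pick the backing bmm: the stock kernel on FP64-capable GPUs, or the sliced
    # 32-bit variant from modules/intel/ipex/attention.py on Alchemist/ARC.
    if torch.xpu.has_fp64_dtype():
        original_torch_bmm = torch.bmm
    else:
        original_torch_bmm = torch.bmm  # stand-in here for torch_bmm_32_bit

    # The dtype fix is now applied unconditionally, on top of whichever backend was chosen:
    def torch_bmm(input, mat2, *, out=None):
        if input.dtype != mat2.dtype:
            mat2 = mat2.to(input.dtype)
        return original_torch_bmm(input, mat2, out=out)

    torch.bmm = torch_bmm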
5 changes: 0 additions & 5 deletions modules/intel/ipex/__init__.py
@@ -165,11 +165,6 @@ def ipex_init(): # pylint: disable=too-many-statements
 
     ipex_hijacks()
     if not torch.xpu.has_fp64_dtype():
-        try:
-            from .attention import attention_init
-            attention_init()
-        except Exception: # pylint: disable=broad-exception-caught
-            pass
         try:
             from .diffusers import ipex_diffusers
             ipex_diffusers()
29 changes: 8 additions & 21 deletions modules/intel/ipex/attention.py
@@ -4,11 +4,8 @@
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
 original_torch_bmm = torch.bmm
-def torch_bmm(input, mat2, *, out=None):
-    if input.dtype != mat2.dtype:
-        mat2 = mat2.to(input.dtype)
-
-    #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+def torch_bmm_32_bit(input, mat2, *, out=None):
+    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
     batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
     block_multiply = input.element_size()
     slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
@@ -17,7 +14,7 @@ def torch_bmm(input, mat2, *, out=None):
     split_slice_size = batch_size_attention
     if block_size > 4:
         do_split = True
-        #Find something divisible with the input_tokens
+        # Find something divisible with the input_tokens
         while (split_slice_size * slice_block_size) > 4:
             split_slice_size = split_slice_size // 2
             if split_slice_size <= 1:
@@ -30,7 +27,7 @@ def torch_bmm(input, mat2, *, out=None):
     if split_slice_size * slice_block_size > 4:
         slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
         do_split_2 = True
-        #Find something divisible with the input_tokens
+        # Find something divisible with the input_tokens
         while (split_2_slice_size * slice_block_size2) > 4:
             split_2_slice_size = split_2_slice_size // 2
             if split_2_slice_size <= 1:
@@ -64,8 +61,8 @@ def torch_bmm(input, mat2, *, out=None):
     return hidden_states
 
 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
-def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
-    #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
     if len(query.shape) == 3:
         batch_size_attention, query_tokens, shape_four = query.shape
         shape_one = 1
@@ -74,19 +71,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
         shape_one, batch_size_attention, query_tokens, shape_four = query.shape
         no_shape_one = False
 
-    if query.dtype != key.dtype:
-        key = key.to(dtype=query.dtype)
-    if query.dtype != value.dtype:
-        value = value.to(dtype=query.dtype)
-
     block_multiply = query.element_size()
     slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
     block_size = batch_size_attention * slice_block_size
 
     split_slice_size = batch_size_attention
     if block_size > 6:
         do_split = True
-        #Find something divisible with the shape_one
+        # Find something divisible with the shape_one
         while (split_slice_size * slice_block_size) > 4:
             split_slice_size = split_slice_size // 2
             if split_slice_size <= 1:
@@ -99,7 +91,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
     if split_slice_size * slice_block_size > 6:
         slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
         do_split_2 = True
-        #Find something divisible with the batch_size_attention
+        # Find something divisible with the batch_size_attention
         while (split_2_slice_size * slice_block_size2) > 4:
             split_2_slice_size = split_2_slice_size // 2
             if split_2_slice_size <= 1:
@@ -155,8 +147,3 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
             query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
         )
     return hidden_states
-
-def attention_init():
-    #ARC GPUs can't allocate more than 4GB to a single block:
-    torch.bmm = torch_bmm
-    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
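To make the slicing heuristic in torch_bmm_32_bit above concrete, here is a small standalone sketch of the block-size estimate and the halving loop, using the same arithmetic as the function; the shape values are made-up examples, not taken from the commit:

    # Standalone sketch of the block-size estimate used by torch_bmm_32_bit.
    # The shapes below are invented example values.
    batch_size_attention, input_tokens, mat2_shape = 16, 4096, 4096
    element_size = 2  # bytes per element for float16

    slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * element_size  # 32.0
    block_size = batch_size_attention * slice_block_size                       # 512.0

    split_slice_size = batch_size_attention
    if block_size > 4:
        # Halve the per-slice batch size until a slice fits under the threshold (or hits 1):
        while (split_slice_size * slice_block_size) > 4:
            split_slice_size = split_slice_size // 2
            if split_slice_size <= 1:
                split_slice_size = 1
                break

    print(split_slice_size)  # 1 with these example shapes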
38 changes: 34 additions & 4 deletions modules/intel/ipex/hijacks.py
@@ -93,6 +93,31 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
     else:
         return original_linalg_solve(A, B, *args, **kwargs)
 
+if torch.xpu.has_fp64_dtype():
+    original_torch_bmm = torch.bmm
+    original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+else:
+    # 64 bit attention workarounds for Alchemist:
+    try:
+        from .attention import torch_bmm_32_bit as original_torch_bmm
+        from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
+    except Exception: # pylint: disable=broad-exception-caught
+        original_torch_bmm = torch.bmm
+        original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+
+# dtype errors:
+def torch_bmm(input, mat2, *, out=None):
+    if input.dtype != mat2.dtype:
+        mat2 = mat2.to(input.dtype)
+    return original_torch_bmm(input, mat2, out=out)
+
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+    if query.dtype != key.dtype:
+        key = key.to(dtype=query.dtype)
+    if query.dtype != value.dtype:
+        value = value.to(dtype=query.dtype)
+    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+
 @property
 def is_cuda(self):
     return self.device.type == 'xpu'
@@ -184,11 +209,16 @@ def ipex_hijacks():
         lambda orig_func, *args, **kwargs: True)
 
     # Functions that make compile mad with CondFunc:
-    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
     torch.nn.DataParallel = DummyDataParallel
+    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
+
     torch.autocast = ipex_autocast
-    torch.cat = torch_cat
-    torch.linalg.solve = linalg_solve
-    torch.backends.cuda.sdp_kernel = return_null_context
     torch.UntypedStorage.is_cuda = is_cuda
+
     torch.nn.functional.interpolate = interpolate
+    torch.backends.cuda.sdp_kernel = return_null_context
+    torch.linalg.solve = linalg_solve
+
+    torch.bmm = torch_bmm
+    torch.cat = torch_cat
+    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention

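For illustration, once ipex_hijacks() has installed these wrappers, a mixed-dtype torch.bmm on XPU tensors is promoted to the first operand's dtype instead of raising a dtype error. A hypothetical snippet (the tensor shapes are made up, and it assumes an XPU device with ipex_init()/ipex_hijacks() already applied):

    import torch
    import intel_extension_for_pytorch  # noqa: F401  -- assumed installed

    # Hypothetical example; assumes the IPEX hijacks above have already run.
    a = torch.randn(2, 64, 32, dtype=torch.float16, device="xpu")
    b = torch.randn(2, 32, 64, dtype=torch.float32, device="xpu")

    # The hijacked torch.bmm casts b down to float16 before dispatching,
    # so this no longer fails with a dtype mismatch:
    out = torch.bmm(a, b)
    print(out.dtype)  # torch.float16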