Skip to content

Commit

Permalink
Disable non blocking on mps.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Dec 10, 2023
1 parent 614b7e7 commit 340177e
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,15 +553,19 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu():
device_supports_cast = True

non_blocking = True
if is_device_mps(device):
non_blocking = False #pytorch bug? mps doesn't support non blocking

if device_supports_cast:
if copy:
if tensor.device == device:
return tensor.to(dtype, copy=copy, non_blocking=True)
return tensor.to(device, copy=copy, non_blocking=True).to(dtype, non_blocking=True)
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True)
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, dtype, copy=copy, non_blocking=True)
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)

def xformers_enabled():
global directml_enabled
Expand Down

0 comments on commit 340177e

Please sign in to comment.