Skip to content

Commit

Permalink
Fix stable diffusion output error on WebGPU (tinygrad#2032)
Browse files Browse the repository at this point in the history
* Fix stable diffusion on WebGPU

* Remove hack, numpy cast only on webgpu

* No-copy numpy cast
  • Loading branch information
wpmed92 committed Oct 10, 2023
1 parent e40f141 commit e27fedf
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions examples/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,12 +651,14 @@ def do_step(latent, timestep, index):

# make image correct size and scale
x = (x + 1.0) / 2.0
x = (x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255).cast(dtypes.uint8)
x = (x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255)
if Device.DEFAULT != "WEBGPU": x = x.cast(dtypes.uint8)
print(x.shape)

# save image
from PIL import Image
im = Image.fromarray(x.numpy())
import numpy as np
im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
print(f"saving {args.out}")
im.save(args.out)
# Open image.
Expand Down

0 comments on commit e27fedf

Please sign in to comment.