This is a fork of https://github.com/facebookresearch/llama that runs on Apple M2 using the MPS (Metal Performance Shaders) backend.
For installation, license, and usage instructions, please refer to the official facebookresearch/llama page.
- Note: you need to set the PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable to run this code.
- This is a workaround for the 'aten::polar.out' operator, which the MPS backend does not support.
- The following warning is expected. It tells you that PyTorch falls back to the CPU for that one operation: "UserWarning: The operator 'aten::polar.out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1aidzjezue/croot/pytorch_1687856425340/work/aten/src/ATen/mps/MPSFallback.mm:11.) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64"
The code should run past that as shown in this screenshot.
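
As a minimal sketch of the note above, the variable can also be set from Python instead of the shell. This assumes the variable has to be in place before torch is imported, so the assignment sits at the very top of the entry script. The helper name pick_device is hypothetical and not part of this repository.

```python
import os

# Assumption: PyTorch reads this variable when it is loaded, so set it
# before any "import torch" statement in the process.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


def pick_device():
    """Hypothetical helper: prefer the MPS device, otherwise use the CPU."""
    import torch  # imported lazily, after the variable has been set

    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```

Running export PYTORCH_ENABLE_MPS_FALLBACK=1 in the shell before launching the script has the same effect and needs no code change.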