Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YOLOv5 Apple Metal Performance Shader (MPS) support #7878

Merged
merged 5 commits into from
May 24, 2022
Merged

Conversation

glenn-jocher
Copy link
Member

@glenn-jocher glenn-jocher commented May 18, 2022

Following https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/ posted in pytorch/pytorch#47702

Should work with Apple M1 devices with PyTorch nightly installed with command --device mps. Usage examples:

python train.py --device mps
python detect.py --device mps
python val.py --device mps

EDIT: Requires universal2 installer with Python>=3.9.1 from https://www.python.org/downloads/macos/ using command:

pip install --pre -r requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu

# Install torchvision nightly
pip uninstall -y torchvision
git clone https://github.com/pytorch/vision
cd vision
MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install

EDIT2: Pending new nightly torchvision arm64 distribution in pytorch/vision#6050
EDIT3: PR is merged, PyTorch seems to have a few TODOs on their side which are out of my control: 1) create torchvision nightly arm64, 2) resolve buffer is not large enough error reported by many users pytorch/pytorch#77748 (comment)

🛠️ PR Summary

Made with ❤️ by Ultralytics Actions

🌟 Summary

Enhanced device compatibility in model loading and environment configuration for Ultralytics YOLOv5.

📊 Key Changes

  • attempt_load function signature: Replaced map_location parameter with device for clarity.
  • Model loading: Streamlined by using the device argument directly.
  • Device selection: Introduced support for Apple Metal Performance Shaders (MPS) by extending the select_device function.

🎯 Purpose & Impact

  • Improved Device Handling: Simplifies the device specification when loading models, enhancing code readability and reducing potential confusion.
  • Broader Support: Adding MPS support opens up efficient GPU-accelerated inference for Mac users with Apple Silicon, improving performance on these devices.
  • Streamlined User Experience: Users can expect more intuitive interactions with the library when working across different hardware platforms, making advanced AI more accessible.

Following https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/

Should work with Apple M1 devices with PyTorch nightly installed with command `--device mps`. Usage examples:
```bash
python train.py --device mps
python detect.py --device mps
python val.py --device mps
```
@glenn-jocher glenn-jocher self-assigned this May 18, 2022
@glenn-jocher
Copy link
Member Author

Raised issue in pytorch/pytorch#77748

@glenn-jocher
Copy link
Member Author

glenn-jocher commented May 18, 2022

Python-3.9.13 torch-1.11.0 (Macbook Air M1) - CPU

(venv) (base) glennjocher@Glenns-MacBook-Air yolov5 % python detect.py
detect: weights=yolov5s.pt, source=data/images, data=data/coco128.yaml, imgsz=[640, 640], conf_thres=0.25, iou_thres=0.45, max_det=1000, device=, view_img=False, save_txt=False, save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, augment=False, visualize=False, update=False, project=runs/detect, name=exp, exist_ok=False, line_thickness=3, hide_labels=False, hide_conf=False, half=False, dnn=False
YOLOv5 🚀 v6.1-212-g7c13c46 Python-3.9.13 torch-1.11.0 CPU

Fusing layers... 
YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients
image 1/2 /Users/glennjocher/PycharmProjects/yolov5/data/images/bus.jpg: 640x480 4 persons, 1 bus, Done. (0.084s)
image 2/2 /Users/glennjocher/PycharmProjects/yolov5/data/images/zidane.jpg: 384x640 2 persons, 2 ties, Done. (0.068s)
Speed: 0.4ms pre-process, 76.1ms inference, 0.5ms NMS per image at shape (1, 3, 640, 640)

Python-3.9.13 torch-1.11.0 (Macbook Air M1) - MPS

python detect.py --device mps
TODO

@glenn-jocher glenn-jocher changed the title Apple Metal Performance Shader (MPS) device support YOLOv5 Apple Metal Performance Shader (MPS) device support May 19, 2022
@glenn-jocher glenn-jocher changed the title YOLOv5 Apple Metal Performance Shader (MPS) device support YOLOv5 Apple Metal Performance Shader (MPS) support May 19, 2022
@glenn-jocher glenn-jocher mentioned this pull request May 19, 2022
2 tasks
@glenn-jocher glenn-jocher merged commit c215878 into master May 24, 2022
@glenn-jocher glenn-jocher deleted the apple/mps branch May 24, 2022 11:34
@glenn-jocher glenn-jocher removed the TODO label May 24, 2022
tdhooghe pushed a commit to tdhooghe/yolov5 that referenced this pull request Jun 10, 2022
* Apple Metal Performance Shader (MPS) device support

Following https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/

Should work with Apple M1 devices with PyTorch nightly installed with command `--device mps`. Usage examples:
```bash
python train.py --device mps
python detect.py --device mps
python val.py --device mps
```

* Update device strategy to fix MPS issue
@RacerChen
Copy link

It seems that I still cannot use GPU of M1 chip to train YOLOv5 model. I followed the instructions before, and trained models with --device mps. But I got this
image
Is the issue from the pytorch side unsolved still?

@glenn-jocher
Copy link
Member Author

@RacerChen if you've installed pytorch nightly and you have a supported device then the correct usage example would be:

python train.py --device mps
python detect.py --device mps
etc...

@RacerChen
Copy link

Thanks for answering. My machine is MacBook Air M1 2020. I already installed the pytorch nightly. I trained model with --device mps, but still got this RuntimeError:

/Users/cjj/Desktop/YOLOv5/venv/bin/python /Users/cjj/Desktop/YOLOv5/yolov5/train.py --device mps --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640
train: weights='', cfg=yolov5s.yaml, data=coco128.yaml, hyp=data/hyps/hyp.scratch-low.yaml, epochs=300, batch_size=16, imgsz=640, rect=False, resume=False, nosave=False, noval=False, noautoanchor=False, noplots=False, evolve=None, bucket=, cache=None, image_weights=False, device=mps, multi_scale=False, single_cls=False, optimizer=SGD, sync_bn=False, workers=8, project=runs/train, name=exp, exist_ok=False, quad=False, cos_lr=False, label_smoothing=0.0, patience=100, freeze=[0], save_period=-1, local_rank=-1, entity=None, upload_dataset=False, bbox_interval=-1, artifact_alias=latest
github: up to date with https://github.com/ultralytics/yolov5 ✅
YOLOv5 🚀 v6.1-253-g75bbaa8 Python-3.10.5 torch-1.11.0 MPS

Traceback (most recent call last):
  File "/Users/cjj/Desktop/YOLOv5/yolov5/train.py", line 670, in <module>
    main(opt)
  File "/Users/cjj/Desktop/YOLOv5/yolov5/train.py", line 551, in main
    device = select_device(opt.device, batch_size=opt.batch_size)
  File "/Users/cjj/Desktop/YOLOv5/yolov5/utils/torch_utils.py", line 83, in select_device
    return torch.device('cuda:0' if cuda else 'mps' if mps else 'cpu')
RuntimeError: Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: mps

Process finished with exit code 1

By the way, if I remove the --device mps, the CPU training works well. But I wanna use the GPU.

@Djordi97
Copy link

@RacerChen had the same issue here. Try uninstalling torch, torchvision and torchaudio by running pip3 uninstall torch torchvision torchaudio as mentioned in https://discuss.pytorch.org/t/how-to-check-mps-availability/152015/3. Then install nightly version of torch again and it should work.

@RacerChen
Copy link

@Djordi97 Thanks a lot, I got it. And there is an easy way that cloning a new yolov5 project and configuring it again. : )
Also, check the MacOS version >12.3.

Now it works, but not totally. I am now facing the same problem of Error: buffer is not large enough. Must be 19200 bytes

@glenn-jocher
Copy link
Member Author

@RacerChen the buffer is not large enough. Must be 19200 bytes error is a torch error that I believe they are aware of.

@mabedd
Copy link

mabedd commented Jul 6, 2022

By using the mentioned command to start training
python3 train.py --device mps --img 640 --batch 16 --epochs 3 --data cells.yaml --weights yolov5s.pt
it generates the following error message:
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Note that installed PyTorch Nighty from their official website using the mentioned commands.

@glenn-jocher
Copy link
Member Author

@mohammed-ab99 PyTorch team is aware of ongoing MPS issues tracked in pytorch/pytorch#77886 but I can't tell from your message if this falls under that. Are you saying detect.py --device mps works correctly but not train.py --device mps?

@mabedd
Copy link

mabedd commented Jul 7, 2022

@glenn-jocher I am trying to train on my custom data and falling with this error. It is generated after running train.py --device.

This is the complete traceback:
Traceback (most recent call last): File "train.py", line 666, in <module> main(opt) File "train.py", line 561, in main train(opt.hyp, opt, device, callbacks) File "train.py", line 285, in train model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Also there is a warning that is being generated before throwing the error:
/Users/mabed/Dev/Repos/yolov5/utils/general.py:812: UserWarning: The operator 'aten::nonzero' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)

@glenn-jocher
Copy link
Member Author

@mohammed-ab99 the first might be resolved by reducing any FP64 variables to FP32. Do you know which variable is FP64?

The second issue is already open in #8508

@mabedd
Copy link

mabedd commented Jul 7, 2022

@glenn-jocher Actually I am using the code as is without any modifications, but according to the traceback it is in this line:
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights

This is the lables_to_class_weights function
`def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels
if labels[0] is None: # no labels loaded
return torch.Tensor()

labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
weights = np.bincount(classes, minlength=nc)  # occurrences per class

# Prepend gridpoint count (for uCE training)
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
# weights = np.hstack([gpi * len(labels)  - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

weights[weights == 0] = 1  # replace empty bins with 1
weights = 1 / weights  # number of targets per class
weights /= weights.sum()  # normalize
return torch.from_numpy(weights)`

But I am not sure if this is the variable or it is another one.

glenn-jocher added a commit that referenced this pull request Jul 7, 2022
@glenn-jocher
Copy link
Member Author

@mohammed-ab99 good news 😃! Your original issue may now be fixed ✅ in PR #8511. To receive this update:

  • Gitgit pull from within your yolov5/ directory or git clone https://github.com/ultralytics/yolov5 again
  • PyTorch Hub – Force-reload model = torch.hub.load('ultralytics/yolov5', 'yolov5s', force_reload=True)
  • Notebooks – View updated notebooks Open In Colab Open In Kaggle
  • Dockersudo docker pull ultralytics/yolov5:latest to update your image Docker Pulls

Thank you for spotting this issue and informing us of the problem. This likely won't resolve all issues for you, so if you run into another error on training with MPS please let us know.

@mabedd
Copy link

mabedd commented Jul 7, 2022

@glenn-jocher thanks the error disappeared now. However, I think that now I have fallen to PyTorch support problem for MPS as this error appeared:
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
by running the command:
PYTORCH_ENABLE_MPS_FALLBACK=1 python3 train.py --device mps --img 640 --batch 16 --epochs 3 --data cells.yaml --weights yolov5s.pt

I will be following up with the other issue.
Thanks

@glenn-jocher
Copy link
Member Author

@mohammed-ab99 what line is causing that error?

@mabedd
Copy link

mabedd commented Jul 7, 2022

@glenn-jocher This is the traceback:

Traceback (most recent call last):
  File "train.py", line 667, in <module>
    main(opt)
  File "train.py", line 562, in main
    train(opt.hyp, opt, device, callbacks)
  File "train.py", line 353, in train
    loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
  File "/Users/mabed/Dev/Repos/yolov5/utils/loss.py", line 125, in __call__
    tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets
  File "/Users/mabed/Dev/Repos/yolov5/utils/loss.py", line 208, in build_targets
    t = t[j]  # filter

This is the loop inside loss.py

    for i in range(self.nl):
        anchors, shape = self.anchors[i], p[i].shape
        gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain

        # Match targets to anchors
        t = targets * gain  # shape(3,n,7)
        if nt:
            # Matches
            r = t[..., 4:6] / anchors[:, None]  # wh ratio
            j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']  # compare
            # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
            t = t[j]  # filter

            # Offsets
            gxy = t[:, 2:4]  # grid xy
            gxi = gain[[2, 3]] - gxy  # inverse
            j, k = ((gxy % 1 < g) & (gxy > 1)).T
            l, m = ((gxi % 1 < g) & (gxi > 1)).T
            j = torch.stack((torch.ones_like(j), j, k, l, m))
            t = t.repeat((5, 1, 1))[j]
            offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
        else:
            t = targets[0]
            offsets = 0

        # Define
        bc, gxy, gwh, a = t.chunk(4, 1)  # (image, class), grid xy, grid wh, anchors
        a, (b, c) = a.long().view(-1), bc.long().T  # anchors, image, class
        gij = (gxy - offsets).long()
        gi, gj = gij.T  # grid indices

        # Append
        indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image, anchor, grid
        tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
        anch.append(anchors[a])  # anchors
        tcls.append(c)  # class

    return tcls, tbox, indices, anch

`

@glenn-jocher
Copy link
Member Author

glenn-jocher commented Jul 8, 2022

@mohammed-ab99 got it. Seems like aten::nonzero is required for the indexing op on loss.py L208, as well as in NMS. I would stop using PYTORCH_ENABLE_MPS_FALLBACK=1 and start debugging loss.py L208 to see if you can restructure this op in a different way that bypasses the aten:nonzero requirement. I don't have availability right now to do this but I'll add a TODO to track this closer.

t = t[j] # filter

@glenn-jocher
Copy link
Member Author

@mohammed-ab99 for example you could try using https://pytorch.org/docs/stable/generated/torch.index_select.html

@glenn-jocher
Copy link
Member Author

@mohammed-ab99 I noticed that j is also a boolean tensor. Perhaps you need to use torch.nonzero to get True indices on the boolean vector and then that might work. https://pytorch.org/docs/stable/generated/torch.nonzero.html

@mabedd
Copy link

mabedd commented Jul 8, 2022

Hopefully this can be fixed later on as well as the MPS officially. Either ways, thanks for your notes I ll check and let you know.

@glenn-jocher
Copy link
Member Author

@mohammed-ab99 well yes, ideally the torch team should fix this but without a clear schedule we should try to debug alternative implementations on our end, making sure to profile any changes for speed differences.

@glenn-jocher
Copy link
Member Author

@mohammed-ab99 I should be able to test on our M1 Macbook this weekend.

@mabedd
Copy link

mabedd commented Jul 9, 2022

That sounds good !!
Please keep us updated regarding any solutions.

Thanks dear.

Shivvrat pushed a commit to Shivvrat/epic-yolov5 that referenced this pull request Jul 12, 2022
@interactivetech
Copy link

interactivetech commented Jul 15, 2022

I am running into the same issue that @mohammed-ab99 is having #7878 (comment)

I attempted to replace t[j] to t = torch.index_select(t,dim=0, index=j.nonzero().int().reshape(-1)), but I am running into an issue when I test on CPU with the following error:

  File "train.py", line 667, in <module>
    main(opt)
  File "train.py", line 562, in main
    train(opt.hyp, opt, device, callbacks)
  File "train.py", line 353, in train
    loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
  File "/Users/mendeza/Documents/projects/yolov5/utils/loss.py", line 125, in __call__
    tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets
  File "/Users/mendeza/Documents/projects/yolov5/utils/loss.py", line 219, in build_targets
    t = torch.index_select(t,dim=0, index=j_ind.reshape(-1))
IndexError: index out of range in self

I am able to show that index_select with nonzero is the same as indexing on nonzero for 2D case, but 3D I am having a hard time how to reshape:

e = torch.eye(2)
shap = e.shape
e2 = e.index_select(0,index=e.nonzero().int().flatten()).reshape(shap[0],-1,shap[1])
print(torch.equal(e[e.nonzero()],e2))

@glenn-jocher glenn-jocher removed the TODO label Jul 30, 2022
ctjanuhowski pushed a commit to ctjanuhowski/yolov5 that referenced this pull request Sep 8, 2022
* Apple Metal Performance Shader (MPS) device support

Following https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/

Should work with Apple M1 devices with PyTorch nightly installed with command `--device mps`. Usage examples:
```bash
python train.py --device mps
python detect.py --device mps
python val.py --device mps
```

* Update device strategy to fix MPS issue
ctjanuhowski pushed a commit to ctjanuhowski/yolov5 that referenced this pull request Sep 8, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants