Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nd convolution and pooling with cuDNN #3983

Open
wants to merge 32 commits into
base: master
Choose a base branch
from

Conversation

christianpayer
Copy link

This branch implements n-dimensional convolution and pooling with cudnn, similar to #2824 and #3515. In contrast to #2824 and #3515 this PR does not create new layers, but is built on top of the existing convolution and pooling layers.

The convolution layer already has an interface for nd convolutions (#2049), but the pooling layer has not. I changed the interface of the pooling layer to support nd pooling similar to #2049. This new interface made changes in the caffe internal pooling and some test files necessary, but I did not change the caffe CPU and GPU pooling to support nd pooling. Nd pooling (2d, 3d and possibly more) is only working for average pooling using cudnn.

I also changed some calls to legacy shape accessors (num(), channels(), etc.) when I needed them. There are still many layers that do not support the new shape() accessors. So if you encounter errors from blob.hpp saying "Cannot use legacy accessors on Blobs with > 4 axes.", check the layer for the legacy accessors and make the necessary changes.

I tested the code with cudnn v4 for 2d and 3d convolutions and poolings. I am not sure, if older/other versions of cudnn would work.
All test cases pass with the new pooling interface. Currently, I did not create any new cases that test nd convolution and pooling, but I'm planning to implement some.

@christianpayer
Copy link
Author

If you want to use nd pooling without cudnn, see #2442. The new pooling interface that I implemented is (almost) the same as for #2442. So if #2442 is merged into master, I will adapt and rebase my code.

@jcpeterson
Copy link

Do you plan to add MAX pooling? Also, it appears the checks have failed.

@shelhamer shelhamer mentioned this pull request Apr 13, 2016
@christianpayer
Copy link
Author

@jcpeterson I did not include MAX pooling in this PR, as it would change the behaviour of caffe. See the following code snippet from layer_factory.cpp

    // CuDNN assumes layers are not being modified in place, thus
    // breaking our index tracking for updates in some cases in Caffe.
    // Until there is a workaround in Caffe (index management) or
    // cuDNN, use Caffe layer to max pooling, or don't use in place
    // layers after max pooling layers
    if (param.pooling_param().pool() == PoolingParameter_PoolMethod_MAX) {
        return shared_ptr<Layer<Dtype> >(new PoolingLayer<Dtype>(param));
    } else {
        return shared_ptr<Layer<Dtype> >(new CuDNNPoolingLayer<Dtype>(param));
    }

I don't know if the behaviour of caffe or cudnn changed in the meantime and how to solve this.
So if you are not affected by this, just remove this condition. I also used MAX pooling with cudnn a lot and it seemed to work very well.

The automatic build failed because of an issue with travis-ci (boost-1.56 was missing?). I don't know how to initiate another build without committing something. At least on my computer, every test passes.

@christianpayer
Copy link
Author

I just added tests for 3D cudnn convolution and pooling. The new convolution tests are based on the already existing 3D caffe convolution tests, the 3D pooling tests that are based on the existing 2D tests. I used a matlab script for generating the hard-coded input and output values.

// dimensions in all spatial dimensions, or once per spatial dimension.
repeated uint32 pad = 4; // The padding size; defaults to 0
repeated uint32 kernel_size = 2; // The kernel size
repeated uint32 stride = 3; // The stride; defaults to 1
Copy link
Contributor

@ajtulloch ajtulloch Apr 17, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a safe transformation (changing from option to repeated) in general? If I have a protobinary-serialized NetParameter and I reload it, does it get correctly deserialized? I understand that text formatted NetParameters are consistent though.

http:https://stackoverflow.com/questions/20220459/what-do-i-need-to-be-careful-of-when-changing-protobuf-serialized-types-in-order seems to indicate it's not safe for Protobuf. I'm more familiar with Thrift where this is definitely not a safe transformation that led to numerous site issues (to the point where we wrote a linter to verify diffs don't modify Thrift struct types without modifying the field ids as well).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is fine, and is the approach taken for N-D conv in #2049. My parse of the protobuf documentation on updating messages is that optional -> required is ok since "optional is compatible with repeated." However, it could be more clearly spelled out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah nice, it's clearly correct then, thanks for the link @shelhamer.

@futurely
Copy link

Is the compatibility to cudnn-v5 tested by Travis?

@futurely
Copy link

I've compiled this branch with CUDA 8.0 and cuDNN v5.0.5. All tests passed.

@futurely
Copy link

I just tried to rebase against the master branch. There're too many merge conflicts caused by 6028022 (compatibility to cudnn-v5) since the master has already supported cuDNN v5 with #4159 while there's only a single minor conflict to rebase from the commit 46261f7 (fix CUDNN_BAD_PARAM when using InnerProduct layer). It would be much easier to get this branch merged by seprarating the cuDNN v5 support into an independent PR.

@futurely
Copy link

The git diff file produced by cherry-picking 6028022 after first rebasing 46261f7 to master is attached.
nd_cudnn_v5.diff.txt

@christianpayer
Copy link
Author

@futurely thanks for testing and your comments. Unfortunately I am on vacation until July 21st, so I'm not able to rebase or make any changes to the code.
I use this PR for a couple of months and it seems to work fine, but if you encounter any problems or have any suggestions, let me know!

@futurely
Copy link

Training a c3d model with the rebased version of this branch and the video data layer of @chuckcho chuckcho/video-caffe#22 didn't converge. The network prototxt is adapted from this one replacing NdConvolution and NdPooling with Convolution and Pooling. I would appreciate a lot if you could add an example to showcase how to effectively train and test your nd convolution and pooling?

@bhack bhack mentioned this pull request Jul 2, 2016
3 tasks
@rogertrullo
Copy link

Thanks @christianpayer this is very helpful; I would like to ask you something, when I use in the protoxt, the MAX method in the pooling using CUDNN engine, it runs normally. Is that ok?, Should I avoid using max pooling and stick with avg? Thanks!

@christianpayer
Copy link
Author

@rogertrullo you're welcome! By default, caffe does not use CUDNN for max pooling, even if the engine is set to CUDNN (see the code snippet of layer_factory.cpp I posted on Apr 14). I did not look into the comment of this snippet, so I don't know exactly, why CUDNN max pooling is never used.
We used CUDNN's max pooling extensively and did not observe wrong/strange behavior. If you want to use it as well, just rewrite the condition of the snippet, such that the CUDNN layer is used.

@shelhamer
Copy link
Member

@christianpayer see #3574 for why cuDNN max pooling is disabled by default.

@christianpayer
Copy link
Author

@shelhamer thanks for the link!

@antran89
Copy link
Contributor

@christianpayer Hi Christian, thank you for implementing Nd-pooling layer. I want to try your code into my Caffe repo for videos with many modifications from main BLVC. Can you recommend an easy way to integrate your changes? Thank you.

@christianpayer
Copy link
Author

@antran89 Hi! Sorry for the late answer. I don't know how different the code from your repository is to the code from the main repository, so it is hard to estimate how difficult the merging will be. For getting nd convolution and pooling to work, you would need to merge the corresponding files (cudnn, cudnn_conv_layer, pooling_layer, cudnn_pooling_layer), which could result in many conflicts, as there are lots of changes. The remaining files do not contain so many changes, so merging should work fine.
If you have any problems, feel free to ask.

@oyxhust
Copy link

oyxhust commented Feb 17, 2017

So caffe can use 3d conv and 3d pooling now? If I want to use 3d conv, how should I write my prototxt?

@John1231983
Copy link

John1231983 commented Feb 22, 2017

@christianpayer: Thanks for support ND Convolution. Does it support ND Deconvolution, ND ReLU also?
In summary, your PR will be support as bellow configuration. Please let me know if it is wrong

  1. For ND convolution:
layer {
  name: "conv1_1"
  type: "Convolution"
  bottom: "data"
  top: "conv1_1"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  convolution_param {
    num_output: 32
    pad: 1
    kernel_size: 3
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
      value: -0.1
    }
    engine: CUDNN
  }
}

For ND Pooling

layer {
  name: "pool1"
  type: "Pooling"
  bottom: "conv1_2"
  top: "pool1"
  pooling_param {
    pool: AVE
    kernel_size: 2
    stride: 2
    engine: CUDNN
  }
}

@christianpayer
Copy link
Author

@oyxhust Caffe will use 3D or 2D convolutions depending on the input blob size. So if an input layer creates a 4D blob (image, channel, y, x), the following convolutions will use 2D filters, and if an input layer creates a 5D blob (image, channel, z, y, x), the following convolutions will be 3D. So if you want to use 3D convolutions, you need to use an input layer that creates 5D outputs (e.g. HDF5 data layer). Look into thread #2049 for more discussions.

@John1231983 ND convolution and ND deconvolution are already supported without my PR. This PR only adds ND convolutions and ND pooling with CUDNN. So if you use ND deconvolution (with or without my PR), the default caffe implementation is used.
I additionally changed some activation functions (and weight fillers) to support ND. You can look into the committed files of my PR to know which activation functions I changes and how to adapt other functions that are currently not supported.
Looking at your files, I don't see any problems and they should work with this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet