Fix transposed convolution in CPU w/o MKLDNN. #14031
apeforest wants to merge 38 commits into apache:master
Conversation
@zhreshold @thomelane Please help to review. Thanks!

Can you verify the result in unittest?

@zhreshold unit test added.

@mxnet-label-bot add [pr-awaiting-review]
```cpp
Stream<xpu> *s = ctx.get_stream<xpu>();
#if defined(__CUDACC__)
CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
    << "Must init CuBLAS handle in stream";
```
"cuBLAS" is the official abbreviation :)
```python
# # check_layer_forward(layer, (1, 10, 10, 10, 4))


@with_seed()
def test_deconv_dilation():
```
Since deconv is a really important op, I suggest revisiting the original deconv test cases and adding dilation > 1 cases alongside the old tests. This ensures better coverage than this single test case.
Feel free to keep this unit test, which LGTM as well.
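For anyone extending those test cases: the output size of a transposed convolution along one axis follows the standard formula, where a dilated kernel spans `dilation * (kernel - 1) + 1`. A minimal sketch of that formula (the function name is hypothetical, not part of MXNet), useful as a shape check when writing dilation > 1 tests:

```python
def deconv_out_size(in_size, kernel, stride=1, pad=0, dilation=1):
    """Output length along one axis of a transposed convolution.

    The dilated kernel covers an effective extent of
    dilation * (kernel - 1) + 1, which drives the output size.
    """
    return (in_size - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1

# e.g. a 10-wide input with k=3 grows by 2 per unit of dilation
print(deconv_out_size(10, 3, dilation=1))  # 12
print(deconv_out_size(10, 3, dilation=2))  # 14
```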
Vikas-kum left a comment

Good catch! Looks good.
```cpp
DeconvolutionParam param_;
mshadow::Shape<2> shape_colunit_;
mshadow::Shape<3> shape_dstunit_;
index_t nstep_;
```
Can you please tell me why this was removed?
The col2im method does not support such a step.
```cpp
Tensor<xpu, 1, DType> workspace =
    ctx.requested[deconv::kTempSpace].get_space_typed<xpu, 1, DType>(
        Shape1(this->InitTemp(out.shape_, data.shape_)), s);
for (index_t i = 0; i < nbatch; i += nstep_) {
```
Do you know what "nstep_" was doing earlier? It would help to understand the problem with the earlier code.
The col2im method does not support such a step.
```cpp
ctx.requested[deconv::kTempSpace].get_space_typed<xpu, 1, DType>(
    Shape1(this->InitTemp(grad.shape_, data.shape_)), s);
for (index_t i = 0; i < nbatch; i += nstep_) {
  const index_t step = std::min(nstep_, nbatch - i);
```
Again, can you tell what the purpose of "step" was in the previous code?
I think it was used to convert multiple batches of image data into columns in the previous library. However, that is not supported in the col2im method.
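To illustrate the distinction being discussed (this is a hypothetical sketch, not the actual operator code): the old loop advanced by up to `nstep_` images per library call, with the last chunk possibly smaller, while an im2col/col2im interface that operates on a single image forces the batch loop to advance one sample at a time.

```python
def process_batched(batch, nstep):
    """Old style: handle up to nstep images per unpack call."""
    i = 0
    while i < len(batch):
        step = min(nstep, len(batch) - i)  # last chunk may be smaller
        yield batch[i:i + step]            # one library call per chunk
        i += step

def process_per_image(batch):
    """New style: im2col/col2im operate on one image at a time."""
    for img in batch:
        yield [img]

# A batch of 10 with nstep=4 yields chunks of 4, 4, and 2.
print(list(process_batched(list(range(10)), 4)))
```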
The code changes are consistent with the optimized way to perform the deconv operation on CPU, but I have some questions that will help me understand what was happening earlier and why it was that way. Otherwise, your code is correct and precise. Good work!
@apeforest can you please rebase and resolve the merge conflicts?

@apeforest Could you please have a look at the CI failures?

@apeforest Gentle ping...

@apeforest Can you take a look at the failing CI build?

@apeforest Could you please provide an update on this PR about your progress and thoughts, so that other community members can benefit from it. Thanks!

@karan6181 The new function im2col has a different signature and calling sequence from the old col_unpack(). The changes failed a few unit tests, and I ended up re-implementing the operator itself. Given that MKLDNN is currently the default on CPU and has no issue with the Conv2DTranspose operator, I would like to treat this issue as lower priority and complete it in a few weeks.

@mxnet-label-bot Update [pr-work-in-progress] @apeforest Can you look into the CI failures?

@apeforest Hi! Any update on this PR? The PR is important :)
Description

The transposed convolution operator on CPU w/o MKLDNN does not work properly when dilation is set. This is because the mshadow library functions `unpack_patch2col` and `pack_col2patch` generate incorrect results with the dilation parameter. This PR replaces these two functions with the MXNet native functions `im2col` and `col2im`.

This PR fixes issue #11203
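For readers unfamiliar with im2col, here is a minimal single-channel 2-D sketch with dilation support (a naive illustration under assumed semantics, not the actual MXNet implementation). The key point is that a dilated kernel samples taps `dilation` pixels apart, covering an effective extent of `dilation * (k - 1) + 1` — the part of patch extraction that the dilation parameter changes.

```python
import numpy as np

def im2col(img, k, stride=1, dilation=1):
    """Naive im2col for a single-channel 2-D image (no padding).

    Each output column holds one k*k patch, flattened; dilation
    spaces the sampled taps within the patch.
    """
    eff = dilation * (k - 1) + 1  # effective kernel extent
    h, w = img.shape
    oh = (h - eff) // stride + 1
    ow = (w - eff) // stride + 1
    cols = np.empty((k * k, oh * ow), dtype=img.dtype)
    for y in range(oh):
        for x in range(ow):
            patch = img[y*stride : y*stride + eff : dilation,
                        x*stride : x*stride + eff : dilation]
            cols[:, y * ow + x] = patch.ravel()
    return cols

# A 5x5 image with k=3, dilation=2 yields exactly one 3x3 patch,
# sampled from rows/cols 0, 2, 4.
img = np.arange(25).reshape(5, 5)
print(im2col(img, 3, dilation=2).ravel())
```

col2im is the adjoint operation: it scatters the columns back into the image, accumulating where patches overlap, which is what the backward/transposed path needs.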
Passed the local test in the issue:
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.