From 7a536d760ca10880ccfe10f3143e7d741b8dbe5c Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 3 Sep 2019 14:33:51 -0700 Subject: [PATCH 01/22] 2d transpose naive --- src/operator/tensor/matrix_op-inl.h | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 58a535353e10..f551cea10078 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -257,6 +257,26 @@ struct TransposeParam : public dmlc::Parameter { } }; + +// using namespace mshadow; +template +MSHADOW_XINLINE void Transpose2D(DType *in, DType *out, index_t shape_0, index_t shape_1){ + +// ensure cache line hits and prevent cache miss for any configuration +index_t blocksize = 4; +index_t n = shape_0; +index_t p = shape_1; +for (index_t i = 0; i < n; i += blocksize) { + for (index_t j = 0; j < p; ++j) { + // transpose the block + for(index_t b = 0; b < blocksize && i + b < n; ++b) { + out[j*n + i + b] = in[(i + b)*p + j]; + } + } +} +} + + template void TransposeImpl(RunContext ctx, const TBlob& src, @@ -285,8 +305,9 @@ void TransposeImpl(RunContext ctx, case 2: { mshadow::Tensor in = src.FlatTo2D(s); mshadow::Tensor out = ret.FlatTo2D(s); + if (axes[0] == 1 && axes[1] == 0) { - out = in.T(); + Transpose2D(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); } else { Copy(out, in, s); } From fe8700ce584007dad26069d8845efc08b4b245ad Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 4 Sep 2019 10:05:52 -0700 Subject: [PATCH 02/22] omp pragma --- src/operator/tensor/matrix_op-inl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index f551cea10078..4fd0d17a2a58 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -266,6 +266,7 @@ MSHADOW_XINLINE void Transpose2D(DType *in, DType *out, index_t shape_0, index_t index_t blocksize = 4; index_t n = shape_0; index_t p = shape_1; +#pragma omp parallel for for (index_t i = 0; i < n; i += blocksize) { for (index_t j = 0; j < p; ++j) { // transpose the block From 5644dfd4845009df1f9c06a4cd7bead1929310c3 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 4 Sep 2019 10:47:52 -0700 Subject: [PATCH 03/22] omp pragma unroll --- src/operator/tensor/matrix_op-inl.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 4fd0d17a2a58..9b96848f352a 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -266,10 +266,12 @@ MSHADOW_XINLINE void Transpose2D(DType *in, DType *out, index_t shape_0, index_t index_t blocksize = 4; index_t n = shape_0; index_t p = shape_1; -#pragma omp parallel for + for (index_t i = 0; i < n; i += blocksize) { + #pragma omp parallel for for (index_t j = 0; j < p; ++j) { // transpose the block + #pragma unroll 4 for(index_t b = 0; b < blocksize && i + b < n; ++b) { out[j*n + i + b] = in[(i + b)*p + j]; } From 0d01d71cc8602fae59d79faa6aefe9cdee2d505b Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 4 Sep 2019 10:53:28 -0700 Subject: [PATCH 04/22] blocksize --- src/operator/tensor/matrix_op-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 9b96848f352a..5959238320ea 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -263,7 +263,7 @@ template MSHADOW_XINLINE void Transpose2D(DType *in, DType *out, index_t shape_0, index_t shape_1){ // ensure cache line hits and prevent cache miss for any configuration -index_t blocksize = 4; +index_t blocksize = 32; index_t n = shape_0; index_t p = shape_1; From c7864c2903a00c1881648a7918b1c36be88f7ff3 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 4 Sep 2019 11:32:04 -0700 Subject: [PATCH 05/22] make it 2d tile --- src/operator/tensor/matrix_op-inl.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 5959238320ea..d0d67c7a1dfc 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -269,12 +269,14 @@ index_t p = shape_1; for (index_t i = 0; i < n; i += blocksize) { #pragma omp parallel for - for (index_t j = 0; j < p; ++j) { + for (index_t j = 0; j < p; j += blocksize) { // transpose the block #pragma unroll 4 - for(index_t b = 0; b < blocksize && i + b < n; ++b) { - out[j*n + i + b] = in[(i + b)*p + j]; - } + for (index_t a = 0; a < blocksize && j + a < n; ++a) { + for (index_t b = 0; b < blocksize && i + b < n; ++b) { + out[(j + a) * n + i + b] = in[(i + b) * p + (j + a)]; + } + } } } } From b44d4ea79f3cb241509fcc483977bf188f26667d Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 4 Sep 2019 17:04:49 -0700 Subject: [PATCH 06/22] loop peeling --- src/operator/tensor/matrix_op-inl.h | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index d0d67c7a1dfc..2a0fd03cbe45 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -272,9 +272,14 @@ for (index_t i = 0; i < n; i += blocksize) { for (index_t j = 0; j < p; j += blocksize) { // transpose the block #pragma unroll 4 - for (index_t a = 0; a < blocksize && j + a < n; ++a) { - for (index_t b = 0; b < blocksize && i + b < n; ++b) { - out[(j + a) * n + i + b] = in[(i + b) * p + (j + a)]; + for (index_t a = j; a < j + blocksize; ++a) { + for (index_t b = i; b < i + blocksize; ++b) { + out[a * n + i] = in[i * p + j]; + } + } + for (index_t a = a; a < n; ++a) { + for (index_t b = b; b < n; ++b) { + out[a * n + i] = in[i * p + j]; } } } From bdfde0a23a83aac5720086a274bc204b8df77c69 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 5 Sep 2019 10:32:42 -0700 Subject: [PATCH 07/22] better loop peeling --- src/operator/tensor/matrix_op-inl.h | 39 +++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 2a0fd03cbe45..fe5bf8d8fd39 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -267,23 +267,42 @@ index_t blocksize = 32; index_t n = shape_0; index_t p = shape_1; -for (index_t i = 0; i < n; i += blocksize) { +index_t N = (n/blocksize)*blocksize; +bool n_strip = (n - N > 0); +index_t P = (p/blocksize)*blocksize; +bool p_strip = (p - P > 0); + +for (index_t i = 0; i < N; i += blocksize) { #pragma omp parallel for - for (index_t j = 0; j < p; j += blocksize) { + for (index_t j = 0; j < P; j += blocksize) { // transpose the block #pragma unroll 4 - for (index_t a = j; a < j + blocksize; ++a) { - for (index_t b = i; b < i + blocksize; ++b) { - out[a * n + i] = in[i * p + j]; - } + index_t a_limit = j + blocksize; + index_t b_limit = i + blocksize; + for (index_t a = j; a < a_limit; ++a) { + for (index_t b = i; b < b_limit; ++b) { + out[a * n + b] = in[b * p + a]; + } } - for (index_t a = a; a < n; ++a) { - for (index_t b = b; b < n; ++b) { - out[a * n + i] = in[i * p + j]; - } + } +} + +for (index_t i = N; i < n; i += blocksize) { + #pragma omp parallel for + for (index_t j = P; j < p; j += blocksize) { + // transpose the block + index_t a_limit = j + blocksize; + index_t b_limit = i + blocksize; + #pragma unroll 4 + for (index_t a = j; (a < a_limit && a < p); ++a) { + for (index_t b = i; (b < b_limit && b < n); ++b) { + out[a * n + b] = in[b * p + a]; + } } } } + + } From 31f43cd193df0ff91979d03f2d49ca80e6310e97 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 5 Sep 2019 10:44:34 -0700 Subject: [PATCH 08/22] redundancy --- src/operator/tensor/matrix_op-inl.h | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index fe5bf8d8fd39..e818d1d7cc14 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -287,15 +287,29 @@ for (index_t i = 0; i < N; i += blocksize) { } } -for (index_t i = N; i < n; i += blocksize) { +for (index_t i = 0; i < N; i += blocksize) { #pragma omp parallel for for (index_t j = P; j < p; j += blocksize) { // transpose the block - index_t a_limit = j + blocksize; index_t b_limit = i + blocksize; #pragma unroll 4 - for (index_t a = j; (a < a_limit && a < p); ++a) { - for (index_t b = i; (b < b_limit && b < n); ++b) { + for (index_t a = j; a < p; ++a) { + for (index_t b = i; b < b_limit; ++b) { + out[a * n + b] = in[b * p + a]; + } + } + } +} + +for (index_t i = N; i < n; i += blocksize) { + #pragma omp parallel for + for (index_t j = 0; j < P; j += blocksize) { + // transpose the block + + index_t a_limit = j + blocksize; + #pragma unroll 4 + for (index_t a = j; a < a_limit; ++a) { + for (index_t b = i; b < n; ++b) { out[a * n + b] = in[b * p + a]; } } @@ -303,6 +317,12 @@ for (index_t i = N; i < n; i += blocksize) { } +#pragma unroll 4 +for (index_t a = N; a < n; ++a) { + for (index_t b = P; b < p; ++b) { + out[a * n + b] = in[b * p + a]; + } +} } From 3f5e7ceb584f20e8b7be8c7614a7998f566e5ea0 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 5 Sep 2019 10:50:49 -0700 Subject: [PATCH 09/22] removed bool --- src/operator/tensor/matrix_op-inl.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index e818d1d7cc14..521caecec99a 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -268,9 +268,7 @@ index_t n = shape_0; index_t p = shape_1; index_t N = (n/blocksize)*blocksize; -bool n_strip = (n - N > 0); index_t P = (p/blocksize)*blocksize; -bool p_strip = (p - P > 0); for (index_t i = 0; i < N; i += blocksize) { #pragma omp parallel for From 875868d64d1db997690127004ce438ae4e34f859 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 5 Sep 2019 11:03:56 -0700 Subject: [PATCH 10/22] removing excess for loops, memory save --- src/operator/tensor/matrix_op-inl.h | 60 ++++++++++++++--------------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 521caecec99a..52bb7f9790c9 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -269,55 +269,51 @@ index_t p = shape_1; index_t N = (n/blocksize)*blocksize; index_t P = (p/blocksize)*blocksize; +index_t a_limit = 0, b_limit = 0, i = 0, j = 0, a = 0, b = 0; -for (index_t i = 0; i < N; i += blocksize) { +for (i = 0; i < N; i += blocksize) { #pragma omp parallel for - for (index_t j = 0; j < P; j += blocksize) { + for (j = 0; j < P; j += blocksize) { // transpose the block #pragma unroll 4 - index_t a_limit = j + blocksize; - index_t b_limit = i + blocksize; - for (index_t a = j; a < a_limit; ++a) { - for (index_t b = i; b < b_limit; ++b) { + a_limit = j + blocksize; + b_limit = i + blocksize; + for (a = j; a < a_limit; ++a) { + for (b = i; b < b_limit; ++b) { out[a * n + b] = in[b * p + a]; } } } } -for (index_t i = 0; i < N; i += blocksize) { - #pragma omp parallel for - for (index_t j = P; j < p; j += blocksize) { - // transpose the block - index_t b_limit = i + blocksize; - #pragma unroll 4 - for (index_t a = j; a < p; ++a) { - for (index_t b = i; b < b_limit; ++b) { - out[a * n + b] = in[b * p + a]; - } - } +#pragma omp parallel for +for (i = 0; i < N; i += blocksize) { + // transpose the block + b_limit = i + blocksize; + #pragma unroll 4 + for (a = P; a < p; ++a) { + for (b = i; b < b_limit; ++b) { + out[a * n + b] = in[b * p + a]; } + } } -for (index_t i = N; i < n; i += blocksize) { - #pragma omp parallel for - for (index_t j = 0; j < P; j += blocksize) { - // transpose the block - - index_t a_limit = j + blocksize; - #pragma unroll 4 - for (index_t a = j; a < a_limit; ++a) { - for (index_t b = i; b < n; ++b) { - out[a * n + b] = in[b * p + a]; - } - } + +#pragma omp parallel for +for (j = 0; j < P; j += blocksize) { + // transpose the block + a_limit = j + blocksize; + #pragma unroll 4 + for (a = j; a < a_limit; ++a) { + for (b = N; b < n; ++b) { + out[a * n + b] = in[b * p + a]; } + } } - #pragma unroll 4 -for (index_t a = N; a < n; ++a) { - for (index_t b = P; b < p; ++b) { +for (a = N; a < n; ++a) { + for (b = P; b < p; ++b) { out[a * n + b] = in[b * p + a]; } } From 29f80dd61462634fbe054f6ea766f819ee5edef5 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 5 Sep 2019 11:35:52 -0700 Subject: [PATCH 11/22] fix internal forloop --- src/operator/tensor/matrix_op-inl.h | 50 ++++------------------------- 1 file changed, 6 insertions(+), 44 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 52bb7f9790c9..a78406f0a617 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -267,56 +267,18 @@ index_t blocksize = 32; index_t n = shape_0; index_t p = shape_1; -index_t N = (n/blocksize)*blocksize; -index_t P = (p/blocksize)*blocksize; -index_t a_limit = 0, b_limit = 0, i = 0, j = 0, a = 0, b = 0; - -for (i = 0; i < N; i += blocksize) { +for (index_t i = 0; i < n; i += blocksize) { #pragma omp parallel for - for (j = 0; j < P; j += blocksize) { + for (index_t j = 0; j < p; j += blocksize) { // transpose the block #pragma unroll 4 - a_limit = j + blocksize; - b_limit = i + blocksize; - for (a = j; a < a_limit; ++a) { - for (b = i; b < b_limit; ++b) { - out[a * n + b] = in[b * p + a]; - } + for (index_t a = 0; a < blocksize && j + a < n; ++a) { + for (index_t b = 0; b < blocksize && i + b < p; ++b) { + out[(j + a) * n + i + b] = in[(i + b) * p + (j + a)]; + } } } } - -#pragma omp parallel for -for (i = 0; i < N; i += blocksize) { - // transpose the block - b_limit = i + blocksize; - #pragma unroll 4 - for (a = P; a < p; ++a) { - for (b = i; b < b_limit; ++b) { - out[a * n + b] = in[b * p + a]; - } - } -} - - -#pragma omp parallel for -for (j = 0; j < P; j += blocksize) { - // transpose the block - a_limit = j + blocksize; - #pragma unroll 4 - for (a = j; a < a_limit; ++a) { - for (b = N; b < n; ++b) { - out[a * n + b] = in[b * p + a]; - } - } -} - -#pragma unroll 4 -for (a = N; a < n; ++a) { - for (b = P; b < p; ++b) { - out[a * n + b] = in[b * p + a]; - } -} } From 789e81e6fb509ab469ae6ecfd6e5cdf323efa873 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 5 Sep 2019 12:28:07 -0700 Subject: [PATCH 12/22] remove commented code, lint fix --- src/operator/tensor/matrix_op-inl.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index a78406f0a617..6c00303abf4f 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -258,10 +258,8 @@ struct TransposeParam : public dmlc::Parameter { }; -// using namespace mshadow; template -MSHADOW_XINLINE void Transpose2D(DType *in, DType *out, index_t shape_0, index_t shape_1){ - +MSHADOW_XINLINE void Transpose2D(DType *in, DType *out, index_t shape_0, index_t shape_1) { // ensure cache line hits and prevent cache miss for any configuration index_t blocksize = 32; index_t n = shape_0; From 56c6835fef6d019552b64851955cede9c1ba8b64 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 5 Sep 2019 17:26:22 -0700 Subject: [PATCH 13/22] Trigger notification From 4e10c8cb70804392e1f1e6711c86c768d15f1a83 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Sun, 8 Sep 2019 16:34:24 -0700 Subject: [PATCH 14/22] explain params, indent fix, explain blocksize --- src/operator/tensor/matrix_op-inl.h | 48 +++++++++++++++++++---------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 6c00303abf4f..8353d41b2b59 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -258,25 +258,39 @@ struct TransposeParam : public dmlc::Parameter { }; +/*! + * \brief This function performs transpose operation on a 2D matrix by utilizing the L1 cache + * \param in input tensor + * \param out output tensor + * \param shape_0 shape of dim 0 of input + * \param shape_1 shape of dim 1 of input + */ template MSHADOW_XINLINE void Transpose2D(DType *in, DType *out, index_t shape_0, index_t shape_1) { -// ensure cache line hits and prevent cache miss for any configuration -index_t blocksize = 32; -index_t n = shape_0; -index_t p = shape_1; - -for (index_t i = 0; i < n; i += blocksize) { - #pragma omp parallel for - for (index_t j = 0; j < p; j += blocksize) { - // transpose the block - #pragma unroll 4 - for (index_t a = 0; a < blocksize && j + a < n; ++a) { - for (index_t b = 0; b < blocksize && i + b < p; ++b) { - out[(j + a) * n + i + b] = in[(i + b) * p + (j + a)]; - } - } - } -} + // ensure cache line hits and prevent cache miss for any configuration + // L1 cache size to be utilized = 32kb = 2^15 + // Largest size of a single unit of any dtype <= 8 byte = 2^3 + // Number of elements - (2^15/2^3) = 2^12 + // Block-size - 2^6 v 2^6 (64 v 64) + + // But we could leverage unrolling of for loops (for parallelization) + // Block-size - 2^5 v 2^5 (32 v 32) with 4 pragma for loop unrolled + index_t blocksize = 32; + index_t n = shape_0; + index_t p = shape_1; + + for (index_t i = 0; i < n; i += blocksize) { + #pragma omp parallel for + for (index_t j = 0; j < p; j += blocksize) { + // transpose the block + #pragma unroll 4 + for (index_t a = 0; a < blocksize && j + a < n; ++a) { + for (index_t b = 0; b < blocksize && i + b < p; ++b) { + out[(j + a) * n + i + b] = in[(i + b) * p + (j + a)]; + } + } + } + } } From 52b41fb0072f69392d7098278b21c06b7d9ff33d Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 9 Sep 2019 09:25:40 -0700 Subject: [PATCH 15/22] fix p,n and reduce for loop computation j+a,i+b --- src/operator/tensor/matrix_op-inl.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index c4fadee190e5..f68d0622e908 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -284,9 +284,9 @@ MSHADOW_XINLINE void Transpose2D(DType *in, DType *out, index_t shape_0, index_t for (index_t j = 0; j < p; j += blocksize) { // transpose the block #pragma unroll 4 - for (index_t a = 0; a < blocksize && j + a < n; ++a) { - for (index_t b = 0; b < blocksize && i + b < p; ++b) { - out[(j + a) * n + i + b] = in[(i + b) * p + (j + a)]; + for (index_t a = j; a < blocksize && a < p; ++a) { + for (index_t b = i; b < blocksize && b < n; ++b) { + out[a * n + b] = in[b * p + a]; } } } From e5ac4eea11e71456daacb58874bd07c88f86c33e Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 18 Sep 2019 11:14:40 -0700 Subject: [PATCH 16/22] kernel --- src/operator/tensor/matrix_op-inl.h | 64 +++++++++++++++-------------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index f68d0622e908..518a369e6f75 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -260,38 +260,39 @@ struct TransposeParam : public dmlc::Parameter { /*! * \brief This function performs transpose operation on a 2D matrix by utilizing the L1 cache - * \param in input tensor - * \param out output tensor - * \param shape_0 shape of dim 0 of input - * \param shape_1 shape of dim 1 of input + * \param in input tensor + * \param out output tensor + * \param row shape of dim 0 of input + * \param col shape of dim 1 of input */ -template -MSHADOW_XINLINE void Transpose2D(DType *in, DType *out, index_t shape_0, index_t shape_1) { - // ensure cache line hits and prevent cache miss for any configuration - // L1 cache size to be utilized = 32kb = 2^15 - // Largest size of a single unit of any dtype <= 8 byte = 2^3 - // Number of elements - (2^15/2^3) = 2^12 - // Block-size - 2^6 v 2^6 (64 v 64) - - // But we could leverage unrolling of for loops (for parallelization) - // Block-size - 2^5 v 2^5 (32 v 32) with 4 pragma for loop unrolled - index_t blocksize = 32; - index_t n = shape_0; - index_t p = shape_1; - - for (index_t i = 0; i < n; i += blocksize) { - #pragma omp parallel for - for (index_t j = 0; j < p; j += blocksize) { - // transpose the block - #pragma unroll 4 - for (index_t a = j; a < blocksize && a < p; ++a) { - for (index_t b = i; b < blocksize && b < n; ++b) { - out[a * n + b] = in[b * p + a]; - } - } - } +struct Transpose2D{ + template + MSHADOW_XINLINE static void Map(const DType *in, DType *out, index_t row, index_t col) { + // ensure cache line hits and prevent cache miss for any configuration + // L1 cache size to be utilized = 32kb = 2^15 + // Largest size of a single unit of any dtype <= 8 byte = 2^3 + // Number of elements - (2^15/2^3) = 2^12 + // Block-size - 2^6 v 2^6 (64 v 64) + + // But we could leverage unrolling of for loops (for parallelization) + // Block-size - 2^5 v 2^5 (32 v 32) with 4 pragma for loop unrolled + // blocksize * blocksize * num_threads = cache_size / dtype_size + index_t blocksize = 32; + + for (index_t i = 0; i < row; i += blocksize) { + #pragma omp parallel for + for (index_t j = 0; j < col; j += blocksize) { + // transpose the block + #pragma unroll 4 + for (index_t a = j; a < blocksize && a < col; ++a) { + for (index_t b = i; b < blocksize && b < row; ++b) { + out[a * row + b] = in[b * col + a]; + } + } + } + } } -} +}; template @@ -324,7 +325,8 @@ void TransposeImpl(RunContext ctx, mshadow::Tensor out = ret.FlatTo2D(s); if (axes[0] == 1 && axes[1] == 0) { - Transpose2D(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); + const index_t size = in.Size(); + Kernel::Launch(s, size, in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); } else { Copy(out, in, s); } From 0f28c7d7135c702b4798224332c700115ed8b4f6 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 18 Sep 2019 14:45:57 -0700 Subject: [PATCH 17/22] gpu thread 1 --- src/operator/tensor/matrix_op-inl.h | 76 +++++++++++++++++------------ 1 file changed, 46 insertions(+), 30 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 518a369e6f75..1a8e74659ea5 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -258,43 +258,59 @@ struct TransposeParam : public dmlc::Parameter { }; -/*! - * \brief This function performs transpose operation on a 2D matrix by utilizing the L1 cache - * \param in input tensor - * \param out output tensor - * \param row shape of dim 0 of input - * \param col shape of dim 1 of input - */ -struct Transpose2D{ - template - MSHADOW_XINLINE static void Map(const DType *in, DType *out, index_t row, index_t col) { - // ensure cache line hits and prevent cache miss for any configuration - // L1 cache size to be utilized = 32kb = 2^15 - // Largest size of a single unit of any dtype <= 8 byte = 2^3 - // Number of elements - (2^15/2^3) = 2^12 - // Block-size - 2^6 v 2^6 (64 v 64) - - // But we could leverage unrolling of for loops (for parallelization) - // Block-size - 2^5 v 2^5 (32 v 32) with 4 pragma for loop unrolled - // blocksize * blocksize * num_threads = cache_size / dtype_size +struct transpose_2d_kernel{ + template + MSHADOW_XINLINE static void Map(index_t t, const DType *in, DType *out, index_t row, index_t col) { index_t blocksize = 32; - for (index_t i = 0; i < row; i += blocksize) { - #pragma omp parallel for - for (index_t j = 0; j < col; j += blocksize) { - // transpose the block - #pragma unroll 4 - for (index_t a = j; a < blocksize && a < col; ++a) { - for (index_t b = i; b < blocksize && b < row; ++b) { + for (index_t i = 0; i < row; i += blocksize) { + #pragma omp parallel for + for (index_t j = 0; j < col; j += blocksize) { + // transpose the block + #pragma unroll 4 + for (index_t a = j; a < blocksize && a < col; ++a) { + for (index_t b = i; b < blocksize && b < row; ++b) { out[a * row + b] = in[b * col + a]; } } - } - } + } } + } }; +/*! + * \brief This function performs transpose operation on a 2D matrix by utilizing the L1 cache + * \param in input tensor + * \param out output tensor + * \param row shape of dim 0 of input + * \param col shape of dim 1 of input + */ +// template +// MSHADOW_XINLINE void Transpose2D(Stream* s, const DType *in, DType *out, index_t row, index_t col) { +// // ensure cache line hits and prevent cache miss for any configuration +// // L1 cache size to be utilized = 32kb = 2^15 +// // Largest size of a single unit of any dtype <= 8 byte = 2^3 +// // Number of elements - (2^15/2^3) = 2^12 +// // Block-size - 2^6 v 2^6 (64 v 64) + +// // But we could leverage unrolling of for loops (for parallelization) +// // Block-size - 2^5 v 2^5 (32 v 32) with 4 pragma for loop unrolled +// // blocksize * blocksize * num_threads = cache_size / dtype_size +// index_t blocksize = 32; + +// for (index_t i = 0; i < row; i += blocksize) { +// #pragma omp parallel for +// for (index_t j = 0; j < col; j += blocksize) { +// // transpose the block +// #pragma unroll 4 +// for (index_t a = j; a < blocksize && a < col; ++a) { +// mxnet_op::Kernel::Launch(s, size, in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); +// } +// } +// } + + template void TransposeImpl(RunContext ctx, const TBlob& src, @@ -325,8 +341,8 @@ void TransposeImpl(RunContext ctx, mshadow::Tensor out = ret.FlatTo2D(s); if (axes[0] == 1 && axes[1] == 0) { - const index_t size = in.Size(); - Kernel::Launch(s, size, in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); + // const index_t size = in.Size(); + mxnet_op::Kernel::Launch(s, 1, in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); } else { Copy(out, in, s); } From b91c5c6f2a950327a2944f834bb068d3b9debe20 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 24 Sep 2019 11:38:42 -0700 Subject: [PATCH 18/22] remove gpu implementation --- src/operator/tensor/matrix_op-inl.h | 77 ++++++++++++----------------- 1 file changed, 31 insertions(+), 46 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 1a8e74659ea5..2788e8949fd7 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -258,27 +258,6 @@ struct TransposeParam : public dmlc::Parameter { }; -struct transpose_2d_kernel{ - template - MSHADOW_XINLINE static void Map(index_t t, const DType *in, DType *out, index_t row, index_t col) { - index_t blocksize = 32; - - for (index_t i = 0; i < row; i += blocksize) { - #pragma omp parallel for - for (index_t j = 0; j < col; j += blocksize) { - // transpose the block - #pragma unroll 4 - for (index_t a = j; a < blocksize && a < col; ++a) { - for (index_t b = i; b < blocksize && b < row; ++b) { - out[a * row + b] = in[b * col + a]; - } - } - } - } - } -}; - - /*! * \brief This function performs transpose operation on a 2D matrix by utilizing the L1 cache * \param in input tensor @@ -286,29 +265,32 @@ struct transpose_2d_kernel{ * \param row shape of dim 0 of input * \param col shape of dim 1 of input */ -// template -// MSHADOW_XINLINE void Transpose2D(Stream* s, const DType *in, DType *out, index_t row, index_t col) { -// // ensure cache line hits and prevent cache miss for any configuration -// // L1 cache size to be utilized = 32kb = 2^15 -// // Largest size of a single unit of any dtype <= 8 byte = 2^3 -// // Number of elements - (2^15/2^3) = 2^12 -// // Block-size - 2^6 v 2^6 (64 v 64) - -// // But we could leverage unrolling of for loops (for parallelization) -// // Block-size - 2^5 v 2^5 (32 v 32) with 4 pragma for loop unrolled -// // blocksize * blocksize * num_threads = cache_size / dtype_size -// index_t blocksize = 32; - -// for (index_t i = 0; i < row; i += blocksize) { -// #pragma omp parallel for -// for (index_t j = 0; j < col; j += blocksize) { -// // transpose the block -// #pragma unroll 4 -// for (index_t a = j; a < blocksize && a < col; ++a) { -// mxnet_op::Kernel::Launch(s, size, in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); -// } -// } -// } +template +MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) { + // ensure cache line hits and prevent cache miss for any configuration + // L1 cache size to be utilized = 32kb = 2^15 + // Largest size of a single unit of any dtype <= 8 byte = 2^3 + // Number of elements - (2^15/2^3) = 2^12 + // Block-size - 2^6 v 2^6 (64 v 64) + + // But we could leverage unrolling of for loops (for parallelization) + // Block-size - 2^5 v 2^5 (32 v 32) with 4 pragma for loop unrolled + // blocksize * blocksize * num_threads = cache_size / dtype_size + index_t blocksize = 32; + + for (index_t i = 0; i < row; i += blocksize) { + #pragma omp parallel for + for (index_t j = 0; j < col; j += blocksize) { + // transpose the block + #pragma unroll 4 + for (index_t a = j; a < blocksize && a < col; ++a) { + for (index_t b = i; b < blocksize && b < row; ++b) { + out[a * row + b] = in[b * col + a]; + } + } + } + } +} template @@ -341,8 +323,11 @@ void TransposeImpl(RunContext ctx, mshadow::Tensor out = ret.FlatTo2D(s); if (axes[0] == 1 && axes[1] == 0) { - // const index_t size = in.Size(); - mxnet_op::Kernel::Launch(s, 1, in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); + if (ctx.get_ctx().dev_mask() == cpu::kDevMask) { + Transpose2D(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); + } else { + out = in.T(); + } } else { Copy(out, in, s); } From 22650b5850c457830ba23fcb63acd5520241533e Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 7 Oct 2019 15:32:35 -0700 Subject: [PATCH 19/22] fix internal for loop --- src/operator/tensor/matrix_op-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 2788e8949fd7..81e8bb0cdfeb 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -283,8 +283,8 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index for (index_t j = 0; j < col; j += blocksize) { // transpose the block #pragma unroll 4 - for (index_t a = j; a < blocksize && a < col; ++a) { - for (index_t b = i; b < blocksize && b < row; ++b) { + for (index_t a = j; (a < blocksize + j) && (a < col); ++a) { + for (index_t b = i; (b < blocksize + i) && (b < row); ++b) { out[a * row + b] = in[b * col + a]; } } From 9868f5a3378086095e1e5d5b728495ea7b936801 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 7 Oct 2019 18:24:38 -0700 Subject: [PATCH 20/22] unittest to catch the previous error --- tests/python/unittest/test_operator.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index bcf618f2784a..1d46c4a7a21f 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2840,6 +2840,13 @@ def test_transpose(): assert_allclose(np.transpose(x.asnumpy()), y.asnumpy()) +@with_seed() +def test_larger_transpose(): + x = mx.nd.random.normal(shape=(50,51)) + y = mx.nd.transpose(x) + assert_allclose(np.transpose(x.asnumpy()), y.asnumpy()) + + @with_seed() def test_expand_dims(): for ndim in range(1, 6): From 95507345fa815fe80268b0fa3ee99a798484c0a9 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 8 Oct 2019 11:45:59 -0700 Subject: [PATCH 21/22] optimizations --- src/operator/tensor/matrix_op-inl.h | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 81e8bb0cdfeb..dc28699384e6 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -274,15 +274,17 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index // Block-size - 2^6 v 2^6 (64 v 64) // But we could leverage unrolling of for loops (for parallelization) - // Block-size - 2^5 v 2^5 (32 v 32) with 4 pragma for loop unrolled + // Block-size - 2^5 v 2^5 (32 v 32) with potential 4 pragma for loop unrolled // blocksize * blocksize * num_threads = cache_size / dtype_size + // Instead of explicit unroll, let compiler figure out optimal unroll factor index_t blocksize = 32; + // collapse 2 parallelizes 2 for loops + // inner 2 for loops aren't parallelized to prevent cache miss + #pragma omp parallel for collapse(2) for (index_t i = 0; i < row; i += blocksize) { - #pragma omp parallel for for (index_t j = 0; j < col; j += blocksize) { // transpose the block - #pragma unroll 4 for (index_t a = j; (a < blocksize + j) && (a < col); ++a) { for (index_t b = i; (b < blocksize + i) && (b < row); ++b) { out[a * row + b] = in[b * col + a]; From e3d3f5b1972b369f46c71ae10fca60c2d705fe45 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 8 Oct 2019 13:09:12 -0700 Subject: [PATCH 22/22] microsoft cpp doesn't support omp collapse --- src/operator/tensor/matrix_op-inl.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index dc28699384e6..dc78d359a543 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -281,7 +281,14 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index // collapse 2 parallelizes 2 for loops // inner 2 for loops aren't parallelized to prevent cache miss - #pragma omp parallel for collapse(2) + + // Microsoft Visual C++ compiler does not support omp collapse + #ifdef _MSC_VER + #pragma omp parallel for + #else + #pragma omp parallel for collapse(2) + #endif // _MSC_VER + for (index_t i = 0; i < row; i += blocksize) { for (index_t j = 0; j < col; j += blocksize) { // transpose the block