| // This file is part of Eigen, a lightweight C++ template library |
| // for linear algebra. |
| // |
| // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> |
| // |
| // This Source Code Form is subject to the terms of the Mozilla |
| // Public License v. 2.0. If a copy of the MPL was not distributed |
| // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| |
| #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H |
| #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H |
| |
| #if defined(EIGEN_USE_GPU) && defined(__CUDACC__) |
| |
| namespace Eigen { |
| |
| template<typename Scalar, typename Index, typename LhsMapper, |
| typename RhsMapper, typename OutputMapper, bool needs_edge_check> |
| __device__ EIGEN_STRONG_INLINE void |
| EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, |
| const OutputMapper output, volatile Scalar* lhs_shmem, volatile Scalar* rhs_shmem, |
| const Index m_size, const Index n_size, const Index k_size) { |
| |
| const Index m_block_idx = blockIdx.x; |
| const Index n_block_idx = blockIdx.y; |
| |
| const Index base_m = 64 * m_block_idx; |
| const Index base_n = 64 * n_block_idx; |
| |
  // declare and initialize 64 registers for the 8x8 output block
| |
| // prefetch registers |
| Scalar lhs_pf0; |
| Scalar lhs_pf1; |
| Scalar lhs_pf2; |
| Scalar lhs_pf3; |
| Scalar lhs_pf4; |
| Scalar lhs_pf5; |
| Scalar lhs_pf6; |
| Scalar lhs_pf7; |
| |
| Scalar rhs_pf0; |
| Scalar rhs_pf1; |
| Scalar rhs_pf2; |
| Scalar rhs_pf3; |
| Scalar rhs_pf4; |
| Scalar rhs_pf5; |
| Scalar rhs_pf6; |
| Scalar rhs_pf7; |
| |
| // shared memory is formatted |
| // (contract idx in block, nocontract idx in block, block idx) |
| // where block idx is column major. This transposition limits the number of |
  // bank conflicts when reading the LHS. The core idea is that since the contracting
  // index is shared by both sides, it should live in threadIdx.x.
| |
| // On the LHS, we pad each row inside of each block with an extra element. This makes |
| // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts |
| // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks. |
| |
| // On the RHS we just add 8 padding elements to the end of each block. This gives no bank |
| // conflicts on writes and also none on reads. |
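  // As a worked example of the sizes involved: each padded LHS block is 8 rows of
  // 9 elements = 72 floats, a column of 8 such blocks spans 8 * 72 = 576 floats
  // (the stride used below), and the full 8x8 grid of blocks occupies 64 * 72 floats,
  // matching the 72 * 64 lhs_shmem / rhs_shmem allocations in EigenContractionKernel.
  // The RHS blocks are 64 + 8 = 72 floats as well, so the RHS uses the same strides.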
| |
| // storage indices |
| const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z; |
| const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x; |
| |
| const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0; |
| const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1; |
| const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2; |
| const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3; |
| const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4; |
| const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5; |
| const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6; |
| const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7; |
| |
| const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0; |
| const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1; |
| const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2; |
| const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3; |
| const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4; |
| const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5; |
| const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6; |
| const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7; |
| |
| // in the loading code, the following variables are important: |
| // threadIdx.x: the vertical position in an 8x8 block |
| // threadIdx.y: the vertical index of the 8x8 block in the grid |
| // threadIdx.z: the horizontal position in an 8x8 block |
| // k: the horizontal index of the 8x8 block in the grid |
| // |
  // The k parameter is implicit: it was the counter of a loop that ran from 0 to 7,
  // but that loop has been unrolled into the code below.
| |
| const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y; |
| const Index lhs_vert = base_m + load_idx_vert; |
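  // Note: load_idx_vert = threadIdx.x + 8 * threadIdx.y ranges over 0..63, so the 64
  // (x, y) thread pairs together cover every row of the 64-row tile; threadIdx.z
  // (combined with the unrolled k index below) selects the columns of the 64-wide
  // k slab that each thread loads.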
| |
| #define prefetchIntoRegisters(base_k) \ |
| { \ |
| lhs_pf0 = Scalar(0); \ |
| lhs_pf1 = Scalar(0); \ |
| lhs_pf2 = Scalar(0); \ |
| lhs_pf3 = Scalar(0); \ |
| lhs_pf4 = Scalar(0); \ |
| lhs_pf5 = Scalar(0); \ |
| lhs_pf6 = Scalar(0); \ |
| lhs_pf7 = Scalar(0); \ |
| \ |
| rhs_pf0 = Scalar(0); \ |
| rhs_pf1 = Scalar(0); \ |
| rhs_pf2 = Scalar(0); \ |
| rhs_pf3 = Scalar(0); \ |
| rhs_pf4 = Scalar(0); \ |
| rhs_pf5 = Scalar(0); \ |
| rhs_pf6 = Scalar(0); \ |
| rhs_pf7 = Scalar(0); \ |
| \ |
| if (!needs_edge_check || lhs_vert < m_size) { \ |
| const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \ |
| const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \ |
| const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \ |
| const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \ |
| const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \ |
| const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \ |
| const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \ |
| const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \ |
| \ |
| if (!needs_edge_check || lhs_horiz_7 < k_size) { \ |
| lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ |
| lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ |
| lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ |
| lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \ |
| lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \ |
| lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \ |
| lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \ |
| lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \ |
| } else if (lhs_horiz_6 < k_size) { \ |
| lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ |
| lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ |
| lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ |
| lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \ |
| lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \ |
| lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \ |
| lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \ |
| } else if (lhs_horiz_5 < k_size) { \ |
| lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ |
| lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ |
| lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ |
| lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \ |
| lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \ |
| lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \ |
| } else if (lhs_horiz_4 < k_size) { \ |
| lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ |
| lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ |
| lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ |
| lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \ |
| lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \ |
| } else if (lhs_horiz_3 < k_size) { \ |
| lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ |
| lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ |
| lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ |
| lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \ |
| } else if (lhs_horiz_2 < k_size) { \ |
| lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ |
| lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ |
| lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ |
| } else if (lhs_horiz_1 < k_size) { \ |
| lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ |
| lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ |
| } else if (lhs_horiz_0 < k_size) { \ |
| lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ |
| } \ |
| } \ |
| \ |
| const Index rhs_vert = base_k + load_idx_vert; \ |
| if (!needs_edge_check || rhs_vert < k_size) { \ |
| const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \ |
| const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \ |
| const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \ |
| const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \ |
| const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \ |
| const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \ |
| const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \ |
| const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \ |
| \ |
| if (rhs_horiz_7 < n_size) { \ |
| rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ |
| rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ |
| rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ |
| rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \ |
| rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \ |
| rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \ |
| rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \ |
| rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \ |
| } else if (rhs_horiz_6 < n_size) { \ |
| rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ |
| rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ |
| rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ |
| rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \ |
| rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \ |
| rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \ |
| rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \ |
| } else if (rhs_horiz_5 < n_size) { \ |
| rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ |
| rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ |
| rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ |
| rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \ |
| rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \ |
| rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \ |
| } else if (rhs_horiz_4 < n_size) { \ |
| rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ |
| rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ |
| rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ |
| rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \ |
| rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \ |
| } else if (rhs_horiz_3 < n_size) { \ |
| rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ |
| rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ |
| rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ |
| rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \ |
| } else if (rhs_horiz_2 < n_size) { \ |
| rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ |
| rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ |
| rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ |
| } else if (rhs_horiz_1 < n_size) { \ |
| rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ |
| rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ |
| } else if (rhs_horiz_0 < n_size) { \ |
| rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ |
| } \ |
| } \ |
| } \ |
| |
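  // Copy the eight prefetched LHS and RHS values from registers into shared memory,
  // using the padded storage indices computed above.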
| #define writeRegToShmem(_) \ |
| lhs_shmem[lhs_store_idx_0] = lhs_pf0; \ |
| rhs_shmem[rhs_store_idx_0] = rhs_pf0; \ |
| \ |
| lhs_shmem[lhs_store_idx_1] = lhs_pf1; \ |
| rhs_shmem[rhs_store_idx_1] = rhs_pf1; \ |
| \ |
| lhs_shmem[lhs_store_idx_2] = lhs_pf2; \ |
| rhs_shmem[rhs_store_idx_2] = rhs_pf2; \ |
| \ |
| lhs_shmem[lhs_store_idx_3] = lhs_pf3; \ |
| rhs_shmem[rhs_store_idx_3] = rhs_pf3; \ |
| \ |
| lhs_shmem[lhs_store_idx_4] = lhs_pf4; \ |
| rhs_shmem[rhs_store_idx_4] = rhs_pf4; \ |
| \ |
| lhs_shmem[lhs_store_idx_5] = lhs_pf5; \ |
| rhs_shmem[rhs_store_idx_5] = rhs_pf5; \ |
| \ |
| lhs_shmem[lhs_store_idx_6] = lhs_pf6; \ |
| rhs_shmem[rhs_store_idx_6] = rhs_pf6; \ |
| \ |
| lhs_shmem[lhs_store_idx_7] = lhs_pf7; \ |
| rhs_shmem[rhs_store_idx_7] = rhs_pf7; \ |
| |
| // declare and initialize result array |
| #define res(i, j) _res_##i##j |
| #define initResultRow(i) \ |
| Scalar res(i, 0) = Scalar(0); \ |
| Scalar res(i, 1) = Scalar(0); \ |
| Scalar res(i, 2) = Scalar(0); \ |
| Scalar res(i, 3) = Scalar(0); \ |
| Scalar res(i, 4) = Scalar(0); \ |
| Scalar res(i, 5) = Scalar(0); \ |
| Scalar res(i, 6) = Scalar(0); \ |
| Scalar res(i, 7) = Scalar(0); \ |
| |
| initResultRow(0); |
| initResultRow(1); |
| initResultRow(2); |
| initResultRow(3); |
| initResultRow(4); |
| initResultRow(5); |
| initResultRow(6); |
| initResultRow(7); |
| #undef initResultRow |
| |
| for (Index base_k = 0; base_k < k_size; base_k += 64) { |
    // wait for the previous iteration to finish with shmem. Perhaps counterintuitively,
    // the code is a bit faster with the barrier here than at the bottom of the loop.
| __syncthreads(); |
| |
| prefetchIntoRegisters(base_k); |
| writeRegToShmem(); |
| |
| #undef prefetchIntoRegisters |
| #undef writeRegToShmem |
| |
| // wait for shared mem packing to be done before starting computation |
| __syncthreads(); |
| |
| // compute 8x8 matrix product by outer product. This involves packing one column |
| // of LHS and one row of RHS into registers (takes 16 registers). |
| |
| #define lcol(i) _lcol##i |
| Scalar lcol(0); |
| Scalar lcol(1); |
| Scalar lcol(2); |
| Scalar lcol(3); |
| Scalar lcol(4); |
| Scalar lcol(5); |
| Scalar lcol(6); |
| Scalar lcol(7); |
| |
| #define rrow(j) _rrow##j |
| Scalar rrow(0); |
| Scalar rrow(1); |
| Scalar rrow(2); |
| Scalar rrow(3); |
| Scalar rrow(4); |
| Scalar rrow(5); |
| Scalar rrow(6); |
| Scalar rrow(7); |
| |
| // Now x corresponds to k, y to m, and z to n |
| const volatile Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y]; |
| const volatile Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z]; |
| |
| #define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))] |
| #define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))] |
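    // lhs_element / rhs_element walk the padded tiles: the first index advances by
    // one padded 72-element block, the second by 576 elements (a full column of 8
    // blocks), mirroring the strides used by the stores in prefetchIntoRegisters.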
| |
| #define loadData(i, j) \ |
| lcol(0) = lhs_element(0, j); \ |
| rrow(0) = rhs_element(i, 0); \ |
| lcol(1) = lhs_element(1, j); \ |
| rrow(1) = rhs_element(i, 1); \ |
| lcol(2) = lhs_element(2, j); \ |
| rrow(2) = rhs_element(i, 2); \ |
| lcol(3) = lhs_element(3, j); \ |
| rrow(3) = rhs_element(i, 3); \ |
| lcol(4) = lhs_element(4, j); \ |
| rrow(4) = rhs_element(i, 4); \ |
| lcol(5) = lhs_element(5, j); \ |
| rrow(5) = rhs_element(i, 5); \ |
| lcol(6) = lhs_element(6, j); \ |
| rrow(6) = rhs_element(i, 6); \ |
| lcol(7) = lhs_element(7, j); \ |
| rrow(7) = rhs_element(i, 7); \ |
| |
| #define computeCol(j) \ |
| res(0, j) += lcol(0) * rrow(j); \ |
| res(1, j) += lcol(1) * rrow(j); \ |
| res(2, j) += lcol(2) * rrow(j); \ |
| res(3, j) += lcol(3) * rrow(j); \ |
| res(4, j) += lcol(4) * rrow(j); \ |
| res(5, j) += lcol(5) * rrow(j); \ |
| res(6, j) += lcol(6) * rrow(j); \ |
| res(7, j) += lcol(7) * rrow(j); \ |
| |
| #define computePass(i) \ |
| loadData(i, i); \ |
| \ |
| computeCol(0); \ |
| computeCol(1); \ |
| computeCol(2); \ |
| computeCol(3); \ |
| computeCol(4); \ |
| computeCol(5); \ |
| computeCol(6); \ |
| computeCol(7); \ |
| |
| computePass(0); |
| computePass(1); |
| computePass(2); |
| computePass(3); |
| computePass(4); |
| computePass(5); |
| computePass(6); |
| computePass(7); |
| |
| #undef lcol |
| #undef rrow |
| #undef lhs_element |
| #undef rhs_element |
| #undef loadData |
| #undef computeCol |
| #undef computePass |
| } // end loop over k |
| |
  // We've now iterated over all of the large (i.e. width 64) k blocks and
  // accumulated results in registers. At this point thread (x, y, z) contains
  // the sum, across all big k blocks, of the products of the little k block of
  // index (x, y) with the block of index (y, z). To compute the final output,
  // we need to sum-reduce across the 8 threads that share the same (y, z).
| #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask) |
| |
| #define reduceRow(i, mask) \ |
| shuffleInc(i, 0, mask); \ |
| shuffleInc(i, 1, mask); \ |
| shuffleInc(i, 2, mask); \ |
| shuffleInc(i, 3, mask); \ |
| shuffleInc(i, 4, mask); \ |
| shuffleInc(i, 5, mask); \ |
| shuffleInc(i, 6, mask); \ |
| shuffleInc(i, 7, mask); \ |
| |
| #define reduceMatrix(mask) \ |
| reduceRow(0, mask); \ |
| reduceRow(1, mask); \ |
| reduceRow(2, mask); \ |
| reduceRow(3, mask); \ |
| reduceRow(4, mask); \ |
| reduceRow(5, mask); \ |
| reduceRow(6, mask); \ |
| reduceRow(7, mask); \ |
| |
  // Actually perform the reduction. Afterwards, each thread of index (_, y, z)
  // contains in its registers the correct values that belong in the output
  // block.
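  // The xor masks 1, 2 and 4 form a three-step butterfly over the low three bits of
  // the warp lane id, i.e. over threadIdx.x, so the eight partial sums held by the
  // threads sharing a given (y, z) are summed together.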
| reduceMatrix(1); |
| reduceMatrix(2); |
| reduceMatrix(4); |
| |
| #undef shuffleInc |
| #undef reduceRow |
| #undef reduceMatrix |
| |
  // Now we need to copy the 64 values into main memory. We can't split the work
  // among threads because all the values live in registers. There are two ways
  // to do this:
  // (1) have 1 thread do 64 writes from registers into global memory
  // (2) have 1 thread do 64 writes into shared memory, and then have 8 threads
  //     each do 8 writes into global memory. We can simply overwrite the shared
  //     memory from the problem we just solved.
  // Option (2) is slightly faster than (1) due to less branching and more ILP.
| |
| // TODO: won't yield much gain, but could just use currently unused shared mem |
| // and then we won't have to sync |
| // wait for shared mem to be out of use |
| __syncthreads(); |
| |
| #define writeResultShmem(i, j) \ |
| lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \ |
| |
| #define writeRow(i) \ |
| writeResultShmem(i, 0); \ |
| writeResultShmem(i, 1); \ |
| writeResultShmem(i, 2); \ |
| writeResultShmem(i, 3); \ |
| writeResultShmem(i, 4); \ |
| writeResultShmem(i, 5); \ |
| writeResultShmem(i, 6); \ |
| writeResultShmem(i, 7); \ |
| |
| if (threadIdx.x == 0) { |
| writeRow(0); |
| writeRow(1); |
| writeRow(2); |
| writeRow(3); |
| writeRow(4); |
| writeRow(5); |
| writeRow(6); |
| writeRow(7); |
| } |
| #undef writeResultShmem |
| #undef writeRow |
| |
| const int max_i_write = (min)((int)((m_size - base_m - threadIdx.y + 7) / 8), 8); |
| const int max_j_write = (min)((int)((n_size - base_n - threadIdx.z + 7) / 8), 8); |
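  // max_i_write counts how many of the rows base_m + threadIdx.y + 8 * i (i = 0..7)
  // fall below m_size, and max_j_write does the same for the columns
  // base_n + threadIdx.z + 8 * j; the guard threadIdx.x < max_i_write below keeps
  // this thread's output row in bounds.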
| |
| if (threadIdx.x < max_i_write) { |
| if (max_j_write == 8) { |
| Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0]; |
| Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1]; |
| Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2]; |
| Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3]; |
| Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4]; |
| Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5]; |
| Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6]; |
| Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7]; |
| |
| output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0; |
| output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1; |
| output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2; |
| output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3; |
| output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4; |
| output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5; |
| output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6; |
| output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7; |
| } else { |
| #pragma unroll 7 |
| for (int j = 0; j < max_j_write; j++) { |
| Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j]; |
| output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val; |
| } |
| } |
| } |
| #undef res |
| } |
| |
| |
| template<typename Scalar, typename Index, typename LhsMapper, |
| typename RhsMapper, typename OutputMapper> |
| __global__ void |
| __launch_bounds__(512) |
| EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, |
| const OutputMapper output, |
| const Index m_size, const Index n_size, const Index k_size) { |
| __shared__ volatile Scalar lhs_shmem[72 * 64]; |
| __shared__ volatile Scalar rhs_shmem[72 * 64]; |
| |
| const Index m_block_idx = blockIdx.x; |
| const Index n_block_idx = blockIdx.y; |
| |
| const Index base_m = 64 * m_block_idx; |
| const Index base_n = 64 * n_block_idx; |
| |
| if (base_m + 63 < m_size && base_n + 63 < n_size) { |
| EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); |
| } else { |
| EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); |
| } |
| } |
| |
| |
| |
| template<typename Index, typename LhsMapper, |
| typename RhsMapper, typename OutputMapper, bool needs_edge_check> |
| __device__ EIGEN_STRONG_INLINE void |
| EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, |
| const OutputMapper output, float4* lhs_shmem4, float2* rhs_shmem2, |
| const Index m_size, const Index n_size, const Index k_size) { |
| typedef float Scalar; |
| |
| const Index m_block_idx = blockIdx.x; |
| const Index n_block_idx = blockIdx.y; |
| |
| const Index base_m = 64 * m_block_idx; |
| const Index base_n = 64 * n_block_idx; |
| |
| const Index lane = threadIdx.x + 8 * (threadIdx.y % 4); |
| |
| // prefetch registers |
| float4 lhs_pf0; |
| float4 lhs_pf1; |
| |
| float4 rhs_pf0; |
| float4 rhs_pf1; |
| |
| // shared memory is formatted |
| // (contract idx in block, nocontract idx in block, block idx) |
| // where block idx is column major. This transposition limits the number of |
  // bank conflicts when reading the LHS. The core idea is that since the contracting
  // index is shared by both sides, it should live in threadIdx.x.
| |
| // all of these indices assume float4 loading |
| // this thread loads the float4 starting at this index, and then also loads |
  // another float4 starting 32 columns to the right
| const Index horiz_block_idx = threadIdx.z / 2; |
| const Index vert_block_idx = threadIdx.x / 2 + 4 * (threadIdx.y % 2); |
| const Index horiz_idx_in_block = threadIdx.y / 2 + 4 * (threadIdx.z % 2); |
| const Index vert_idx_in_block = threadIdx.x % 2; |
| |
| // there's padding in both the LHS and RHS shared memory layouts. This padding |
| // allows for 0 bank conflicts on all shmem stores and loads. |
| // LHS padding: 1 float4 on each 8x8 block of floats |
| // RHS padding: 1 float2 on each block, and 12 additional float2s between vertical blocks |
| // 3 and 4 |
| |
| // storage indices |
| // lhs index with respect to float4s |
| const Index lhs_store_idx_base = |
| 136 * horiz_block_idx + |
| 17 * vert_block_idx + |
| 8 * vert_idx_in_block + |
| horiz_idx_in_block; |
| |
| // rhs index with respect to floats |
| const Index rhs_store_idx_base = |
| 552 * horiz_block_idx + |
| 66 * vert_block_idx + |
| 32 * (horiz_idx_in_block / 4) + (horiz_idx_in_block % 4) + |
| 16 * vert_idx_in_block + |
| ((vert_block_idx < 4) ? 0 : 24); |
| |
| const Index lhs_store_idx_0 = lhs_store_idx_base + 544 * 0; |
| const Index lhs_store_idx_1 = lhs_store_idx_base + 544 * 1; |
| |
| const Index rhs_store_idx_0 = (rhs_store_idx_base / 2) + ((lane < 16) ? 0 : 4); |
| const Index rhs_store_idx_1 = rhs_store_idx_0 + 2; |
| const Index rhs_store_idx_2 = rhs_store_idx_0 + 1104; |
| const Index rhs_store_idx_3 = rhs_store_idx_1 + 1104; |
| |
| // The below diagrams show which shmem index (with respect to floats) each element |
| // in an 8x8 input block gets packed into: |
| // LHS: |
| // 0 4 8 12 16 20 24 28 |
| // 1 5 9 13 17 21 25 29 |
| // 2 6 10 14 18 22 26 30 |
| // 3 7 11 15 19 23 27 31 |
| // 32 36 40 44 48 52 56 60 |
| // ... (pack as 2 rows of float4 indexed row major, each float4 is vertical) |
| // |
| // RHS: |
| // 0 1 2 3 32 33 34 35 |
| // 4 5 6 7 36 37 38 39 |
| // ... (pack as 2 cols of float4 indexed col major, each float4 is horizontal) |
| |
  // Each thread in a warp loads 2 float4s. This happens in 2 instructions. On each of these
  // instructions, the warp loads 2 columns (2 cols * 64 elements / col = 128 elements = 32 threads
  // * 4 elements / thread). For the LHS, we're able to store the loaded float4 directly into
  // shmem (using a 128 bit store instruction). For the RHS, we need to transpose the data, which
  // is done with warp shuffles. Furthermore, we only use 64 bit stores for the RHS, because
  // 64 bits is only 2 columns (which is all we load in a warp), and the RHS layout doesn't meet
  // the 128 bit alignment requirement (the 4 consecutive floats that we want to load on the RHS
  // are 8 byte aligned, not 16 byte aligned, as required for float4).
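  // Concretely, in the swap code at the end of prefetchIntoRegisters, lanes 0-15
  // exchange the (z, w) halves of their rhs_pf registers with the (x, y) halves of
  // lanes 16-31 via __shfl_xor(..., 16); after the swap each register pair can be
  // written out as two float2 (64 bit) stores.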
| |
| const Index load_idx_vert = 4 * (threadIdx.x + 8 * (threadIdx.y % 2)); |
| const Index load_idx_horiz = (threadIdx.y / 2) + 4 * threadIdx.z; |
| |
| const Index lhs_vert = base_m + load_idx_vert; |
| const Index rhs_horiz_0 = base_n + load_idx_horiz; |
| const Index rhs_horiz_1 = base_n + load_idx_horiz + 32; |
| |
| #define prefetchIntoRegisters(base_k) \ |
| { \ |
| lhs_pf0 = internal::pset1<float4>(0); \ |
| lhs_pf1 = internal::pset1<float4>(0); \ |
| \ |
| rhs_pf0 = internal::pset1<float4>(0); \ |
| rhs_pf1 = internal::pset1<float4>(0); \ |
| \ |
| const Index lhs_horiz_0 = base_k + load_idx_horiz; \ |
| const Index lhs_horiz_1 = base_k + load_idx_horiz + 32; \ |
| if (!needs_edge_check || lhs_vert + 3 < m_size) { \ |
| if (lhs_horiz_1 < k_size) { \ |
| lhs_pf0 = lhs.loadPacket(lhs_vert, lhs_horiz_0); \ |
| lhs_pf1 = lhs.loadPacket(lhs_vert, lhs_horiz_1); \ |
| } else if (lhs_horiz_0 < k_size) { \ |
| lhs_pf0 = lhs.loadPacket(lhs_vert, lhs_horiz_0); \ |
| } \ |
| } else if (lhs_vert + 2 < m_size) { \ |
| if (lhs_horiz_1 < k_size) { \ |
| lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ |
| lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ |
| lhs_pf0.z = lhs(lhs_vert + 2, lhs_horiz_0); \ |
| \ |
| lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \ |
| lhs_pf1.y = lhs(lhs_vert + 1, lhs_horiz_1); \ |
| lhs_pf1.z = lhs(lhs_vert + 2, lhs_horiz_1); \ |
| } else if (lhs_horiz_0 < k_size) { \ |
| lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ |
| lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ |
| lhs_pf0.z = lhs(lhs_vert + 2, lhs_horiz_0); \ |
| } \ |
| } else if (lhs_vert + 1 < m_size) { \ |
| if (lhs_horiz_1 < k_size) { \ |
| lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ |
| lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ |
| \ |
| lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \ |
| lhs_pf1.y = lhs(lhs_vert + 1, lhs_horiz_1); \ |
| } else if (lhs_horiz_0 < k_size) { \ |
| lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ |
| lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ |
| } \ |
| } else if (lhs_vert < m_size) { \ |
| if (lhs_horiz_1 < k_size) { \ |
| lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ |
| lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \ |
| } else if (lhs_horiz_0 < k_size) { \ |
| lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ |
| } \ |
| } \ |
| \ |
| const Index rhs_vert = base_k + load_idx_vert; \ |
| if (rhs_vert + 3 < k_size) { \ |
| if (!needs_edge_check || rhs_horiz_1 < n_size) { \ |
| rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz_0); \ |
| rhs_pf1 = rhs.loadPacket(rhs_vert, rhs_horiz_1); \ |
| } else if (rhs_horiz_0 < n_size) { \ |
| rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz_0); \ |
| } \ |
| } else if (rhs_vert + 2 < k_size) { \ |
| if (!needs_edge_check || rhs_horiz_1 < n_size) { \ |
| rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ |
| rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ |
| rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz_0); \ |
| \ |
| rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \ |
| rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz_1); \ |
| rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz_1); \ |
| } else if (rhs_horiz_0 < n_size) { \ |
| rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ |
| rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ |
| rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz_0); \ |
| } \ |
| } else if (rhs_vert + 1 < k_size) { \ |
| if (!needs_edge_check || rhs_horiz_1 < n_size) { \ |
| rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ |
| rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ |
| \ |
| rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \ |
| rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz_1); \ |
| } else if (rhs_horiz_0 < n_size) { \ |
| rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ |
| rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ |
| } \ |
| } else if (rhs_vert < k_size) { \ |
| if (!needs_edge_check || rhs_horiz_1 < n_size) { \ |
| rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ |
| rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \ |
| } else if (rhs_horiz_0 < n_size) { \ |
| rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ |
| } \ |
| } \ |
| \ |
| float swap_val0 = (lane < 16) ? rhs_pf0.z : rhs_pf0.x; \ |
| float swap_val1 = (lane < 16) ? rhs_pf0.w : rhs_pf0.y; \ |
| float swap_val2 = (lane < 16) ? rhs_pf1.z : rhs_pf1.x; \ |
| float swap_val3 = (lane < 16) ? rhs_pf1.w : rhs_pf1.y; \ |
| \ |
| swap_val0 = __shfl_xor(swap_val0, 16); \ |
| swap_val1 = __shfl_xor(swap_val1, 16); \ |
| swap_val2 = __shfl_xor(swap_val2, 16); \ |
| swap_val3 = __shfl_xor(swap_val3, 16); \ |
| \ |
| if (lane < 16) { \ |
| rhs_pf0.z = swap_val0; \ |
| rhs_pf0.w = swap_val1; \ |
| rhs_pf1.z = swap_val2; \ |
| rhs_pf1.w = swap_val3; \ |
| } else { \ |
| rhs_pf0.x = swap_val0; \ |
| rhs_pf0.y = swap_val1; \ |
| rhs_pf1.x = swap_val2; \ |
| rhs_pf1.y = swap_val3; \ |
| } \ |
| } \ |
| |
| |
| #define writeRegToShmem(_) \ |
| lhs_shmem4[lhs_store_idx_0] = lhs_pf0; \ |
| \ |
| rhs_shmem2[rhs_store_idx_0] = make_float2(rhs_pf0.x, rhs_pf0.z); \ |
| rhs_shmem2[rhs_store_idx_1] = make_float2(rhs_pf0.y, rhs_pf0.w); \ |
| \ |
| lhs_shmem4[lhs_store_idx_1] = lhs_pf1; \ |
| \ |
| rhs_shmem2[rhs_store_idx_2] = make_float2(rhs_pf1.x, rhs_pf1.z); \ |
| rhs_shmem2[rhs_store_idx_3] = make_float2(rhs_pf1.y, rhs_pf1.w); \ |
| |
| // declare and initialize result array |
| #define res(i, j) _res_##i##j |
| #define initResultRow(i) \ |
| Scalar res(i, 0) = Scalar(0); \ |
| Scalar res(i, 1) = Scalar(0); \ |
| Scalar res(i, 2) = Scalar(0); \ |
| Scalar res(i, 3) = Scalar(0); \ |
| Scalar res(i, 4) = Scalar(0); \ |
| Scalar res(i, 5) = Scalar(0); \ |
| Scalar res(i, 6) = Scalar(0); \ |
| Scalar res(i, 7) = Scalar(0); \ |
| |
| initResultRow(0); |
| initResultRow(1); |
| initResultRow(2); |
| initResultRow(3); |
| initResultRow(4); |
| initResultRow(5); |
| initResultRow(6); |
| initResultRow(7); |
| #undef initResultRow |
| |
| for (Index base_k = 0; base_k < k_size; base_k += 64) { |
    // wait for the previous iteration to finish with shmem. Perhaps counterintuitively,
    // the code is a bit faster with the barrier here than at the bottom of the loop.
| __syncthreads(); |
| |
| prefetchIntoRegisters(base_k); |
| writeRegToShmem(); |
| |
| #undef prefetchIntoRegisters |
#undef writeRegToShmem
| |
| // wait for shared mem packing to be done before starting computation |
| __syncthreads(); |
| |
| // compute 8x8 matrix product by outer product. This involves packing one column |
| // of LHS and one row of RHS into registers (takes 16 registers). |
| |
| float4 _lcol0; |
| float4 _lcol1; |
| float2 _rrow0; |
| float2 _rrow1; |
| float2 _rrow2; |
| float2 _rrow3; |
| |
| #define lcol0 _lcol0.x |
| #define lcol1 _lcol0.y |
| #define lcol2 _lcol0.z |
| #define lcol3 _lcol0.w |
| #define lcol4 _lcol1.x |
| #define lcol5 _lcol1.y |
| #define lcol6 _lcol1.z |
| #define lcol7 _lcol1.w |
| #define rrow0 _rrow0.x |
| #define rrow1 _rrow0.y |
| #define rrow2 _rrow1.x |
| #define rrow3 _rrow1.y |
| #define rrow4 _rrow2.x |
| #define rrow5 _rrow2.y |
| #define rrow6 _rrow3.x |
| #define rrow7 _rrow3.y |
| |
| // Now x corresponds to k, y to m, and z to n |
| const float4* lhs_block = &lhs_shmem4[threadIdx.x + 8 * (threadIdx.y % 2) + 17 * (threadIdx.y / 2)]; |
| const float2* rhs_block = &rhs_shmem2[2 * threadIdx.x + 16 * (threadIdx.z % 2) + 276 * (threadIdx.z / 2)]; |
| |
| #define lhs_element(i, k) lhs_block[68 * i + 136 * k] |
| #define rhs_element(k, j) rhs_block[33 * k + 1104 * j + ((k < 4) ? 0 : 12)] |
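    // Strides for rhs_element: 33 float2s (= 66 floats) per contracting sub-block k,
    // 1104 float2s between the two 32-column halves (j = 0, 1), plus the 12 extra
    // float2s of padding once k crosses from vertical block 3 to block 4, matching
    // the RHS padding described above.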
| |
| #define loadData(i) \ |
| _lcol0 = lhs_element(0, i); \ |
| _rrow0 = rhs_element(i, 0); \ |
| _rrow1 = *(&(rhs_element(i, 0)) + 1); \ |
| _lcol1 = lhs_element(1, i); \ |
| _rrow2 = rhs_element(i, 1); \ |
| _rrow3 = *(&(rhs_element(i, 1)) + 1); \ |
| |
| #define computeCol(j) \ |
| res(0, j) += lcol0 * rrow##j; \ |
| res(1, j) += lcol1 * rrow##j; \ |
| res(2, j) += lcol2 * rrow##j; \ |
| res(3, j) += lcol3 * rrow##j; \ |
| res(4, j) += lcol4 * rrow##j; \ |
| res(5, j) += lcol5 * rrow##j; \ |
| res(6, j) += lcol6 * rrow##j; \ |
| res(7, j) += lcol7 * rrow##j; \ |
| |
| #define computePass(i) \ |
| loadData(i); \ |
| \ |
| computeCol(0); \ |
| computeCol(1); \ |
| computeCol(2); \ |
| computeCol(3); \ |
| computeCol(4); \ |
| computeCol(5); \ |
| computeCol(6); \ |
| computeCol(7); \ |
| |
| computePass(0); |
| computePass(1); |
| computePass(2); |
| computePass(3); |
| computePass(4); |
| computePass(5); |
| computePass(6); |
| computePass(7); |
| |
| #undef lcol0 |
| #undef lcol1 |
| #undef lcol2 |
| #undef lcol3 |
| #undef lcol4 |
| #undef lcol5 |
| #undef lcol6 |
| #undef lcol7 |
| #undef rrow0 |
| #undef rrow1 |
| #undef rrow2 |
| #undef rrow3 |
| #undef rrow4 |
| #undef rrow5 |
| #undef rrow6 |
| #undef rrow7 |
| |
| #undef computePass |
| #undef computeCol |
| #undef loadData |
| #undef lhs_element |
| #undef rhs_element |
| |
| } // end loop over k |
| |
  // We've now iterated over all of the large (i.e. width 64) k blocks and
  // accumulated results in registers. At this point thread (x, y, z) contains
  // the sum, across all big k blocks, of the products of the little k block of
  // index (x, y) with the block of index (y, z). To compute the final output,
  // we need to sum-reduce across the 8 threads that share the same (y, z).
| #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask) |
| |
| #define reduceRow(i, mask) \ |
| shuffleInc(i, 0, mask); \ |
| shuffleInc(i, 1, mask); \ |
| shuffleInc(i, 2, mask); \ |
| shuffleInc(i, 3, mask); \ |
| shuffleInc(i, 4, mask); \ |
| shuffleInc(i, 5, mask); \ |
| shuffleInc(i, 6, mask); \ |
| shuffleInc(i, 7, mask); \ |
| |
| #define reduceMatrix(mask) \ |
| reduceRow(0, mask); \ |
| reduceRow(1, mask); \ |
| reduceRow(2, mask); \ |
| reduceRow(3, mask); \ |
| reduceRow(4, mask); \ |
| reduceRow(5, mask); \ |
| reduceRow(6, mask); \ |
| reduceRow(7, mask); \ |
| |
  // Actually perform the reduction. Afterwards, each thread of index (_, y, z)
  // contains in its registers the correct values that belong in the output
  // block.
| reduceMatrix(1); |
| reduceMatrix(2); |
| reduceMatrix(4); |
| |
| #undef shuffleInc |
| #undef reduceRow |
| #undef reduceMatrix |
| |
  // Now we need to copy the 64 values into main memory. We can't split the work
  // among threads because all the values live in registers. There are three ways
  // to do this:
  // (1) have 1 thread do 64 writes from registers into global memory
  // (2) have 1 thread do 64 writes into shared memory, and then have 8 threads
  //     each do 8 writes into global memory. We can simply overwrite the shared
  //     memory from the problem we just solved.
  // (3) have each thread copy its values into fresh registers using conditional
  //     logic and then write them to global memory itself.
  // This kernel uses option (3).
| |
| #define makeAssignments(i) \ |
| val0 = res(i, 0); \ |
| val1 = res(i, 1); \ |
| val2 = res(i, 2); \ |
| val3 = res(i, 3); \ |
| val4 = res(i, 4); \ |
| val5 = res(i, 5); \ |
| val6 = res(i, 6); \ |
| val7 = res(i, 7); \ |
| |
| Scalar val0; |
| Scalar val1; |
| Scalar val2; |
| Scalar val3; |
| Scalar val4; |
| Scalar val5; |
| Scalar val6; |
| Scalar val7; |
| |
| switch (threadIdx.x) { |
| case 0: |
| makeAssignments(0); |
| break; |
| case 1: |
| makeAssignments(1); |
| break; |
| case 2: |
| makeAssignments(2); |
| break; |
| case 3: |
| makeAssignments(3); |
| break; |
| case 4: |
| makeAssignments(4); |
| break; |
| case 5: |
| makeAssignments(5); |
| break; |
| case 6: |
| makeAssignments(6); |
| break; |
| case 7: |
| makeAssignments(7); |
| break; |
| } |
| |
| #undef res |
| |
| const Index vert_base = base_m + 4 * threadIdx.y + (threadIdx.x % 4) + 32 * (threadIdx.x / 4); |
| const Index horiz_base = base_n + 4 * threadIdx.z; |
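  // vert_base enumerates all 64 rows of the tile: 4 * threadIdx.y + (threadIdx.x % 4)
  // covers 0..31 and the 32 * (threadIdx.x / 4) term selects the lower or upper half.
  // The columns written below are horiz_base + {0..3} and horiz_base + {32..35}, which
  // across the 8 values of threadIdx.z cover all 64 columns.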
| |
| if (!needs_edge_check || vert_base < m_size) { |
| if (!needs_edge_check || horiz_base + 35 < n_size) { |
| output(vert_base, horiz_base + 0) = val0; |
| output(vert_base, horiz_base + 1) = val1; |
| output(vert_base, horiz_base + 2) = val2; |
| output(vert_base, horiz_base + 3) = val3; |
| output(vert_base, horiz_base + 32) = val4; |
| output(vert_base, horiz_base + 33) = val5; |
| output(vert_base, horiz_base + 34) = val6; |
| output(vert_base, horiz_base + 35) = val7; |
| } else if (horiz_base + 34 < n_size) { |
| output(vert_base, horiz_base + 0) = val0; |
| output(vert_base, horiz_base + 1) = val1; |
| output(vert_base, horiz_base + 2) = val2; |
| output(vert_base, horiz_base + 3) = val3; |
| output(vert_base, horiz_base + 32) = val4; |
| output(vert_base, horiz_base + 33) = val5; |
| output(vert_base, horiz_base + 34) = val6; |
| } else if (horiz_base + 33 < n_size) { |
| output(vert_base, horiz_base + 0) = val0; |
| output(vert_base, horiz_base + 1) = val1; |
| output(vert_base, horiz_base + 2) = val2; |
| output(vert_base, horiz_base + 3) = val3; |
| output(vert_base, horiz_base + 32) = val4; |
| output(vert_base, horiz_base + 33) = val5; |
| } else if (horiz_base + 32 < n_size) { |
| output(vert_base, horiz_base + 0) = val0; |
| output(vert_base, horiz_base + 1) = val1; |
| output(vert_base, horiz_base + 2) = val2; |
| output(vert_base, horiz_base + 3) = val3; |
| output(vert_base, horiz_base + 32) = val4; |
| } else if (horiz_base + 3 < n_size) { |
| output(vert_base, horiz_base + 0) = val0; |
| output(vert_base, horiz_base + 1) = val1; |
| output(vert_base, horiz_base + 2) = val2; |
| output(vert_base, horiz_base + 3) = val3; |
| } else if (horiz_base + 2 < n_size) { |
| output(vert_base, horiz_base + 0) = val0; |
| output(vert_base, horiz_base + 1) = val1; |
| output(vert_base, horiz_base + 2) = val2; |
| } else if (horiz_base + 1 < n_size) { |
| output(vert_base, horiz_base + 0) = val0; |
| output(vert_base, horiz_base + 1) = val1; |
| } else if (horiz_base < n_size) { |
| output(vert_base, horiz_base + 0) = val0; |
| } |
| } |
| } |
| |
| |
| template<typename Index, typename LhsMapper, |
| typename RhsMapper, typename OutputMapper> |
| __global__ void |
| __launch_bounds__(512) |
| EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, |
| const OutputMapper output, |
| const Index m_size, const Index n_size, const Index k_size) { |
| __shared__ float4 lhs_shmem[(68 * 64) / 4]; |
| __shared__ float2 rhs_shmem[((66 * 8 + 24) * 8) / 2]; |
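  // Padded shared-memory tiles: 68 * 64 floats for the LHS (stored as float4) and
  // (66 * 8 + 24) * 8 floats for the RHS (stored as float2), matching the padded
  // layouts described in EigenFloatContractionKernelInternal.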
| |
| const Index m_block_idx = blockIdx.x; |
| const Index n_block_idx = blockIdx.y; |
| |
| const Index base_m = 64 * m_block_idx; |
| const Index base_n = 64 * n_block_idx; |
| |
| if (base_m + 63 < m_size && base_n + 63 < n_size) { |
| EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); |
| } else { |
| EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); |
| } |
| } |
| |
| |
| template<typename Indices, typename LeftArgType, typename RightArgType> |
| struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> : |
| public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> > { |
| |
| typedef GpuDevice Device; |
| |
| typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; |
| typedef TensorContractionEvaluatorBase<Self> Base; |
| |
| typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; |
| typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; |
| typedef typename XprType::Packet Packet; |
| typedef typename XprType::Index Index; |
| typedef typename XprType::CoeffReturnType CoeffReturnType; |
| typedef typename XprType::PacketReturnType PacketReturnType; |
| |
| typedef array<Index, TensorEvaluator<LeftArgType, Device>::Dimensions::count> left_dim_mapper_t; |
| typedef array<Index, TensorEvaluator<RightArgType, Device>::Dimensions::count> right_dim_mapper_t; |
| |
| typedef array<Index, internal::array_size<Indices>::value> contract_t; |
| typedef array<Index, TensorEvaluator<LeftArgType, Device>::Dimensions::count - internal::array_size<Indices>::value> left_nocontract_t; |
| typedef array<Index, TensorEvaluator<RightArgType, Device>::Dimensions::count - internal::array_size<Indices>::value> right_nocontract_t; |
| |
| static const int NumDims = max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count - 2 * internal::array_size<Indices>::value>::size; |
| |
| typedef DSizes<Index, NumDims> Dimensions; |
| |
| // typedefs needed in evalTo |
| typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar; |
| typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar; |
| |
| typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator; |
| typedef TensorEvaluator<RightArgType, Device> RightEvaluator; |
| |
| typedef typename LeftEvaluator::Dimensions LeftDimensions; |
| typedef typename RightEvaluator::Dimensions RightDimensions; |
| |
| EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : |
| Base(op, device) {} |
| |
| // We need to redefine this method to make nvcc happy |
| EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { |
| this->m_leftImpl.evalSubExprsIfNeeded(NULL); |
| this->m_rightImpl.evalSubExprsIfNeeded(NULL); |
| if (data) { |
| evalTo(data); |
| return false; |
| } else { |
| this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar))); |
| evalTo(this->m_result); |
| return true; |
| } |
| } |
| |
| void evalTo(Scalar* buffer) const { |
| if (this->m_lhs_inner_dim_contiguous) { |
| if (this->m_rhs_inner_dim_contiguous) { |
| if (this->m_rhs_inner_dim_reordered) { |
| evalTyped<true, true, true, Unaligned>(buffer); |
| } |
| else { |
| evalTyped<true, true, false, Unaligned>(buffer); |
| } |
| } |
| else { |
| if (this->m_rhs_inner_dim_reordered) { |
| evalTyped<true, false, true, Unaligned>(buffer); |
| } |
| else { |
| evalTyped<true, false, false, Unaligned>(buffer); |
| } |
| } |
| } |
| else { |
| if (this->m_rhs_inner_dim_contiguous) { |
| if (this->m_rhs_inner_dim_reordered) { |
| evalTyped<false, true, true, Unaligned>(buffer); |
| } |
| else { |
| evalTyped<false, true, false, Unaligned>(buffer); |
| } |
| } |
| else { |
| if (this->m_rhs_inner_dim_reordered) { |
| evalTyped<false, false, true, Unaligned>(buffer); |
| } |
| else { |
| evalTyped<false, false, false, Unaligned>(buffer); |
| } |
| } |
| } |
| } |
| |
| template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> |
| void evalTyped(Scalar* buffer) const { |
| // columns in left side, rows in right side |
| const Index k = this->m_k_size; |
| |
| // rows in left side |
| const Index m = this->m_i_size; |
| |
| // columns in right side |
| const Index n = this->m_j_size; |
| |
    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
| this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); |
| |
| typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, |
| LeftEvaluator, left_nocontract_t, |
| contract_t, 4, |
| lhs_inner_dim_contiguous, |
| false, Unaligned> LhsMapper; |
| |
| typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, |
| RightEvaluator, right_nocontract_t, |
| contract_t, 4, |
| rhs_inner_dim_contiguous, |
| rhs_inner_dim_reordered, Unaligned> RhsMapper; |
| |
| typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper; |
| |
| |
| // initialize data mappers |
| LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides, |
| this->m_left_contracting_strides, this->m_k_strides); |
| |
| RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides, |
| this->m_right_contracting_strides, this->m_k_strides); |
| |
| OutputMapper output(buffer, m); |
| |
| const Index m_blocks = (m + 63) / 64; |
| const Index n_blocks = (n + 63) / 64; |
| const dim3 num_blocks(m_blocks, n_blocks, 1); |
| const dim3 block_size(8, 8, 8); |
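    // Each block of 8 * 8 * 8 = 512 threads (matching __launch_bounds__(512)) computes
    // one 64x64 tile of the output; the grid covers ceil(m / 64) * ceil(n / 64) tiles.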
| |
| cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte); |
| if (internal::is_same<LhsScalar, float>::value && |
| internal::is_same<RhsScalar, float>::value) { |
| EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper> |
| <<<num_blocks, block_size, 0, this->m_device.stream()>>>(lhs, rhs, output, m, n, k); |
| } else { |
| EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper> |
| <<<num_blocks, block_size, 0, this->m_device.stream()>>>(lhs, rhs, output, m, n, k); |
| } |
| |
| assert(cudaGetLastError() == cudaSuccess); |
| } |
| }; |
| |
| } // end namespace Eigen |
| |
| #endif // EIGEN_USE_GPU and __CUDACC__ |
| |
| #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H |