blob: 20b6c609a61140e95611d9b2a672361624478a00 [file]
// Benchmarks for Eigen Tensor contraction (generalized GEMM).
// Measures single-threaded (DefaultDevice) and multi-threaded (ThreadPoolDevice) evaluation, plus row-major and rank-3 variants.
// SPDX-FileCopyrightText: The Eigen Authors
// SPDX-License-Identifier: MPL-2.0
#define EIGEN_USE_THREADS
#include <benchmark/benchmark.h>
#include <unsupported/Eigen/Tensor>
#include <unsupported/Eigen/ThreadPool>
using namespace Eigen;
// Element type used by every tensor in this file. Override at build time
// with -DSCALAR=<type>; defaults to float.
#if !defined(SCALAR)
#define SCALAR float
#endif
using Scalar = SCALAR;
// --- DefaultDevice contraction (rank-2, equivalent to matrix multiply) ---
// Times C(M, N) = A(M, K) * B(K, N) evaluated with the default device.
// Args: {M, N, K}.
// The "GFLOPS" counter is an iteration-invariant rate: 2*M*N*K flops per
// iteration, divided by elapsed time, so its value is FLOP/s. kIs1000 makes
// the reporter scale it with base-1000 suffixes.
static void BM_Contraction(benchmark::State& state) {
  using Tensor2 = Tensor<Scalar, 2>;
  using DimPair = Tensor2::DimensionPair;

  const int rows = static_cast<int>(state.range(0));
  const int cols = static_cast<int>(state.range(1));
  const int inner = static_cast<int>(state.range(2));

  Tensor2 lhs(rows, inner);
  Tensor2 rhs(inner, cols);
  Tensor2 out(rows, cols);
  lhs.setRandom();
  rhs.setRandom();

  // Pair lhs dim 1 (K) with rhs dim 0 (K).
  const Eigen::array<DimPair, 1> dims = {DimPair(1, 0)};

  for (auto _ : state) {
    out = lhs.contract(rhs, dims);
    // Keep the result observable so the contraction is not optimized away.
    benchmark::DoNotOptimize(out.data());
    benchmark::ClobberMemory();
  }

  const double flops_per_iter = 2.0 * rows * cols * inner;
  state.counters["GFLOPS"] =
      benchmark::Counter(flops_per_iter, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::kIs1000);
}
// --- ThreadPoolDevice contraction ---
// Same rank-2 contraction as the DefaultDevice benchmark, evaluated through
// a ThreadPoolDevice backed by a pool created outside the timed loop.
// Args: {M, N, K, threads}.
// NOTE: the contraction work runs on pool threads. The "GFLOPS" rate (FLOP/s,
// iteration-invariant) is therefore only meaningful when this benchmark is
// registered with UseRealTime(). With the default main-thread CPU timer, it
// can overstate throughput.
static void BM_Contraction_ThreadPool(benchmark::State& state) {
  using Tensor2 = Tensor<Scalar, 2>;
  using DimPair = Tensor2::DimensionPair;

  const int rows = static_cast<int>(state.range(0));
  const int cols = static_cast<int>(state.range(1));
  const int inner = static_cast<int>(state.range(2));
  const int num_threads = static_cast<int>(state.range(3));

  Tensor2 lhs(rows, inner);
  Tensor2 rhs(inner, cols);
  Tensor2 out(rows, cols);
  lhs.setRandom();
  rhs.setRandom();

  ThreadPool pool(num_threads);
  ThreadPoolDevice device(&pool, num_threads);

  // Pair lhs dim 1 (K) with rhs dim 0 (K).
  const Eigen::array<DimPair, 1> dims = {DimPair(1, 0)};

  for (auto _ : state) {
    out.device(device) = lhs.contract(rhs, dims);
    benchmark::DoNotOptimize(out.data());
    benchmark::ClobberMemory();
  }

  const double flops_per_iter = 2.0 * rows * cols * inner;
  state.counters["GFLOPS"] =
      benchmark::Counter(flops_per_iter, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::kIs1000);
  state.counters["threads"] = num_threads;
}
// --- Rank-3 batch contraction ---
// A(batch, M, K) and B(batch, K, N) are contracted over two index pairs at
// once: batch with batch (0<->0) and K with K (2<->1). Both indices are summed
// away, leaving C(M, N) with
//   C(m, n) = sum_b sum_k A(b, m, k) * B(b, k, n).
// This is not a batched matmul producing one (M, N) slice per batch entry.
// Args: {batch, M, N, K}; the flop count per iteration is 2*batch*M*N*K.
static void BM_BatchContraction(benchmark::State& state) {
  using Tensor3 = Tensor<Scalar, 3>;
  using DimPair = Tensor3::DimensionPair;

  const int batches = static_cast<int>(state.range(0));
  const int rows = static_cast<int>(state.range(1));
  const int cols = static_cast<int>(state.range(2));
  const int inner = static_cast<int>(state.range(3));

  Tensor3 lhs(batches, rows, inner);
  Tensor3 rhs(batches, inner, cols);
  Tensor<Scalar, 2> out(rows, cols);
  lhs.setRandom();
  rhs.setRandom();

  const Eigen::array<DimPair, 2> dims = {DimPair(0, 0), DimPair(2, 1)};

  for (auto _ : state) {
    out = lhs.contract(rhs, dims);
    benchmark::DoNotOptimize(out.data());
    benchmark::ClobberMemory();
  }

  const double flops_per_iter = 2.0 * batches * rows * cols * inner;
  state.counters["GFLOPS"] =
      benchmark::Counter(flops_per_iter, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::kIs1000);
}
// --- RowMajor contraction ---
// Rank-2 contraction C(M, N) = A(M, K) * B(K, N) with all three tensors in
// RowMajor layout, evaluated with the default device. Args: {M, N, K}.
// "GFLOPS" holds FLOP/s: 2*M*N*K flops per iteration divided by elapsed time.
static void BM_Contraction_RowMajor(benchmark::State& state) {
  using RowTensor2 = Tensor<Scalar, 2, RowMajor>;
  using DimPair = RowTensor2::DimensionPair;

  const int rows = static_cast<int>(state.range(0));
  const int cols = static_cast<int>(state.range(1));
  const int inner = static_cast<int>(state.range(2));

  RowTensor2 lhs(rows, inner);
  RowTensor2 rhs(inner, cols);
  RowTensor2 out(rows, cols);
  lhs.setRandom();
  rhs.setRandom();

  // Pair lhs dim 1 (K) with rhs dim 0 (K).
  const Eigen::array<DimPair, 1> dims = {DimPair(1, 0)};

  for (auto _ : state) {
    out = lhs.contract(rhs, dims);
    benchmark::DoNotOptimize(out.data());
    benchmark::ClobberMemory();
  }

  const double flops_per_iter = 2.0 * rows * cols * inner;
  state.counters["GFLOPS"] =
      benchmark::Counter(flops_per_iter, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::kIs1000);
}
// Argument lists for the registrations that follow. Each macro expands to a
// chain of ->Args(...) calls appended to a BENCHMARK(...) registration.
// clang-format off

// {M, N, K}: square sizes 32..1024, then a K-heavy shape (256x256, K=1024)
// and a tall-skinny shape (M=1024, N=K=64).
#define CONTRACTION_SIZES      \
  ->Args({32, 32, 32})         \
  ->Args({64, 64, 64})         \
  ->Args({128, 128, 128})      \
  ->Args({256, 256, 256})      \
  ->Args({512, 512, 512})      \
  ->Args({1024, 1024, 1024})   \
  ->Args({256, 256, 1024})     \
  ->Args({1024, 64, 64})

// {M, N, K, threads}: each square size swept over 1, 2, 4, 8 and 16 threads.
#define CONTRACTION_THREADPOOL_SIZES \
  ->Args({64, 64, 64, 1})            \
  ->Args({64, 64, 64, 2})            \
  ->Args({64, 64, 64, 4})            \
  ->Args({64, 64, 64, 8})            \
  ->Args({64, 64, 64, 16})           \
  ->Args({256, 256, 256, 1})         \
  ->Args({256, 256, 256, 2})         \
  ->Args({256, 256, 256, 4})         \
  ->Args({256, 256, 256, 8})         \
  ->Args({256, 256, 256, 16})        \
  ->Args({512, 512, 512, 1})         \
  ->Args({512, 512, 512, 2})         \
  ->Args({512, 512, 512, 4})         \
  ->Args({512, 512, 512, 8})         \
  ->Args({512, 512, 512, 16})        \
  ->Args({1024, 1024, 1024, 1})      \
  ->Args({1024, 1024, 1024, 2})      \
  ->Args({1024, 1024, 1024, 4})      \
  ->Args({1024, 1024, 1024, 8})      \
  ->Args({1024, 1024, 1024, 16})

// {batch, M, N, K}: batch counts 1, 8 and 32, each at cube sizes 64 and 256.
#define BATCH_SIZES              \
  ->Args({1, 64, 64, 64})        \
  ->Args({1, 256, 256, 256})     \
  ->Args({8, 64, 64, 64})        \
  ->Args({8, 256, 256, 256})     \
  ->Args({32, 64, 64, 64})       \
  ->Args({32, 256, 256, 256})

// clang-format on
BENCHMARK(BM_Contraction) CONTRACTION_SIZES;
BENCHMARK(BM_Contraction_RowMajor) CONTRACTION_SIZES;
// The contraction runs on ThreadPool worker threads. By default,
// google-benchmark times with the benchmark thread's own CPU time. That time
// excludes the pool threads' work, so it distorts the time column and inflates
// the GFLOPS rate counter. Use wall-clock time for iteration control and
// rates, and report CPU time for the whole process.
BENCHMARK(BM_Contraction_ThreadPool) CONTRACTION_THREADPOOL_SIZES
    ->MeasureProcessCPUTime()
    ->UseRealTime();
BENCHMARK(BM_BatchContraction) BATCH_SIZES;