blob: fcb9623e0fc159d1c070d469a838ba99862a9976 [file]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
// SPDX-FileCopyrightText: The Eigen Authors
// SPDX-License-Identifier: MPL-2.0
#ifndef EIGEN_SME_GENERALBLOCKPANELKERNEL_H
#define EIGEN_SME_GENERALBLOCKPANELKERNEL_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
#include <arm_sme.h>
namespace Eigen {
namespace internal {
// Streaming vector length in floats, derived from the compile-time SVE VL.
// EIGEN_ARM64_SVE_VL is set in ConfigureVectorization.h from __ARM_FEATURE_SVE_BITS.
static constexpr int kSmeVlFloats = EIGEN_ARM64_SVE_VL / 32;
// Micro-kernel dimensions (2x2 ZA tile grid).
// mr = 2 * SVL32 = 32 (LHS panel width, loaded as 2 VL-wide rows)
// nr = 2 * SVL32 = 32 (RHS panel width, loaded as 2 VL-wide cols)
// The 4 ZA tiles are arranged as a 2x2 grid covering a 32x32 output block:
// ZA0 = A_lo x B_lo (rows 0-15, cols 0-15)
// ZA1 = A_lo x B_hi (rows 0-15, cols 16-31)
// ZA2 = A_hi x B_lo (rows 16-31, cols 0-15)
// ZA3 = A_hi x B_hi (rows 16-31, cols 16-31)
// Per depth step: 2 LHS + 2 RHS VL-wide loads drive 4 FMOPAs -- a 1:1
// compute:load ratio, and each of A_lo/A_hi/B_lo/B_hi is reused across 2
// FMOPAs.
static constexpr int kSmeVl = kSmeVlFloats; // 16 -- one VL's worth of floats
static constexpr int kSmeMr = 2 * kSmeVl; // 32 -- LHS panel width
static constexpr int kSmeNr = 2 * kSmeVl; // 32 -- RHS panel width
// Copies a panel of `depth` lines, each `width` floats wide, from a strided
// source into a densely packed destination.
//
//   dst        packed output; line k starts at dst + k * width
//   src        source; line k starts at src + k * src_stride
//   src_stride distance (in elements) between consecutive source lines
//   depth      number of lines to copy
//   width      floats per line; split into a low part of at most kSmeVl lanes
//              and a high part holding whatever exceeds kSmeVl
//
// Both parts are moved with predicated vector loads/stores, so a partial
// width never touches lanes beyond `width`.
template <typename Scalar, typename Index>
static EIGEN_ALWAYS_INLINE void sve_copy_panel(Scalar* EIGEN_RESTRICT dst, const Scalar* EIGEN_RESTRICT src,
                                               Index src_stride, Index depth, int width) __arm_streaming {
  const bool spans_two_vectors = width > kSmeVl;
  const int first_lanes = spans_two_vectors ? kSmeVl : width;
  const int second_lanes = spans_two_vectors ? width - kSmeVl : 0;
  const svbool_t first_mask = svwhilelt_b32(uint32_t(0), uint32_t(first_lanes));
  const svbool_t second_mask = svwhilelt_b32(uint32_t(0), uint32_t(second_lanes));

  for (Index line = 0; line < depth; ++line) {
    const Index in_pos = line * src_stride;
    const Index out_pos = line * width;

    // Low vector: always present.
    const svfloat32_t first_vec = svld1_f32(first_mask, &src[in_pos]);
    svst1_f32(first_mask, &dst[out_pos], first_vec);

    // High vector: only when the line is wider than one vector.
    if (second_lanes > 0) {
      const svfloat32_t second_vec = svld1_f32(second_mask, &src[in_pos + kSmeVl]);
      svst1_f32(second_mask, &dst[out_pos + kSmeVl], second_vec);
    }
  }
}
// Packs one full panel of kSmeMr (== kSmeNr) source lines into depth-major
// order by transposing through two ZA.S tiles.
//
// Source line r (r in [0, 2 * kSmeVl)) starts at src + r * src_stride and
// holds `depth` contiguous values. For every depth step d the packed panel
// stores the values src[r * src_stride + d] of all lines contiguously at
// dst + d * kSmeMr.
//
// Lines [0, kSmeVl) are loaded as horizontal slices of tile 0 and lines
// [kSmeVl, 2 * kSmeVl) as horizontal slices of tile 1; reading the tiles back
// as vertical slices yields the transposed (depth-major) vectors.
template <typename Scalar, typename Index>
static EIGEN_ALWAYS_INLINE void sme_transpose_pack_32(Scalar* EIGEN_RESTRICT dst, const Scalar* EIGEN_RESTRICT src,
                                                      Index src_stride, Index depth) __arm_streaming __arm_inout("za") {
  static_assert(kSmeMr == kSmeNr, "SME transpose pack assumes square 32-wide panels");
  constexpr int kPanelWidth = kSmeMr;
  constexpr int kChunk = kSmeVl;

  const svbool_t all_lanes = svptrue_b32();
  const svfloat32_t fill = svdup_f32(0.f);

  // Split the depth range into whole kChunk-sized blocks plus a remainder.
  const Index leftover_depth = depth % kChunk;
  const Index blocked_depth = depth - leftover_depth;

  // Whole blocks: every slice is fully populated.
  for (Index d0 = 0; d0 < blocked_depth; d0 += kChunk) {
    for (uint32_t line = 0; line < kChunk; ++line) {
      svld1_hor_za32(0, line, all_lanes, &src[(line + 0) * src_stride + d0]);
      svld1_hor_za32(1, line, all_lanes, &src[(line + kChunk) * src_stride + d0]);
    }
    for (uint32_t step = 0; step < kChunk; ++step) {
      Scalar* const out = &dst[(d0 + step) * kPanelWidth];
      svst1_f32(all_lanes, out + 0, svread_ver_za32_f32_m(fill, all_lanes, 0, step));
      svst1_f32(all_lanes, out + kChunk, svread_ver_za32_f32_m(fill, all_lanes, 1, step));
    }
  }

  if (leftover_depth <= 0) {
    return;
  }

  // Remainder: load only the valid depth lanes of each line, then emit just
  // the vertical slices that correspond to real depth steps. Each emitted
  // slice still spans all lines of the panel, hence the full store predicate.
  const int leftover = static_cast<int>(leftover_depth);
  const svbool_t leftover_mask = svwhilelt_b32(uint32_t(0), uint32_t(leftover));
  for (uint32_t line = 0; line < kChunk; ++line) {
    svld1_hor_za32(0, line, leftover_mask, &src[(line + 0) * src_stride + blocked_depth]);
    svld1_hor_za32(1, line, leftover_mask, &src[(line + kChunk) * src_stride + blocked_depth]);
  }
  for (int step = 0; step < leftover; ++step) {
    Scalar* const out = &dst[(blocked_depth + step) * kPanelWidth];
    svst1_f32(all_lanes, out + 0, svread_ver_za32_f32_m(fill, all_lanes, 0, uint32_t(step)));
    svst1_f32(all_lanes, out + kChunk, svread_ver_za32_f32_m(fill, all_lanes, 1, uint32_t(step)));
  }
}
// Scalar transpose-pack for a partial panel of `tail` source lines.
//
// Source line i starts at src + i * src_stride and holds `depth` contiguous
// values. The packed panel is depth-major with a stride of `tail`: the value
// of line i at depth step k lands at dst_panel[k * tail + i].
template <typename Scalar, typename Index>
static EIGEN_ALWAYS_INLINE void scalar_tail_pack(Scalar* EIGEN_RESTRICT dst_panel, const Scalar* EIGEN_RESTRICT src,
                                                 Index src_stride, Index depth, Index tail) __arm_streaming {
  for (Index step = 0; step < depth; ++step) {
    Scalar* const packed_step = dst_panel + step * tail;
    for (Index line = 0; line < tail; ++line) {
      packed_step[line] = src[line * src_stride + step];
    }
  }
}
/*****************************************************************************
* gebp_traits specialization for SME (float x float)
*
* Derives from the GEBPPacketHalf instantiation of the same traits and
* replaces only the register-block sizes mr and nr with the SME panel widths
* kSmeMr / kSmeNr (both 32 for a 512-bit vector length).
*
* Intended effect on the rest of the GEMM pipeline (these happen in code that
* is not part of this file -- NOTE(review): confirm against the blocking
* code):
* - gemm_pack_lhs receives Pack1 = mr, creating uniform LHS panels
* - gemm_pack_rhs receives nr, creating uniform RHS panels
* - mc is rounded to a multiple of mr, nc to a multiple of nr
* - Cache blocking (kc, mc, nc) is recomputed accordingly
*
* This file provides custom gemm_pack_lhs/gemm_pack_rhs specializations for
* float, so both ColMajor and RowMajor source matrices produce an identical,
* simple packed format that the SME kernel consumes.
*****************************************************************************/
template <>
class gebp_traits<float, float, false, false, Architecture::Target, GEBPPacketFull>
: public gebp_traits<float, float, false, false, Architecture::Target, GEBPPacketHalf> {
public:
// Everything except mr/nr (packet typedefs, progress constants, ...) is
// inherited unchanged from the base class.
// The enumerators below hide the base-class values of the same name.
enum {
mr = kSmeMr, // LHS panel width (2 vectors of kSmeVl floats; 32 at VL = 16)
nr = kSmeNr // RHS panel width (2 vectors of kSmeVl floats; 32 at VL = 16)
};
};
/*****************************************************************************
 * gemm_pack_lhs specialization for SME (float, ColMajor)
 *
 * Packs the LHS matrix into uniform panels of width = mr = kSmeMr.
 * Each depth step k writes exactly MR contiguous floats (two VL-wide vectors).
 * Rows left over after the last full panel form one narrower panel whose
 * per-depth stride equals its own width.
 *
 * The source is read through a raw pointer to lhs(0, 0) and lhs.stride(), i.e.
 * rows are assumed to be contiguous in memory -- NOTE(review): confirm that no
 * mapper with a non-unit inner increment reaches this specialization.
 * Pack1, Pack2, Packet and Conjugate are accepted for interface compatibility
 * but are not used by this implementation.
 *****************************************************************************/
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  typedef float Scalar;

  // Streaming-mode worker.
  //   dst_base    start of the packed buffer
  //   src         address of element (0, 0) of the column-major source
  //   src_stride  distance in elements between consecutive source columns
  //   depth/rows  extent of the block to pack
  //   dst_stride, dst_offset  PanelMode layout parameters; a panel starting at
  //               row i is placed at i * dst_stride + dst_offset * panel_width.
  //               Without PanelMode panels are packed back to back (i * depth).
  __arm_locally_streaming static void pack_lhs_colmajor(Scalar* dst_base, const Scalar* EIGEN_RESTRICT src,
                                                        Index src_stride, Index depth, Index rows, Index dst_stride,
                                                        Index dst_offset) {
    constexpr int MR = kSmeMr;
    const Index peeled_rows = (rows / MR) * MR;
    // Full panels of width MR -- two VL-wide SVE load+stores per depth step.
    for (Index i = 0; i < peeled_rows; i += MR) {
      Scalar* dst_panel = PanelMode ? dst_base + i * dst_stride + dst_offset * MR : dst_base + i * depth;
      sve_copy_panel(dst_panel, src + i, src_stride, depth, MR);
    }
    // Tail panel: fewer than MR rows, handled by the same predicated copy.
    if (peeled_rows < rows) {
      const Index tail = rows - peeled_rows;
      Scalar* dst_panel =
          PanelMode ? dst_base + peeled_rows * dst_stride + dst_offset * tail : dst_base + peeled_rows * depth;
      sve_copy_panel(dst_panel, src + peeled_rows, src_stride, depth, static_cast<int>(tail));
    }
  }

  // Packs the rows x depth block addressed by `lhs` into blockA.
  // `stride` and `offset` are only meaningful in PanelMode.
  EIGEN_DONT_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                                    Index offset = 0) {
    if (PanelMode) {
      eigen_assert(stride >= depth && offset <= stride);
    }
    // Nothing to pack: return before touching lhs(0, 0). This also avoids
    // entering streaming mode and doing pointer arithmetic on a null source.
    if (rows <= 0 || depth <= 0) {
      return;
    }
    const Scalar* src = &lhs(0, 0);
    pack_lhs_colmajor(blockA, src, lhs.stride(), depth, rows, stride, offset);
  }
};
// RowMajor LHS packer -- SME in-ZA transpose.
//
// The packed output wants depth-major layout (MR rows contiguous per depth
// step), but in the RowMajor source the depth values of one row are
// contiguous and consecutive rows are src_stride elements apart. A gather
// would be slow; instead the two-dimensional ZA storage is used as a
// transpose buffer: load VL rows as horizontal slices of a ZA.S tile, then
// read vertical slices to produce depth-major output. With MR = 2 * VL two
// tiles are used per VL-deep iteration (first VL rows in tile 0, next VL rows
// in tile 1).
//
// The source is read through a raw pointer to lhs(0, 0) and lhs.stride() --
// NOTE(review): confirm that no mapper with a non-unit inner increment
// reaches this specialization. Pack1, Pack2, Packet and Conjugate are accepted
// for interface compatibility but are not used by this implementation.
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  typedef float Scalar;

  // Streaming-mode worker with its own ZA state.
  //   dst_base    start of the packed buffer
  //   src         address of element (0, 0) of the row-major source
  //   src_stride  distance in elements between consecutive source rows
  //   depth/rows  extent of the block to pack
  //   dst_stride, dst_offset  PanelMode layout parameters; a panel starting at
  //               row i is placed at i * dst_stride + dst_offset * panel_width.
  //               Without PanelMode panels are packed back to back (i * depth).
  __arm_locally_streaming __arm_new("za") static void pack_lhs_rowmajor(Scalar* dst_base,
                                                                        const Scalar* EIGEN_RESTRICT src,
                                                                        Index src_stride, Index depth, Index rows,
                                                                        Index dst_stride, Index dst_offset) {
    constexpr int MR = kSmeMr;
    const Index peeled_rows = (rows / MR) * MR;
    for (Index i = 0; i < peeled_rows; i += MR) {
      Scalar* dst_panel = PanelMode ? dst_base + i * dst_stride + dst_offset * MR : dst_base + i * depth;
      sme_transpose_pack_32(dst_panel, src + i * src_stride, src_stride, depth);
    }
    // Row tail (rows - peeled_rows in [1, MR-1]). This branch runs at most
    // once per call with fewer than MR rows and would need partial-ZA-tile
    // handling to vectorise; total copies are < MR * depth per call, which
    // is noise next to the main packer's workload, so scalar is the simple
    // choice.
    if (peeled_rows < rows) {
      const Index tail = rows - peeled_rows;
      Scalar* dst_panel =
          PanelMode ? dst_base + peeled_rows * dst_stride + dst_offset * tail : dst_base + peeled_rows * depth;
      scalar_tail_pack(dst_panel, src + peeled_rows * src_stride, src_stride, depth, tail);
    }
  }

  // Packs the rows x depth block addressed by `lhs` into blockA.
  // `stride` and `offset` are only meaningful in PanelMode.
  EIGEN_DONT_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                                    Index offset = 0) {
    if (PanelMode) {
      eigen_assert(stride >= depth && offset <= stride);
    }
    // Nothing to pack: return before touching lhs(0, 0). This also avoids
    // entering streaming mode, setting up a fresh ZA state and doing pointer
    // arithmetic on a null source.
    if (rows <= 0 || depth <= 0) {
      return;
    }
    const Scalar* src = &lhs(0, 0);
    pack_lhs_rowmajor(blockA, src, lhs.stride(), depth, rows, stride, offset);
  }
};
/*****************************************************************************
 * gemm_pack_rhs specialization for SME (float, ColMajor) -- SME in-ZA
 * transpose, mirroring the RowMajor LHS packer.
 *
 * Packs the RHS matrix into panels of width = nr = kSmeNr. In a ColMajor
 * source the depth values of one column are contiguous; NR columns are loaded
 * as horizontal ZA slices and vertical slices are read back to produce
 * depth-major packed output.
 *
 * The source is read through a raw pointer to rhs(0, 0) and rhs.stride() --
 * NOTE(review): confirm that no mapper with a non-unit inner increment
 * reaches this specialization. nr_ and Conjugate are accepted for interface
 * compatibility but are not used by this implementation.
 *****************************************************************************/
template <typename Index, typename DataMapper, int nr_, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr_, ColMajor, Conjugate, PanelMode> {
  typedef float Scalar;

  // Streaming-mode worker with its own ZA state.
  //   dst_base    start of the packed buffer
  //   src         address of element (0, 0) of the column-major source
  //   src_stride  distance in elements between consecutive source columns
  //   depth/cols  extent of the block to pack
  //   dst_stride, dst_offset  PanelMode layout parameters; a panel starting at
  //               column j is placed at j * dst_stride + dst_offset * panel_width.
  //               Without PanelMode panels are packed back to back (j * depth).
  __arm_locally_streaming __arm_new("za") static void pack_rhs_colmajor(Scalar* dst_base,
                                                                        const Scalar* EIGEN_RESTRICT src,
                                                                        Index src_stride, Index depth, Index cols,
                                                                        Index dst_stride, Index dst_offset) {
    constexpr int NR = kSmeNr;
    const Index peeled_cols = (cols / NR) * NR;
    for (Index j = 0; j < peeled_cols; j += NR) {
      Scalar* dst_panel = PanelMode ? dst_base + j * dst_stride + dst_offset * NR : dst_base + j * depth;
      sme_transpose_pack_32(dst_panel, src + j * src_stride, src_stride, depth);
    }
    // Col tail (cols - peeled_cols in [1, NR-1]). Runs at most once per call
    // with fewer than NR columns, so the scalar path is used instead of
    // partial-ZA-tile handling.
    if (peeled_cols < cols) {
      const Index tail = cols - peeled_cols;
      Scalar* dst_panel =
          PanelMode ? dst_base + peeled_cols * dst_stride + dst_offset * tail : dst_base + peeled_cols * depth;
      scalar_tail_pack(dst_panel, src + peeled_cols * src_stride, src_stride, depth, tail);
    }
  }

  // Packs the depth x cols block addressed by `rhs` into blockB.
  // `stride` and `offset` are only meaningful in PanelMode.
  EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) {
    if (PanelMode) {
      eigen_assert(stride >= depth && offset <= stride);
    }
    // Nothing to pack: return before touching rhs(0, 0). This also avoids
    // entering streaming mode, setting up a fresh ZA state and doing pointer
    // arithmetic on a null source.
    if (cols <= 0 || depth <= 0) {
      return;
    }
    const Scalar* src = &rhs(0, 0);
    pack_rhs_colmajor(blockB, src, rhs.stride(), depth, cols, stride, offset);
  }
};
// RowMajor RHS packer -- streaming SVE copy (mirrors the ColMajor LHS packer).
// In a RowMajor source the columns of one depth step are contiguous, so each
// depth step of a full panel is NR contiguous fp32 values = 2 VL-wide vectors.
//
// The source is read through a raw pointer to rhs(0, 0) and rhs.stride() --
// NOTE(review): confirm that no mapper with a non-unit inner increment
// reaches this specialization. nr_ and Conjugate are accepted for interface
// compatibility but are not used by this implementation.
template <typename Index, typename DataMapper, int nr_, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr_, RowMajor, Conjugate, PanelMode> {
  typedef float Scalar;

  // Streaming-mode worker.
  //   dst_base    start of the packed buffer
  //   src         address of element (0, 0) of the row-major source
  //   src_stride  distance in elements between consecutive depth steps
  //   depth/cols  extent of the block to pack
  //   dst_stride, dst_offset  PanelMode layout parameters; a panel starting at
  //               column j is placed at j * dst_stride + dst_offset * panel_width.
  //               Without PanelMode panels are packed back to back (j * depth).
  __arm_locally_streaming static void pack_rhs_rowmajor(Scalar* dst_base, const Scalar* EIGEN_RESTRICT src,
                                                        Index src_stride, Index depth, Index cols, Index dst_stride,
                                                        Index dst_offset) {
    constexpr int NR = kSmeNr;
    const Index peeled_cols = (cols / NR) * NR;
    // Full panels of width NR.
    for (Index j = 0; j < peeled_cols; j += NR) {
      Scalar* dst_panel = PanelMode ? dst_base + j * dst_stride + dst_offset * NR : dst_base + j * depth;
      sve_copy_panel(dst_panel, src + j, src_stride, depth, NR);
    }
    // Tail panel: fewer than NR columns, handled by the same predicated copy.
    if (peeled_cols < cols) {
      const Index tail = cols - peeled_cols;
      Scalar* dst_panel =
          PanelMode ? dst_base + peeled_cols * dst_stride + dst_offset * tail : dst_base + peeled_cols * depth;
      sve_copy_panel(dst_panel, src + peeled_cols, src_stride, depth, static_cast<int>(tail));
    }
  }

  // Packs the depth x cols block addressed by `rhs` into blockB.
  // `stride` and `offset` are only meaningful in PanelMode.
  EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) {
    if (PanelMode) {
      eigen_assert(stride >= depth && offset <= stride);
    }
    // Nothing to pack: return before touching rhs(0, 0). This also avoids
    // entering streaming mode and doing pointer arithmetic on a null source.
    if (cols <= 0 || depth <= 0) {
      return;
    }
    const Scalar* src = &rhs(0, 0);
    pack_rhs_rowmajor(blockB, src, rhs.stride(), depth, cols, stride, offset);
  }
};
/*****************************************************************************
 * sme_store_za_tile -- Accumulate one ZA.S tile into C, scaled by alpha.
 *
 * Computes C(row_start + r, col_start + c) += alpha * ZA[TileId](r, c) for
 * r in [0, pw) and c in [0, cw).
 *
 *   C, C_stride_row, C_stride_col  base pointer and element strides of C
 *   alpha                          scale applied to the accumulated tile
 *   row_start, col_start           top-left position of the tile inside C
 *   pw                             number of valid tile rows
 *   cw                             number of valid tile columns
 *
 * Three layouts are distinguished: unit row stride (column-major C, vertical
 * slices), unit column stride (row-major C, horizontal slices) and the
 * general strided case, which goes through a small stack buffer.
 *****************************************************************************/
template <int TileId, typename Scalar, typename Index>
EIGEN_ALWAYS_INLINE void sme_store_za_tile(Scalar* EIGEN_RESTRICT C, Index C_stride_row, Index C_stride_col,
                                           Scalar alpha, Index row_start, int pw, Index col_start,
                                           int cw) __arm_streaming __arm_inout("za") {
  const svbool_t row_lanes = svwhilelt_b32(static_cast<uint32_t>(0), static_cast<uint32_t>(pw));
  const svbool_t col_lanes = svwhilelt_b32(static_cast<uint32_t>(0), static_cast<uint32_t>(cw));
  // Value taken by inactive lanes of the ZA slice reads below.
  const svfloat32_t inactive_fill = svdup_f32(0.f);
  // FMLA and FADD have equal latency/throughput on ARMv9 cores, and
  // multiplying by alpha = 1.0 is exact in IEEE-754, so an unconditional
  // multiply-add gives the same bits as a plain add in that case. Using it
  // everywhere keeps the store path branch-free with respect to alpha.
  const svfloat32_t alpha_splat = svdup_f32(alpha);

  if (C_stride_row == 1) {
    // Column-major C: one vertical slice of the tile per output column.
    for (int col = 0; col < cw; ++col) {
      Scalar* const c_column = C + row_start + (col_start + col) * C_stride_col;
      const svfloat32_t acc = svread_ver_za32_f32_m(inactive_fill, row_lanes, TileId, static_cast<uint32_t>(col));
      const svfloat32_t previous = svld1_f32(row_lanes, c_column);
      svst1_f32(row_lanes, c_column, svmla_f32_x(row_lanes, previous, acc, alpha_splat));
    }
    return;
  }

  if (C_stride_col == 1) {
    // Row-major C: one horizontal slice of the tile per output row.
    for (int row = 0; row < pw; ++row) {
      Scalar* const c_row = C + (row_start + row) * C_stride_row + col_start;
      const svfloat32_t acc = svread_hor_za32_f32_m(inactive_fill, col_lanes, TileId, static_cast<uint32_t>(row));
      const svfloat32_t previous = svld1_f32(col_lanes, c_row);
      svst1_f32(col_lanes, c_row, svmla_f32_x(col_lanes, previous, acc, alpha_splat));
    }
    return;
  }

  // General strides: scale each tile row into a temporary buffer, then add
  // its elements to C one by one.
  Scalar staged_row[kSmeVl];
  for (int row = 0; row < pw; ++row) {
    const svfloat32_t acc = svread_hor_za32_f32_m(inactive_fill, col_lanes, TileId, static_cast<uint32_t>(row));
    svst1_f32(col_lanes, staged_row, svmul_f32_x(col_lanes, acc, alpha_splat));
    for (int col = 0; col < cw; ++col) {
      C[(row_start + row) * C_stride_row + (col_start + col) * C_stride_col] += staged_row[col];
    }
  }
}
/*****************************************************************************
* sme_process_2x2 -- 2x2 tile micro-kernel (pw x NR output)
*
* Accumulates the product of one packed LHS panel and one full packed RHS
* panel into the four ZA.S tiles and then adds alpha times the result to C.
*
* Performance-critical inner kernel. Per 4 unrolled depth steps, performs
* 2 svld1_f32_x4 for LHS (8 VL vectors = 4 depth × 2 halves)
* 2 svld1_f32_x4 for RHS (8 VL vectors = 4 depth × 2 halves)
* 16 svmopa (4 depth × 2x2 tile grid)
* Total: 4 load instructions + 16 FMOPAs per 4 depth steps -- a 1:1
* compute:load ratio at the vector level.
*
* Parameters:
* C, C_stride_row, C_stride_col -- output base pointer and element strides
* blA -- packed LHS panel; depth step k starts at blA + k * MR for a full
* panel and at blA + k * pw for a tail panel
* blB -- packed RHS panel; depth step k starts at blB + k * NR
* depth -- number of depth steps to accumulate
* alpha -- scale applied when the tiles are added to C
* row_start, col_start -- position of the output block inside C
* pw -- actual LHS panel width (== MR for full panels, < MR for the tail).
* The full path uses the compile-time constant MR as the LHS stride.
*
* All of ZA is zeroed on entry, so nothing may be kept in ZA across calls.
*****************************************************************************/
template <typename Scalar, typename Index>
EIGEN_ALWAYS_INLINE void sme_process_2x2(Scalar* EIGEN_RESTRICT C, Index C_stride_row, Index C_stride_col,
const Scalar* EIGEN_RESTRICT blA, const Scalar* EIGEN_RESTRICT blB,
Index depth, Scalar alpha, Index row_start, int pw,
Index col_start) __arm_streaming __arm_inout("za") {
constexpr int MR = kSmeMr; // 32 (full LHS panel width)
constexpr int NR = kSmeNr; // 32 (full RHS panel width)
constexpr int VL = kSmeVl; // 16
// Split the panel width into the part covered by the first vector (lo) and
// the part covered by the second vector (hi); hi is empty when pw <= VL.
const int pw_lo = pw > VL ? VL : pw;
const int pw_hi = pw > VL ? pw - VL : 0;
const svbool_t pg_lo = svwhilelt_b32((uint32_t)0, (uint32_t)pw_lo);
const svbool_t pg_hi = svwhilelt_b32((uint32_t)0, (uint32_t)pw_hi);
const svbool_t pg_all = svptrue_b32();
// Predicate-as-counter form required by the multi-vector (x4) loads.
const svcount_t pn_all = svptrue_c32();
// Start from a clean accumulator: clears the whole ZA array.
svzero_za();
// ---- Full LHS panel: 4x unrolled depth loop plus a scalar-depth remainder ----
if (pw == MR) {
const Index depth_4 = (depth / 4) * 4;
Index k = 0;
for (; k < depth_4; k += 4) {
// Load 4 depth steps × 2 halves = 8 VL vectors per operand as two x4 loads.
// va_01 = [d0 lo, d0 hi, d1 lo, d1 hi]
// va_23 = [d2 lo, d2 hi, d3 lo, d3 hi]
// (vb_01 / vb_23 hold the RHS in the same arrangement.)
svfloat32x4_t va_01 = svld1_f32_x4(pn_all, &blA[k * MR]);
svfloat32x4_t vb_01 = svld1_f32_x4(pn_all, &blB[k * NR]);
// d0
svfloat32_t a0_lo = svget4_f32(va_01, 0), a0_hi = svget4_f32(va_01, 1);
svfloat32_t b0_lo = svget4_f32(vb_01, 0), b0_hi = svget4_f32(vb_01, 1);
svmopa_za32_f32_m(0, pg_all, pg_all, a0_lo, b0_lo);
svmopa_za32_f32_m(1, pg_all, pg_all, a0_lo, b0_hi);
svmopa_za32_f32_m(2, pg_all, pg_all, a0_hi, b0_lo);
svmopa_za32_f32_m(3, pg_all, pg_all, a0_hi, b0_hi);
// d1
svfloat32_t a1_lo = svget4_f32(va_01, 2), a1_hi = svget4_f32(va_01, 3);
svfloat32_t b1_lo = svget4_f32(vb_01, 2), b1_hi = svget4_f32(vb_01, 3);
svmopa_za32_f32_m(0, pg_all, pg_all, a1_lo, b1_lo);
svmopa_za32_f32_m(1, pg_all, pg_all, a1_lo, b1_hi);
svmopa_za32_f32_m(2, pg_all, pg_all, a1_hi, b1_lo);
svmopa_za32_f32_m(3, pg_all, pg_all, a1_hi, b1_hi);
// Second pair of depth steps (k + 2, k + 3).
svfloat32x4_t va_23 = svld1_f32_x4(pn_all, &blA[(k + 2) * MR]);
svfloat32x4_t vb_23 = svld1_f32_x4(pn_all, &blB[(k + 2) * NR]);
// d2
svfloat32_t a2_lo = svget4_f32(va_23, 0), a2_hi = svget4_f32(va_23, 1);
svfloat32_t b2_lo = svget4_f32(vb_23, 0), b2_hi = svget4_f32(vb_23, 1);
svmopa_za32_f32_m(0, pg_all, pg_all, a2_lo, b2_lo);
svmopa_za32_f32_m(1, pg_all, pg_all, a2_lo, b2_hi);
svmopa_za32_f32_m(2, pg_all, pg_all, a2_hi, b2_lo);
svmopa_za32_f32_m(3, pg_all, pg_all, a2_hi, b2_hi);
// d3
svfloat32_t a3_lo = svget4_f32(va_23, 2), a3_hi = svget4_f32(va_23, 3);
svfloat32_t b3_lo = svget4_f32(vb_23, 2), b3_hi = svget4_f32(vb_23, 3);
svmopa_za32_f32_m(0, pg_all, pg_all, a3_lo, b3_lo);
svmopa_za32_f32_m(1, pg_all, pg_all, a3_lo, b3_hi);
svmopa_za32_f32_m(2, pg_all, pg_all, a3_hi, b3_lo);
svmopa_za32_f32_m(3, pg_all, pg_all, a3_hi, b3_hi);
}
// Remaining 0-3 depth steps, one at a time with single-vector loads.
for (; k < depth; ++k) {
svfloat32_t a_lo = svld1_f32(pg_all, &blA[k * MR]);
svfloat32_t a_hi = svld1_f32(pg_all, &blA[k * MR + VL]);
svfloat32_t b_lo = svld1_f32(pg_all, &blB[k * NR]);
svfloat32_t b_hi = svld1_f32(pg_all, &blB[k * NR + VL]);
svmopa_za32_f32_m(0, pg_all, pg_all, a_lo, b_lo);
svmopa_za32_f32_m(1, pg_all, pg_all, a_lo, b_hi);
svmopa_za32_f32_m(2, pg_all, pg_all, a_hi, b_lo);
svmopa_za32_f32_m(3, pg_all, pg_all, a_hi, b_hi);
}
} else {
// ---- Tail LHS panel (pw < MR, stride = pw) ----
// Predicate the mopa rows; RHS is still full NR wide.
// Tiles 2 and 3 are only touched when the panel is wider than one vector.
for (Index k = 0; k < depth; ++k) {
svfloat32_t a_lo = svld1_f32(pg_lo, &blA[k * pw]);
svfloat32_t b_lo = svld1_f32(pg_all, &blB[k * NR]);
svfloat32_t b_hi = svld1_f32(pg_all, &blB[k * NR + VL]);
svmopa_za32_f32_m(0, pg_lo, pg_all, a_lo, b_lo);
svmopa_za32_f32_m(1, pg_lo, pg_all, a_lo, b_hi);
if (pw_hi > 0) {
svfloat32_t a_hi = svld1_f32(pg_hi, &blA[k * pw + VL]);
svmopa_za32_f32_m(2, pg_hi, pg_all, a_hi, b_lo);
svmopa_za32_f32_m(3, pg_hi, pg_all, a_hi, b_hi);
}
}
}
// Store ZA tiles back to C, split into the 2x2 grid.
// ZA0: rows [0..pw_lo), cols [0..VL)
// ZA1: rows [0..pw_lo), cols [VL..NR)
// ZA2: rows [VL..VL+pw_hi), cols [0..VL)
// ZA3: rows [VL..VL+pw_hi), cols [VL..NR)
// (row/col ranges are relative to row_start / col_start)
sme_store_za_tile<0>(C, C_stride_row, C_stride_col, alpha, row_start, pw_lo, col_start, VL);
sme_store_za_tile<1>(C, C_stride_row, C_stride_col, alpha, row_start, pw_lo, col_start + VL, VL);
if (pw_hi > 0) {
sme_store_za_tile<2>(C, C_stride_row, C_stride_col, alpha, row_start + VL, pw_hi, col_start, VL);
sme_store_za_tile<3>(C, C_stride_row, C_stride_col, alpha, row_start + VL, pw_hi, col_start + VL, VL);
}
}
/*****************************************************************************
 * sme_process_microblock -- Handle a single partial tile of size (pw x cw).
 *
 * Used for the tail column region when 0 < cw < NR. One ZA tile per call:
 * the tile TileId is cleared, accumulates `depth` outer products and is then
 * added (scaled by alpha) to C at (row_start, col_start).
 *
 *   blA, blA_stride  packed LHS data; depth step k starts at blA + k * blA_stride
 *   blB, blB_stride  packed RHS data; depth step k starts at blB + k * blB_stride
 *   pw, cw           number of valid rows / columns of the tile
 *****************************************************************************/
template <int TileId, typename Scalar, typename Index>
EIGEN_ALWAYS_INLINE void sme_process_microblock(Scalar* EIGEN_RESTRICT C, Index C_stride_row, Index C_stride_col,
                                                const Scalar* EIGEN_RESTRICT blA, const Scalar* EIGEN_RESTRICT blB,
                                                int blA_stride, int blB_stride, Index depth, Scalar alpha,
                                                Index row_start, int pw, Index col_start,
                                                int cw) __arm_streaming __arm_inout("za") {
  const svbool_t row_lanes = svwhilelt_b32(static_cast<uint32_t>(0), static_cast<uint32_t>(pw));
  const svbool_t col_lanes = svwhilelt_b32(static_cast<uint32_t>(0), static_cast<uint32_t>(cw));

  // The 8-bit mask of svzero_mask_za has one bit per 64-bit-element tile
  // (ZA0.D .. ZA7.D). The fp32 tile ZA<TileId>.S occupies the storage of
  // ZA<TileId>.D and ZA<TileId + 4>.D, so 0x11 << TileId, i.e.
  // (1 << TileId) | (1 << (TileId + 4)), clears exactly this ZA.S tile and
  // leaves the other three untouched.
  svzero_mask_za(0x11 << TileId);

  // One predicated outer product per depth step.
  for (Index step = 0; step < depth; ++step) {
    const svfloat32_t lhs_vec = svld1_f32(row_lanes, blA + step * blA_stride);
    const svfloat32_t rhs_vec = svld1_f32(col_lanes, blB + step * blB_stride);
    svmopa_za32_f32_m(TileId, row_lanes, col_lanes, lhs_vec, rhs_vec);
  }

  sme_store_za_tile<TileId>(C, C_stride_row, C_stride_col, alpha, row_start, pw, col_start, cw);
}
// Streaming-mode GEBP driver: C += alpha * A * B on packed operands.
//
//   C, C_stride_row, C_stride_col  output base pointer and element strides
//   blockA   packed LHS; the panel starting at row i lives at
//            blockA + i * strideA + offsetA * panel_width
//   blockB   packed RHS; the panel starting at column j lives at
//            blockB + j * strideB + offsetB * panel_width
//   rows, depth, cols              problem extents
//
// Full NR-wide column panels go through the 2x2 tile kernel; the remaining
// columns (fewer than NR) are handled tile by tile with predicated
// micro-blocks.
template <typename Scalar, typename Index>
EIGEN_DONT_INLINE __arm_locally_streaming __arm_new("za") void sme_gebp_impl(
    Scalar* C, Index C_stride_row, Index C_stride_col, const Scalar* blockA, const Scalar* blockB, Index rows,
    Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  static_assert(Eigen::internal::is_same<Scalar, float>::value, "SME fp32 kernel only supports float");
  constexpr int MR = kSmeMr;
  constexpr int NR = kSmeNr;
  constexpr int VL = kSmeVl;

  // Extents covered by full panels, and the widths of the trailing partial panels.
  const Index full_rows = rows - rows % MR;
  const Index full_cols = cols - cols % NR;
  const int edge_rows = static_cast<int>(rows - full_rows);
  const int edge_cols = static_cast<int>(cols - full_cols);

  // Column-outer, row-inner: keeps the current RHS panel (kc x NR) hot in L1
  // while the LHS panels stream from L2. The outer GOTO loop in
  // GeneralMatrixMatrix.h ensures blockA fits in L2 via mc-blocking.
  for (Index col = 0; col < full_cols; col += NR) {
    const Scalar* rhs_panel = blockB + col * strideB + offsetB * NR;
    for (Index row = 0; row < full_rows; row += MR) {
      const Scalar* lhs_panel = blockA + row * strideA + offsetA * MR;
      sme_process_2x2(C, C_stride_row, C_stride_col, lhs_panel, rhs_panel, depth, alpha, row, MR, col);
    }
    if (edge_rows > 0) {
      const Scalar* lhs_panel = blockA + full_rows * strideA + offsetA * edge_rows;
      sme_process_2x2(C, C_stride_row, C_stride_col, lhs_panel, rhs_panel, depth, alpha, full_rows, edge_rows, col);
    }
  }

  // Tail columns (cols - full_cols in [1, NR - 1]); nothing left otherwise.
  if (edge_cols <= 0) {
    return;
  }

  const Scalar* rhs_edge = blockB + full_cols * strideB + offsetB * edge_cols;
  const int cols_lo = edge_cols > VL ? VL : edge_cols;
  const int cols_hi = edge_cols > VL ? edge_cols - VL : 0;
  const bool has_hi_cols = cols_hi > 0;

  // Walk every LHS panel, full ones first and the partial one (if any) last.
  for (Index row = 0; row < rows; row += MR) {
    const int pw = (row < full_rows) ? MR : edge_rows;
    const int rows_lo = pw > VL ? VL : pw;
    const int rows_hi = pw > VL ? pw - VL : 0;
    const bool has_hi_rows = rows_hi > 0;
    const Scalar* lhs_panel = blockA + row * strideA + offsetA * pw;

    // Low rows x low columns.
    sme_process_microblock<0>(C, C_stride_row, C_stride_col, lhs_panel, rhs_edge, pw, edge_cols, depth, alpha, row,
                              rows_lo, full_cols, cols_lo);
    // Low rows x high columns.
    if (has_hi_cols) {
      sme_process_microblock<1>(C, C_stride_row, C_stride_col, lhs_panel, rhs_edge + VL, pw, edge_cols, depth, alpha,
                                row, rows_lo, full_cols + VL, cols_hi);
    }
    // High rows x low columns.
    if (has_hi_rows) {
      sme_process_microblock<2>(C, C_stride_row, C_stride_col, lhs_panel + VL, rhs_edge, pw, edge_cols, depth, alpha,
                                row + VL, rows_hi, full_cols, cols_lo);
    }
    // High rows x high columns.
    if (has_hi_rows && has_hi_cols) {
      sme_process_microblock<3>(C, C_stride_row, C_stride_col, lhs_panel + VL, rhs_edge + VL, pw, edge_cols, depth,
                                alpha, row + VL, rows_hi, full_cols + VL, cols_hi);
    }
  }
}
// gebp_kernel specialization for float x float: forwards to sme_gebp_impl.
//
// The mr / nr template arguments are not used here; the panel sizes come from
// kSmeMr / kSmeNr. Conjugation is rejected at compile time.
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef float Scalar;
  typedef float ResScalar;

  // Computes res += alpha * blockA * blockB on packed operands.
  // A stride of -1 means "tightly packed", i.e. equal to depth.
  EIGEN_DONT_INLINE void operator()(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows,
                                    Index depth, Index cols, ResScalar alpha, Index strideA = -1, Index strideB = -1,
                                    Index offsetA = 0, Index offsetB = 0) {
    static_assert(!ConjugateLhs && !ConjugateRhs, "SME fp32 kernel does not support conjugation");

    // Empty product: nothing to accumulate.
    if (rows <= 0 || cols <= 0 || depth <= 0) {
      return;
    }

    const Index lhs_stride = (strideA == -1) ? depth : strideA;
    const Index rhs_stride = (strideB == -1) ? depth : strideB;

    // Recover the raw layout of the result from the mapper: the base address
    // and the element distance of one step along each dimension. Only the
    // addresses of res(1, 0) / res(0, 1) are taken; they are never read.
    const Scalar* origin = &res(0, 0);
    const Index C_stride_row = &res(1, 0) - origin;
    const Index C_stride_col = &res(0, 1) - origin;
    Scalar* C_base = const_cast<Scalar*>(origin);

    sme_gebp_impl<Scalar, Index>(C_base, C_stride_row, C_stride_col, blockA, blockB, rows, depth, cols, alpha,
                                 lhs_stride, rhs_stride, offsetA, offsetB);
  }
};
} // namespace internal
} // namespace Eigen
#endif // EIGEN_SME_GENERALBLOCKPANELKERNEL_H