Files
triton/unittest/Dialect/TritonGPU/SwizzleTest.cpp
Keren Zhou 153aecb339 [Triton-MLIR][BACKEND] insert_slice_async on GPUs < sm80 (#908)
`insert_slice_async` is decomposed into `load + insert_slice` in the
backend.

It is unclear whether V100 performance can match the master branch with
this approach. Performance might improve if the instructions are arranged
in the following form:

```
%0 = load
%1 = load 
%2 = load 
...
insert_slice %0
insert_slice %1
insert_slice %2
```

Tested on A100 when manually enabling this decomposition.
Tests on V100 haven't been integrated yet; we can divide the tests into
two phases:
1. Test only load, insert_slice, and insert_slice_async, given TritonGPU
IRs in `test_backend.py`.
2. End-to-end GEMM tests on V100.
2022-11-24 14:05:54 -08:00

53 lines
1.8 KiB
C++

#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <gtest/gtest.h>
using namespace mlir;
using mlir::triton::gpu::SharedEncodingAttr;
// Reference swizzle values for one test case. The DotOperands test compares
// each field against the matching getter of the SharedEncodingAttr it builds.
struct swizzleParams {
  int vec;      // expected result of getVec()
  int perPhase; // expected result of getPerPhase()
  int maxPhase; // expected result of getMaxPhase()
};
// One parameterized test case: the inputs used to build the shared layout,
// followed by the swizzle values that layout is expected to report.
struct ParamT {
  // Two-element shape handed to SharedEncodingAttr::get.
  std::array<int64_t, 2> shape;
  // Operand index handed to DotOperandEncodingAttr::get.
  int opIdx;
  // Width handed to IntegerType::get to create the element type.
  int typeWidth;
  // Expected vec / perPhase / maxPhase of the resulting layout.
  swizzleParams refSwizzle;
};
// Value-parameterized GoogleTest fixture: every test instantiated from it is
// run once per ParamT supplied by INSTANTIATE_TEST_SUITE_P.
class SwizzleDotOperandTestFixture : public ::testing::TestWithParam<ParamT> {
protected:
  // Member of the fixture's parameter type; the DotOperands test in this file
  // reads its case through GetParam() and does not touch this member.
  ParamType param;
};
// Builds a shared-memory layout for a dot operand described by the current
// test parameter and checks that its swizzle values (vec, perPhase, maxPhase)
// equal the reference values stored in the parameter.
TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
  namespace ttg = mlir::triton::gpu;

  const ParamT &testCase = GetParam();

  // Fresh MLIR context with the TritonGPU dialect loaded.
  MLIRContext context;
  context.loadDialect<ttg::TritonGPUDialect>();

  // Dot-operand encoding whose parent is an MMA encoding.
  auto mmaParent = ttg::MmaEncodingAttr::get(&context, 2, {1, 1});
  auto dotOperand =
      ttg::DotOperandEncodingAttr::get(&context, testCase.opIdx, mmaParent);

  // Integer element type of the requested width.
  Type elementType = IntegerType::get(&context, testCase.typeWidth);

  // Shared layout under test, derived from the operand encoding, the shape,
  // the order {1, 0} and the element type.
  auto sharedLayout = ttg::SharedEncodingAttr::get(
      &context, dotOperand, testCase.shape, {1, 0}, elementType);

  const swizzleParams &expected = testCase.refSwizzle;
  ASSERT_EQ(sharedLayout.getVec(), expected.vec);
  ASSERT_EQ(sharedLayout.getPerPhase(), expected.perPhase);
  ASSERT_EQ(sharedLayout.getMaxPhase(), expected.maxPhase);
}
// Instantiates the SwizzleDotOperandTestFixture tests under the prefix
// "TestDotOperands", once per entry in the table below.
// Each entry follows ParamT's field order:
//   {shape[0], shape[1]}, opIdx, typeWidth, {vec, perPhase, maxPhase}
// where the trailing triple holds the expected swizzle values. Every case in
// this table uses a typeWidth of 16 and covers both opIdx 0 and opIdx 1.
INSTANTIATE_TEST_SUITE_P(TestDotOperands, SwizzleDotOperandTestFixture,
::testing::Values(ParamT{{128, 64}, 0, 16, {8, 1, 8}},
ParamT{{64, 256}, 1, 16, {8, 1, 8}},
ParamT{{128, 32}, 0, 16, {8, 2, 4}},
ParamT{{32, 128}, 1, 16, {8, 1, 8}},
ParamT{{32, 32}, 0, 16, {8, 2, 4}},
ParamT{{32, 32}, 1, 16, {8, 2, 4}},
ParamT{{16, 16}, 0, 16, {8, 4, 2}},
ParamT{{16, 16}, 1, 16, {8, 4, 2}}));