[Triton-MLIR][OPTIMIZER] Cleaned up swizzling (#869)
Swizzling is no longer implemented as a separate pass. It is instead done in a specialized constructor of SharedEncodingAttr, and tested via google tests instead of triton-opt + filecheck. In the future we may want to implement it as a pass again once we have an additional dialect between TritonGPU and LLVM.
This commit is contained in:
1
unittest/Dialect/CMakeLists.txt
Normal file
1
unittest/Dialect/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(TritonGPU)
|
6
unittest/Dialect/TritonGPU/CMakeLists.txt
Normal file
6
unittest/Dialect/TritonGPU/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
|
||||
add_triton_ut(
|
||||
NAME TestSwizzling
|
||||
SRCS SwizzleTest.cpp
|
||||
LIBS TritonGPUIR ${dialect_libs} ${conversion_libs}
|
||||
)
|
52
unittest/Dialect/TritonGPU/SwizzleTest.cpp
Normal file
52
unittest/Dialect/TritonGPU/SwizzleTest.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using namespace mlir;
|
||||
using mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
struct swizzleParams {
|
||||
int vec;
|
||||
int perPhase;
|
||||
int maxPhase;
|
||||
};
|
||||
|
||||
struct ParamT {
|
||||
std::array<int64_t, 2> shape;
|
||||
int opIdx;
|
||||
int typeWidth;
|
||||
swizzleParams refSwizzle;
|
||||
};
|
||||
|
||||
class SwizzleDotOperandTestFixture : public ::testing::TestWithParam<ParamT> {
|
||||
protected:
|
||||
ParamType param;
|
||||
};
|
||||
|
||||
TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
|
||||
auto params = GetParam();
|
||||
// init context
|
||||
MLIRContext ctx;
|
||||
ctx.loadDialect<triton::gpu::TritonGPUDialect>();
|
||||
// create encoding
|
||||
auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, {1, 1});
|
||||
auto encoding =
|
||||
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent);
|
||||
|
||||
// create element type
|
||||
Type eltType = IntegerType::get(&ctx, params.typeWidth);
|
||||
auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, eltType);
|
||||
|
||||
ASSERT_EQ(layout.getVec(), params.refSwizzle.vec);
|
||||
ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase);
|
||||
ASSERT_EQ(layout.getMaxPhase(), params.refSwizzle.maxPhase);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestDotOperands, SwizzleDotOperandTestFixture,
|
||||
::testing::Values(ParamT{{128, 64}, 0, 16, {8, 1, 8}},
|
||||
ParamT{{64, 256}, 1, 16, {8, 1, 8}},
|
||||
ParamT{{128, 32}, 0, 16, {8, 2, 4}},
|
||||
ParamT{{32, 128}, 1, 16, {8, 1, 8}},
|
||||
ParamT{{32, 32}, 0, 16, {8, 2, 4}},
|
||||
ParamT{{32, 32}, 1, 16, {8, 2, 4}},
|
||||
ParamT{{16, 16}, 0, 16, {8, 4, 2}},
|
||||
ParamT{{16, 16}, 1, 16, {8, 4, 2}}));
|
Reference in New Issue
Block a user