[Triton-MLIR][BACKEND] Support $c from mma layout in dot (#798)

This PR does the following:

1. Support the case where $c holds an mma layout; this is useful for the
k-axis for-loop in GEMM (see the sketch after this list)
2. Fix the `unrealized_conversion_cast` in ConvertLayout[shared->dot_op]
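
For context, here is a minimal sketch of the pattern item 1 enables. It is not part of this diff (names are illustrative) and mirrors the `matmul_kernel` exercised by `test_gemm` below: the accumulator returned by `tl.dot` is fed back in as the $c operand on the next iteration, so $c carries the mma layout across the k-loop.

```python
import triton
import triton.language as tl

@triton.jit
def gemm_k_loop(a_ptr, b_ptr, c_ptr,
                stride_am, stride_ak, stride_bk, stride_bn,
                stride_cm, stride_cn, K,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    # The accumulator stays in registers with the mma layout and is the
    # $c operand of every tl.dot in the loop.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, accumulator)
```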

Known issues

1. There is an IO conflict in GEMM with a k-forloop. It is temporarily
worked around by [adding a
barrier](https://github.com/openai/triton/pull/798/files#diff-8a9a5a7f4a025fb1299af29d190d5626bd9000406d3ea47c49679272d3d6abe9R3028)
in the dot conversion; we are still working on it and will land a more
general fix in a follow-up PR.
2. The pipeline pass produces an instruction with a buggy result type,
casting the void result of `cp.async.commit_group` to a shared-memory pointer:
```mlir
%1049 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.commit_group ;", ""  : () -> !llvm.void
%1050 = builtin.unrealized_conversion_cast %1049 : !llvm.void to !llvm.ptr<f16, 3>
```
We therefore disable that pass temporarily.
Yan Chunwei
2022-10-26 10:33:04 +08:00
committed by GitHub
parent a2cbe7af91
commit 4dc2396ca0
3 changed files with 226 additions and 64 deletions


@@ -622,6 +622,13 @@ protected:
Value smem;
};
Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
Type resType, Type elemType,
Value constVal,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc);
// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
// LLVM::StructType value.
//
@@ -632,16 +639,26 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(srcType, constVal);
size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
llvm::SmallVector<Value, 4> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(srcType, constVal);
size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
return getStructFromElements(loc, elems, rewriter, structTy);
return getStructFromElements(loc, elems, rewriter, structTy);
} else if (auto mmaLayout =
tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
return convertSplatLikeOpWithMmaLayout(
mmaLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
} else
assert(false && "Unsupported layout found in ConvertSplatLikeOp");
return Value{};
}
struct SplatOpConversion
@@ -2436,8 +2453,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
MLIRContext *ctx = op->getContext();
bool allowTF32 = op.allowTF32();
assert(isSplatLike(C) && "Currently only splat-like C is supported now");
// Here we assume the DotOp's operands always come from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
size_t reduceAxis = 1;
@@ -2536,6 +2551,31 @@ struct DotOpConversionHelper {
mmaType = getTensorCoreTypeFromOperand(operandTy);
}
// Get the M and N of mat instruction shape.
static std::tuple<int, int> getMatShapeMN() {
// According to DotOpConversionHelper::mmaMatShape, all of the mat shapes
// have M,N = {8,8}.
return {8, 8};
}
// Get the M and N of mma instruction shape.
static std::tuple<int, int> getInstrShapeMN() {
// According to DotOpConversionHelper::mmaInstrShape, all of the instruction shapes have M,N = {16,8}.
return {16, 8};
}
static std::tuple<int, int> getRepMN(const RankedTensorType &tensorTy) {
auto mmaLayout = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
int M = tensorTy.getShape()[0];
int N = tensorTy.getShape()[1];
auto [instrM, instrN] = getInstrShapeMN();
int repM = std::max<int>(M / (wpt[0] * instrM), 1);
int repN = std::max<int>(N / (wpt[1] * instrN), 1);
return {repM, repN};
}
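// Worked example (illustrative, not part of this diff): for a 128x128
// tensor with warpsPerCTA = {2, 2}, instrM = 16 and instrN = 8 give
// repM = max(128 / (2 * 16), 1) = 4 and repN = max(128 / (2 * 8), 1) = 8,
// so the $c/$d struct holds 4 * repM * repN = 128 values per thread.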
Type getShemPtrTy() const {
switch (mmaType) {
case TensorCoreType::FP32_FP16_FP16_FP32:
@@ -2633,15 +2673,20 @@ struct DotOpConversionHelper {
return mmaInstrShape.at(mmaType);
}
static ArrayRef<int> getMmaInstrShape(TensorCoreType tensorCoreType) {
assert(tensorCoreType != TensorCoreType::NOT_APPLICABLE &&
"Unknown mma type found.");
return mmaInstrShape.at(tensorCoreType);
}
ArrayRef<int> getMmaMatShape() const {
assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
"Unknown mma type found.");
return mmaMatShape.at(mmaType);
}
// Deduce the TensorCoreType from either $a or $b's type. This method is not
// safe, but we cannot get the DotOp in some getMmaMatShape use cases.
TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) const {
// Deduce the TensorCoreType from either $a or $b's type.
static TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) {
auto tensorTy = operandTy.cast<RankedTensorType>();
auto elemTy = tensorTy.getElementType();
if (elemTy.isF16())
@@ -2814,22 +2859,58 @@ struct MMA16816ConversionHelper {
// \param operand is either $a or $b's type.
inline int getNumRepM(Type operand, int M) const {
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
return std::max<int>(M / (wpt[0] * mmaInstrM), 1);
return getNumRepM(operand, M, wpt[0]);
}
// \param operand is either $a or $b's type.
inline int getNumRepN(Type operand, int N) const {
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
return std::max<int>(N / (wpt[1] * mmaInstrN), 1);
return getNumRepN(operand, N, wpt[1]);
}
// \param operand is either $a or $b's type.
inline int getNumRepK(Type operand, int K) const {
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
return getNumRepK_(operand, K);
}
static int getNumRepM(Type operand, int M, int wpt) {
auto tensorCoreType =
DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
int mmaInstrM = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[0];
return std::max<int>(M / (wpt * mmaInstrM), 1);
}
static int getNumRepN(Type operand, int N, int wpt) {
auto tensorCoreType =
DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
int mmaInstrN = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[1];
return std::max<int>(N / (wpt * mmaInstrN), 1);
}
static int getNumRepK_(Type operand, int K) {
auto tensorCoreType =
DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
int mmaInstrK = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[2];
return std::max<int>(K / mmaInstrK, 1);
}
// Get number of elements per thread for $a operand.
static size_t getANumElemsPerThread(RankedTensorType operand,
ArrayRef<unsigned> wpt) {
auto shape = operand.getShape();
int repM = getNumRepM(operand, shape[0], wpt[0]);
int repK = getNumRepK_(operand, shape[1]);
return 4 * repM * repK;
}
// Get number of elements per thread for $b operand.
static size_t getBNumElemsPerThread(RankedTensorType operand,
ArrayRef<unsigned> wpt) {
auto shape = operand.getShape();
int repK = getNumRepK_(operand, shape[0]);
int repN = getNumRepN(operand, shape[1], wpt[1]);
return 4 * std::max(repN / 2, 1) * repK;
}
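// Worked example (illustrative, not part of this diff; assumes fp16
// operands, i.e. the m16n8k16 mma): with wpt = {2, 2}, a 128x64 $a tile
// gives repM = 128 / (2 * 16) = 4 and repK = 64 / 16 = 4, so each thread
// holds 4 * 4 * 4 = 64 values; a 64x128 $b tile gives repK = 4 and
// repN = 128 / (2 * 8) = 8, i.e. 4 * (8 / 2) * 4 = 64 values. Both counts
// are in units of the 2-element vectors built by the type converter.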
// Loading $a from smem to registers, returns a LLVM::Struct.
Value loadA(Value tensor, Value llTensor) const {
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
@@ -2863,9 +2944,6 @@ struct MMA16816ConversionHelper {
// step2. Format the values into an LLVM::Struct for passing to mma codegen.
Value result = composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
// TODO[Superjomn]: Replace the convert_layout op with the result once the
// DotOperandEncodingAttr is ready.
return result;
}
@@ -2894,15 +2972,21 @@ struct MMA16816ConversionHelper {
return result;
}
// Loading $c from smem(?) to registers, returns a Value.
// NOTE Only SplatLike tensor is supported now.
Value loadC(Value tensor) const {
// Currently, we only support a SplatLike C. For the other cases, e.g., C in
// shared layout or blocked layout, we will support them by expanding
// convert_layout.
auto hc = helper.loadSplatLikeC(tensor, loc, rewriter);
assert(hc.size() == 4UL && "Only splat-like C is supported now");
return hc[0];
// Loading $c to registers, returns a Value.
Value loadC(Value tensor, Value llTensor) const {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy);
size_t fcSize = 4 * repM * repN;
assert(tensorTy.getEncoding().isa<MmaEncodingAttr>() &&
"Currently, we only support $c with a mma layout.");
// Load a normal C tensor with mma layout; it should be an
// LLVM::struct with fcSize elements.
auto structTy = llTensor.getType().cast<LLVM::LLVMStructType>();
assert(structTy.getBody().size() == fcSize &&
"DotOp's $c operand should pass the same number of values as $d in "
"mma layout.");
return llTensor;
}
// Conduct the Dot conversion.
@@ -2934,9 +3018,8 @@ struct MMA16816ConversionHelper {
getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK);
ValueTable hb = getValuesFromDotOperandLayoutStruct(
loadedB, std::max(numRepN / 2, 1), numRepK);
const int fcSize = 4 * numRepM * numRepN;
SmallVector<Value> fc(fcSize, loadedC);
auto fc = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
loc, loadedC, rewriter);
auto callMma = [&](unsigned m, unsigned n, unsigned k) {
unsigned colsPerThread = numRepN * 2;
@@ -2974,6 +3057,11 @@ struct MMA16816ConversionHelper {
for (unsigned n = 0; n < numRepN; ++n)
callMma(2 * m, n, 2 * k);
// NOTE: the barrier here is a temporary trick that makes GEMM with a
// k-forloop pass the precision test; without it, the test fails.
// TODO[Superjomn]: Fix this in a more general and performance-friendly way.
barrier;
// replace with new packed result
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
@@ -3123,9 +3211,6 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = mmaHelper.loadB(src, adaptor.src());
} else if (dotOperandLayout.getOpIdx() == 2) {
// operand $c
res = mmaHelper.loadC(src);
}
rewriter.replaceOp(op, res);
@@ -3163,10 +3248,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
loadedB = mmaHelper.loadB(op.b(), adaptor.b());
}
// TODO[Superjomn]: Process C as a mma layout.
// Currently, C is simply treated as a Splat Op, and its data layout does
// not matter.
loadedC = mmaHelper.loadC(op.c());
loadedC = mmaHelper.loadC(op.c(), adaptor.c());
return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
adaptor);
@@ -3174,6 +3256,26 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
/// ====================== mma codegen end ============================
Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
Type resType, Type elemType,
Value constVal,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) {
if (layout.getVersion() == 2) {
auto tensorTy = resType.cast<RankedTensorType>();
auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy);
size_t fcSize = 4 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(fcSize, elemType));
return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal),
rewriter, structTy);
}
assert(false && "Unsupported mma layout found");
}
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;
@@ -3199,6 +3301,7 @@ public:
}
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
auto ctx = type.getContext();
Attribute layout = type.getEncoding();
if (layout &&
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
@@ -3207,11 +3310,50 @@ public:
getElemsPerThread(layout, type.getShape());
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto shared_layout =
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
} else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2) {
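// Illustrative example (not part of this diff): with warpsPerCTA = {2, 2},
// a tensor<128x128xf32, #mma> has repM = 4 and repN = 8, so it converts to
// an LLVM struct of 4 * 4 * 8 = 128 f32 values.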
auto [repM, repN] = DotOpConversionHelper::getRepMN(type);
size_t fcSize = 4 * repM * repN;
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fcSize, type.getElementType()));
}
llvm::errs()
<< "Unexpected mma layout detected in TritonToLLVMTypeConverter";
return llvm::None;
} else if (auto dot_op_layout =
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
if (mmaLayout.getVersion() == 2) {
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = type.getElementType();
if (dot_op_layout.getOpIdx() == 0) { // $a
int elems =
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
Type x2Ty = vec_ty(elemTy, 2);
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems, x2Ty));
}
if (dot_op_layout.getOpIdx() == 1) { // $b
int elems =
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
Type x2Ty = vec_ty(elemTy, 2);
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems, x2Ty));
}
}
llvm::errs() << "Unexpected dot operand layout detected in "
"TritonToLLVMTypeConverter";
return llvm::None;
}
return llvm::None;
}
};


@@ -35,6 +35,9 @@ def matmul_no_scf_kernel(
[256, 128, 16, 4],
[128, 16, 32, 4],
[32, 128, 64, 4],
[128, 128, 64, 4],
[64, 128, 128, 4],
[64, 128, 128, 2],
])
def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
@@ -78,24 +81,39 @@ def matmul_kernel(
tl.store(c_ptrs, accumulator)
# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
# @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
# [128, 256, 128, 4, 128, 256, 32],
# # [256, 128, 64, 4, 256, 128, 16],
# # [128, 16, 128, 4, 128, 16, 32],
# # [32, 128, 256, 4, 32, 128, 64],
# ])
# def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
# a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
# b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
# c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
# grid = lambda META: (1, )
# matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
# stride_am=a.stride(0), stride_ak=a.stride(1),
# stride_bk=b.stride(0), stride_bn=b.stride(1),
# stride_cm=c.stride(0), stride_cn=c.stride(1),
# M=a.shape[0], N=b.shape[1], K=a.shape[1],
# BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
# num_warps=NUM_WARPS)
# golden = torch.matmul(a, b)
# torch.set_printoptions(profile="full")
# assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
# Non-forloop
[64, 32, 64, 4, 64, 32, 64],
[128, 64, 128, 4, 128, 64, 128],
# K-Forloop
[64, 32, 128, 4, 64, 32, 64],
[128, 16, 128, 4, 128, 16, 32],
[32, 16, 128, 4, 32, 16, 32],
[32, 64, 128, 4, 32, 64, 32],
[32, 128, 256, 4, 32, 128, 64],
[64, 128, 64, 4, 64, 128, 32],
[128, 128, 64, 4, 128, 128, 32],
[64, 64, 128, 4, 64, 64, 32],
[128, 128, 128, 4, 128, 128, 32],
[128, 128, 256, 4, 128, 128, 64],
[128, 256, 128, 4, 128, 256, 32],
[256, 128, 64, 4, 256, 128, 16],
[128, 64, 128, 4, 128, 64, 32],
])
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
grid = lambda META: (1, )
matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
M=a.shape[0], N=b.shape[1], K=a.shape[1],
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
num_warps=NUM_WARPS)
golden = torch.matmul(a, b)
torch.set_printoptions(profile="full")
assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)


@@ -872,7 +872,9 @@ def make_tritongpu_ir(mod, num_warps):
def optimize_tritongpu_ir(mod, num_stages):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_tritongpu_pipeline_pass(num_stages)
# This errors in the backend due to a wrong conversion when expanding
# async-related instructions.
# TODO[Superjomn]: Re-enable once fixed.
# pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_coalesce_pass()