finish porting the original logic

This commit is contained in:
Superjomn
2022-11-04 13:35:49 +08:00
parent da2993e1c7
commit 1f552308c4
2 changed files with 166 additions and 17 deletions

View File

@@ -2642,20 +2642,21 @@ public:
   Value elems[4];
   Type elemTy = type::f32Ty(ctx);
   Type elemPtrTy = ptr_ty(elemTy);
   if (kOrder == 1) {
-    elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
-    elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
+    elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
+    elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
     elems[2] =
-        load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+        load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
     elems[3] =
-        load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+        load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
   } else {
-    elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
-    elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
+    elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
+    elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
     elems[1] =
-        load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+        load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
     elems[3] =
-        load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+        load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
   }
   return {elems[0], elems[1], elems[2], elems[3]};
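The fix above concerns the first argument of the gep helper: in the MLIR LLVM dialect, LLVM::GEPOp takes the result pointer type, not the pointee type. A minimal sketch of the corrected pattern, assuming the gep/load/ptr_ty/i32_val shorthands used throughout this file (ptr and off stand in for the base pointers and offsets above):

Type elemTy = type::f32Ty(ctx);
Type elemPtrTy = ptr_ty(elemTy);                       // !llvm.ptr<f32>
Value elem = load(gep(elemPtrTy, ptr, i32_val(off)));  // GEP yields a pointer; load yields f32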
@@ -2799,6 +2800,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
   size_t reduceAxis = 1;
   unsigned K = AShape[reduceAxis];
   bool isOuter = K == 1;
   bool isMMA = D.getType()
                    .cast<RankedTensorType>()
                    .getEncoding()
@@ -2810,11 +2812,13 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
                         .getEncoding()
                         .cast<MmaEncodingAttr>();
-  if (!isOuter && isMMA) {
+  bool isHMMA = isDotHMMA(op);
+  if (!isOuter && isMMA && isHMMA) {
     if (mmaLayout.getVersion() == 1)
       return convertMMA884(op, adaptor, rewriter);
     if (mmaLayout.getVersion() == 2)
       return convertMMA16816(op, adaptor, rewriter);
     llvm::report_fatal_error(
         "Unsupported MMA kind found when converting DotOp to LLVM.");
   }
@@ -2827,6 +2831,46 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
"Unsupported DotOp found when converting TritonGPU to LLVM.");
}
// Tell whether a DotOp support HMMA.
// This is port from the master branch, the original logic is retained.
static bool isDotHMMA(DotOp op) {
auto a = op.a();
auto b = op.b();
auto c = op.c();
auto d = op.getResult();
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto cTensorTy = c.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
return false;
auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
auto aElemTy = aTensorTy.getElementType();
auto bElemTy = bTensorTy.getElementType();
// Refer to mma section for the data type supported by Volta and Hopper
// Tensor Core in
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
return (aElemTy.isF16() && bElemTy.isF16()) ||
(aElemTy.isBF16() && bElemTy.isBF16()) ||
(aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
mmaLayout.getVersion() >= 2) ||
(aElemTy.isInteger(8) && bElemTy.isInteger(8) &&
mmaLayout.getVersion() >= 2);
}
// Tell whether a DotOp support HMMA by the operand type(either $a or $b).
// We cannot get both the operand types(in TypeConverter), here we assume the
// types of both the operands are identical here.
// TODO[Superjomn]: Find a better way to implement it.
static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
auto elemTy = operand.getElementType();
return elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
(elemTy.isInteger(8) && mmaVersion >= 2);
}
private:
// Convert to mma.m16n8k16
LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
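A caveat on the two predicates just added: the single-operand overload sees only one element type, so a dot whose operands differ (which the full-op overload rejects) can be misclassified. A small illustration with hypothetical types:

// $a is f16 but $b is f32: no clause of the full-op overload matches,
// while the operand-only overload, judging $a alone, accepts it.
isDotHMMA(op);                                // false for f16 x f32
isDotHMMA(aTensorTy, /*allowTF32=*/false, 2); // true, from $a alone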
@@ -3590,6 +3634,73 @@ private:
   }
 };

+// Helper for FMADot conversion.
+class DotOpFMAConversionHelper {
+public:
+  MmaEncodingAttr mmaLayout;
+  ArrayRef<unsigned> wpt;
+
+  using ValueTable = std::map<std::pair<int, int>, Value>;
+
+  explicit DotOpFMAConversionHelper(MmaEncodingAttr mmaLayout)
+      : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
+
+  // Currently we can only tell whether to use FMADot from the operand type,
+  // while in the original code FMADot requires that both the operands and the
+  // result of dot be fp32.
+  // This method should be safe to use in cases where tensor cores are not
+  // applicable.
+  static bool useFMA(TensorType operand) {
+    return operand.getElementType().isF32();
+  }
+
+  SmallVector<unsigned> getOrder() const {
+    SmallVector<unsigned> order(2);
+    if (mmaLayout.getVersion() == 1)
+      order = {0, 1};
+    else if (mmaLayout.getVersion() == 2)
+      order = {1, 0};
+    else {
+      assert(false && "Unexpected MMA version found.");
+    }
+    return order;
+  }
+
+  Value loadA(Value tensor, Value llTensor, Value threadId, Location loc,
+              Value smem, ConversionPatternRewriter &rewriter) const {
+    auto *ctx = rewriter.getContext();
+    auto tensorTy = tensor.getType().cast<RankedTensorType>();
+    auto aShape = tensorTy.getShape();
+    auto aLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+    auto aOrder = aLayout.getOrder();
+    auto order = getOrder();
+
+    bool isARow = aOrder[0] == 1;
+    int strideAM = isARow ? aShape[1] : 1;
+    int strideAK = isARow ? 1 : aShape[0];
+    int strideA0 = isARow ? strideAK : strideAM;
+    int strideA1 = isARow ? strideAM : strideAK;
+    int lda = isARow ? strideAM : strideAK;
+    int aPerPhase = aLayout.getPerPhase();
+    int aMaxPhase = aLayout.getMaxPhase();
+    int aNumPtr = 8;
+    int bNumPtr = 8;
+    int aVec = 2;
+    int NK = aShape[isARow ? 1 : 0];
+
+    // Stub: convertFMADot below still computes its operand pointers and loads
+    // inline; this helper does not emit the loads yet.
+    return Value{};
+  }
+
+  Value loadB(Value tensor, Value llTensor, Value threadId, Location loc,
+              Value smem, ConversionPatternRewriter &rewriter) const {
+    // Stub, same as loadA.
+    return Value{};
+  }
+
+  ValueTable extractLoadedOperand(Value llTensor) const { return ValueTable{}; }
+};

 LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
     triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
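Together with the DotOp changes above, lowering a dot operand now picks one of three paths. A condensed sketch of the dispatch (not literal code; names as in the hunks below):

// !isOuter && version == 2 && isHMMA -> MMA16816ConversionHelper (tensor core v2)
// !isOuter && version == 1 && isHMMA -> DotOpMmaV1ConversionHelper (tensor core v1)
// useFMA(dstTensorTy)                -> DotOpFMAConversionHelper (llvm.intr.fmuladd)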
@@ -3605,8 +3716,24 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
       dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
   assert(mmaLayout);

+  bool isOuter{};
+  {
+    int K{};
+    if (dotOperandLayout.getOpIdx() == 0) // $a
+      K = dstTensorTy.getShape()[1];
+    else // $b
+      K = dstTensorTy.getShape()[0];
+    isOuter = K == 1;
+  }
+
+  // TODO[Superjomn]: allowTF32 is not available in ConvertLayoutOp because it
+  // is an attribute of DotOp.
+  bool allowTF32 = false;
+  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
+                                           mmaLayout.getVersion());
+
   Value res;
-  if (mmaLayout.getVersion() == 2) {
+  if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
     MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
                                        rewriter, getTypeConverter(),
                                        op.getLoc());
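Because allowTF32 is hard-wired to false here, an fp32 operand on a version-2 MMA parent is conservatively classified as non-HMMA and falls through to the FMA branch added below:

// f32 element type, allowTF32 = false, mmaVersion = 2:
// isDotHMMA(dstTensorTy, /*allowTF32=*/false, /*mmaVersion=*/2) == false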
@@ -3618,7 +3745,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
       // operand $b
       res = mmaHelper.loadB(src, adaptor.src());
     }
-  } else if (mmaLayout.getVersion() == 1) {
+  } else if (!isOuter && mmaLayout.getVersion() == 1 &&
+             isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
     if (dotOperandLayout.getOpIdx() == 0) {
       // operand $a
@@ -3629,6 +3757,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
       res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
                          adaptor.src(), loc, rewriter);
     }
+  } else if (DotOpFMAConversionHelper::useFMA(dstTensorTy)) { // fmadot
   } else {
     assert(false && "Unsupported mma layout found");
   }
@@ -4184,9 +4314,11 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
   for (int i = 0; i < bNumPtr; ++i)
     bPtrs[i] = gep(f32PtrTy, adaptor.b(), bOff[i]);

-  // TODO initialize ret with $c.
   std::map<std::pair<int, int>, Value> has, hbs;
-  // TODO initialize ret with zeros.
-  SmallVector<Value> ret(NK);
+  auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
+  SmallVector<Value> ret(cShape[0] * cShape[1], cc[0]);
   for (unsigned k = 0; k < NK; k++) {
     int z = 0;
     for (unsigned i = 0; i < cShape[order[1]]; i += cShapePerCTA[order[1]])
@@ -4203,14 +4335,15 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
             Value va = load(pa);
             has[{m + mm, k}] = va;
           }
-          if (!has.count({n + nn, k})) {
+          if (!hbs.count({n + nn, k})) {
             Value pb = gep(f32PtrTy, bPtrs[0],
                            i32_val((n + nn) * strideBN + k * strideBK));
             Value vb = load(pb);
-            has[{n + nn, k}] = vb;
+            hbs[{n + nn, k}] = vb;
           }
-          ret[z++] = rewriter.create<LLVM::FMulAddOp>(
-              loc, has[{m + mm, k}], hbs[{n + nn, k}], ret[z]);
+          ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
+                                                    hbs[{n + nn, k}], ret[z]);
+          ++z;
         }
       }
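The hunk above fixes the b-operand cache: values were stored into has (the a-value cache) but read back from hbs, so hbs[{n + nn, k}] default-constructed a null Value, and a-keys {m + mm, k} could also collide with b-keys {n + nn, k} in the shared map. Splitting ret[z++] = ... ret[z] ... into a plain assignment followed by ++z also avoids mixing the read and the increment of z in one statement. A sketch of the failure with hypothetical indices:

// Hypothetical iteration where m + mm == n + nn == 3 and k == 0:
// old: has[{3, 0}] = vb;   // b stored into the a-cache
//      ... hbs[{3, 0}] ... // read from the never-filled hbs -> null Value
// new: hbs[{3, 0}] = vb;   // b keeps its own cache
//      ... hbs[{3, 0}] ... // valid f32 operand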

View File

@@ -789,3 +789,19 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     return
   }
 }

+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func @matmul_fmadot_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+    %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    // CHECK: llvm.intr.fmuladd
+    %28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<32x32xf32, #mma>
+    return
+  }
+}
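With f32 operands and allowTF32 = false, isDotHMMA rejects this dot even though its layout is a version-2 MMA, so the lowering takes the new FMA path; the CHECK line matches the op emitted per element by convertFMADot above:

// The C++ that produces the checked IR (from the convertFMADot hunk):
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
                                          hbs[{n + nn, k}], ret[z]);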