finish porting the original logic
This commit is contained in:
@@ -2642,20 +2642,21 @@ public:
|
||||
|
||||
Value elems[4];
|
||||
Type elemTy = type::f32Ty(ctx);
|
||||
Type elemPtrTy = ptr_ty(elemTy);
|
||||
if (kOrder == 1) {
|
||||
elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
|
||||
elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
|
||||
elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
|
||||
elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
|
||||
elems[2] =
|
||||
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
elems[3] =
|
||||
load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
} else {
|
||||
elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
|
||||
elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
|
||||
elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
|
||||
elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
|
||||
elems[1] =
|
||||
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
elems[3] =
|
||||
load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
}
|
||||
|
||||
return {elems[0], elems[1], elems[2], elems[3]};
|
||||
@@ -2799,6 +2800,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
size_t reduceAxis = 1;
|
||||
unsigned K = AShape[reduceAxis];
|
||||
bool isOuter = K == 1;
|
||||
|
||||
bool isMMA = D.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
@@ -2810,11 +2812,13 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
.getEncoding()
|
||||
.cast<MmaEncodingAttr>();
|
||||
|
||||
if (!isOuter && isMMA) {
|
||||
bool isHMMA = isDotHMMA(op);
|
||||
if (!isOuter && isMMA && isHMMA) {
|
||||
if (mmaLayout.getVersion() == 1)
|
||||
return convertMMA884(op, adaptor, rewriter);
|
||||
if (mmaLayout.getVersion() == 2)
|
||||
return convertMMA16816(op, adaptor, rewriter);
|
||||
|
||||
llvm::report_fatal_error(
|
||||
"Unsupported MMA kind found when converting DotOp to LLVM.");
|
||||
}
|
||||
@@ -2827,6 +2831,46 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
"Unsupported DotOp found when converting TritonGPU to LLVM.");
|
||||
}
|
||||
|
||||
// Tell whether a DotOp support HMMA.
|
||||
// This is port from the master branch, the original logic is retained.
|
||||
static bool isDotHMMA(DotOp op) {
|
||||
auto a = op.a();
|
||||
auto b = op.b();
|
||||
auto c = op.c();
|
||||
auto d = op.getResult();
|
||||
auto aTensorTy = a.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
||||
auto cTensorTy = c.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
||||
|
||||
if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
|
||||
return false;
|
||||
|
||||
auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
|
||||
auto aElemTy = aTensorTy.getElementType();
|
||||
auto bElemTy = bTensorTy.getElementType();
|
||||
// Refer to mma section for the data type supported by Volta and Hopper
|
||||
// Tensor Core in
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
|
||||
return (aElemTy.isF16() && bElemTy.isF16()) ||
|
||||
(aElemTy.isBF16() && bElemTy.isBF16()) ||
|
||||
(aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
|
||||
mmaLayout.getVersion() >= 2) ||
|
||||
(aElemTy.isInteger(8) && bElemTy.isInteger(8) &&
|
||||
mmaLayout.getVersion() >= 2);
|
||||
}
|
||||
|
||||
// Tell whether a DotOp support HMMA by the operand type(either $a or $b).
|
||||
// We cannot get both the operand types(in TypeConverter), here we assume the
|
||||
// types of both the operands are identical here.
|
||||
// TODO[Superjomn]: Find a better way to implement it.
|
||||
static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
|
||||
auto elemTy = operand.getElementType();
|
||||
return elemTy.isF16() || elemTy.isBF16() ||
|
||||
(elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
|
||||
(elemTy.isInteger(8) && mmaVersion >= 2);
|
||||
}
|
||||
|
||||
private:
|
||||
// Convert to mma.m16n8k16
|
||||
LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
|
||||
@@ -3590,6 +3634,73 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
// Helper for FMADot conversion.
|
||||
class DotOpFMAConversionHelper {
|
||||
public:
|
||||
MmaEncodingAttr mmaLayout;
|
||||
ArrayRef<unsigned> wpt;
|
||||
|
||||
using ValueTable = std::map<std::pair<int, int>, Value>;
|
||||
|
||||
explicit DotOpFMAConversionHelper(MmaEncodingAttr mmaLayout)
|
||||
: mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
|
||||
|
||||
// Currently, we can tell whether to use FMAdot only from the operand type,
|
||||
// while in the original code, FMADot requires that both the operand and
|
||||
// result of dot should be fp32.
|
||||
// This method should be safe to use in the cases where tensor core is not
|
||||
// appliable.
|
||||
static bool useFMA(TensorType operand) {
|
||||
return operand.getElementType().isF32();
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getOrder() const {
|
||||
SmallVector<unsigned> order(2);
|
||||
if (mmaLayout.getVersion() == 1)
|
||||
order = {0, 1};
|
||||
else if (mmaLayout.getVersion() == 0)
|
||||
order = {1, 0};
|
||||
else {
|
||||
assert(false && "Unexpected MMA version found.");
|
||||
}
|
||||
return order;
|
||||
}
|
||||
|
||||
Value loadA(Value tensor, Value llTensor, Value threadId, Location loc,
|
||||
Value smem, ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto aShape = tensorTy.getShape();
|
||||
auto aLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto aOrder = aLayout.getOrder();
|
||||
auto order = getOrder();
|
||||
|
||||
bool isARow = aOrder[0] == 1;
|
||||
|
||||
int strideAM = isARow ? aShape[1] : 1;
|
||||
int strideAK = isARow ? 1 : aShape[0];
|
||||
int strideA0 = isARow ? strideAK : strideAM;
|
||||
int strideA1 = isARow ? strideAM : strideAK;
|
||||
int lda = isARow ? strideAM : strideAK;
|
||||
int aPerPhase = aLayout.getPerPhase();
|
||||
int aMaxPhase = aLayout.getMaxPhase();
|
||||
int aNumPtr = 8;
|
||||
int bNumPtr = 8;
|
||||
int aVec = 2;
|
||||
int NK = aShape[isARow ? 1 : 0];
|
||||
|
||||
return Value{};
|
||||
}
|
||||
|
||||
Value loadB(Value tensor, Value llTensor, Value threadId, Location loc,
|
||||
Value smem, ConversionPatternRewriter &rewriter) const {
|
||||
return Value{};
|
||||
}
|
||||
|
||||
ValueTable extractLoadedOperand(Value llTensor) const { return ValueTable{}; }
|
||||
};
|
||||
|
||||
LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
@@ -3605,8 +3716,24 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
||||
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
|
||||
assert(mmaLayout);
|
||||
|
||||
bool isOuter{};
|
||||
{
|
||||
int K{};
|
||||
if (dotOperandLayout.getOpIdx() == 0) // $a
|
||||
K = dstTensorTy.getShape()[1];
|
||||
else // $b
|
||||
K = dstTensorTy.getShape()[0];
|
||||
isOuter = K == 1;
|
||||
}
|
||||
|
||||
// TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it
|
||||
// is an attribute of DotOp.
|
||||
bool allowTF32 = false;
|
||||
bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
|
||||
mmaLayout.getVersion());
|
||||
|
||||
Value res;
|
||||
if (mmaLayout.getVersion() == 2) {
|
||||
if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
|
||||
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
|
||||
rewriter, getTypeConverter(),
|
||||
op.getLoc());
|
||||
@@ -3618,7 +3745,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
||||
// operand $b
|
||||
res = mmaHelper.loadB(src, adaptor.src());
|
||||
}
|
||||
} else if (mmaLayout.getVersion() == 1) {
|
||||
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
|
||||
isHMMA) { // tensor core v1
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
if (dotOperandLayout.getOpIdx() == 0) {
|
||||
// operand $a
|
||||
@@ -3629,6 +3757,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
||||
res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
|
||||
adaptor.src(), loc, rewriter);
|
||||
}
|
||||
} else if (DotOpFMAConversionHelper::useFMA(dstTensorTy)) { // fmadot
|
||||
|
||||
} else {
|
||||
assert(false && "Unsupported mma layout found");
|
||||
}
|
||||
@@ -4184,9 +4314,11 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
for (int i = 0; i < bNumPtr; ++i)
|
||||
bPtrs[i] = gep(f32PtrTy, adaptor.b(), bOff[i]);
|
||||
|
||||
// TODO initialize ret with $c.
|
||||
std::map<std::pair<int, int>, Value> has, hbs;
|
||||
// TODO initialize ret with zeros.
|
||||
SmallVector<Value> ret(NK);
|
||||
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
|
||||
SmallVector<Value> ret(cShape[0] * cShape[1], cc[0]);
|
||||
|
||||
for (unsigned k = 0; k < NK; k++) {
|
||||
int z = 0;
|
||||
for (unsigned i = 0; i < cShape[order[1]]; i += cShapePerCTA[order[1]])
|
||||
@@ -4203,14 +4335,15 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
Value va = load(pa);
|
||||
has[{m + mm, k}] = va;
|
||||
}
|
||||
if (!has.count({n + nn, k})) {
|
||||
if (!hbs.count({n + nn, k})) {
|
||||
Value pb = gep(f32PtrTy, bPtrs[0],
|
||||
i32_val((n + nn) * strideBN + k * strideBK));
|
||||
Value vb = load(pb);
|
||||
has[{n + nn, k}] = vb;
|
||||
hbs[{n + nn, k}] = vb;
|
||||
}
|
||||
ret[z++] = rewriter.create<LLVM::FMulAddOp>(
|
||||
loc, has[{m + mm, k}], hbs[{n + nn, k}], ret[z]);
|
||||
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
|
||||
hbs[{n + nn, k}], ret[z]);
|
||||
++z;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -789,3 +789,19 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
|
||||
// CHECK: llvm.intr.fmuladd
|
||||
%28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<16x16xf32, #mma>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user