[BACKEND] Support splat constant on the DotOperandLayout (#1008)

Keren Zhou
2022-12-22 00:48:46 -08:00
committed by GitHub
parent 925d3d7f98
commit fd2da4aff6
2 changed files with 53 additions and 14 deletions
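Before this change, lowering a splat constant (e.g. the result of tl.zeros) whose tensor type carried a triton::gpu::DotOperandEncodingAttr tripped the "Unsupported layout found in ConvertSplatLikeOp" assertion, so tl.dot could not consume constants that were never loaded from memory. A minimal sketch of the pattern this commit enables, condensed from the re-enabled test below:

    a = tl.zeros((32, 32), tl.float32)  # splat constants; as tl.dot operands
    b = tl.zeros((32, 32), tl.float32)  # they carry a DotOperand layout
    c = tl.dot(a, b)                    # previously asserted during lowering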


@@ -4,11 +4,13 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::getElemsPerThread;
struct SplatOpConversion
@@ -38,6 +40,11 @@ struct SplatOpConversion
            LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
      return getStructFromElements(loc, elems, rewriter, structTy);
    } else if (auto dotLayout =
                   tensorTy.getEncoding()
                       .dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
      return convertSplatLikeOpWithDotOperandLayout(
          dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
    } else if (auto mmaLayout =
                   tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
      return convertSplatLikeOpWithMmaLayout(
@@ -48,6 +55,38 @@ struct SplatOpConversion
    return {};
  }

  static Value convertSplatLikeOpWithDotOperandLayout(
      const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
      Type elemType, Value constVal, TypeConverter *typeConverter,
      ConversionPatternRewriter &rewriter, Location loc) {
    auto tensorTy = resType.cast<RankedTensorType>();
    auto shape = tensorTy.getShape();
    auto parent = layout.getParent();
    // How many copies of the constant each thread holds depends on the
    // parent layout the dot operand is distributed over.
    int numElems{};
    if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
      if (mmaLayout.isAmpere()) {
        numElems = layout.getOpIdx() == 0
                       ? MMA16816ConversionHelper::getANumElemsPerThread(
                             tensorTy, mmaLayout.getWarpsPerCTA()[0])
                       : MMA16816ConversionHelper::getBNumElemsPerThread(
                             tensorTy, mmaLayout.getWarpsPerCTA()[1]);
      } else if (mmaLayout.isVolta()) {
        DotOpMmaV1ConversionHelper helper(mmaLayout);
        numElems = layout.getOpIdx() == 0
                       ? helper.numElemsPerThreadA(shape, {0, 1})
                       : helper.numElemsPerThreadB(shape, {0, 1});
      }
    } else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
      numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
    } else {
      assert(false && "Unsupported layout found");
    }
    // Pack numElems copies of the splatted constant into an LLVM struct.
    auto structTy = LLVM::LLVMStructType::getLiteral(
        rewriter.getContext(), SmallVector<Type>(numElems, elemType));
    return getStructFromElements(loc, SmallVector<Value>(numElems, constVal),
                                 rewriter, structTy);
  }
  static Value convertSplatLikeOpWithMmaLayout(
      const MmaEncodingAttr &layout, Type resType, Type elemType,
      Value constVal, TypeConverter *typeConverter,

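The new convertSplatLikeOpWithDotOperandLayout helper reads the per-thread element count off the dot operand's parent layout (Ampere MMA, Volta MMA, or a blocked layout on the FMA path) and packs numElems copies of the constant into one LLVM struct. A rough Python model of that packing step, with illustrative names only (this is not the real MLIR API):

    def splat_to_struct(const_val, num_elems):
        # one struct field per element this thread owns; every field holds
        # the same splatted constant, as getStructFromElements does above
        return tuple([const_val] * num_elems)

    splat_to_struct(0.0, 4)  # (0.0, 0.0, 0.0, 0.0), i.e. !llvm.struct<(f32, f32, f32, f32)>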

@@ -1227,20 +1227,20 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
    elif dtype == 'int8':
        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
# FIXME: Unsupported layout found in ConvertSplatLikeOp
# def test_dot_without_load():
#     @triton.jit
#     def kernel(out):
#         pid = tl.program_id(axis=0)
#         a = tl.zeros((32, 32), tl.float32)
#         b = tl.zeros((32, 32), tl.float32)
#         c = tl.zeros((32, 32), tl.float32)
#         c = tl.dot(a, b)
#         pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
#         tl.store(pout, c)
#
#     out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
#     kernel[(1,)](out)
def test_dot_without_load():
    @triton.jit
    def kernel(out):
        pid = tl.program_id(axis=0)
        a = tl.zeros((32, 32), tl.float32)
        b = tl.zeros((32, 32), tl.float32)
        c = tl.zeros((32, 32), tl.float32)
        c = tl.dot(a, b)
        pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
        tl.store(pout, c)

    out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
    kernel[(1,)](out)
# ---------------
# test arange
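Finally, a sanity check on the semantics of the re-enabled test_dot_without_load (the assertion is an illustration, not part of the commit): tl.dot of two zero matrices is all zeros, so the kernel overwrites the initial ones in out:

    out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
    kernel[(1,)](out)
    assert out.eq(0.0).all()  # dot(zeros, zeros) stores 0.0 everywhere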