[BACKEND] Support splat constant on the DotOperandLayout (#1008)
This commit is contained in:
@@ -4,11 +4,13 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::DotOpFMAConversionHelper;
|
||||
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
|
||||
using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
|
||||
using ::mlir::LLVM::getElementsFromStruct;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStructFromElements;
|
||||
using ::mlir::LLVM::MMA16816ConversionHelper;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
|
||||
struct SplatOpConversion
|
||||
@@ -38,6 +40,11 @@ struct SplatOpConversion
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
|
||||
|
||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||
} else if (auto dotLayout =
|
||||
tensorTy.getEncoding()
|
||||
.dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
return convertSplatLikeOpWithDotOperandLayout(
|
||||
dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
|
||||
} else if (auto mmaLayout =
|
||||
tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
|
||||
return convertSplatLikeOpWithMmaLayout(
|
||||
@@ -48,6 +55,38 @@ struct SplatOpConversion
|
||||
return {};
|
||||
}
|
||||
|
||||
static Value convertSplatLikeOpWithDotOperandLayout(
|
||||
const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
|
||||
Type elemType, Value constVal, TypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto shape = tensorTy.getShape();
|
||||
auto parent = layout.getParent();
|
||||
int numElems{};
|
||||
if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isAmpere()) {
|
||||
numElems = layout.getOpIdx() == 0
|
||||
? MMA16816ConversionHelper::getANumElemsPerThread(
|
||||
tensorTy, mmaLayout.getWarpsPerCTA()[0])
|
||||
: MMA16816ConversionHelper::getBNumElemsPerThread(
|
||||
tensorTy, mmaLayout.getWarpsPerCTA()[1]);
|
||||
} else if (mmaLayout.isVolta()) {
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
numElems = layout.getOpIdx() == 0
|
||||
? helper.numElemsPerThreadA(shape, {0, 1})
|
||||
: helper.numElemsPerThreadB(shape, {0, 1});
|
||||
}
|
||||
} else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
|
||||
numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
|
||||
} else {
|
||||
assert(false && "Unsupported layout found");
|
||||
}
|
||||
auto structTy = LLVM::LLVMStructType::getLiteral(
|
||||
rewriter.getContext(), SmallVector<Type>(numElems, elemType));
|
||||
return getStructFromElements(loc, SmallVector<Value>(numElems, constVal),
|
||||
rewriter, structTy);
|
||||
}
|
||||
|
||||
static Value convertSplatLikeOpWithMmaLayout(
|
||||
const MmaEncodingAttr &layout, Type resType, Type elemType,
|
||||
Value constVal, TypeConverter *typeConverter,
|
||||
|
@@ -1227,20 +1227,20 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
elif dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
# FIXME: Unsupported layout found in ConvertSplatLikeOp
|
||||
# def test_dot_without_load():
|
||||
# @triton.jit
|
||||
# def kernel(out):
|
||||
# pid = tl.program_id(axis=0)
|
||||
# a = tl.zeros((32, 32), tl.float32)
|
||||
# b = tl.zeros((32, 32), tl.float32)
|
||||
# c = tl.zeros((32, 32), tl.float32)
|
||||
# c = tl.dot(a, b)
|
||||
# pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
||||
# tl.store(pout, c)
|
||||
#
|
||||
# out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
|
||||
# kernel[(1,)](out)
|
||||
|
||||
def test_dot_without_load():
|
||||
@triton.jit
|
||||
def kernel(out):
|
||||
pid = tl.program_id(axis=0)
|
||||
a = tl.zeros((32, 32), tl.float32)
|
||||
b = tl.zeros((32, 32), tl.float32)
|
||||
c = tl.zeros((32, 32), tl.float32)
|
||||
c = tl.dot(a, b)
|
||||
pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
||||
tl.store(pout, c)
|
||||
|
||||
out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
|
||||
kernel[(1,)](out)
|
||||
|
||||
# ---------------
|
||||
# test arange
|
||||
|
Reference in New Issue
Block a user