diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index 21dccb963..7c33f795c 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -4,11 +4,13 @@
 using namespace mlir;
 using namespace mlir::triton;
 
+using ::mlir::LLVM::DotOpFMAConversionHelper;
 using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
 using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
 using ::mlir::LLVM::getElementsFromStruct;
 using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
 using ::mlir::LLVM::getStructFromElements;
+using ::mlir::LLVM::MMA16816ConversionHelper;
 using ::mlir::triton::gpu::getElemsPerThread;
 
 struct SplatOpConversion
@@ -38,6 +40,11 @@ struct SplatOpConversion
           LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
       return getStructFromElements(loc, elems, rewriter, structTy);
+    } else if (auto dotLayout =
+                   tensorTy.getEncoding()
+                       .dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
+      return convertSplatLikeOpWithDotOperandLayout(
+          dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
     } else if (auto mmaLayout =
                    tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
       return convertSplatLikeOpWithMmaLayout(
           mmaLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
@@ -48,6 +55,38 @@ struct SplatOpConversion
     return {};
   }
 
+  static Value convertSplatLikeOpWithDotOperandLayout(
+      const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
+      Type elemType, Value constVal, TypeConverter *typeConverter,
+      ConversionPatternRewriter &rewriter, Location loc) {
+    auto tensorTy = resType.cast<RankedTensorType>();
+    auto shape = tensorTy.getShape();
+    auto parent = layout.getParent();
+    int numElems{};
+    if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
+      if (mmaLayout.isAmpere()) {
+        numElems = layout.getOpIdx() == 0
+                       ? MMA16816ConversionHelper::getANumElemsPerThread(
+                             tensorTy, mmaLayout.getWarpsPerCTA()[0])
+                       : MMA16816ConversionHelper::getBNumElemsPerThread(
+                             tensorTy, mmaLayout.getWarpsPerCTA()[1]);
+      } else if (mmaLayout.isVolta()) {
+        DotOpMmaV1ConversionHelper helper(mmaLayout);
+        numElems = layout.getOpIdx() == 0
+                       ? helper.numElemsPerThreadA(shape, {0, 1})
+                       : helper.numElemsPerThreadB(shape, {0, 1});
+      }
+    } else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
+      numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
+    } else {
+      assert(false && "Unsupported layout found");
+    }
+    auto structTy = LLVM::LLVMStructType::getLiteral(
+        rewriter.getContext(), SmallVector<Type>(numElems, elemType));
+    return getStructFromElements(loc, SmallVector<Value>(numElems, constVal),
+                                 rewriter, structTy);
+  }
+
   static Value convertSplatLikeOpWithMmaLayout(
       const MmaEncodingAttr &layout, Type resType, Type elemType,
       Value constVal, TypeConverter *typeConverter,
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index ac381d50a..862687f6f 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1227,20 +1227,20 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
     elif dtype == 'int8':
         assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
 
-# FIXME: Unsupported layout found in ConvertSplatLikeOp
-# def test_dot_without_load():
-#     @triton.jit
-#     def kernel(out):
-#         pid = tl.program_id(axis=0)
-#         a = tl.zeros((32, 32), tl.float32)
-#         b = tl.zeros((32, 32), tl.float32)
-#         c = tl.zeros((32, 32), tl.float32)
-#         c = tl.dot(a, b)
-#         pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
-#         tl.store(pout, c)
-#
-#     out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
-#     kernel[(1,)](out)
+
+def test_dot_without_load():
+    @triton.jit
+    def kernel(out):
+        pid = tl.program_id(axis=0)
+        a = tl.zeros((32, 32), tl.float32)
+        b = tl.zeros((32, 32), tl.float32)
+        c = tl.zeros((32, 32), tl.float32)
+        c = tl.dot(a, b)
+        pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
+        tl.store(pout, c)
+
+    out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
+    kernel[(1,)](out)
 
 # ---------------
 # test arange
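
Note (not part of the patch): the re-enabled test only exercises compilation and launch; it stores the result but asserts nothing. A minimal standalone sketch of a stronger check, assuming a CUDA device and this version of the Triton API, could look like the following. The final assert is an addition for illustration, relying only on the fact that the product of two zero matrices is zero and that pout covers the whole 32x32 output.

import torch
import triton
import triton.language as tl

@triton.jit
def kernel(out):
    # a and b come from tl.zeros, not tl.load, so the splats must be
    # lowered directly into the dot-operand layout -- exactly the case
    # convertSplatLikeOpWithDotOperandLayout adds support for.
    a = tl.zeros((32, 32), tl.float32)
    b = tl.zeros((32, 32), tl.float32)
    c = tl.dot(a, b)
    pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
    tl.store(pout, c)

out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
kernel[(1,)](out)
assert torch.all(out == 0)  # 0 @ 0 == 0 overwrites the initial ones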