From df8d276089a711a682e34f9705c7b05b0addc1ea Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Tue, 27 Sep 2022 17:28:17 +0800 Subject: [PATCH] [Triton-MLIR][Backend] Fix smem base bug in dot codegen (#715) Get the SMEM base address of an input operand from `adapter.arg()` instead of `getSharedMemoryBase(arg, ...)`, as the latter does not work with memory aliasing, for example: ```mlir %a = extract_slice %b, %offset %c = dot %a, %d ``` `%a` should have a different smem base address from `%b` --- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 742ba60d7..165fb86fb 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -455,13 +455,6 @@ public: return multiDimBase; } - SmallVector - emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b, - const MmaEncodingAttr &mmaLayout, - ArrayRef shape) const { - // ongoing - } - SmallVector> emitIndices(Location loc, ConversionPatternRewriter &b, const Attribute &layout, @@ -2446,8 +2439,8 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, // Load A or B matrix. auto getLoadMatrixFn = - [&](Value tensor, int wpt, int kOrder, ArrayRef instrShape, - ArrayRef matShape, Value warpId, + [&](Value tensor, Value llTensor, int wpt, int kOrder, + ArrayRef instrShape, ArrayRef matShape, Value warpId, decltype(ha) &vals) -> std::function { auto tensorTy = tensor.getType().cast(); // We assumes that the input operand of Dot should be from shared layout. 
@@ -2468,10 +2461,9 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, SmallVector ptrs(numPtrs); Type smemPtrTy = helper.getShemPtrTy(); - auto smemBase = getSharedMemoryBase(loc, rewriter, tensor); for (int i = 0; i < numPtrs; ++i) { - ptrs[i] = bit_cast( - smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]}))); + ptrs[i] = + bit_cast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]}))); } bool needTrans = kOrder != order[0]; @@ -2499,16 +2491,16 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, std::function loadA; std::function loadB = getLoadMatrixFn( - B, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/, - {mmaInstrK, mmaInstrN} /*instrShpae*/, + B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, + 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/, {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/); if (aTensorTy.getEncoding() .dyn_cast()) { // load from smem - loadA = getLoadMatrixFn(A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, - 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/, - {matShapeM, matShapeK} /*matShape*/, - warpM /*warpId*/, ha /*vals*/); + loadA = getLoadMatrixFn( + A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, + 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/, + {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/); } else if (auto blockedLayout = aTensorTy.getEncoding() .dyn_cast()) { // load from registers,