[Triton-MLIR][Backend] Fix smem base bug in dot codegen (#715)

Get the SMEM base address of an input operand from `adapter.arg()` instead
of `getSharedMemoryBase(arg, ...)`, since the latter does not work with
memory aliasing. For example:

```mlir
%a = extract_slice %b, %offset
%c = dot %a, %d
```

`%a` should have a different smem base address from `%b`.
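
To see why the allocation-based lookup goes wrong here, consider a minimal C++ analogy (all types and names below are hypothetical, not Triton's API): a slice view aliases its source buffer, so a lookup keyed on the allocation returns the source's base address and silently drops the slice offset, whereas the per-value pointer, which is what the adaptor operand corresponds to after lowering, keeps it.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// A shared-memory allocation, standing in for the buffer behind %b.
struct Buffer {
  std::vector<float> storage;
};

// A view into a buffer, standing in for a value such as %a = extract_slice %b.
struct View {
  Buffer *root;   // the allocation this view aliases
  int64_t offset; // element offset of the view into the root buffer
  float *ptr() const { return root->storage.data() + offset; }
};

// Mimics the buggy lookup: any aliasing view resolves to its allocation's
// base address, so the slice offset is silently dropped.
float *allocationBase(const View &v) { return v.root->storage.data(); }

int main() {
  Buffer b{std::vector<float>(64, 0.0f)};
  View whole{&b, 0};  // %b
  View slice{&b, 16}; // %a = extract_slice %b, 16

  // The allocation-based lookup collapses the slice onto %b's base address...
  assert(allocationBase(slice) == allocationBase(whole));
  // ...while the per-value pointer correctly includes the slice offset.
  assert(slice.ptr() == whole.ptr() + 16);
  return 0;
}
```

The fix below follows the second behavior: the already-lowered operand (`llTensor`, taken from the adaptor) carries the correct address, so no base lookup is needed.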
Author: Yan Chunwei
Date: 2022-09-27 17:28:17 +08:00
Committed by: GitHub
Parent: 3a84278530
Commit: df8d276089


@@ -455,13 +455,6 @@ public:
     return multiDimBase;
   }
-  SmallVector<Value>
-  emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
-                                const MmaEncodingAttr &mmaLayout,
-                                ArrayRef<int64_t> shape) const {
-    // ongoing
-  }
   SmallVector<SmallVector<Value>> emitIndices(Location loc,
                                               ConversionPatternRewriter &b,
                                               const Attribute &layout,
@@ -2446,8 +2439,8 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   // Load A or B matrix.
   auto getLoadMatrixFn =
-      [&](Value tensor, int wpt, int kOrder, ArrayRef<int> instrShape,
-          ArrayRef<int> matShape, Value warpId,
+      [&](Value tensor, Value llTensor, int wpt, int kOrder,
+          ArrayRef<int> instrShape, ArrayRef<int> matShape, Value warpId,
           decltype(ha) &vals) -> std::function<void(int, int)> {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
     // We assumes that the input operand of Dot should be from shared layout.
@@ -2468,10 +2461,9 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
     SmallVector<Value> ptrs(numPtrs);
     Type smemPtrTy = helper.getShemPtrTy();
-    auto smemBase = getSharedMemoryBase(loc, rewriter, tensor);
     for (int i = 0; i < numPtrs; ++i) {
-      ptrs[i] = bit_cast(
-          smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));
+      ptrs[i] =
+          bit_cast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
     }
     bool needTrans = kOrder != order[0];
@@ -2499,16 +2491,16 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
   std::function<void(int, int)> loadA;
   std::function<void(int, int)> loadB = getLoadMatrixFn(
-      B, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/,
-      {mmaInstrK, mmaInstrN} /*instrShpae*/,
+      B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+      0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
       {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
   if (aTensorTy.getEncoding()
           .dyn_cast<SharedEncodingAttr>()) { // load from smem
-    loadA = getLoadMatrixFn(A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
-                            1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
-                            {matShapeM, matShapeK} /*matShape*/,
-                            warpM /*warpId*/, ha /*vals*/);
+    loadA = getLoadMatrixFn(
+        A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+        1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
+        {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
   } else if (auto blockedLayout =
                  aTensorTy.getEncoding()
                      .dyn_cast<BlockedEncodingAttr>()) { // load from registers,