[Triton-MLIR][Backend] Fix smem base bug in dot codegen (#715)

Get SMEM base address of an input operand from `adapter.arg()` instead
of `getSharedMemoryBase(arg, ...)`, for the latter one not works with
memory alias, for example:

```llvm
%a = extract_slice %b, %offset
%c = dot %a, %d
```

`%a` should have different smem base address from `%b`
This commit is contained in:
Yan Chunwei
2022-09-27 17:28:17 +08:00
committed by GitHub
parent 3a84278530
commit df8d276089

View File

@@ -455,13 +455,6 @@ public:
return multiDimBase;
}
SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
// ongoing
}
SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
const Attribute &layout,
@@ -2446,8 +2439,8 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
// Load A or B matrix.
auto getLoadMatrixFn =
[&](Value tensor, int wpt, int kOrder, ArrayRef<int> instrShape,
ArrayRef<int> matShape, Value warpId,
[&](Value tensor, Value llTensor, int wpt, int kOrder,
ArrayRef<int> instrShape, ArrayRef<int> matShape, Value warpId,
decltype(ha) &vals) -> std::function<void(int, int)> {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
// We assumes that the input operand of Dot should be from shared layout.
@@ -2468,10 +2461,9 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
SmallVector<Value> ptrs(numPtrs);
Type smemPtrTy = helper.getShemPtrTy();
auto smemBase = getSharedMemoryBase(loc, rewriter, tensor);
for (int i = 0; i < numPtrs; ++i) {
ptrs[i] = bit_cast(
smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));
ptrs[i] =
bit_cast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
}
bool needTrans = kOrder != order[0];
@@ -2499,16 +2491,16 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
std::function<void(int, int)> loadA;
std::function<void(int, int)> loadB = getLoadMatrixFn(
B, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/,
{mmaInstrK, mmaInstrN} /*instrShpae*/,
B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
if (aTensorTy.getEncoding()
.dyn_cast<SharedEncodingAttr>()) { // load from smem
loadA = getLoadMatrixFn(A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
{matShapeM, matShapeK} /*matShape*/,
warpM /*warpId*/, ha /*vals*/);
loadA = getLoadMatrixFn(
A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
} else if (auto blockedLayout =
aTensorTy.getEncoding()
.dyn_cast<BlockedEncodingAttr>()) { // load from registers,