[Triton-MLIR][Backend] Fix smem base bug in dot codegen (#715)
Get SMEM base address of an input operand from `adapter.arg()` instead of `getSharedMemoryBase(arg, ...)`, for the latter one not works with memory alias, for example: ```llvm %a = extract_slice %b, %offset %c = dot %a, %d ``` `%a` should have different smem base address from `%b`
This commit is contained in:
@@ -455,13 +455,6 @@ public:
|
||||
return multiDimBase;
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
// ongoing
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<Value>> emitIndices(Location loc,
|
||||
ConversionPatternRewriter &b,
|
||||
const Attribute &layout,
|
||||
@@ -2446,8 +2439,8 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
|
||||
|
||||
// Load A or B matrix.
|
||||
auto getLoadMatrixFn =
|
||||
[&](Value tensor, int wpt, int kOrder, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, Value warpId,
|
||||
[&](Value tensor, Value llTensor, int wpt, int kOrder,
|
||||
ArrayRef<int> instrShape, ArrayRef<int> matShape, Value warpId,
|
||||
decltype(ha) &vals) -> std::function<void(int, int)> {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
// We assumes that the input operand of Dot should be from shared layout.
|
||||
@@ -2468,10 +2461,9 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
|
||||
SmallVector<Value> ptrs(numPtrs);
|
||||
|
||||
Type smemPtrTy = helper.getShemPtrTy();
|
||||
auto smemBase = getSharedMemoryBase(loc, rewriter, tensor);
|
||||
for (int i = 0; i < numPtrs; ++i) {
|
||||
ptrs[i] = bit_cast(
|
||||
smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));
|
||||
ptrs[i] =
|
||||
bit_cast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
|
||||
}
|
||||
|
||||
bool needTrans = kOrder != order[0];
|
||||
@@ -2499,16 +2491,16 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
|
||||
|
||||
std::function<void(int, int)> loadA;
|
||||
std::function<void(int, int)> loadB = getLoadMatrixFn(
|
||||
B, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/,
|
||||
{mmaInstrK, mmaInstrN} /*instrShpae*/,
|
||||
B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
|
||||
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
|
||||
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
|
||||
|
||||
if (aTensorTy.getEncoding()
|
||||
.dyn_cast<SharedEncodingAttr>()) { // load from smem
|
||||
loadA = getLoadMatrixFn(A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
|
||||
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
|
||||
{matShapeM, matShapeK} /*matShape*/,
|
||||
warpM /*warpId*/, ha /*vals*/);
|
||||
loadA = getLoadMatrixFn(
|
||||
A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
|
||||
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
|
||||
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
|
||||
} else if (auto blockedLayout =
|
||||
aTensorTy.getEncoding()
|
||||
.dyn_cast<BlockedEncodingAttr>()) { // load from registers,
|
||||
|
Reference in New Issue
Block a user