[Triton-MLIR][Backend] Fix smem base bug in dot codegen (#715)
Get SMEM base address of an input operand from `adapter.arg()` instead of `getSharedMemoryBase(arg, ...)`, for the latter one not works with memory alias, for example: ```llvm %a = extract_slice %b, %offset %c = dot %a, %d ``` `%a` should have different smem base address from `%b`
This commit is contained in:
@@ -455,13 +455,6 @@ public:
|
|||||||
return multiDimBase;
|
return multiDimBase;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value>
|
|
||||||
emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
|
|
||||||
const MmaEncodingAttr &mmaLayout,
|
|
||||||
ArrayRef<int64_t> shape) const {
|
|
||||||
// ongoing
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<SmallVector<Value>> emitIndices(Location loc,
|
SmallVector<SmallVector<Value>> emitIndices(Location loc,
|
||||||
ConversionPatternRewriter &b,
|
ConversionPatternRewriter &b,
|
||||||
const Attribute &layout,
|
const Attribute &layout,
|
||||||
@@ -2446,8 +2439,8 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
|
|||||||
|
|
||||||
// Load A or B matrix.
|
// Load A or B matrix.
|
||||||
auto getLoadMatrixFn =
|
auto getLoadMatrixFn =
|
||||||
[&](Value tensor, int wpt, int kOrder, ArrayRef<int> instrShape,
|
[&](Value tensor, Value llTensor, int wpt, int kOrder,
|
||||||
ArrayRef<int> matShape, Value warpId,
|
ArrayRef<int> instrShape, ArrayRef<int> matShape, Value warpId,
|
||||||
decltype(ha) &vals) -> std::function<void(int, int)> {
|
decltype(ha) &vals) -> std::function<void(int, int)> {
|
||||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||||
// We assumes that the input operand of Dot should be from shared layout.
|
// We assumes that the input operand of Dot should be from shared layout.
|
||||||
@@ -2468,10 +2461,9 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
|
|||||||
SmallVector<Value> ptrs(numPtrs);
|
SmallVector<Value> ptrs(numPtrs);
|
||||||
|
|
||||||
Type smemPtrTy = helper.getShemPtrTy();
|
Type smemPtrTy = helper.getShemPtrTy();
|
||||||
auto smemBase = getSharedMemoryBase(loc, rewriter, tensor);
|
|
||||||
for (int i = 0; i < numPtrs; ++i) {
|
for (int i = 0; i < numPtrs; ++i) {
|
||||||
ptrs[i] = bit_cast(
|
ptrs[i] =
|
||||||
smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));
|
bit_cast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]})));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool needTrans = kOrder != order[0];
|
bool needTrans = kOrder != order[0];
|
||||||
@@ -2499,16 +2491,16 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
|
|||||||
|
|
||||||
std::function<void(int, int)> loadA;
|
std::function<void(int, int)> loadA;
|
||||||
std::function<void(int, int)> loadB = getLoadMatrixFn(
|
std::function<void(int, int)> loadB = getLoadMatrixFn(
|
||||||
B, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/,
|
B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
|
||||||
{mmaInstrK, mmaInstrN} /*instrShpae*/,
|
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
|
||||||
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
|
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
|
||||||
|
|
||||||
if (aTensorTy.getEncoding()
|
if (aTensorTy.getEncoding()
|
||||||
.dyn_cast<SharedEncodingAttr>()) { // load from smem
|
.dyn_cast<SharedEncodingAttr>()) { // load from smem
|
||||||
loadA = getLoadMatrixFn(A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
|
loadA = getLoadMatrixFn(
|
||||||
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
|
A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
|
||||||
{matShapeM, matShapeK} /*matShape*/,
|
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
|
||||||
warpM /*warpId*/, ha /*vals*/);
|
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
|
||||||
} else if (auto blockedLayout =
|
} else if (auto blockedLayout =
|
||||||
aTensorTy.getEncoding()
|
aTensorTy.getEncoding()
|
||||||
.dyn_cast<BlockedEncodingAttr>()) { // load from registers,
|
.dyn_cast<BlockedEncodingAttr>()) { // load from registers,
|
||||||
|
Reference in New Issue
Block a user