[Triton-MLIR] Minor fixes related with scf/swizzling support (#791)

1, Disable static loop unrolling in the frontend by default;
2, A minor fix in axisAnalysis in order to support scf;
3, A minor fix in TritonGPUToLLVM to support swizzling.
This commit is contained in:
goostavz
2022-10-21 11:46:28 +08:00
committed by GitHub
parent dc0588a898
commit c4726333bf
4 changed files with 79 additions and 26 deletions

View File

@@ -40,7 +40,8 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
if (TensorType ty = value.getType().dyn_cast<TensorType>())
rank = ty.getRank();
int divHint = 1;
if (BlockArgument blockArg = value.dyn_cast<BlockArgument>()) {
BlockArgument blockArg = value.dyn_cast<BlockArgument>();
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
if (FuncOp fun = dyn_cast<FuncOp>(op)) {
Attribute attr =

View File

@@ -1867,8 +1867,8 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
linearIdxInNanoTile, srcBlockedLayout.getSizePerThread());
multiDimIdxInNanoTile[inOrd[0]] /= minVec;
unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
multiDimIdxInNanoTile[inOrd[0]] /= minVec;
unsigned wordVecIdx =
getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep);
wordVecs[wordVecIdx] =