testing things...
This commit is contained in:
@@ -3427,6 +3427,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
||||
}
|
||||
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
|
||||
isHMMA) { // tensor core v1
|
||||
// vprintf("offset 0", smemObj.offsets[0]}, rewriter);
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
bool isMMAv1Row =
|
||||
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
@@ -3443,6 +3444,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
||||
}
|
||||
|
||||
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
|
||||
// LLVM::vprintf_array(i32_val(0), smemObj.offsets, "offsets ", "%d", rewriter);
|
||||
// TODO[Superjomn]: transA is not available here.
|
||||
bool transA = false;
|
||||
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
|
||||
@@ -4716,6 +4718,47 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
void rewriteConvertToDotOperand(ModuleOp mod) {
|
||||
mod.walk([&](triton::gpu::ConvertLayoutOp cvt){
|
||||
OpBuilder builder(cvt);
|
||||
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||
// order
|
||||
ArrayRef<unsigned> order;
|
||||
if(auto srcBlockedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||
order = srcBlockedLayout.getOrder();
|
||||
else if(auto srcSharedLayout =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
|
||||
order = srcSharedLayout.getOrder();
|
||||
else
|
||||
return;
|
||||
// dot operand output
|
||||
auto dstDotOperandLayout =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (!dstDotOperandLayout)
|
||||
return;
|
||||
unsigned opIdx = dstDotOperandLayout.getOpIdx();
|
||||
if(!dstDotOperandLayout.getIsMMAv1Row())
|
||||
return;
|
||||
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
if((order[0] == 1 && isMMAv1Row) ||
|
||||
(order[0] == 0 && !isMMAv1Row))
|
||||
return;
|
||||
auto newIsRow = BoolAttr::get(cvt.getContext(), !isMMAv1Row);
|
||||
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
cvt.getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
|
||||
newIsRow);
|
||||
auto newDstType = RankedTensorType::get(
|
||||
dstType.getShape(),
|
||||
dstType.getElementType(), newDstEncoding);
|
||||
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvt.getLoc(), newDstType, cvt.getOperand());
|
||||
cvt.replaceAllUsesWith(newCvt.getResult());
|
||||
cvt.erase();
|
||||
});
|
||||
}
|
||||
|
||||
void decomposeInsertSliceAsyncOp(ModuleOp mod) {
|
||||
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
||||
axisInfoAnalysis.run(mod);
|
||||
@@ -4835,6 +4878,7 @@ public:
|
||||
// separation between 1/4 is that, step 3 is out of the scope of Dialect
|
||||
// Conversion, thus we need to make sure the smem is not revised during the
|
||||
// conversion of step 4.
|
||||
rewriteConvertToDotOperand(mod);
|
||||
decomposeMmaToDotOperand(mod, numWarps);
|
||||
|
||||
decomposeBlockedToDotOperand(mod);
|
||||
@@ -4845,6 +4889,7 @@ public:
|
||||
MembarAnalysis membarPass(&allocation);
|
||||
membarPass.run();
|
||||
|
||||
llvm::outs() << mod << "\n";
|
||||
RewritePatternSet scf_patterns(context);
|
||||
mlir::populateLoopToStdConversionPatterns(scf_patterns);
|
||||
mlir::ConversionTarget scf_target(*context);
|
||||
|
Reference in New Issue
Block a user