testing things...

This commit is contained in:
Phil Tillet
2022-12-09 19:31:34 -08:00
parent fa6dbbff60
commit 58d2867fe6
6 changed files with 105 additions and 30 deletions

View File

@@ -3427,6 +3427,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
}
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1
// vprintf("offset 0", smemObj.offsets[0]}, rewriter);
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
@@ -3443,6 +3444,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
}
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// LLVM::vprintf_array(i32_val(0), smemObj.offsets, "offsets ", "%d", rewriter);
// TODO[Superjomn]: transA is not available here.
bool transA = false;
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
@@ -4716,6 +4718,47 @@ private:
});
}
void rewriteConvertToDotOperand(ModuleOp mod) {
mod.walk([&](triton::gpu::ConvertLayoutOp cvt){
OpBuilder builder(cvt);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return;
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return;
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return;
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((order[0] == 1 && isMMAv1Row) ||
(order[0] == 0 && !isMMAv1Row))
return;
auto newIsRow = BoolAttr::get(cvt.getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
cvt.getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvt.getLoc(), newDstType, cvt.getOperand());
cvt.replaceAllUsesWith(newCvt.getResult());
cvt.erase();
});
}
void decomposeInsertSliceAsyncOp(ModuleOp mod) {
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
axisInfoAnalysis.run(mod);
@@ -4835,6 +4878,7 @@ public:
// separation between 1/4 is that, step 3 is out of the scope of Dialect
// Conversion, thus we need to make sure the smem is not revised during the
// conversion of step 4.
rewriteConvertToDotOperand(mod);
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);
@@ -4845,6 +4889,7 @@ public:
MembarAnalysis membarPass(&allocation);
membarPass.run();
llvm::outs() << mod << "\n";
RewritePatternSet scf_patterns(context);
mlir::populateLoopToStdConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context);