testing things...

Phil Tillet
2022-12-09 19:31:34 -08:00
parent fa6dbbff60
commit 58d2867fe6
6 changed files with 105 additions and 30 deletions

@@ -1356,6 +1356,20 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const {
// [1, 0] (isRow = True)
// x x x x || x x x x
// x x x x || x x x x
// stride = [8, 1]
// strideA0 = strideAk = 1
// strideA1 = strideAm = 8
// [0, 1] (isRow = False)
// x x x x || x x x x
// x x x x || x x x x
// stride = [1, 2]
// strideA0 = strideAm = 1
// strideA1 = strideAk = 2
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
@@ -1364,8 +1378,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
// Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Value smemBase = smemObj.base;
bool isARow = order[0] != 0;
AParam param(isARow);
@@ -1387,6 +1401,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value strideA0 = isARow ? strideAK : strideAM;
Value strideA1 = isARow ? strideAM : strideAK;
smemBase = gep(ptr_ty(f16_ty), smemBase, Value(smemObj.offsets[1]));
int strideRepM = wpt[0] * fpw[0] * 8;
int strideRepK = 1;
@@ -1401,7 +1416,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
Value offA0 = isARow ? offsetAK : offsetAM;
Value offA1 = isARow ? offsetAM : offsetAK;
Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
offA0 = add(offA0, cSwizzleOffset);
// offA0 = add(offA0, smemObj.offsets[order[0]]);
// offA1 = add(offA1, smemObj.offsets[order[1]]);
SmallVector<Value> offA(numPtrA);
for (int i = 0; i < numPtrA; i++) {
Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
@@ -1422,6 +1439,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
Type f16PtrTy = ptr_ty(f16_ty);
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
vals[{m, k}] = {val0, val1};
};
@@ -1451,7 +1469,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
}
};
unsigned numM = getNumM(shape, order);
llvm::outs() << "LOAD A " << numM << " " << NK << "\n";
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
loadA(m, k);

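Note: the comment block added to DotOpMmaV1ConversionHelper::loadA above documents how the two shared-memory strides are chosen from the layout order. Below is a minimal standalone sketch of that selection for the 2x8 A-tile in the diagram; the StridesA struct and pickStridesA helper are hypothetical names for illustration, not part of the Triton sources.

    #include <array>
    #include <cassert>

    // Hypothetical model of the stride selection in loadA: strideA0 is the
    // stride along the contiguous (fastest-varying) dimension, strideA1 the
    // stride along the other dimension.
    struct StridesA {
      int strideA0;
      int strideA1;
    };

    // order[0] != 0 means the K (inner) dimension is contiguous, i.e. isARow,
    // matching the `bool isARow = order[0] != 0;` test in loadA.
    StridesA pickStridesA(const std::array<unsigned, 2> &order, int strideAM,
                          int strideAK) {
      bool isARow = order[0] != 0;
      return isARow ? StridesA{/*strideA0=*/strideAK, /*strideA1=*/strideAM}
                    : StridesA{/*strideA0=*/strideAM, /*strideA1=*/strideAK};
    }

    int main() {
      // [1, 0] (isRow = true), 2x8 tile: stride = [8, 1], so strideA0 = 1, strideA1 = 8.
      StridesA row = pickStridesA({1, 0}, /*strideAM=*/8, /*strideAK=*/1);
      assert(row.strideA0 == 1 && row.strideA1 == 8);
      // [0, 1] (isRow = false), 2x8 tile: stride = [1, 2], so strideA0 = 1, strideA1 = 2.
      StridesA col = pickStridesA({0, 1}, /*strideAM=*/1, /*strideAK=*/2);
      assert(col.strideA0 == 1 && col.strideA1 == 2);
      return 0;
    }

When order[0] == 1 the K dimension is contiguous, so strideA0 takes the K stride; a column-major order swaps the roles, which is exactly the isARow branch in loadA.
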
@@ -3427,6 +3427,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
}
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1
// vprintf("offset 0", smemObj.offsets[0]}, rewriter);
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
@@ -3443,6 +3444,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
}
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// LLVM::vprintf_array(i32_val(0), smemObj.offsets, "offsets ", "%d", rewriter);
// TODO[Superjomn]: transA is not available here.
bool transA = false;
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
@@ -4716,6 +4718,47 @@ private:
});
}
void rewriteConvertToDotOperand(ModuleOp mod) {
mod.walk([&](triton::gpu::ConvertLayoutOp cvt){
OpBuilder builder(cvt);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return;
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return;
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return;
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((order[0] == 1 && isMMAv1Row) ||
(order[0] == 0 && !isMMAv1Row))
return;
auto newIsRow = BoolAttr::get(cvt.getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
cvt.getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
cvt.getLoc(), newDstType, cvt.getOperand());
cvt.replaceAllUsesWith(newCvt.getResult());
cvt.erase();
});
}
void decomposeInsertSliceAsyncOp(ModuleOp mod) {
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
axisInfoAnalysis.run(mod);
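Note: rewriteConvertToDotOperand above only rewrites a convert_layout when the destination's isMMAv1Row flag disagrees with the source layout's fastest-varying dimension; every other case is an early return. A minimal sketch of that trigger condition, assuming order[0] == 1 means the inner (row) dimension is contiguous; the needsRowFlagFlip name is made up for illustration.

    // Hypothetical distillation of the early-exit checks in
    // rewriteConvertToDotOperand: the pass rebuilds the destination
    // DotOperandEncodingAttr with the flag negated only on a mismatch.
    bool needsRowFlagFlip(unsigned order0, bool isMMAv1Row) {
      bool srcIsRow = (order0 == 1);
      // (order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row)
      // is the "already consistent, nothing to do" case in the pass above.
      return srcIsRow != isMMAv1Row;
    }

When this condition holds, the pass clones the convert with !isMMAv1Row in the encoding, replaces all uses, and erases the original op, as in the walk above.
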
@@ -4835,6 +4878,7 @@ public:
// separation between 1/4 is that, step 3 is out of the scope of Dialect
// Conversion, thus we need to make sure the smem is not revised during the
// conversion of step 4.
rewriteConvertToDotOperand(mod);
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);
@@ -4845,6 +4889,7 @@ public:
MembarAnalysis membarPass(&allocation);
membarPass.run();
llvm::outs() << mod << "\n";
RewritePatternSet scf_patterns(context);
mlir::populateLoopToStdConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context);