testing things...
This commit is contained in:
@@ -432,13 +432,12 @@ section 9.7.13.4.1 for more details.
|
|||||||
let builders = [
|
let builders = [
|
||||||
AttrBuilder<(ins "unsigned":$opIdx,
|
AttrBuilder<(ins "unsigned":$opIdx,
|
||||||
"Attribute":$parent), [{
|
"Attribute":$parent), [{
|
||||||
|
Attribute isMMAv1Row;
|
||||||
if(parent.isa<MmaEncodingAttr>() &&
|
if(parent.isa<MmaEncodingAttr>() &&
|
||||||
parent.cast<MmaEncodingAttr>().getVersion() == 1){
|
parent.cast<MmaEncodingAttr>().getVersion() == 1){
|
||||||
llvm::report_fatal_error("DotOperand for MMAv1 must have isMMAv1Row field");
|
isMMAv1Row = BoolAttr::get(context, true);
|
||||||
return {};
|
|
||||||
}
|
}
|
||||||
Attribute none;
|
return $_get(context, opIdx, parent, isMMAv1Row);
|
||||||
return $_get(context, opIdx, parent, none);
|
|
||||||
}]>
|
}]>
|
||||||
|
|
||||||
];
|
];
|
||||||
|
@@ -1356,6 +1356,20 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|||||||
Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
|
Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
|
||||||
Location loc, ConversionPatternRewriter &rewriter) const {
|
Location loc, ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
// [1, 0] (isRow = True)
|
||||||
|
// x x x x || x x x x
|
||||||
|
// x x x x || x x x x
|
||||||
|
// stride = [8, 1]
|
||||||
|
// strideA0 = strideAk = 1
|
||||||
|
// strideA1 = strideAm = 8
|
||||||
|
|
||||||
|
// [0, 1] (isRow = False)
|
||||||
|
// x x x x || x x x x
|
||||||
|
// x x x x || x x x x
|
||||||
|
// stride = [1, 2]
|
||||||
|
// strideA0 = strideAm = 1
|
||||||
|
// strideA1 = strideAk = 2
|
||||||
|
|
||||||
auto *ctx = rewriter.getContext();
|
auto *ctx = rewriter.getContext();
|
||||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||||
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||||
@@ -1364,8 +1378,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|||||||
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
||||||
sharedLayout.getOrder().end());
|
sharedLayout.getOrder().end());
|
||||||
|
|
||||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
// Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||||
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
Value smemBase = smemObj.base;
|
||||||
|
|
||||||
bool isARow = order[0] != 0;
|
bool isARow = order[0] != 0;
|
||||||
AParam param(isARow);
|
AParam param(isARow);
|
||||||
@@ -1387,6 +1401,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|||||||
Value strideA0 = isARow ? strideAK : strideAM;
|
Value strideA0 = isARow ? strideAK : strideAM;
|
||||||
Value strideA1 = isARow ? strideAM : strideAK;
|
Value strideA1 = isARow ? strideAM : strideAK;
|
||||||
|
|
||||||
|
smemBase = gep(ptr_ty(f16_ty), smemBase, Value(smemObj.offsets[1]));
|
||||||
int strideRepM = wpt[0] * fpw[0] * 8;
|
int strideRepM = wpt[0] * fpw[0] * 8;
|
||||||
int strideRepK = 1;
|
int strideRepK = 1;
|
||||||
|
|
||||||
@@ -1401,7 +1416,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|||||||
Value offA0 = isARow ? offsetAK : offsetAM;
|
Value offA0 = isARow ? offsetAK : offsetAM;
|
||||||
Value offA1 = isARow ? offsetAM : offsetAK;
|
Value offA1 = isARow ? offsetAM : offsetAK;
|
||||||
Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
|
Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
|
||||||
offA0 = add(offA0, cSwizzleOffset);
|
// offA0 = add(offA0, smemObj.offsets[order[0]]);
|
||||||
|
// offA1 = add(offA1, smemObj.offsets[order[1]]);
|
||||||
|
|
||||||
SmallVector<Value> offA(numPtrA);
|
SmallVector<Value> offA(numPtrA);
|
||||||
for (int i = 0; i < numPtrA; i++) {
|
for (int i = 0; i < numPtrA; i++) {
|
||||||
Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
|
Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
|
||||||
@@ -1422,6 +1439,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|||||||
|
|
||||||
Type f16PtrTy = ptr_ty(f16_ty);
|
Type f16PtrTy = ptr_ty(f16_ty);
|
||||||
|
|
||||||
|
|
||||||
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
|
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
|
||||||
vals[{m, k}] = {val0, val1};
|
vals[{m, k}] = {val0, val1};
|
||||||
};
|
};
|
||||||
@@ -1451,7 +1469,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
unsigned numM = getNumM(shape, order);
|
unsigned numM = getNumM(shape, order);
|
||||||
|
llvm::outs() << "LOAD A " << numM << " " << NK << "\n";
|
||||||
|
|
||||||
for (unsigned k = 0; k < NK; k += 4)
|
for (unsigned k = 0; k < NK; k += 4)
|
||||||
for (unsigned m = 0; m < numM / 2; ++m)
|
for (unsigned m = 0; m < numM / 2; ++m)
|
||||||
loadA(m, k);
|
loadA(m, k);
|
||||||
|
@@ -3427,6 +3427,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
|||||||
}
|
}
|
||||||
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
|
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
|
||||||
isHMMA) { // tensor core v1
|
isHMMA) { // tensor core v1
|
||||||
|
// vprintf("offset 0", smemObj.offsets[0]}, rewriter);
|
||||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||||
bool isMMAv1Row =
|
bool isMMAv1Row =
|
||||||
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||||
@@ -3443,6 +3444,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
|
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
|
||||||
|
// LLVM::vprintf_array(i32_val(0), smemObj.offsets, "offsets ", "%d", rewriter);
|
||||||
// TODO[Superjomn]: transA is not available here.
|
// TODO[Superjomn]: transA is not available here.
|
||||||
bool transA = false;
|
bool transA = false;
|
||||||
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
|
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
|
||||||
@@ -4716,6 +4718,47 @@ private:
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void rewriteConvertToDotOperand(ModuleOp mod) {
|
||||||
|
mod.walk([&](triton::gpu::ConvertLayoutOp cvt){
|
||||||
|
OpBuilder builder(cvt);
|
||||||
|
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||||
|
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||||
|
// order
|
||||||
|
ArrayRef<unsigned> order;
|
||||||
|
if(auto srcBlockedLayout =
|
||||||
|
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||||
|
order = srcBlockedLayout.getOrder();
|
||||||
|
else if(auto srcSharedLayout =
|
||||||
|
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
|
||||||
|
order = srcSharedLayout.getOrder();
|
||||||
|
else
|
||||||
|
return;
|
||||||
|
// dot operand output
|
||||||
|
auto dstDotOperandLayout =
|
||||||
|
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||||
|
if (!dstDotOperandLayout)
|
||||||
|
return;
|
||||||
|
unsigned opIdx = dstDotOperandLayout.getOpIdx();
|
||||||
|
if(!dstDotOperandLayout.getIsMMAv1Row())
|
||||||
|
return;
|
||||||
|
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||||
|
if((order[0] == 1 && isMMAv1Row) ||
|
||||||
|
(order[0] == 0 && !isMMAv1Row))
|
||||||
|
return;
|
||||||
|
auto newIsRow = BoolAttr::get(cvt.getContext(), !isMMAv1Row);
|
||||||
|
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||||
|
cvt.getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
|
||||||
|
newIsRow);
|
||||||
|
auto newDstType = RankedTensorType::get(
|
||||||
|
dstType.getShape(),
|
||||||
|
dstType.getElementType(), newDstEncoding);
|
||||||
|
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||||
|
cvt.getLoc(), newDstType, cvt.getOperand());
|
||||||
|
cvt.replaceAllUsesWith(newCvt.getResult());
|
||||||
|
cvt.erase();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void decomposeInsertSliceAsyncOp(ModuleOp mod) {
|
void decomposeInsertSliceAsyncOp(ModuleOp mod) {
|
||||||
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
|
||||||
axisInfoAnalysis.run(mod);
|
axisInfoAnalysis.run(mod);
|
||||||
@@ -4835,6 +4878,7 @@ public:
|
|||||||
// separation between 1/4 is that, step 3 is out of the scope of Dialect
|
// separation between 1/4 is that, step 3 is out of the scope of Dialect
|
||||||
// Conversion, thus we need to make sure the smem is not revised during the
|
// Conversion, thus we need to make sure the smem is not revised during the
|
||||||
// conversion of step 4.
|
// conversion of step 4.
|
||||||
|
rewriteConvertToDotOperand(mod);
|
||||||
decomposeMmaToDotOperand(mod, numWarps);
|
decomposeMmaToDotOperand(mod, numWarps);
|
||||||
|
|
||||||
decomposeBlockedToDotOperand(mod);
|
decomposeBlockedToDotOperand(mod);
|
||||||
@@ -4845,6 +4889,7 @@ public:
|
|||||||
MembarAnalysis membarPass(&allocation);
|
MembarAnalysis membarPass(&allocation);
|
||||||
membarPass.run();
|
membarPass.run();
|
||||||
|
|
||||||
|
llvm::outs() << mod << "\n";
|
||||||
RewritePatternSet scf_patterns(context);
|
RewritePatternSet scf_patterns(context);
|
||||||
mlir::populateLoopToStdConversionPatterns(scf_patterns);
|
mlir::populateLoopToStdConversionPatterns(scf_patterns);
|
||||||
mlir::ConversionTarget scf_target(*context);
|
mlir::ConversionTarget scf_target(*context);
|
||||||
|
@@ -713,9 +713,9 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class OptimizeBlockedToDotOperand : public mlir::RewritePattern {
|
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
|
||||||
public:
|
public:
|
||||||
OptimizeBlockedToDotOperand(mlir::MLIRContext *context)
|
OptimizeConvertToDotOperand(mlir::MLIRContext *context)
|
||||||
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
|
||||||
context) {}
|
context) {}
|
||||||
|
|
||||||
@@ -725,18 +725,27 @@ public:
|
|||||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||||
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||||
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
|
||||||
auto srcBlockedLayout =
|
// order
|
||||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
ArrayRef<unsigned> order;
|
||||||
|
if(auto srcBlockedLayout =
|
||||||
|
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
||||||
|
order = srcBlockedLayout.getOrder();
|
||||||
|
else if(auto srcSharedLayout =
|
||||||
|
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
|
||||||
|
order = srcSharedLayout.getOrder();
|
||||||
|
else
|
||||||
|
return failure();
|
||||||
|
// dot operand output
|
||||||
auto dstDotOperandLayout =
|
auto dstDotOperandLayout =
|
||||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||||
if (!srcBlockedLayout || !dstDotOperandLayout)
|
if (!dstDotOperandLayout)
|
||||||
return failure();
|
return failure();
|
||||||
unsigned opIdx = dstDotOperandLayout.getOpIdx();
|
unsigned opIdx = dstDotOperandLayout.getOpIdx();
|
||||||
if(!dstDotOperandLayout.getIsMMAv1Row())
|
if(!dstDotOperandLayout.getIsMMAv1Row())
|
||||||
return failure();
|
return failure();
|
||||||
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||||
if((srcBlockedLayout.getOrder()[0] == 1 && isMMAv1Row) ||
|
if((order[0] == 1 && isMMAv1Row) ||
|
||||||
(srcBlockedLayout.getOrder()[0] == 0 && !isMMAv1Row))
|
(order[0] == 0 && !isMMAv1Row))
|
||||||
return failure();
|
return failure();
|
||||||
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
|
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
|
||||||
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||||
@@ -862,7 +871,7 @@ public:
|
|||||||
mlir::RewritePatternSet patterns(context);
|
mlir::RewritePatternSet patterns(context);
|
||||||
|
|
||||||
patterns.add<OptimizeBlockedToShared>(context);
|
patterns.add<OptimizeBlockedToShared>(context);
|
||||||
patterns.add<OptimizeBlockedToDotOperand>(context);
|
// patterns.add<OptimizeConvertToDotOperand>(context);
|
||||||
patterns.add<SimplifyConversion>(context);
|
patterns.add<SimplifyConversion>(context);
|
||||||
patterns.add<DecomposeDotOperand>(context);
|
patterns.add<DecomposeDotOperand>(context);
|
||||||
patterns.add<RematerializeBackward>(context);
|
patterns.add<RematerializeBackward>(context);
|
||||||
@@ -873,6 +882,7 @@ public:
|
|||||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
|
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -297,22 +297,22 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
|
|||||||
# NOTE this is useful only on Volta GPU.
|
# NOTE this is useful only on Volta GPU.
|
||||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
|
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
|
||||||
# Non-forloop
|
# Non-forloop
|
||||||
[16, 16, 16, 1, 16, 16, 16, False, False],
|
# [16, 16, 16, 1, 16, 16, 16, False, False],
|
||||||
[16, 16, 32, 1, 16, 16, 32, False, False],
|
# [16, 16, 32, 1, 16, 16, 32, False, False],
|
||||||
[32, 16, 32, 1, 32, 16, 32, False, False],
|
# [32, 16, 32, 1, 32, 16, 32, False, False],
|
||||||
[32, 32, 32, 1, 32, 32, 32, False, False],
|
# [32, 32, 32, 1, 32, 32, 32, False, False],
|
||||||
[128, 32, 32, 1, 128, 32, 32, False, False],
|
# [128, 32, 32, 1, 128, 32, 32, False, False],
|
||||||
|
|
||||||
[128, 32, 32, 1, 128, 32, 32, True, False],
|
# [128, 32, 32, 1, 128, 32, 32, True, False],
|
||||||
[128, 32, 32, 1, 128, 32, 32, True, True],
|
# [128, 32, 32, 1, 128, 32, 32, True, True],
|
||||||
|
|
||||||
# split-K
|
# # split-K
|
||||||
[16, 16, 32, 1, 16, 16, 16, False, False],
|
# [16, 16, 32, 1, 16, 16, 16, False, False],
|
||||||
[64, 64, 128, 1, 64, 64, 32, False, False],
|
# [64, 64, 128, 1, 64, 64, 32, False, False],
|
||||||
|
|
||||||
[16, 16, 32, 1, 16, 16, 16, True, False],
|
# [16, 16, 32, 1, 16, 16, 16, True, False],
|
||||||
[16, 16, 32, 1, 16, 16, 16, True, True],
|
# [16, 16, 32, 1, 16, 16, 16, True, True],
|
||||||
[64, 64, 128, 1, 64, 64, 32, True, True],
|
[64, 64, 64, 1, 64, 64, 32, True, False],
|
||||||
])
|
])
|
||||||
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
|
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
|
||||||
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
|
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
|
||||||
|
@@ -1402,9 +1402,9 @@ def compile(fn, **kwargs):
|
|||||||
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||||
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||||
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
|
lambda src: ttir_to_ttgir(src, num_warps, num_stages, 70)),
|
||||||
"llir": (lambda path: Path(path).read_bytes(),
|
"llir": (lambda path: Path(path).read_bytes(),
|
||||||
lambda src: ttgir_to_llir(src, extern_libs, capability)),
|
lambda src: ttgir_to_llir(src, extern_libs, 70)),
|
||||||
"ptx": (lambda path: Path(path).read_text(),
|
"ptx": (lambda path: Path(path).read_text(),
|
||||||
lambda src: llir_to_ptx(src, capability)),
|
lambda src: llir_to_ptx(src, capability)),
|
||||||
"cubin": (lambda path: Path(path).read_bytes(),
|
"cubin": (lambda path: Path(path).read_bytes(),
|
||||||
|
Reference in New Issue
Block a user