[BACKEND] Add isRow attribute for DotOp tensors whose parent is mmav1 (#970)

Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Philippe Tillet
2022-12-11 19:01:57 -08:00
committed by GitHub
parent 4fb048873a
commit 52accd4c2b
7 changed files with 186 additions and 27 deletions


@@ -416,15 +416,35 @@ In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parent field in DotOperandEncodingAttr is the layout of d.
For MMA v1, an additional attribute `isMMAv1Row` determines whether, e.g., the a operand is used
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation,
section 9.7.13.4.1, for more details.
}];
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"Attribute":$isMMAv1Row
);
let builders = [
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent), [{
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
isMMAv1Row = BoolAttr::get(context, true);
}
return $_get(context, opIdx, parent, isMMAv1Row);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
#endif
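
For reference, a minimal C++ sketch (not part of this commit) of how the new builder behaves, using only the APIs visible in this diff; the include paths and the standalone helper function are assumptions. It mirrors the AttrBuilder above: isMMAv1Row is populated only when the parent is an MMA v1 layout and defaults to row-major (true), which then prints as e.g. #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}> (see the lit test at the bottom of this diff).

#include "mlir/IR/BuiltinAttributes.h"            // BoolAttr (assumed include path)
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"  // MmaEncodingAttr, DotOperandEncodingAttr

using namespace mlir;
using namespace mlir::triton::gpu;

// Hypothetical helper mirroring the default AttrBuilder added above.
Attribute makeDotOperandEncoding(MLIRContext *ctx, unsigned opIdx, Attribute parent) {
  Attribute isMMAv1Row;
  if (auto mma = parent.dyn_cast<MmaEncodingAttr>())
    if (mma.getVersion() == 1)
      isMMAv1Row = BoolAttr::get(ctx, /*value=*/true);
  // For MMA v2 and for non-MMA parents the attribute stays null, so the
  // printer changed later in this diff only emits isMMAv1Row when it is set.
  return DotOperandEncodingAttr::get(ctx, opIdx, parent, isMMAv1Row);
}
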


@@ -3432,6 +3432,20 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<SharedEncodingAttr>();
// Can only convert [1, 0] to row or [0, 1] to col for now
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
return Value();
}
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// TODO[Superjomn]: transA is not available here.
bool transA = false;
@@ -3544,6 +3558,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
auto ALayout = A.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto BLayout = B.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
@@ -3555,12 +3577,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
// TODO[Superjomn]: order cannot be accessed in DotOp.
SmallVector<unsigned> AOrder({1, 0});
SmallVector<unsigned> BOrder({1, 0});
bool isARow = AOrder[0] != 0;
bool isBRow = BOrder[0] != 0;
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
// TODO[Superjomn]: ld.v4 is not supported.
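
For clarity, a standalone sketch (an assumption, not taken from the commit) of the compatibility rule enforced by the new check in lowerSharedToDotOperandMMA: a shared operand with order [1, 0] can only feed a row-major MMAv1 dot operand, and order [0, 1] only a column-major one; anything else takes the unsupported path above. The same flag then drives isARow/isBRow in convertMMA884 instead of the previously hard-coded [1, 0] order. The function name and the use of std::array are hypothetical.

#include <array>
#include <cstdio>

// Returns true when the shared layout's order agrees with the operand's
// isMMAv1Row flag; mirrors the early-exit check added in this hunk.
bool sharedOrderMatchesMMAv1Row(std::array<unsigned, 2> sharedOrder,
                                bool isMMAv1Row) {
  if (sharedOrder[0] == 1 && !isMMAv1Row) return false; // row-major data, col operand
  if (sharedOrder[0] == 0 && isMMAv1Row)  return false; // col-major data, row operand
  return true;
}

int main() {
  std::printf("%d\n", sharedOrderMatchesMMAv1Row({1, 0}, true));  // 1: supported
  std::printf("%d\n", sharedOrderMatchesMMAv1Row({1, 0}, false)); // 0: unsupported
  return 0;
}
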


@@ -589,15 +589,24 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
isMMAv1Row = attrs.get("isMMAv1Row");
if(!isMMAv1Row)
llvm::report_fatal_error("isMMAv1Row attribute is missing");
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
parent, isMMAv1Row);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent() << "}>";
<< "parent = " << getParent();
if(getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
printer << "}>";
}
//===----------------------------------------------------------------------===//
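
As a sanity check on the parse/print changes above, here is a hypothetical round-trip sketch (not part of this commit); the test scaffolding and the mlir::parseAttribute entry point with its header location are assumptions, while the attribute text is taken from the lit test at the bottom of this diff. The parser requires isMMAv1Row whenever the parent is an MMA v1 layout and reports a fatal error if it is missing.

#include "mlir/AsmParser/AsmParser.h"   // mlir::parseAttribute (assumed header)
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <cassert>

int main() {
  mlir::MLIRContext ctx;
  ctx.getOrLoadDialect<mlir::triton::gpu::TritonGPUDialect>();
  const char *text =
      "#triton_gpu.dot_op<{opIdx = 0, "
      "parent = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>, "
      "isMMAv1Row = true}>";
  mlir::Attribute attr = mlir::parseAttribute(text, &ctx);
  assert(attr && "failed to parse dot_op attribute");
  auto dotOp = attr.cast<mlir::triton::gpu::DotOperandEncodingAttr>();
  assert(dotOp.getIsMMAv1Row().cast<mlir::BoolAttr>().getValue());
  return 0;
}
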


@@ -715,6 +715,55 @@ public:
}
};
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if(auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if(auto srcSharedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
unsigned opIdx = dstDotOperandLayout.getOpIdx();
if(!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row = dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if((order[0] == 1 && isMMAv1Row) ||
(order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(), dstDotOperandLayout.getParent(),
newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(),
dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
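
To make the intent of OptimizeConvertToDotOperand explicit, a hypothetical standalone distillation (not from the commit) of its decision follows: when the destination dot-operand layout's isMMAv1Row disagrees with the source layout's order, the pattern rewrites the convert with the flipped value so the later Shared -> DotOperand[MMAv1] lowering stays on the supported path. The function name and the std::optional packaging are assumptions.

#include <array>
#include <optional>

// Returns the isMMAv1Row value the pattern rewrites to, or nullopt when the
// convert is already consistent and matchAndRewrite returns failure().
std::optional<bool> flippedIsMMAv1Row(std::array<unsigned, 2> srcOrder,
                                      bool isMMAv1Row) {
  // order[0] == 1 (row-major source) already matches a row operand,
  // order[0] == 0 (col-major source) already matches a col operand.
  if ((srcOrder[0] == 1 && isMMAv1Row) || (srcOrder[0] == 0 && !isMMAv1Row))
    return std::nullopt;
  return !isMMAv1Row; // flip so the operand layout matches the source order
}
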
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
@@ -772,14 +821,28 @@ public:
Value b = dotOp.b();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
auto oldBOrder = oldBType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>()
.getParent().cast<triton::gpu::BlockedEncodingAttr>().getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if(version == 1){
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
newRetType.getEncoding(),
isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
newRetType.getEncoding(),
isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(dotOp.getLoc(), newRetType, a,
@@ -791,6 +854,51 @@ public:
}
};
class FixupLoop : public mlir::RewritePattern {
public:
FixupLoop(mlir::MLIRContext *context)
: mlir::RewritePattern(scf::ForOp::getOperationName(), 2,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto forOp = cast<scf::ForOp>(op);
// Rewrite init argument
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
bool shouldRematerialize = false;
for(size_t i = 0; i < newInitArgs.size(); i++){
auto initArg = newInitArgs[i];
auto regionArg = forOp.getRegionIterArgs()[i];
if(newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()){
shouldRematerialize = true;
break;
}
}
if(!shouldRematerialize)
return failure();
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInitArgs);
newForOp->moveBefore(forOp);
rewriter.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
for (Operation &op : forOp.getBody()->getOperations()) {
Operation *newOp = rewriter.clone(op, mapping);
}
rewriter.replaceOp(forOp, newForOp.getResults());
return success();
}
};
} // namespace
#define GEN_PASS_CLASSES
@@ -810,6 +918,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
@@ -820,6 +929,13 @@ public:
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
// llvm::outs() << m << "\n";
mlir::RewritePatternSet loopFixup(context);
loopFixup.add<FixupLoop>(context);
if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {
signalPassFailure();
}
}
};


@@ -32,7 +32,7 @@ def matmul_no_scf_kernel(
(shape, num_warps, trans_a, trans_b)
for shape in [
[128, 256, 32],
[256, 128, 16],
# [256, 128, 16],
[128, 16, 32],
[32, 128, 64],
[128, 128, 64],
@@ -43,8 +43,6 @@ def matmul_no_scf_kernel(
for trans_b in [False, True]
])
def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)
SIZE_M, SIZE_N, SIZE_K = SHAPE
if (TRANS_A):
a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
@@ -83,7 +81,7 @@ def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
for trans_b in [False, True]
])
def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B, is_int8=True)
guard_for_volta(is_int8=True)
SIZE_M, SIZE_N, SIZE_K = SHAPE
@@ -199,7 +197,6 @@ def get_proper_err(a, b, golden):
[128, 64, 128, 4, 128, 64, 32, False, True],
])
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
guard_for_volta(NUM_WARPS, TRANS_A, TRANS_B)
if (TRANS_A):
a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
@@ -276,7 +273,7 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, c_mask)
guard_for_volta(num_warps, trans_a=False, trans_b=False, is_tf32=allow_tf32)
guard_for_volta(is_tf32=allow_tf32)
# Configure the pytorch counterpart
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
@@ -302,7 +299,7 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
def guard_for_volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
def guard_for_volta(is_int8=False, is_tf32=False):
'''
Tell whether the test case is valid on Volta GPUs.
Some features are still WIP, so the corresponding support is missing.
@@ -311,8 +308,7 @@ def guard_for_volta(num_warps, trans_a, trans_b, is_int8=False, is_tf32=False):
is_on_Volta = capability[0] < 8
# TODO[Superjomn]: Remove the constraints below when features are ready
is_feature_supported = not (is_int8 or is_tf32)
is_feature_ready = not (trans_a or trans_b)
if is_on_Volta:
if (not is_feature_supported) or (not is_feature_ready):
pytest.skip("Not valid on Volta")
if (not is_feature_supported):
pytest.skip("Not valid on Volta")


@@ -1385,6 +1385,8 @@ arg_type_pattern = {
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
def compile(fn, **kwargs):
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
# we get the kernel, i.e. the first function generated in the module
# if fn is not a JITFunction, then it
# has to be a path to a file
@@ -1392,11 +1394,9 @@ def compile(fn, **kwargs):
asm = dict()
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3)
num_stages = kwargs.get("num_stages", 3 if capability >= 75 else 2)
extern_libs = kwargs.get("extern_libs", dict())
device = kwargs.get("device", torch.cuda.current_device())
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
# build compilation stages
stages = {
"ast": (lambda path: fn, None),


@@ -879,8 +879,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {