[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)
@@ -272,7 +272,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
//
// SPMD Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);

let results = (outs I32:$result);
@@ -280,7 +280,7 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
let assemblyFormat = "attr-dict `:` type($result)";
}

def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);

let results = (outs I32:$result);
@@ -301,7 +301,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
$d = matrix_multiply($a, $b) + $c
}];

let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);

let results = (outs TT_FpIntTensor:$d);

@@ -324,6 +324,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
];

let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";

}

//

@@ -328,4 +328,23 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
}

def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
let mnemonic = "dot_op";

let description = [{
In the TritonGPU dialect, considering `d = tt.dot a, b, c`,
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parent field in DotOperandEncodingAttr is the layout of d.
}];

let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
);

let extraClassDeclaration = extraBaseClassDeclaration;
}

#endif
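For illustration only, a hypothetical IR snippet using this encoding could look like the following; the #mma parent layout and the tensor shapes are made-up placeholders, and only the `dot_op` mnemonic with its `opIdx`/`parent` fields comes from the definition above:

    #mma = #triton_gpu.mma<{...}>                             // assumed parent layout, i.e. the layout of d
    #dot_a = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>   // layout of operand a
    #dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>   // layout of operand b
    %d = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_a> * tensor<32x128xf16, #dot_b> -> tensor<128x128xf32, #mma>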
@@ -14,7 +14,7 @@ class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;

def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
[NoSideEffect]> {
[SameOperandsAndResultShape, NoSideEffect]> {
let summary = "convert layout";

let arguments = (ins TT_Tensor:$src);
@@ -32,10 +32,10 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let assemblyFormat = "attr-dict";
}

// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
// This is needed because Arith's Cmp ops don't
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
// This is needed because these ops don't
// handle encodings
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let summary = "integer comparison operation";

@@ -48,7 +48,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let results = (outs TT_BoolLike:$result);
}

def TTG_CmpFOp : TTG_Op<"cmpf"> {
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> {
let summary = "floating-point comparison operation";

let description = [{}];
@@ -60,6 +60,20 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
let results = (outs TT_BoolLike:$result);
}

// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
let summary = "select operation";

let description = [{}];

let arguments = (ins TT_BoolLike:$condition,
TT_Tensor:$true_value,
TT_Tensor:$false_value);

let results = (outs TT_Tensor:$result);
}


def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[SameVariadicOperandSize,
// MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?

@@ -1,6 +1,7 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -88,15 +89,7 @@ void populateArithmeticPatternsAndLegality(
// non-null encoding
// --------------
MLIRContext *context = patterns.getContext();
// // Legality rule
// target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
// // TODO: check above rule here
// [](Operation *op){
// return true;
// }
// );
// Rewrite rule
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
// TODO: there's probably a better way to avoid adding all ops one-by-one
patterns.add<
ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
@@ -121,8 +114,35 @@ void populateArithmeticPatternsAndLegality(
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
// Cast Ops
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>>(
typeConverter, context);
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
GenericOpPattern<arith::SIToFPOp>>(typeConverter, context);
}

// this shouldn't exist if mlir's SelectOp checked encodings properly
class StdSelectPattern : public OpConversionPattern<SelectOp> {
public:
using OpConversionPattern<SelectOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
triton::gpu::SelectOp res =
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
adaptor.getFalseValue());
return success();
}
};

void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns,
TritonGPUConversionTarget &target) {
MLIRContext *context = patterns.getContext();
// Rewrite rule
patterns.add<StdSelectPattern>(typeConverter, context);
target.addLegalOp<ReturnOp>(); // this is ok because all functions are inlined
// by the frontend
}

void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
@@ -231,7 +251,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32());
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
adaptor.transB());
return success();
}
};
@@ -418,6 +439,7 @@ public:
// rewrite patterns
RewritePatternSet patterns(context);
// add rules
populateStdPatternsAndLegality(typeConverter, patterns, target);
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
populateMathPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns);

@@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
[(Constraint<CPred<"isZero($0)">> $c)]>;

def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
[(Constraint<CPred<"isZero($0)">> $c)]>;
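As a rough sketch of what these DRR patterns rewrite (value names and tensor types are illustrative, not taken from this commit), an add whose other operand is a dot with a zero accumulator is folded into the dot itself:

    // before
    %zero = arith.constant dense<0.00e+00> : tensor<128x128xf32>
    %res = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
    %out = arith.addf %res, %d : tensor<128x128xf32>
    // after
    %out = tt.dot %a, %b, %d {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>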
@@ -484,6 +484,30 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
return success();
}

//===----------------------------------------------------------------------===//
// DotOperand Encoding
//===----------------------------------------------------------------------===//
Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
NamedAttrList attrs;
if (parser.parseOptionalAttrDict(attrs).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");

return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
}

void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent() << "}>";
}

//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//

@@ -34,6 +34,45 @@ namespace {
//
// -----------------------------------------------------------------------------

// convert(blocked, dot_operand) ->
// convert(blocked, mma) + convert(mma, dot_operand)
// if this value is itself the result of a dot operation
// this is a heuristic to accommodate a pattern seen in fused attention
// kernels.
// TODO: replace this by something more generic, i.e. layout-aware CSE
class DecomposeDotOperand : public mlir::RewritePattern {

public:
DecomposeDotOperand(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}

mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
auto tmpType =
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
dstType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent());
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), tmpType, convert.getOperand());
auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), dstType, tmp);
rewriter.replaceOp(op, {newConvert});
return mlir::success();
}
return mlir::failure();
}
};
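A minimal sketch of the rewrite this pattern performs, written as IR; the layout names and shapes are illustrative, and #dot_a's parent is assumed to be an #mma layout as the code above requires:

    // before: a single conversion straight from the blocked layout to the dot-operand layout
    %a = triton_gpu.convert_layout %x : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #dot_a>
    // after: go through the parent (#mma) layout first, so the mma -> dot_op step can be handled separately
    %tmp = triton_gpu.convert_layout %x : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #mma>
    %a = triton_gpu.convert_layout %tmp : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #dot_a>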
// Layout conversions can't deduce their return type automatically.
// IIUC they are therefore not handled by DRR right now
class SimplifyConversion : public mlir::RewritePattern {
@@ -47,6 +86,13 @@ public:
mlir::PatternRewriter &rewriter) const override {
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
@@ -197,12 +243,16 @@ public:
if (isSharedLayout(cvt->getResults()[0]) ||
isSharedLayout(cvt->getOperand(0)))
return mlir::failure();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return mlir::failure();
// DFS
SetVector<Operation *> processed;
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
std::vector<std::pair<Operation *, Attribute>> queue;
std::vector<std::pair<Value, Attribute>> toConvert;
queue.push_back({cvt, targetType.getEncoding()});
int numCvts = 1;
while (!queue.empty()) {
@@ -222,17 +272,20 @@ public:
// add all operands to the queue
for (Value argI : currOp->getOperands()) {
Attribute newEncoding;
// cannot invert the current encoding for this operand
// we stop everything
if (failed(invertEncoding(currLayout, currOp, newEncoding)))
return mlir::failure();
toConvert.push_back({argI, newEncoding});
if (toConvert.count(argI) && toConvert[argI] != newEncoding)
return mlir::failure();
//
Operation *opArgI = argI.getDefiningOp();
if (!opArgI)
continue;
toConvert.insert({argI, newEncoding});
if (!opArgI || processed.contains(opArgI) ||
(opArgI->getBlock() != cvt->getBlock()))
continue;
// if the conversion can be folded into opArgI then
// we actually haven't added any conversion
// we don't count this conversion as expensive
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp>(*opArgI))
continue;
@@ -246,31 +299,30 @@ public:
if (numCvts > 0)
return mlir::failure();

FuncOp parentFunc = cvt->getParentOfType<FuncOp>();
bool test = cvt->getResult(0)
.getType()
.cast<RankedTensorType>()
.getEncoding()
.isa<triton::gpu::MmaEncodingAttr>();
// if (test)
//   llvm::outs() << "--------\nConverting " << *cvt << "\n---------\n";
SmallVector<Value, 4> sortedValues;
SetVector<Operation *> tmp;
for (auto it = toConvert.begin(); it != toConvert.end(); ++it) {
Value v = it->first;
if (v.getDefiningOp())
tmp.insert(v.getDefiningOp());
else
sortedValues.push_back(v);
}
tmp = mlir::topologicalSort(tmp);
for (Operation *op : tmp)
sortedValues.push_back(op->getResult(0));

// llvm::outs() << "----\n";
BlockAndValueMapping mapping;
for (int i = toConvert.size() - 1; i >= 0; i--) {
for (Value currOperand : sortedValues) {
// unpack information
Value currOperand;
Attribute targetLayout;
std::tie(currOperand, targetLayout) = toConvert[i];
// if (test)
//   llvm::outs() << "current " << currOperand << "\n";
Attribute targetLayout = toConvert.lookup(currOperand);
// rematerialize the operand if necessary
Operation *currOperation = currOperand.getDefiningOp();
if (processed.contains(currOperation)) {
currOperation = cloneWithInferType(rewriter, currOperation, mapping);
currOperand = currOperation->getResult(0);
}
if (i == 0)
break;
// compute target type for the layout cast
auto currType = currOperand.getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
@@ -281,6 +333,7 @@ public:
newOperand->moveAfter(currOperation);
mapping.map(currOperand, newOperand);
}
// llvm::outs() << cvt->getParentOfType<mlir::FuncOp>() << "\n";
rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
return mlir::success();
}
@@ -290,97 +343,71 @@ public:
//
// -----------------------------------------------------------------------------

// This modifies the loop in-place
bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
mlir::PatternRewriter &rewriter) {
auto targetType = toPreserve.begin()->getType().cast<RankedTensorType>();
auto newType = [&](RankedTensorType origType) {
return RankedTensorType::get(origType.getShape(), origType.getElementType(),
targetType.getEncoding());
};
bool hasSameTypes = op->getDialect()->getNamespace() == "arith" ||
isa<triton::SplatOp, triton::AddPtrOp>(op);
if (hasSameTypes) {
// replace argument types
for (auto arg : llvm::enumerate(op->getOperands())) {
auto argType = arg.value().getType().dyn_cast<RankedTensorType>();
if (toPreserve.count(arg.value()) || !argType)
continue;
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
rewriter.getUnknownLoc(), newType(argType), arg.value());
newArg->moveBefore(op);
op->setOperand(arg.index(), newArg);
}
// replace result types
if (!isa<triton::SplatOp>(op))
op->getResult(0).setType(op->getOperand(0).getType());
return true;
}
return false;
}

std::pair<SmallVector<Value, 4>, scf::ForOp>
tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
Type newType) {
forOp.getInductionVar();
auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
auto ctx = forOp.getContext();
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
// Rewrite init argument
Type origType = forOp.getInitArgs()[i].getType();
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
// Clone for loop
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInitArgs);
newForOp->moveBefore(forOp);
rewriter.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
// traverse all ops in the loop
for (Operation &op : forOp.getBody()->without_terminator()) {
// we clone the op
Operation *newOp = rewriter.clone(op, mapping);
// if any argument of this op has changed type, then the
// new operation is not legal and we should try to
// legalize it.
DenseSet<Value> modifiedTypes;
for (Value arg : op.getOperands()) {
if (mapping.contains(arg) &&
mapping.lookup(arg).getType() != arg.getType())
modifiedTypes.insert(mapping.lookup(arg));
}

bool shouldTryLegalize = !modifiedTypes.empty();
if (shouldTryLegalize)
tryLegalizeOp(newOp, modifiedTypes, rewriter);
}
// create yield, inserting conversions if necessary
auto yieldOp = forOp.getBody()->getTerminator();
SmallVector<Value, 4> newYieldArgs;
for (Value arg : yieldOp->getOperands())
newYieldArgs.push_back(mapping.lookup(arg));
newYieldArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
yieldOp->getLoc(), newType, newYieldArgs[i]);
rewriter.create<scf::YieldOp>(forOp.getLoc(), newYieldArgs);

// replace
SmallVector<Value, 4> newResults = newForOp->getResults();
newResults[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
rewriter.getUnknownLoc(), origType, newForOp->getResult(i));
newResults[i].getDefiningOp()->moveAfter(newForOp);
return {newResults, newForOp};
}
// int test = 0;

class MoveConvertOutOfLoop : public mlir::RewritePattern {
public:
MoveConvertOutOfLoop(mlir::MLIRContext *context)
: mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}

SmallVector<Value, 4>
rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
size_t i, RankedTensorType newType,
triton::gpu::ConvertLayoutOp origConversion) const {

auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
auto ctx = forOp.getContext();
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
// Rewrite init argument
Type origType = forOp.getInitArgs()[i].getType();
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
// Clone for loop
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInitArgs);
newForOp->moveBefore(forOp);
rewriter.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(origConversion.getResult(), newForOp.getRegionIterArgs()[i]);
// the iter arg of interest may have other uses than the conversion
// we're hoisting out of the loop. If that's the case we will
// need to add extra conversions for all uses... which is only useful
// if these extra conversions can be removed by another pattern
auto oldArg = forOp.getRegionIterArgs()[i];
auto newArg = newForOp.getRegionIterArgs()[i];
auto newArgFallback = rewriter.create<triton::gpu::ConvertLayoutOp>(
newForOp.getLoc(), origType, newArg);

mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
for (Operation &op : forOp.getBody()->without_terminator()) {
if (&op == (Operation *)(&origConversion))
continue;
Operation *newOp = rewriter.clone(op, mapping);
if (find(oldArg.getUsers(), &op) != oldArg.getUsers().end())
newOp->replaceUsesOfWith(newArg, newArgFallback);
}

// create yield, inserting conversions if necessary
auto yieldOp = forOp.getBody()->getTerminator();
SmallVector<Value, 4> newYieldArgs;
for (Value arg : yieldOp->getOperands())
newYieldArgs.push_back(mapping.lookup(arg));
newYieldArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
yieldOp->getLoc(), newType, newYieldArgs[i]);
rewriter.create<scf::YieldOp>(forOp.getLoc(), newYieldArgs);

// replace
SmallVector<Value, 4> newResults = newForOp->getResults();
newResults[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
rewriter.getUnknownLoc(), origType, newForOp->getResult(i));
newResults[i].getDefiningOp()->moveAfter(newForOp);
return newResults;
}

mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const {

@@ -388,17 +415,38 @@ public:
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
auto iterArgs = forOp.getRegionIterArgs();
for (auto iterArg : llvm::enumerate(iterArgs)) {
// if (iterArg.index() != 1)
//   continue;
// skip non-tensor types
if (!iterArg.value().getType().isa<RankedTensorType>())
continue;
// we only move `iterArg` out of the loop if
//   - there is only a single conversion use
//   - moving this conversion out of the loop will not generate
//     any extra non-removable conversion
auto users = iterArg.value().getUsers();
// check first condition
SetVector<Type> cvtTargetTypes;
for (auto user : users)
if (isa<triton::gpu::ConvertLayoutOp>(user))
cvtTargetTypes.insert(user->getResults()[0].getType());
if (cvtTargetTypes.size() != 1)
continue;
// TODO: check second condition
for (auto user : users) {
if (isa<triton::gpu::ConvertLayoutOp>(user))
continue;
}
// check
for (auto op : iterArg.value().getUsers()) {
if (isa<triton::gpu::ConvertLayoutOp>(op)) {
auto newFor = tryConvertIterArg(forOp, rewriter, iterArg.index(),
op->getResult(0).getType());
rewriter.replaceOp(forOp, newFor.first);
return success();
}
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
if (!cvt)
continue;
auto targetType = op->getResultTypes()[0].cast<RankedTensorType>();
auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(),
targetType, cvt);
rewriter.replaceOp(forOp, newFor);
return success();
}
}
return failure();
@@ -434,20 +482,27 @@ public:
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
if (cvtSlices.empty())
return failure();
// if other operands are in the loop
// then we don't touch anything
Operation *op = cvtSlices.front();
for (Value _arg : op->getOperands()) {
Operation *arg = _arg.getDefiningOp();
if (arg && isInLoop(arg) && (arg != cvt))

for (Operation *op : cvtSlices) {
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultType>())
return failure();
for (Value arg : op->getOperands()) {
Operation *argOp = arg.getDefiningOp();
if (argOp && (argOp != cvt) &&
!isa<arith::ConstantOp, triton::SplatOp>(argOp)) {
return failure();
}
}
}

// otherwise, we push the conversion forward
// since we'll be able to move it out of
// the loop once it reaches the yield op
// op(cvt(arg_0), arg_1, ..., arg_n)
// -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n)))
BlockAndValueMapping mapping;
auto op = cvtSlices.front();
for (Value arg : op->getOperands()) {
if (arg.getDefiningOp() == cvt)
mapping.map(arg, cvt.getOperand());
@@ -492,7 +547,7 @@ public:
oldAcc.getLoc(), newRetType, oldAcc);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
newAcc, dotOp.allowTF32());
newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB());

rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
@@ -515,6 +570,7 @@ public:
mlir::RewritePatternSet patterns(context);

patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);
@@ -42,7 +42,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) {
llvm_unreachable("Not implemented");
llvm_unreachable("Argument rematerialization not implemented");
return llvm::None;
});

@@ -50,7 +50,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
llvm_unreachable("Source rematerialization not implemented");
return llvm::None;
});

@@ -165,7 +165,13 @@ void init_triton_ir(py::module &&m) {
else {
/* issue a warning */
}
});
})
.def("replace_all_uses_with",
[](mlir::Value &self, mlir::Value &newValue) {
self.replaceAllUsesWith(newValue);
})

;
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");

py::class_<mlir::Region>(m, "region")
@@ -189,7 +195,7 @@ void init_triton_ir(py::module &&m) {
if (self.getNumArguments() != 0)
throw std::runtime_error(
"This block has arguments, don't merge");
dst.getOperations().splice(dst.end(), self.getOperations());
dst.getOperations().splice(dst.begin(), self.getOperations());
self.dropAllUses();
self.erase();
})
@@ -262,7 +268,9 @@ void init_triton_ir(py::module &&m) {
return mlir::succeeded(mlir::verify(self.getOperation()));
});
// scf Ops
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
.def("get_induction_var", &mlir::scf::ForOp::getInductionVar);

py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
@@ -501,24 +509,18 @@ void init_triton_ir(py::module &&m) {
})

// Ops
.def("create_function",
[](mlir::OpBuilder &self, std::string name,
mlir::Type &funcType) -> mlir::FuncOp {
// TODO: loc
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
return self.create<mlir::FuncOp>(loc, name, funcTy);
}
throw std::runtime_error("invalid function type");
})
.def("get_or_insert_function",
[](mlir::OpBuilder &self, mlir::ModuleOp &module,
std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp {
std::string &funcName, mlir::Type &funcType,
std::string &visibility) -> mlir::FuncOp {
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
return self.create<mlir::FuncOp>(loc, funcName, funcTy);
mlir::ArrayRef<mlir::NamedAttribute> attrs = {
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
self.getStringAttr(visibility))};
return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);
}
throw std::runtime_error("invalid function type");
})
@@ -648,6 +650,12 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::arith::IndexCastOp>(loc, input,
self.getIndexType());
})
.def("create_index_to_si",
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::IndexCastOp>(loc, input,
self.getI32Type());
})

.def("create_fmul",
[](mlir::OpBuilder &self, mlir::Value &lhs,
@@ -1065,10 +1073,11 @@ void init_triton_ir(py::module &&m) {
})
.def("create_dot",
[](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
mlir::Value &c, bool allowTF32) -> mlir::Value {
mlir::Value &c, bool allowTF32, bool transA,
bool transB) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
allowTF32);
allowTF32, transA, transB);
})
.def("create_exp",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
@@ -1095,7 +1104,6 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::math::SqrtOp>(loc, val);
})
// .def("create_trans", &ir::builder::create_trans, ret::reference)
.def("create_reduce",
[](mlir::OpBuilder &self, mlir::Value &operand,
mlir::triton::RedOp redOp, int axis) -> mlir::Value {
@@ -1118,13 +1126,7 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::SelectOp>(loc, condition, trueValue,
falseValue);
})
// // Intrinsics
// //   These have no place in the IR, and hopefully they can be removed at
// //   some point .def("create_umulhi", &ir::builder::create_umulhi,
// //   ret::reference) .def("create_barrier", &ir::builder::create_barrier,
// //   ret::reference);
;
});

py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
@@ -1144,8 +1146,11 @@ void init_triton_ir(py::module &&m) {
printingFlags);
})
.def("run",
[](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
return mlir::succeeded(self.run(mod.getOperation()));
[](mlir::PassManager &self, mlir::ModuleOp &mod) {
// TODO: maybe dump module to file and print error for better
// diagnostics
if (mlir::failed(self.run(mod.getOperation())))
throw std::runtime_error("PassManager::run failed");
})
.def(
"add_sccp_pass",
@@ -1168,6 +1173,10 @@ void init_triton_ir(py::module &&m) {
})
.def("add_cse_pass",
[](mlir::PassManager &self) { self.addPass(mlir::createCSEPass()); })
.def("add_licm_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createLoopInvariantCodeMotionPass());
})
.def("add_triton_combine_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::triton::createCombineOpsPass());
@@ -54,7 +54,7 @@ def test_binop_type_check():
kernel = triton.compiler._compile(binop_type_check,
signature="*fp32",
device=0,
output="ttgir")
output="ttir")
assert (kernel)
# TODO: Check types of the results

@@ -75,6 +75,6 @@ def test_reduce_type_check():
kernel = triton.compiler._compile(reduce_type_check,
signature="*fp32",
device=0,
output="ttgir")
output="ttir")
assert (kernel)
# TODO: Check types of the results

@@ -64,12 +64,12 @@ def mangle_ty(ty):
return 'fp32'
if ty.is_fp64():
return 'fp64'
if ty.is_void():
return 'V'
if ty.is_block():
elt = mangle_ty(ty.scalar)
shape = '_'.join(map(str, ty.shape))
return f'{elt}S{shape}S'
if ty.is_void():
return 'V'
assert False, "Unsupported type"

@@ -212,7 +212,8 @@ class CodeGenerator(ast.NodeVisitor):
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node)
# initialize function
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder))
visibility = "public" if self.is_kernel else "private"
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility)
self.module.push_back(fn)
entry = fn.add_entry_block()
arg_values = []
@@ -585,6 +586,12 @@ class CodeGenerator(ast.NodeVisitor):
lb = self.builder.create_to_index(lb)
ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step)
# Create placeholder for the loop induction variable
# We can use any value because the variable isn't a constexpr
# but use a distinctive value (of the right type) to ease debugging
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
init_node = ast.Assign(targets=[st_target], value=ast.Num(value=0xBADF00D))
self.visit(init_node)

with enter_sub_region(self) as sr:
liveins, insert_block = sr
@@ -609,13 +616,22 @@ class CodeGenerator(ast.NodeVisitor):
names.append(name)
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))

# create ForOp
self.builder.set_insertion_point_to_end(insert_block)
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
block.merge_block_before(for_op.get_body(0))

# update induction variable with actual value, and replace all uses
self.builder.set_insertion_point_to_start(for_op.get_body(0))
iv = self.builder.create_index_to_si(for_op.get_induction_var())
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))

# create YieldOp
self.builder.set_insertion_point_to_end(for_op.get_body(0))
self.builder.create_yield_op([y.handle for y in yields])
if len(yields) > 0:
self.builder.create_yield_op([y.handle for y in yields])
for_op_region = for_op.get_body(0).get_parent()
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
# replace global uses with block arguments
@@ -625,8 +641,7 @@ class CodeGenerator(ast.NodeVisitor):

# update lscope & local_defs (ForOp defines new values)
for i, name in enumerate(names):
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
self.set_value(name, triton.language.core.tensor(for_op.get_result(i), yields[i].type))

for stmt in node.orelse:
assert False, "Don't know what to do with else after for"
@@ -672,7 +687,7 @@ class CodeGenerator(ast.NodeVisitor):
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types)
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
@@ -839,18 +854,16 @@ def optimize_triton_ir(mod):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_licm_pass()
pm.run(mod)
return mod


def make_tritongpu_ir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.run(mod)
return mod
@@ -864,6 +877,7 @@ def optimize_tritongpu_ir(mod, num_stages):
pm.add_cse_pass()
pm.add_coalesce_pass()
pm.add_triton_gpu_combine_pass()
pm.add_licm_pass()
pm.add_triton_gpu_swizzle_pass()
pm.add_triton_gpu_combine_pass()
pm.add_cse_pass()

@@ -741,7 +741,7 @@ def reshape(input, shape, _builder=None):


@builtin
def dot(input, other, allow_tf32=True, _builder=None):
def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None):
"""
Returns the matrix product of two blocks.

@@ -753,7 +753,7 @@ def dot(input, other, allow_tf32=True, _builder=None):
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
"""
allow_tf32 = _constexpr_to_value(allow_tf32)
return semantic.dot(input, other, allow_tf32, _builder)
return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder)


# -----------------------

@@ -930,6 +930,8 @@ def atomic_xchg(ptr: tl.tensor,
def dot(lhs: tl.tensor,
rhs: tl.tensor,
allow_tf32: bool,
trans_a: bool,
trans_b: bool,
builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
if lhs.type.scalar.is_int():
@@ -938,11 +940,11 @@ def dot(lhs: tl.tensor,
else:
_0 = builder.get_float32(0)
ret_scalar_ty = tl.float32
M = lhs.type.shape[0]
N = rhs.type.shape[1]
M = lhs.type.shape[1 if trans_a else 0]
N = rhs.type.shape[0 if trans_b else 1]
_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b),
ret_ty)


@@ -6,7 +6,7 @@ import triton._C.libtriton.triton as libtriton
if __name__ == '__main__':

# valid source and target formats
VALID_FORMATS = ['llvm-ir', 'ptx', 'triton-ir', 'triton-gpu-ir']
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx']

# set up the argument parser
# TODO: conditional requirements
python/tutorials/06-fused-attention.py (new file, 301 lines)
@@ -0,0 +1,301 @@
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
"""

import pytest
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * stride_kn)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        qk *= sm_scale
        qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vk)
        p = p.to(tl.float16)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_i)
    tl.store(m_ptrs, m_i)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)


@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)


@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv and dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, k, trans_b=True)
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(p.to(tl.float16), do, trans_a=True)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, v, trans_b=True)
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
            # compute dq
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(tl.float16), k)
            tl.store(dq_ptrs, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8

        _fwd_kernel[grid](
            q, k, v, sm_scale,
            tmp, L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk, num_warps=num_warps,
            num_stages=1,
        )
        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.BLOCK = BLOCK
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do, l,
            do_scaled, delta,
            BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )

        num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
            dq, dk, dv,
            l, m,
            delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            ctx.grid[0],
            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
            num_stages=1,
        )
        return dq, dk, dv, None


attention = _attention.apply


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
    sm_scale = 0.3
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    for z in range(Z):
        for h in range(H):
            p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    tri_out = attention(q, k, v, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    triton.testing.assert_almost_equal(ref_out, tri_out)
    triton.testing.assert_almost_equal(ref_dv, tri_dv)
    triton.testing.assert_almost_equal(ref_dk, tri_dk)
    triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
@@ -24,7 +24,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||
// CHECK-NEXT: %6 -> %6
|
||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
|
@@ -29,7 +29,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
// CHECK-NEXT: offset = 8192, size = 8192
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
@@ -59,11 +59,11 @@ func @reusable(%A : !tt.ptr<f16>) {
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
// CHECK-NEXT: offset = 16384, size = 8192
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
// CHECK-NEXT: offset = 0, size = 8192
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
return
// CHECK-NEXT: size = 24576
}
@@ -27,7 +27,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
// CHECK: Membar 13
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
@@ -15,11 +15,3 @@ set(TRITON_TEST_DEPENDS
triton-opt
FileCheck
)

add_lit_testsuite(check-triton "Running the triton regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${TRITON_TEST_DEPENDS}
)
set_target_properties(check-triton PROPERTIES FOLDER "Tests")

add_lit_testsuites(TRITON ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TRITON_TEST_DEPENDS})
@@ -113,13 +113,13 @@ func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
%zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>

// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
%r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
%r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
%r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
%r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
%r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
%r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
%r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>
%r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>

%ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
%ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>
@@ -5,7 +5,7 @@ func @ops() {
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
%0 = tt.dot %a, %b, %c {allowTF32 = true} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
%0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
return
}
@@ -662,7 +662,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
%D = tt.dot %AA, %BB, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0>
%D = tt.dot %AA, %BB, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0>

return
}
@@ -11,12 +11,12 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
%d = arith.constant dense<3.0> : tensor<128x128xf32>

%dot_out = tt.dot %a, %b, %zero {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>

// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res0 = arith.addf %dot_out, %d : tensor<128x128xf32>

// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>

return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
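The hunk above exercises the dot+add combiner: an arith.addf that consumes a tt.dot with a zero accumulator is folded into a single tt.dot whose accumulator is the addend. At the Triton-Python level, that IR typically comes from code like the hypothetical kernel below (an illustrative sketch, not part of this commit; the names and the block-size parameter are assumptions).

# Hypothetical kernel whose IR contains the dot+add pattern targeted above:
# tl.dot(a, b) lowers to tt.dot with a zero accumulator, and the "+ d"
# becomes an arith.addf that the combiner folds into the dot's accumulator.
import triton
import triton.language as tl


@triton.jit
def dot_add_kernel(A, B, D, Out, BLOCK: tl.constexpr):
    # BLOCK is assumed to be >= 16 so that tl.dot is legal.
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    a = tl.load(A + idx)
    b = tl.load(B + idx)
    d = tl.load(D + idx)
    acc = tl.dot(a, b) + d  # candidate for the dot(a, b, d) rewrite
    tl.store(Out + idx, acc)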
@@ -53,7 +53,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
@@ -110,7 +110,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
@@ -160,7 +160,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A :
scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}
@@ -60,9 +60,9 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>
%46 = arith.index_cast %arg5 : i32 to index
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
%77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
%78 = tt.dot %76, %77, %cst_0 {allowTF32 = true} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>
%77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>
%78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
%79 = arith.addf %arg13, %78 : tensor<64x64xf32>
%80 = arith.muli %arg7, %c64_i32 : i32
%81 = tt.splat %80 : (i32) -> tensor<64x64xi32>
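This hunk comes from an end-to-end matmul lowering test; the scf.for above corresponds to the usual Triton-Python accumulation loop. A schematic version is sketched below (illustrative only; the helper name, signature, and block sizes are assumptions rather than the file under test).

# Schematic accumulation loop matching the IR above: tl.load feeds tl.dot,
# the partial product is accumulated via an add, and the tile pointers are
# advanced each iteration (tt.load / tt.dot / arith.addf / tt.addptr).
import triton
import triton.language as tl


@triton.jit
def matmul_accumulate(a_ptrs, b_ptrs, K, stride_ak, stride_bk,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # a_ptrs / b_ptrs are assumed to be (BLOCK_M, BLOCK_K) and (BLOCK_K, BLOCK_N)
    # pointer tiles built by the caller.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)            # tt.load
        b = tl.load(b_ptrs)            # tt.load
        acc += tl.dot(a, b)            # tt.dot followed by arith.addf
        a_ptrs += BLOCK_K * stride_ak  # tt.addptr
        b_ptrs += BLOCK_K * stride_bk  # tt.addptr
    return acc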
@@ -20,7 +20,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma8w>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x256xf16, {{.*}}>) -> tensor<64x256xf16, [[shared_v8p1m8]]>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x64xf16, #shared> * tensor<64x256xf16, #shared> -> tensor<128x256xf32, #mma8w>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #shared> * tensor<64x256xf16, #shared> -> tensor<128x256xf32, #mma8w>
return
}
}
@@ -32,7 +32,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x128xf16, {{.*}}>) -> tensor<64x128xf16, [[shared_v8p1m8]]>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
return
}
}
@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x32xf16, {{.*}}>) -> tensor<128x32xf16, [[shared_v8p2m4]]>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x128xf16, {{.*}}>) -> tensor<32x128xf16, [[shared_v8p1m8]]>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
return
}
}
@@ -54,7 +54,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%cst0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma2w>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<32x32xf16, #shared> * tensor<32x32xf16, #shared> -> tensor<32x32xf32, #mma2w>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<32x32xf16, #shared> * tensor<32x32xf16, #shared> -> tensor<32x32xf32, #mma2w>
return
}
}
@@ -65,7 +65,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1w>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared> * tensor<16x16xf16, #shared> -> tensor<16x16xf32, #mma1w>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #shared> * tensor<16x16xf16, #shared> -> tensor<16x16xf32, #mma1w>
return
}
}