[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)

Author: Philippe Tillet
Date: 2022-10-21 16:52:15 -07:00
Commit: bb0f9235d1 (parent: c4726333bf), committed via GitHub
26 changed files with 683 additions and 229 deletions


@@ -272,7 +272,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
//
// SPMD Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
@@ -280,7 +280,7 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
let assemblyFormat = "attr-dict `:` type($result)";
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
@@ -301,7 +301,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
let results = (outs TT_FpIntTensor:$d);
@@ -324,6 +324,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
];
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
}
//


@@ -328,4 +328,23 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
}
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
let mnemonic = "dot_op";
let description = [{
In the TritonGPU dialect, given `d = tt.dot a, b, c`,
tt.dot's operands a and b must have a DotOperandEncodingAttr layout;
a's opIdx is 0 and b's opIdx is 1.
The `parent` field in DotOperandEncodingAttr is the layout of d.
}];
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
);
let extraClassDeclaration = extraBaseClassDeclaration;
}
#endif
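A note on syntax: the parse/print methods added further below (in the TritonGPU dialect sources) make this attribute round-trip as `#triton_gpu.dot_op<{opIdx = ..., parent = ...}>`. A minimal sketch of casting a tensor into this layout, assuming hypothetical `#blocked` and `#mma` layout definitions:

    #dot_a = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>
    %a = triton_gpu.convert_layout %x : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #dot_a>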


@@ -14,7 +14,7 @@ class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
[NoSideEffect]> {
[SameOperandsAndResultShape, NoSideEffect]> {
let summary = "convert layout";
let arguments = (ins TT_Tensor:$src);
@@ -32,10 +32,10 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let assemblyFormat = "attr-dict";
}
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
// This is needed because Arith's Cmp ops don't
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
// This is needed because these ops don't
// handle encodings
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let summary = "integer comparison operation";
@@ -48,7 +48,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let results = (outs TT_BoolLike:$result);
}
def TTG_CmpFOp : TTG_Op<"cmpf"> {
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> {
let summary = "floating-point comparison operation";
let description = [{}];
@@ -60,6 +60,20 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
let results = (outs TT_BoolLike:$result);
}
// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
let summary = "select operation";
let description = [{}];
let arguments = (ins TT_BoolLike:$condition,
TT_Tensor:$true_value,
TT_Tensor:$false_value);
let results = (outs TT_Tensor:$result);
}
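// Note on TTG_SelectOp above: since no custom assemblyFormat is declared,
// the op round-trips in MLIR's generic form; a sketch, assuming a
// hypothetical #blocked layout:
//   %r = "triton_gpu.select"(%cond, %t, %f)
//       : (tensor<128xi1, #blocked>, tensor<128xf32, #blocked>,
//          tensor<128xf32, #blocked>) -> tensor<128xf32, #blocked>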
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[SameVariadicOperandSize,
// MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?


@@ -1,6 +1,7 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -88,15 +89,7 @@ void populateArithmeticPatternsAndLegality(
// non-null encoding
// --------------
MLIRContext *context = patterns.getContext();
// // Legality rule
// target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
// // TODO: check above rule here
// [](Operation *op){
// return true;
// }
// );
// Rewrite rule
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
// TODO: there's probably a better way to avoid adding all ops one-by-one
patterns.add<
ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
@@ -121,8 +114,35 @@ void populateArithmeticPatternsAndLegality(
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
// Cast Ops
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>>(
typeConverter, context);
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
GenericOpPattern<arith::SIToFPOp>>(typeConverter, context);
}
// this shouldn't exist if mlir's SelectOp checked encodings properly
class StdSelectPattern : public OpConversionPattern<SelectOp> {
public:
using OpConversionPattern<SelectOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
triton::gpu::SelectOp res =
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
adaptor.getFalseValue());
return success();
}
};
void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns,
TritonGPUConversionTarget &target) {
MLIRContext *context = patterns.getContext();
// Rewrite rule
patterns.add<StdSelectPattern>(typeConverter, context);
target.addLegalOp<ReturnOp>(); // this is ok because all functions are inlined
// by the frontend
}
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
@@ -231,7 +251,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32());
op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(),
adaptor.transB());
return success();
}
};
@@ -418,6 +439,7 @@ public:
// rewrite patterns
RewritePatternSet patterns(context);
// add rules
populateStdPatternsAndLegality(typeConverter, patterns, target);
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
populateMathPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns);


@@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d),
(TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB),
[(Constraint<CPred<"isZero($0)">> $c)]>;
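For reference, the net effect of these patterns on IR, mirroring the updated regression test later in this commit (all tensors here are 128x128 f32 and `%zero` is a zero constant):

    %0 = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
    %1 = arith.addf %0, %d : tensor<128x128xf32>
    // the pair above folds into
    %1 = tt.dot %a, %b, %d {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>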


@@ -484,6 +484,30 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
return success();
}
//===----------------------------------------------------------------------===//
// DotOperand Encoding
//===----------------------------------------------------------------------===//
Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
NamedAttrList attrs;
if (parser.parseOptionalAttrDict(attrs).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent() << "}>";
}
//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//


@@ -34,6 +34,45 @@ namespace {
//
// -----------------------------------------------------------------------------
// convert(blocked, dot_operand) ->
// convert(blocked, mma) + convert(mma, dot_operand)
// if this value is itself the result of a dot operation
// this is a heuristic to accommodate a pattern seen in fused attention
// kernels.
// TODO: replace this by something more generic, i.e. layout-aware CSE
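// For example, with #blocked a blocked layout and #mma the parent of the
// dot_op encoding (a sketch of the rewrite below):
//   %a = triton_gpu.convert_layout %x : (tensor<128x64xf16, #blocked>)
//       -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
// is decomposed into
//   %t = triton_gpu.convert_layout %x : (tensor<128x64xf16, #blocked>)
//       -> tensor<128x64xf16, #mma>
//   %a = triton_gpu.convert_layout %t : (tensor<128x64xf16, #mma>)
//       -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>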
class DecomposeDotOperand : public mlir::RewritePattern {
public:
DecomposeDotOperand(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
auto tmpType =
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
dstType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent());
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), tmpType, convert.getOperand());
auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), dstType, tmp);
rewriter.replaceOp(op, {newConvert});
return mlir::success();
}
return mlir::failure();
}
};
// Layout conversions can't deduce their return type automatically.
// IIUC they are therefore not handled by DRR right now
class SimplifyConversion : public mlir::RewritePattern {
@@ -47,6 +86,13 @@ public:
mlir::PatternRewriter &rewriter) const override {
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
@@ -197,12 +243,16 @@ public:
if (isSharedLayout(cvt->getResults()[0]) ||
isSharedLayout(cvt->getOperand(0)))
return mlir::failure();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return mlir::failure();
// DFS
SetVector<Operation *> processed;
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
std::vector<std::pair<Operation *, Attribute>> queue;
std::vector<std::pair<Value, Attribute>> toConvert;
queue.push_back({cvt, targetType.getEncoding()});
int numCvts = 1;
while (!queue.empty()) {
@@ -222,17 +272,20 @@ public:
// add all operands to the queue
for (Value argI : currOp->getOperands()) {
Attribute newEncoding;
// cannot invert the current encoding for this operand
// we stop everything
if (failed(invertEncoding(currLayout, currOp, newEncoding)))
return mlir::failure();
toConvert.push_back({argI, newEncoding});
if (toConvert.count(argI) && toConvert[argI] != newEncoding)
return mlir::failure();
//
Operation *opArgI = argI.getDefiningOp();
if (!opArgI)
continue;
toConvert.insert({argI, newEncoding});
if (!opArgI || processed.contains(opArgI) ||
(opArgI->getBlock() != cvt->getBlock()))
continue;
// if the conversion can be folded into opArgI then
// we actually haven't added any conversion
// we don't count this conversion as expensive
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp>(*opArgI))
continue;
@@ -246,31 +299,30 @@ public:
if (numCvts > 0)
return mlir::failure();
FuncOp parentFunc = cvt->getParentOfType<FuncOp>();
bool test = cvt->getResult(0)
.getType()
.cast<RankedTensorType>()
.getEncoding()
.isa<triton::gpu::MmaEncodingAttr>();
// if (test)
// llvm::outs() << "--------\nConverting " << *cvt << "\n---------\n";
SmallVector<Value, 4> sortedValues;
SetVector<Operation *> tmp;
for (auto it = toConvert.begin(); it != toConvert.end(); ++it) {
Value v = it->first;
if (v.getDefiningOp())
tmp.insert(v.getDefiningOp());
else
sortedValues.push_back(v);
}
tmp = mlir::topologicalSort(tmp);
for (Operation *op : tmp)
sortedValues.push_back(op->getResult(0));
// llvm::outs() << "----\n";
BlockAndValueMapping mapping;
for (int i = toConvert.size() - 1; i >= 0; i--) {
for (Value currOperand : sortedValues) {
// unpack information
Value currOperand;
Attribute targetLayout;
std::tie(currOperand, targetLayout) = toConvert[i];
// if (test)
// llvm::outs() << "current " << currOperand << "\n";
Attribute targetLayout = toConvert.lookup(currOperand);
// rematerialize the operand if necessary
Operation *currOperation = currOperand.getDefiningOp();
if (processed.contains(currOperation)) {
currOperation = cloneWithInferType(rewriter, currOperation, mapping);
currOperand = currOperation->getResult(0);
}
if (i == 0)
break;
// compute target type for the layout cast
auto currType = currOperand.getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
@@ -281,6 +333,7 @@ public:
newOperand->moveAfter(currOperation);
mapping.map(currOperand, newOperand);
}
// llvm::outs() << cvt->getParentOfType<mlir::FuncOp>() << "\n";
rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
return mlir::success();
}
@@ -290,97 +343,71 @@ public:
//
// -----------------------------------------------------------------------------
// This modifies the loop in-place
bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
mlir::PatternRewriter &rewriter) {
auto targetType = toPreserve.begin()->getType().cast<RankedTensorType>();
auto newType = [&](RankedTensorType origType) {
return RankedTensorType::get(origType.getShape(), origType.getElementType(),
targetType.getEncoding());
};
bool hasSameTypes = op->getDialect()->getNamespace() == "arith" ||
isa<triton::SplatOp, triton::AddPtrOp>(op);
if (hasSameTypes) {
// replace argument types
for (auto arg : llvm::enumerate(op->getOperands())) {
auto argType = arg.value().getType().dyn_cast<RankedTensorType>();
if (toPreserve.count(arg.value()) || !argType)
continue;
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
rewriter.getUnknownLoc(), newType(argType), arg.value());
newArg->moveBefore(op);
op->setOperand(arg.index(), newArg);
}
// replace result types
if (!isa<triton::SplatOp>(op))
op->getResult(0).setType(op->getOperand(0).getType());
return true;
}
return false;
}
std::pair<SmallVector<Value, 4>, scf::ForOp>
tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
Type newType) {
forOp.getInductionVar();
auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
auto ctx = forOp.getContext();
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
// Rewrite init argument
Type origType = forOp.getInitArgs()[i].getType();
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
// Clone for loop
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInitArgs);
newForOp->moveBefore(forOp);
rewriter.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
// traverse all ops in the loop
for (Operation &op : forOp.getBody()->without_terminator()) {
// we clone the op
Operation *newOp = rewriter.clone(op, mapping);
// if any argument of this op has changed type, then the
// new operation is not legal and we should try to
// legalize it.
DenseSet<Value> modifiedTypes;
for (Value arg : op.getOperands()) {
if (mapping.contains(arg) &&
mapping.lookup(arg).getType() != arg.getType())
modifiedTypes.insert(mapping.lookup(arg));
}
bool shouldTryLegalize = !modifiedTypes.empty();
if (shouldTryLegalize)
tryLegalizeOp(newOp, modifiedTypes, rewriter);
}
// create yield, inserting conversions if necessary
auto yieldOp = forOp.getBody()->getTerminator();
SmallVector<Value, 4> newYieldArgs;
for (Value arg : yieldOp->getOperands())
newYieldArgs.push_back(mapping.lookup(arg));
newYieldArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
yieldOp->getLoc(), newType, newYieldArgs[i]);
rewriter.create<scf::YieldOp>(forOp.getLoc(), newYieldArgs);
// replace
SmallVector<Value, 4> newResults = newForOp->getResults();
newResults[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
rewriter.getUnknownLoc(), origType, newForOp->getResult(i));
newResults[i].getDefiningOp()->moveAfter(newForOp);
return {newResults, newForOp};
}
// int test = 0;
class MoveConvertOutOfLoop : public mlir::RewritePattern {
public:
MoveConvertOutOfLoop(mlir::MLIRContext *context)
: mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}
SmallVector<Value, 4>
rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp,
size_t i, RankedTensorType newType,
triton::gpu::ConvertLayoutOp origConversion) const {
auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
auto ctx = forOp.getContext();
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
// Rewrite init argument
Type origType = forOp.getInitArgs()[i].getType();
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
// Clone for loop
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInitArgs);
newForOp->moveBefore(forOp);
rewriter.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(origConversion.getResult(), newForOp.getRegionIterArgs()[i]);
// the iter arg of interest may have other uses than the conversion
// we're hoisting out of the loop. If that's the case we will
// need to add extra conversions for all uses... which is only useful
// if these extra conversions can be removed by another pattern
auto oldArg = forOp.getRegionIterArgs()[i];
auto newArg = newForOp.getRegionIterArgs()[i];
auto newArgFallback = rewriter.create<triton::gpu::ConvertLayoutOp>(
newForOp.getLoc(), origType, newArg);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
for (Operation &op : forOp.getBody()->without_terminator()) {
if (&op == (Operation *)(&origConversion))
continue;
Operation *newOp = rewriter.clone(op, mapping);
if (find(oldArg.getUsers(), &op) != oldArg.getUsers().end())
newOp->replaceUsesOfWith(newArg, newArgFallback);
}
// create yield, inserting conversions if necessary
auto yieldOp = forOp.getBody()->getTerminator();
SmallVector<Value, 4> newYieldArgs;
for (Value arg : yieldOp->getOperands())
newYieldArgs.push_back(mapping.lookup(arg));
newYieldArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
yieldOp->getLoc(), newType, newYieldArgs[i]);
rewriter.create<scf::YieldOp>(forOp.getLoc(), newYieldArgs);
// replace
SmallVector<Value, 4> newResults = newForOp->getResults();
newResults[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
rewriter.getUnknownLoc(), origType, newForOp->getResult(i));
newResults[i].getDefiningOp()->moveAfter(newForOp);
return newResults;
}
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const {
@@ -388,17 +415,38 @@ public:
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
auto iterArgs = forOp.getRegionIterArgs();
for (auto iterArg : llvm::enumerate(iterArgs)) {
// if (iterArg.index() != 1)
// continue;
// skip non-tensor types
if (!iterArg.value().getType().isa<RankedTensorType>())
continue;
// we only move `iterArg` out of the loop if
// - there is only a single conversion use
// - moving this conversion out of the loop will not generate
// any extra non-removable conversion
auto users = iterArg.value().getUsers();
// check first condition
SetVector<Type> cvtTargetTypes;
for (auto user : users)
if (isa<triton::gpu::ConvertLayoutOp>(user))
cvtTargetTypes.insert(user->getResults()[0].getType());
if (cvtTargetTypes.size() != 1)
continue;
// TODO: check second condition
for (auto user : users) {
if (isa<triton::gpu::ConvertLayoutOp>(user))
continue;
}
// check
for (auto op : iterArg.value().getUsers()) {
if (isa<triton::gpu::ConvertLayoutOp>(op)) {
auto newFor = tryConvertIterArg(forOp, rewriter, iterArg.index(),
op->getResult(0).getType());
rewriter.replaceOp(forOp, newFor.first);
return success();
}
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
if (!cvt)
continue;
auto targetType = op->getResultTypes()[0].cast<RankedTensorType>();
auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(),
targetType, cvt);
rewriter.replaceOp(forOp, newFor);
return success();
}
}
return failure();
@@ -434,20 +482,27 @@ public:
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
if (cvtSlices.empty())
return failure();
// if other operands are in the loop
// then we don't touch anything
Operation *op = cvtSlices.front();
for (Value _arg : op->getOperands()) {
Operation *arg = _arg.getDefiningOp();
if (arg && isInLoop(arg) && (arg != cvt))
for (Operation *op : cvtSlices) {
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultType>())
return failure();
for (Value arg : op->getOperands()) {
Operation *argOp = arg.getDefiningOp();
if (argOp && (argOp != cvt) &&
!isa<arith::ConstantOp, triton::SplatOp>(argOp)) {
return failure();
}
}
}
// otherwise, we push the conversion forward
// since we'll be able to move it out of
// the loop once it reaches the yield op
// op(cvt(arg_0), arg_1, ..., arg_n)
// -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n)))
BlockAndValueMapping mapping;
auto op = cvtSlices.front();
for (Value arg : op->getOperands()) {
if (arg.getDefiningOp() == cvt)
mapping.map(arg, cvt.getOperand());
@@ -492,7 +547,7 @@ public:
oldAcc.getLoc(), newRetType, oldAcc);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
newAcc, dotOp.allowTF32());
newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
@@ -515,6 +570,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<SimplifyConversion>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);


@@ -42,7 +42,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) {
llvm_unreachable("Not implemented");
llvm_unreachable("Argument rematerialization not implemented");
return llvm::None;
});
@@ -50,7 +50,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Not implemented");
llvm_unreachable("Source rematerialization not implemented");
return llvm::None;
});


@@ -165,7 +165,13 @@ void init_triton_ir(py::module &&m) {
else {
/* issue a warning */
}
});
})
.def("replace_all_uses_with",
[](mlir::Value &self, mlir::Value &newValue) {
self.replaceAllUsesWith(newValue);
})
;
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument");
py::class_<mlir::Region>(m, "region")
@@ -189,7 +195,7 @@ void init_triton_ir(py::module &&m) {
if (self.getNumArguments() != 0)
throw std::runtime_error(
"This block has arguments, don't merge");
dst.getOperations().splice(dst.end(), self.getOperations());
dst.getOperations().splice(dst.begin(), self.getOperations());
self.dropAllUses();
self.erase();
})
@@ -262,7 +268,9 @@ void init_triton_ir(py::module &&m) {
return mlir::succeeded(mlir::verify(self.getOperation()));
});
// scf Ops
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp");
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
.def("get_induction_var", &mlir::scf::ForOp::getInductionVar);
py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
@@ -501,24 +509,18 @@ void init_triton_ir(py::module &&m) {
})
// Ops
.def("create_function",
[](mlir::OpBuilder &self, std::string name,
mlir::Type &funcType) -> mlir::FuncOp {
// TODO: loc
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
return self.create<mlir::FuncOp>(loc, name, funcTy);
}
throw std::runtime_error("invalid function type");
})
.def("get_or_insert_function",
[](mlir::OpBuilder &self, mlir::ModuleOp &module,
std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp {
std::string &funcName, mlir::Type &funcType,
std::string &visibility) -> mlir::FuncOp {
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
return self.create<mlir::FuncOp>(loc, funcName, funcTy);
mlir::ArrayRef<mlir::NamedAttribute> attrs = {
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
self.getStringAttr(visibility))};
return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);
}
throw std::runtime_error("invalid function type");
})
@@ -648,6 +650,12 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::arith::IndexCastOp>(loc, input,
self.getIndexType());
})
.def("create_index_to_si",
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::IndexCastOp>(loc, input,
self.getI32Type());
})
.def("create_fmul",
[](mlir::OpBuilder &self, mlir::Value &lhs,
@@ -1065,10 +1073,11 @@ void init_triton_ir(py::module &&m) {
})
.def("create_dot",
[](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
mlir::Value &c, bool allowTF32) -> mlir::Value {
mlir::Value &c, bool allowTF32, bool transA,
bool transB) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
allowTF32);
allowTF32, transA, transB);
})
.def("create_exp",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
@@ -1095,7 +1104,6 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::math::SqrtOp>(loc, val);
})
// .def("create_trans", &ir::builder::create_trans, ret::reference)
.def("create_reduce",
[](mlir::OpBuilder &self, mlir::Value &operand,
mlir::triton::RedOp redOp, int axis) -> mlir::Value {
@@ -1118,13 +1126,7 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::SelectOp>(loc, condition, trueValue,
falseValue);
})
// // Intrinsics
// // These have no place in the IR, and hopefully they can be removed at
// some point .def("create_umulhi", &ir::builder::create_umulhi,
// ret::reference) .def("create_barrier", &ir::builder::create_barrier,
// ret::reference);
;
});
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
@@ -1144,8 +1146,11 @@ void init_triton_ir(py::module &&m) {
printingFlags);
})
.def("run",
[](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
return mlir::succeeded(self.run(mod.getOperation()));
[](mlir::PassManager &self, mlir::ModuleOp &mod) {
// TODO: maybe dump module to file and print error for better
// diagnostics
if (mlir::failed(self.run(mod.getOperation())))
throw std::runtime_error("PassManager::run failed");
})
.def(
"add_sccp_pass",
@@ -1168,6 +1173,10 @@ void init_triton_ir(py::module &&m) {
})
.def("add_cse_pass",
[](mlir::PassManager &self) { self.addPass(mlir::createCSEPass()); })
.def("add_licm_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createLoopInvariantCodeMotionPass());
})
.def("add_triton_combine_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::triton::createCombineOpsPass());


@@ -54,7 +54,7 @@ def test_binop_type_check():
kernel = triton.compiler._compile(binop_type_check,
signature="*fp32",
device=0,
output="ttgir")
output="ttir")
assert (kernel)
# TODO: Check types of the results
@@ -75,6 +75,6 @@ def test_reduce_type_check():
kernel = triton.compiler._compile(reduce_type_check,
signature="*fp32",
device=0,
output="ttgir")
output="ttir")
assert (kernel)
# TODO: Check types of the results


@@ -64,12 +64,12 @@ def mangle_ty(ty):
return 'fp32'
if ty.is_fp64():
return 'fp64'
if ty.is_void():
return 'V'
if ty.is_block():
elt = mangle_ty(ty.scalar)
shape = '_'.join(map(str, ty.shape))
return f'{elt}S{shape}S'
if ty.is_void():
return 'V'
assert False, "Unsupported type"
@@ -212,7 +212,8 @@ class CodeGenerator(ast.NodeVisitor):
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node)
# initialize function
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder))
visibility = "public" if self.is_kernel else "private"
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility)
self.module.push_back(fn)
entry = fn.add_entry_block()
arg_values = []
@@ -585,6 +586,12 @@ class CodeGenerator(ast.NodeVisitor):
lb = self.builder.create_to_index(lb)
ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step)
# Create placeholder for the loop induction variable
# We can use any value because the variable isn't a constexpr
# but use a distinctive value (of the right type) to ease debugging
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
init_node = ast.Assign(targets=[st_target], value=ast.Num(value=0xBADF00D))
self.visit(init_node)
with enter_sub_region(self) as sr:
liveins, insert_block = sr
@@ -609,13 +616,22 @@ class CodeGenerator(ast.NodeVisitor):
names.append(name)
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
# create ForOp
self.builder.set_insertion_point_to_end(insert_block)
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
block.merge_block_before(for_op.get_body(0))
# update induction variable with actual value, and replace all uses
self.builder.set_insertion_point_to_start(for_op.get_body(0))
iv = self.builder.create_index_to_si(for_op.get_induction_var())
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))
# create YieldOp
self.builder.set_insertion_point_to_end(for_op.get_body(0))
self.builder.create_yield_op([y.handle for y in yields])
if len(yields) > 0:
self.builder.create_yield_op([y.handle for y in yields])
for_op_region = for_op.get_body(0).get_parent()
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
# replace global uses with block arguments
@@ -625,8 +641,7 @@ class CodeGenerator(ast.NodeVisitor):
# update lscope & local_defs (ForOp defines new values)
for i, name in enumerate(names):
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
self.set_value(name, triton.language.core.tensor(for_op.get_result(i), yields[i].type))
for stmt in node.orelse:
assert False, "Don't know what to do with else after for"
@@ -672,7 +687,7 @@ class CodeGenerator(ast.NodeVisitor):
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types)
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
@@ -839,18 +854,16 @@ def optimize_triton_ir(mod):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_licm_pass()
pm.run(mod)
return mod
def make_tritongpu_ir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.run(mod)
return mod
@@ -864,6 +877,7 @@ def optimize_tritongpu_ir(mod, num_stages):
pm.add_cse_pass()
pm.add_coalesce_pass()
pm.add_triton_gpu_combine_pass()
pm.add_licm_pass()
pm.add_triton_gpu_swizzle_pass()
pm.add_triton_gpu_combine_pass()
pm.add_cse_pass()


@@ -741,7 +741,7 @@ def reshape(input, shape, _builder=None):
@builtin
def dot(input, other, allow_tf32=True, _builder=None):
def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None):
"""
Returns the matrix product of two blocks.
@@ -753,7 +753,7 @@ def dot(input, other, allow_tf32=True, _builder=None):
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
"""
allow_tf32 = _constexpr_to_value(allow_tf32)
return semantic.dot(input, other, allow_tf32, _builder)
return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder)
# -----------------------


@@ -930,6 +930,8 @@ def atomic_xchg(ptr: tl.tensor,
def dot(lhs: tl.tensor,
rhs: tl.tensor,
allow_tf32: bool,
trans_a: bool,
trans_b: bool,
builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
if lhs.type.scalar.is_int():
@@ -938,11 +940,11 @@ def dot(lhs: tl.tensor,
else:
_0 = builder.get_float32(0)
ret_scalar_ty = tl.float32
M = lhs.type.shape[0]
N = rhs.type.shape[1]
M = lhs.type.shape[1 if trans_a else 0]
N = rhs.type.shape[0 if trans_b else 1]
_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b),
ret_ty)


@@ -6,7 +6,7 @@ import triton._C.libtriton.triton as libtriton
if __name__ == '__main__':
# valid source and target formats
VALID_FORMATS = ['llvm-ir', 'ptx', 'triton-ir', 'triton-gpu-ir']
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx']
# set up the argument parser
# TODO: conditional requirements


@@ -0,0 +1,301 @@
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
TMP, L, M,  # NOTE: TMP is a scratchpad buffer to work around a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
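# note: the head offset above uses stride_qh for k and v as well, and off_v
# reuses the q strides entirely; in this tutorial q, k and v share the same
# shape and memory layout, so the corresponding strides are equal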
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
# start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(tl.float16)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv and dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, k, trans_b=True)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
# compute dp = dot(do, v^T) - delta
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v, trans_b=True)
# compute ds = p * (dp - delta[:, None]); delta was already folded into dp above
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(tl.float16), k)
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q, k, v, sm_scale,
tmp, L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, l, m = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
num_stages=1,
)
return dq, dk, dv, None
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)


@@ -24,7 +24,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
// CHECK-NEXT: %6 -> %6
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>


@@ -29,7 +29,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
// CHECK-NEXT: offset = 8192, size = 8192
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
@@ -59,11 +59,11 @@ func @reusable(%A : !tt.ptr<f16>) {
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
// CHECK-NEXT: offset = 16384, size = 8192
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
// CHECK-NEXT: offset = 0, size = 8192
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
return
// CHECK-NEXT: size = 24576
}


@@ -27,7 +27,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
// CHECK: Membar 13
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>


@@ -15,11 +15,3 @@ set(TRITON_TEST_DEPENDS
triton-opt
FileCheck
)
add_lit_testsuite(check-triton "Running the triton regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${TRITON_TEST_DEPENDS}
)
set_target_properties(check-triton PROPERTIES FOLDER "Tests")
add_lit_testsuites(TRITON ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TRITON_TEST_DEPENDS})


@@ -113,13 +113,13 @@ func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
%zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
%r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
%r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
%r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
%r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
%r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
%r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
// CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
%r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>
%r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>
%ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
%ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>


@@ -5,7 +5,7 @@ func @ops() {
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
%0 = tt.dot %a, %b, %c {allowTF32 = true} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
%0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
return
}


@@ -662,7 +662,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
%D = tt.dot %AA, %BB, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0>
%D = tt.dot %AA, %BB, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0>
return
}


@@ -11,12 +11,12 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
%d = arith.constant dense<3.0> : tensor<128x128xf32>
%dot_out = tt.dot %a, %b, %zero {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res0 = arith.addf %dot_out, %d : tensor<128x128xf32>
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
%res1 = arith.addf %d, %dot_out : tensor<128x128xf32>
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>


@@ -53,7 +53,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
@@ -110,7 +110,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
@@ -160,7 +160,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A :
scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
}


@@ -60,9 +60,9 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>
%46 = arith.index_cast %arg5 : i32 to index
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
%77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
%78 = tt.dot %76, %77, %cst_0 {allowTF32 = true} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
%77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
%78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
%79 = arith.addf %arg13, %78 : tensor<64x64xf32>
%80 = arith.muli %arg7, %c64_i32 : i32
%81 = tt.splat %80 : (i32) -> tensor<64x64xi32>


@@ -20,7 +20,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma8w>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x256xf16, {{.*}}>) -> tensor<64x256xf16, [[shared_v8p1m8]]>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x64xf16, #shared> * tensor<64x256xf16, #shared> -> tensor<128x256xf32, #mma8w>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #shared> * tensor<64x256xf16, #shared> -> tensor<128x256xf32, #mma8w>
return
}
}
@@ -32,7 +32,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x128xf16, {{.*}}>) -> tensor<64x128xf16, [[shared_v8p1m8]]>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
return
}
}
@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x32xf16, {{.*}}>) -> tensor<128x32xf16, [[shared_v8p2m4]]>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x128xf16, {{.*}}>) -> tensor<32x128xf16, [[shared_v8p1m8]]>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
return
}
}
@@ -54,7 +54,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%cst0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma2w>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<32x32xf16, #shared> * tensor<32x32xf16, #shared> -> tensor<32x32xf32, #mma2w>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<32x32xf16, #shared> * tensor<32x32xf16, #shared> -> tensor<32x32xf32, #mma2w>
return
}
}
@@ -65,7 +65,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1w>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
// CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared> * tensor<16x16xf16, #shared> -> tensor<16x16xf32, #mma1w>
%D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #shared> * tensor<16x16xf16, #shared> -> tensor<16x16xf32, #mma1w>
return
}
}