diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index ce32d3914..0d1675b39 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -272,7 +272,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect, // // SPMD Ops // -def TT_GetProgramIdOp : TT_Op<"get_program_id"> { +def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> { let arguments = (ins I32Attr:$axis); let results = (outs I32:$result); @@ -280,7 +280,7 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> { let assemblyFormat = "attr-dict `:` type($result)"; } -def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> { +def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> { let arguments = (ins I32Attr:$axis); let results = (outs I32:$result); @@ -301,7 +301,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect, $d = matrix_multiply($a, $b) + $c }]; - let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32); + let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB); let results = (outs TT_FpIntTensor:$d); @@ -324,6 +324,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect, ]; let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)"; + } // diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 5964c7c37..0a1a72bc3 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -328,4 +328,23 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { } +def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> { + let mnemonic = "dot_op"; + + let description = [{ +In the TritonGPU dialect, considering `d = tt.dot a, b, c`, +tt.dot's operands a and b must be of DotOperandEncodingAttr layout. +a's opIdx is 0, b's opIdx is 1. +The parent field in DotOperandEncodingAttr is the layout of d. + }]; + + let parameters = ( + ins + "unsigned":$opIdx, + "Attribute":$parent + ); + + let extraClassDeclaration = extraBaseClassDeclaration; +} + #endif diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 5d7f346d2..b88b80e7d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -14,7 +14,7 @@ class TTG_Op traits = []> : Op; def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", - [NoSideEffect]> { + [SameOperandsAndResultShape, NoSideEffect]> { let summary = "convert layout"; let arguments = (ins TT_Tensor:$src); @@ -32,10 +32,10 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { let assemblyFormat = "attr-dict"; } -// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU. -// This is needed because Arith's Cmp ops don't +// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU. 
+// This is needed because these ops don't // handle encodings -// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111 +// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111 def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> { let summary = "integer comparison operation"; @@ -48,7 +48,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> { let results = (outs TT_BoolLike:$result); } -def TTG_CmpFOp : TTG_Op<"cmpf"> { +def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> { let summary = "floating-point comparison operation"; let description = [{}]; @@ -60,6 +60,20 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> { let results = (outs TT_BoolLike:$result); } +// TODO: migrate to arith::SelectOp on LLVM16 +def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> { + let summary = "select operation"; + + let description = [{}]; + + let arguments = (ins TT_BoolLike:$condition, + TT_Tensor:$true_value, + TT_Tensor:$false_value); + + let results = (outs TT_Tensor:$result); +} + + def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async", [SameVariadicOperandSize, // MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should? diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 788d20eaa..13eff038c 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -1,6 +1,7 @@ #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "../PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Transforms/DialectConversion.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -88,15 +89,7 @@ void populateArithmeticPatternsAndLegality( // non-null encoding // -------------- MLIRContext *context = patterns.getContext(); - // // Legality rule - // target.addDynamicallyLegalDialect( - // // TODO: check above rule here - // [](Operation *op){ - // return true; - // } - // ); - // Rewrite rule - // patterns.add(typeConverter, context); + // TODO: there's probably a better way to avoid adding all ops one-by-one patterns.add< ArithConstantPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, @@ -121,8 +114,35 @@ void populateArithmeticPatternsAndLegality( ArithCmpPattern, ArithCmpPattern, // Cast Ops - GenericOpPattern, GenericOpPattern>( - typeConverter, context); + GenericOpPattern, GenericOpPattern, + GenericOpPattern>(typeConverter, context); +} + +// this shouldn't exist if mlir's SelectOp checked encodings properly +class StdSelectPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + triton::gpu::SelectOp res = + rewriter.replaceOpWithNewOp( + op, retType, adaptor.getCondition(), adaptor.getTrueValue(), + adaptor.getFalseValue()); + return success(); + } +}; + +void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + // Rewrite rule + patterns.add(typeConverter, context); + target.addLegalOp(); // this is ok because all 
functions are inlined + // by the frontend } void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, @@ -231,7 +251,8 @@ struct TritonDotPattern : public OpConversionPattern { b = rewriter.create(b.getLoc(), dstType, b); } auto newDot = rewriter.replaceOpWithNewOp( - op, retType, a, b, adaptor.c(), adaptor.allowTF32()); + op, retType, a, b, adaptor.c(), adaptor.allowTF32(), adaptor.transA(), + adaptor.transB()); return success(); } }; @@ -418,6 +439,7 @@ public: // rewrite patterns RewritePatternSet patterns(context); // add rules + populateStdPatternsAndLegality(typeConverter, patterns, target); populateArithmeticPatternsAndLegality(typeConverter, patterns, target); populateMathPatternsAndLegality(typeConverter, patterns, target); populateTritonPatterns(typeConverter, patterns); diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index 1b84cc7f3..b8caf93ce 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -12,21 +12,21 @@ include "triton/Dialect/Triton/IR/TritonOps.td" // AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) // AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) def CombineDotAddIPattern : Pat< - (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)), - (TT_DotOp $a, $b, $d, $allowTF32), + (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)), + (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB), [(Constraint> $c)]>; def CombineDotAddFPattern : Pat< - (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)), - (TT_DotOp $a, $b, $d, $allowTF32), + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB)), + (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB), [(Constraint> $c)]>; def CombineDotAddIRevPattern : Pat< - (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d), - (TT_DotOp $a, $b, $d, $allowTF32), + (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d), + (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB), [(Constraint> $c)]>; def CombineDotAddFRevPattern : Pat< - (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d), - (TT_DotOp $a, $b, $d, $allowTF32), + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $transA, $transB), $d), + (TT_DotOp $a, $b, $d, $allowTF32, $transA, $transB), [(Constraint> $c)]>; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index f9da54a64..12d88f5b1 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -484,6 +484,30 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes( return success(); } +//===----------------------------------------------------------------------===// +// DotOperand Encoding +//===----------------------------------------------------------------------===// +Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + unsigned opIdx = attrs.get("opIdx").cast().getInt(); + Attribute parent = attrs.get("parent"); + + return parser.getChecked(parser.getContext(), opIdx, + parent); +} + +void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "opIdx = " << getOpIdx() << ", " + << "parent = " << getParent() << "}>"; +} + //===----------------------------------------------------------------------===// // ASM 
Interface (i.e.: alias) //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 9f4b690bf..74cd31b2a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -34,6 +34,45 @@ namespace { // // ----------------------------------------------------------------------------- +// convert(blocked, dot_operand) -> +// convert(blocked, mma) + convert(mma, dot_operand) +// if this value is itself the result of a dot operation +// this is a heuristic to accommodate some patterns seen in fused attention +// kernels. +// TODO: replace this by something more generic, i.e. layout-aware CSE +class DecomposeDotOperand : public mlir::RewritePattern { + +public: + DecomposeDotOperand(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), + 1, context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + if (!llvm::isa(op)) + return mlir::failure(); + auto convert = llvm::cast(op); + auto srcType = convert.getOperand().getType().cast(); + auto dstType = convert.getType().cast(); + if (srcType.getEncoding().isa() && + dstType.getEncoding().isa()) { + auto tmpType = + RankedTensorType::get(dstType.getShape(), dstType.getElementType(), + dstType.getEncoding() + .cast() + .getParent()); + auto tmp = rewriter.create( + convert.getLoc(), tmpType, convert.getOperand()); + auto newConvert = rewriter.create( + convert.getLoc(), dstType, tmp); + rewriter.replaceOp(op, {newConvert}); + return mlir::success(); + } + return mlir::failure(); + } +}; + // Layout conversions can't deduce their return type automatically. 
// IIUC they are therefore not handled by DRR right now class SimplifyConversion : public mlir::RewritePattern { @@ -47,6 +86,13 @@ public: mlir::PatternRewriter &rewriter) const override { if (!llvm::isa(op)) return mlir::failure(); + auto convert = llvm::cast(op); + auto srcType = convert.getOperand().getType().cast(); + auto dstType = convert.getType().cast(); + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristic to accommodate fused attention + if (dstType.getEncoding().isa()) + return mlir::failure(); // convert to the same layout -- we can delete if (op->getResultTypes() == op->getOperandTypes()) { rewriter.replaceOp(op, op->getOperands()); @@ -197,12 +243,16 @@ public: if (isSharedLayout(cvt->getResults()[0]) || isSharedLayout(cvt->getOperand(0))) return mlir::failure(); + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristic to accommodate fused attention auto targetType = cvt->getResultTypes()[0].cast(); + if (targetType.getEncoding().isa()) + return mlir::failure(); // DFS SetVector processed; SetVector layout; + llvm::MapVector toConvert; std::vector> queue; - std::vector> toConvert; queue.push_back({cvt, targetType.getEncoding()}); int numCvts = 1; while (!queue.empty()) { @@ -222,17 +272,20 @@ public: // add all operands to the queue for (Value argI : currOp->getOperands()) { Attribute newEncoding; + // cannot invert the current encoding for this operand + // we stop everything if (failed(invertEncoding(currLayout, currOp, newEncoding))) return mlir::failure(); - toConvert.push_back({argI, newEncoding}); + if (toConvert.count(argI) && toConvert[argI] != newEncoding) + return mlir::failure(); + // Operation *opArgI = argI.getDefiningOp(); - if (!opArgI) - continue; + toConvert.insert({argI, newEncoding}); if (!opArgI || processed.contains(opArgI) || (opArgI->getBlock() != cvt->getBlock())) continue; // if the conversion can be folded into opArgI then - // we actually haven't added anny conversion + // we don't count this conversion as expensive if (isa(*opArgI)) continue; @@ -246,31 +299,30 @@ public: if (numCvts > 0) return mlir::failure(); - FuncOp parentFunc = cvt->getParentOfType(); - bool test = cvt->getResult(0) - .getType() - .cast() - .getEncoding() - .isa(); - // if (test) - // llvm::outs() << "--------\nConverting " << *cvt << "\n---------\n"; + SmallVector sortedValues; + SetVector tmp; + for (auto it = toConvert.begin(); it != toConvert.end(); ++it) { + Value v = it->first; + if (v.getDefiningOp()) + tmp.insert(v.getDefiningOp()); + else + sortedValues.push_back(v); + } + tmp = mlir::topologicalSort(tmp); + for (Operation *op : tmp) + sortedValues.push_back(op->getResult(0)); + // llvm::outs() << "----\n"; BlockAndValueMapping mapping; - for (int i = toConvert.size() - 1; i >= 0; i--) { + for (Value currOperand : sortedValues) { // unpack information - Value currOperand; - Attribute targetLayout; - std::tie(currOperand, targetLayout) = toConvert[i]; - // if (test) - // llvm::outs() << "current " << currOperand << "\n"; + Attribute targetLayout = toConvert.lookup(currOperand); // rematerialize the operand if necessary Operation *currOperation = currOperand.getDefiningOp(); if (processed.contains(currOperation)) { currOperation = cloneWithInferType(rewriter, currOperation, mapping); currOperand = currOperation->getResult(0); } - if (i == 0) - break; // compute target type for the layout cast auto currType = currOperand.getType().cast(); auto newType = RankedTensorType::get( @@ -281,6 +333,7 @@ 
newOperand->moveAfter(currOperation); mapping.map(currOperand, newOperand); } + // llvm::outs() << cvt->getParentOfType() << "\n"; rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0))); return mlir::success(); } @@ -290,97 +343,71 @@ public: // // ----------------------------------------------------------------------------- -// This modifies the loop in-place -bool tryLegalizeOp(Operation *op, DenseSet toPreserve, - mlir::PatternRewriter &rewriter) { - auto targetType = toPreserve.begin()->getType().cast(); - auto newType = [&](RankedTensorType origType) { - return RankedTensorType::get(origType.getShape(), origType.getElementType(), - targetType.getEncoding()); - }; - bool hasSameTypes = op->getDialect()->getNamespace() == "arith" || - isa(op); - if (hasSameTypes) { - // replace argument types - for (auto arg : llvm::enumerate(op->getOperands())) { - auto argType = arg.value().getType().dyn_cast(); - if (toPreserve.count(arg.value()) || !argType) - continue; - auto newArg = rewriter.create( - rewriter.getUnknownLoc(), newType(argType), arg.value()); - newArg->moveBefore(op); - op->setOperand(arg.index(), newArg); - } - // replace result types - if (!isa(op)) - op->getResult(0).setType(op->getOperand(0).getType()); - return true; - } - return false; -} - -std::pair, scf::ForOp> -tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i, - Type newType) { - forOp.getInductionVar(); - auto newEncoding = newType.cast().getEncoding(); - auto ctx = forOp.getContext(); - auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; }; - // Rewrite init argument - Type origType = forOp.getInitArgs()[i].getType(); - SmallVector newInitArgs = forOp.getInitArgs(); - newInitArgs[i] = rewriter.create( - newInitArgs[i].getLoc(), newType, newInitArgs[i]); - // Clone for loop - scf::ForOp newForOp = rewriter.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newInitArgs); - newForOp->moveBefore(forOp); - rewriter.setInsertionPointToStart(newForOp.getBody()); - BlockAndValueMapping mapping; - for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) - mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); - mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); - // traverse all ops in the loop - for (Operation &op : forOp.getBody()->without_terminator()) { - // we clone the op - Operation *newOp = rewriter.clone(op, mapping); - // if any argument of this op has changed type, then the - // new operation is not legal and we should try to - // legalize it. 
- DenseSet modifiedTypes; - for (Value arg : op.getOperands()) { - if (mapping.contains(arg) && - mapping.lookup(arg).getType() != arg.getType()) - modifiedTypes.insert(mapping.lookup(arg)); - } - - bool shouldTryLegalize = !modifiedTypes.empty(); - if (shouldTryLegalize) - tryLegalizeOp(newOp, modifiedTypes, rewriter); - } - // create yield, inserting conversions if necessary - auto yieldOp = forOp.getBody()->getTerminator(); - SmallVector newYieldArgs; - for (Value arg : yieldOp->getOperands()) - newYieldArgs.push_back(mapping.lookup(arg)); - newYieldArgs[i] = rewriter.create( - yieldOp->getLoc(), newType, newYieldArgs[i]); - rewriter.create(forOp.getLoc(), newYieldArgs); - - // replace - SmallVector newResults = newForOp->getResults(); - newResults[i] = rewriter.create( - rewriter.getUnknownLoc(), origType, newForOp->getResult(i)); - newResults[i].getDefiningOp()->moveAfter(newForOp); - return {newResults, newForOp}; -} +// int test = 0; class MoveConvertOutOfLoop : public mlir::RewritePattern { public: MoveConvertOutOfLoop(mlir::MLIRContext *context) : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {} + SmallVector + rematerializeForLoop(mlir::PatternRewriter &rewriter, scf::ForOp &forOp, + size_t i, RankedTensorType newType, + triton::gpu::ConvertLayoutOp origConversion) const { + + auto newEncoding = newType.cast().getEncoding(); + auto ctx = forOp.getContext(); + auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; }; + // Rewrite init argument + Type origType = forOp.getInitArgs()[i].getType(); + SmallVector newInitArgs = forOp.getInitArgs(); + newInitArgs[i] = rewriter.create( + newInitArgs[i].getLoc(), newType, newInitArgs[i]); + // Clone for loop + scf::ForOp newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newInitArgs); + newForOp->moveBefore(forOp); + rewriter.setInsertionPointToStart(newForOp.getBody()); + BlockAndValueMapping mapping; + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + mapping.map(origConversion.getResult(), newForOp.getRegionIterArgs()[i]); + // the iter arg of interest may have other uses than the conversion + // we're hoisting out of the loop. If that's the case we will + // need to add extra conversions for all uses... 
which is only useful + // if these extra conversions can be removed by another pattern + auto oldArg = forOp.getRegionIterArgs()[i]; + auto newArg = newForOp.getRegionIterArgs()[i]; + auto newArgFallback = rewriter.create( + newForOp.getLoc(), origType, newArg); + + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (Operation &op : forOp.getBody()->without_terminator()) { + if (&op == (Operation *)(&origConversion)) + continue; + Operation *newOp = rewriter.clone(op, mapping); + if (find(oldArg.getUsers(), &op) != oldArg.getUsers().end()) + newOp->replaceUsesOfWith(newArg, newArgFallback); + } + + // create yield, inserting conversions if necessary + auto yieldOp = forOp.getBody()->getTerminator(); + SmallVector newYieldArgs; + for (Value arg : yieldOp->getOperands()) + newYieldArgs.push_back(mapping.lookup(arg)); + newYieldArgs[i] = rewriter.create( + yieldOp->getLoc(), newType, newYieldArgs[i]); + rewriter.create(forOp.getLoc(), newYieldArgs); + + // replace + SmallVector newResults = newForOp->getResults(); + newResults[i] = rewriter.create( + rewriter.getUnknownLoc(), origType, newForOp->getResult(i)); + newResults[i].getDefiningOp()->moveAfter(newForOp); + return newResults; + } + mlir::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const { @@ -388,17 +415,38 @@ public: auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; }; auto iterArgs = forOp.getRegionIterArgs(); for (auto iterArg : llvm::enumerate(iterArgs)) { + // if (iterArg.index() != 1) + // continue; // skip non-tensor types if (!iterArg.value().getType().isa()) continue; + // we only move `iterArg` out of the loop if + // - there is only a single conversion use + // - moving this conversion out of the loop will not generate + // any extra non-removable conversion + auto users = iterArg.value().getUsers(); + // check first condition + SetVector cvtTargetTypes; + for (auto user : users) + if (isa(user)) + cvtTargetTypes.insert(user->getResults()[0].getType()); + if (cvtTargetTypes.size() != 1) + continue; + // TODO: check second condition + for (auto user : users) { + if (isa(user)) + continue; + } // check for (auto op : iterArg.value().getUsers()) { - if (isa(op)) { - auto newFor = tryConvertIterArg(forOp, rewriter, iterArg.index(), - op->getResult(0).getType()); - rewriter.replaceOp(forOp, newFor.first); - return success(); - } + auto cvt = dyn_cast(op); + if (!cvt) + continue; + auto targetType = op->getResultTypes()[0].cast(); + auto newFor = rematerializeForLoop(rewriter, forOp, iterArg.index(), + targetType, cvt); + rewriter.replaceOp(forOp, newFor); + return success(); } } return failure(); @@ -434,20 +482,27 @@ public: mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter); if (cvtSlices.empty()) return failure(); - // if other operands are in the loop - // then we don't touch anything - Operation *op = cvtSlices.front(); - for (Value _arg : op->getOperands()) { - Operation *arg = _arg.getDefiningOp(); - if (arg && isInLoop(arg) && (arg != cvt)) + + for (Operation *op : cvtSlices) { + if (!op->hasTrait() && + !op->hasTrait()) return failure(); + for (Value arg : op->getOperands()) { + Operation *argOp = arg.getDefiningOp(); + if (argOp && (argOp != cvt) && + !isa(argOp)) { + return failure(); + } + } } + // otherwise, we push the conversion forward // since we'll be able to move it out of // the loop once it reaches the yield op // op(cvt(arg_0), arg_1, ..., arg_n) // -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n))) BlockAndValueMapping 
mapping; + auto op = cvtSlices.front(); for (Value arg : op->getOperands()) { if (arg.getDefiningOp() == cvt) mapping.map(arg, cvt.getOperand()); @@ -492,7 +547,7 @@ public: oldAcc.getLoc(), newRetType, oldAcc); auto newDot = rewriter.create( dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1), - newAcc, dotOp.allowTF32()); + newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB()); rewriter.replaceOpWithNewOp( op, oldRetType, newDot.getResult()); @@ -515,6 +570,7 @@ public: mlir::RewritePatternSet patterns(context); patterns.add(context); + patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index b303efab8..13f97d577 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -42,7 +42,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) { - llvm_unreachable("Not implemented"); + llvm_unreachable("Argument rematerialization not implemented"); return llvm::None; }); @@ -50,7 +50,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // convert origValue to newValue addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) { - llvm_unreachable("Not implemented"); + llvm_unreachable("Source rematerialization not implemented"); return llvm::None; }); diff --git a/python/src/triton.cc b/python/src/triton.cc index 1ea57061d..8ec74c7e2 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -165,7 +165,13 @@ void init_triton_ir(py::module &&m) { else { /* issue an warning */ } - }); + }) + .def("replace_all_uses_with", + [](mlir::Value &self, mlir::Value &newValue) { + self.replaceAllUsesWith(newValue); + }) + + ; py::class_(m, "block_arguement"); py::class_(m, "region") @@ -189,7 +195,7 @@ void init_triton_ir(py::module &&m) { if (self.getNumArguments() != 0) throw std::runtime_error( "This block has arguments, don't merge"); - dst.getOperations().splice(dst.end(), self.getOperations()); + dst.getOperations().splice(dst.begin(), self.getOperations()); self.dropAllUses(); self.erase(); }) @@ -262,7 +268,9 @@ void init_triton_ir(py::module &&m) { return mlir::succeeded(mlir::verify(self.getOperation())); }); // scf Ops - py::class_(m, "ForOp"); + py::class_(m, "ForOp") + .def("get_induction_var", &mlir::scf::ForOp::getInductionVar); + py::class_(m, "IfOp") .def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference) .def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference) @@ -501,24 +509,18 @@ void init_triton_ir(py::module &&m) { }) // Ops - .def("create_function", - [](mlir::OpBuilder &self, std::string name, - mlir::Type &funcType) -> mlir::FuncOp { - // TODO: loc - auto loc = self.getUnknownLoc(); - if (auto funcTy = funcType.dyn_cast()) { - return self.create(loc, name, funcTy); - } - throw std::runtime_error("invalid function type"); - }) .def("get_or_insert_function", [](mlir::OpBuilder &self, mlir::ModuleOp &module, - std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp { + std::string &funcName, mlir::Type &funcType, + std::string &visibility) -> mlir::FuncOp { if (mlir::Operation *funcOperation = module.lookupSymbol(funcName)) return llvm::dyn_cast(funcOperation); auto loc = 
self.getUnknownLoc(); if (auto funcTy = funcType.dyn_cast()) { - return self.create(loc, funcName, funcTy); + mlir::ArrayRef attrs = { + mlir::NamedAttribute(self.getStringAttr("sym_visibility"), + self.getStringAttr(visibility))}; + return self.create(loc, funcName, funcTy, attrs); } throw std::runtime_error("invalid function type"); }) @@ -648,6 +650,12 @@ void init_triton_ir(py::module &&m) { return self.create(loc, input, self.getIndexType()); }) + .def("create_index_to_si", + [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, input, + self.getI32Type()); + }) .def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, @@ -1065,10 +1073,11 @@ void init_triton_ir(py::module &&m) { }) .def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, - mlir::Value &c, bool allowTF32) -> mlir::Value { + mlir::Value &c, bool allowTF32, bool transA, + bool transB) -> mlir::Value { auto loc = self.getUnknownLoc(); return self.create(loc, c.getType(), a, b, c, - allowTF32); + allowTF32, transA, transB); }) .def("create_exp", [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value { @@ -1095,7 +1104,6 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc, val); }) - // .def("create_trans", &ir::builder::create_trans, ret::reference) .def("create_reduce", [](mlir::OpBuilder &self, mlir::Value &operand, mlir::triton::RedOp redOp, int axis) -> mlir::Value { @@ -1118,13 +1126,7 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc, condition, trueValue, falseValue); - }) - // // Intrinsics - // // These have no place in the IR, and hopefully they can be removed at - // some point .def("create_umulhi", &ir::builder::create_umulhi, - // ret::reference) .def("create_barrier", &ir::builder::create_barrier, - // ret::reference); - ; + }); py::class_(m, "pass_manager") .def(py::init()) @@ -1144,8 +1146,11 @@ void init_triton_ir(py::module &&m) { printingFlags); }) .def("run", - [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool { - return mlir::succeeded(self.run(mod.getOperation())); + [](mlir::PassManager &self, mlir::ModuleOp &mod) { + // TODO: maybe dump module to file and print error for better + // diagnostics + if (mlir::failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); }) .def( "add_sccp_pass", @@ -1168,6 +1173,10 @@ void init_triton_ir(py::module &&m) { }) .def("add_cse_pass", [](mlir::PassManager &self) { self.addPass(mlir::createCSEPass()); }) + .def("add_licm_pass", + [](mlir::PassManager &self) { + self.addPass(mlir::createLoopInvariantCodeMotionPass()); + }) .def("add_triton_combine_pass", [](mlir::PassManager &self) { self.addPass(mlir::triton::createCombineOpsPass()); diff --git a/python/tests/test_type.py b/python/tests/test_type.py index 07de3ce27..8580b967a 100644 --- a/python/tests/test_type.py +++ b/python/tests/test_type.py @@ -54,7 +54,7 @@ def test_binop_type_check(): kernel = triton.compiler._compile(binop_type_check, signature="*fp32", device=0, - output="ttgir") + output="ttir") assert (kernel) # TODO: Check types of the results @@ -75,6 +75,6 @@ def test_reduce_type_check(): kernel = triton.compiler._compile(reduce_type_check, signature="*fp32", device=0, - output="ttgir") + output="ttir") assert (kernel) # TODO: Check types of the results diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 355bfc605..7b9f99d8f 100644 --- 
a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -64,12 +64,12 @@ def mangle_ty(ty): return 'fp32' if ty.is_fp64(): return 'fp64' - if ty.is_void(): - return 'V' if ty.is_block(): elt = mangle_ty(ty.scalar) shape = '_'.join(map(str, ty.shape)) return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' assert False, "Unsupported type" @@ -212,7 +212,8 @@ class CodeGenerator(ast.NodeVisitor): init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) self.visit(init_node) # initialize function - fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder)) + visibility = "public" if self.is_kernel else "private" + fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility) self.module.push_back(fn) entry = fn.add_entry_block() arg_values = [] @@ -585,6 +586,12 @@ class CodeGenerator(ast.NodeVisitor): lb = self.builder.create_to_index(lb) ub = self.builder.create_to_index(ub) step = self.builder.create_to_index(step) + # Create placeholder for the loop induction variable + # We can use any value because the variable isn't a constexpr + # but use a distinctive value (of the right type) to ease debugging + st_target = ast.Name(id=node.target.id, ctx=ast.Store()) + init_node = ast.Assign(targets=[st_target], value=ast.Num(value=0xBADF00D)) + self.visit(init_node) with enter_sub_region(self) as sr: liveins, insert_block = sr @@ -609,13 +616,22 @@ class CodeGenerator(ast.NodeVisitor): names.append(name) init_args.append(triton.language.core._to_tensor(liveins[name], self.builder)) yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder)) + # create ForOp self.builder.set_insertion_point_to_end(insert_block) for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) block.merge_block_before(for_op.get_body(0)) + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + iv = self.builder.create_index_to_si(for_op.get_induction_var()) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32)) + # create YieldOp self.builder.set_insertion_point_to_end(for_op.get_body(0)) - self.builder.create_yield_op([y.handle for y in yields]) + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) for_op_region = for_op.get_body(0).get_parent() assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" # replace global uses with block arguments @@ -625,8 +641,7 @@ class CodeGenerator(ast.NodeVisitor): # update lscope & local_defs (ForOp defines new values) for i, name in enumerate(names): - self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) - self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) + self.set_value(name, triton.language.core.tensor(for_op.get_result(i), yields[i].type)) for stmt in node.orelse: assert False, "Don't know what to do with else after for" @@ -672,7 +687,7 @@ class CodeGenerator(ast.NodeVisitor): ret_type = triton.language.void prototype = triton.language.function_type([ret_type], arg_types) gscope = sys.modules[fn.fn.__module__].__dict__ - generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types) + generator = 
CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types) generator.visit(fn.parse()) callee_ret_type = generator.last_ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -839,18 +854,16 @@ def optimize_triton_ir(mod): pm = _triton.ir.pass_manager(mod.context) pm.enable_debug() pm.add_inliner_pass() + pm.add_triton_combine_pass() pm.add_canonicalizer_pass() + pm.add_cse_pass() + pm.add_licm_pass() pm.run(mod) return mod def make_tritongpu_ir(mod, num_warps): pm = _triton.ir.pass_manager(mod.context) - pm.enable_debug() - pm.add_inliner_pass() - pm.add_triton_combine_pass() - pm.add_canonicalizer_pass() - pm.add_cse_pass() pm.add_convert_triton_to_tritongpu_pass(num_warps) pm.run(mod) return mod @@ -864,6 +877,7 @@ def optimize_tritongpu_ir(mod, num_stages): pm.add_cse_pass() pm.add_coalesce_pass() pm.add_triton_gpu_combine_pass() + pm.add_licm_pass() pm.add_triton_gpu_swizzle_pass() pm.add_triton_gpu_combine_pass() pm.add_cse_pass() diff --git a/python/triton/language/core.py b/python/triton/language/core.py index b253a3491..402c6cb86 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -741,7 +741,7 @@ def reshape(input, shape, _builder=None): @builtin -def dot(input, other, allow_tf32=True, _builder=None): +def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None): """ Returns the matrix product of two blocks. @@ -753,7 +753,7 @@ def dot(input, other, allow_tf32=True, _builder=None): :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} """ allow_tf32 = _constexpr_to_value(allow_tf32) - return semantic.dot(input, other, allow_tf32, _builder) + return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 9bf1261fb..aaa19d7f9 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -930,6 +930,8 @@ def atomic_xchg(ptr: tl.tensor, def dot(lhs: tl.tensor, rhs: tl.tensor, allow_tf32: bool, + trans_a: bool, + trans_b: bool, builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() if lhs.type.scalar.is_int(): @@ -938,11 +940,11 @@ def dot(lhs: tl.tensor, else: _0 = builder.get_float32(0) ret_scalar_ty = tl.float32 - M = lhs.type.shape[0] - N = rhs.type.shape[1] + M = lhs.type.shape[1 if trans_a else 0] + N = rhs.type.shape[0 if trans_b else 1] _0 = builder.create_splat(_0, [M, N]) ret_ty = tl.block_type(ret_scalar_ty, [M, N]) - return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b), ret_ty) diff --git a/python/triton/tools/aot.py b/python/triton/tools/aot.py index c1b6010df..72df49d4c 100644 --- a/python/triton/tools/aot.py +++ b/python/triton/tools/aot.py @@ -6,7 +6,7 @@ import triton._C.libtriton.triton as libtriton if __name__ == '__main__': # valid source and target formats - VALID_FORMATS = ['llvm-ir', 'ptx', 'triton-ir', 'triton-gpu-ir'] + VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx'] # set up the argument parser # TODO: conditional requirements diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py new file mode 100644 index 000000000..a5bce4454 --- /dev/null +++ b/python/tutorials/06-fused-attention.py @@ -0,0 +1,301 
@@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) +""" + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + # start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + start_n * stride_kn) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(tl.float16) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_i) + tl.store(m_ptrs, m_i) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + +@triton.jit +def _bwd_preprocess( + Out, DO, L, + NewDO, Delta, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, 
axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel( + Q, K, V, sm_scale, Out, DO, + DQ, DK, DV, + L, M, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + num_block, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, k, trans_b=True) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(p.to(tl.float16), do, trans_a=True) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v, trans_b=True) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(ds.to(tl.float16), q, trans_a=True) + # # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_ptrs, dq) + # # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * 
q.shape[1]) + tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _fwd_kernel[grid]( + q, k, v, sm_scale, + tmp, L, m, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, num_warps=num_warps, + num_stages=1, + ) + ctx.save_for_backward(q, k, v, o, L, m) + ctx.BLOCK = BLOCK + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, l, + do_scaled, delta, + BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + + num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8 + _bwd_kernel[(ctx.grid[1],)]( + q, k, v, ctx.sm_scale, + o, do_scaled, + dq, dk, dv, + l, m, + delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + q.shape[0], q.shape[1], q.shape[2], + ctx.grid[0], + BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps, + num_stages=1, + ) + return dq, dk, dv, None + + +attention = _attention.apply + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)]) +def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() + sm_scale = 0.3 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + for z in range(Z): + for h in range(H): + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # triton implementation + tri_out = attention(q, k, v, sm_scale) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + triton.testing.assert_almost_equal(ref_out, tri_out) + triton.testing.assert_almost_equal(ref_dv, tri_dv) + triton.testing.assert_almost_equal(ref_dk, tri_dk) + triton.testing.assert_almost_equal(ref_dq, tri_dq) diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 550455ccc..d3873e127 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -24,7 +24,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b_ = tt.load %b_ptr, %b_mask, %b_other 
{cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> // CHECK-NEXT: %6 -> %6 %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 8b88f1787..9c7e7fc66 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -29,7 +29,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B // CHECK-NEXT: offset = 8192, size = 8192 %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> @@ -59,11 +59,11 @@ func @reusable(%A : !tt.ptr) { %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> // CHECK-NEXT: offset = 16384, size = 8192 %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> - %c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> // CHECK-NEXT: offset = 0, size = 8192 %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A> - %c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> return // CHECK-NEXT: size = 24576 } diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index e4caf6294..a1c3eab76 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -27,7 +27,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> // CHECK: Membar 13 - %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> %next_b_ptr = 
tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d06671b8a..5c9ead51c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -15,11 +15,3 @@ set(TRITON_TEST_DEPENDS triton-opt FileCheck ) - -add_lit_testsuite(check-triton "Running the triton regression tests" - ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${TRITON_TEST_DEPENDS} - ) -set_target_properties(check-triton PROPERTIES FOLDER "Tests") - -add_lit_testsuites(TRITON ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TRITON_TEST_DEPENDS}) diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index 15e3fc28c..55cb9049f 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -113,13 +113,13 @@ func @dot_ops_infer(%ptr: !tt.ptr, %v : f32) { %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> - %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> + %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32> - %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32> + %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> - %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32> + %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32> - %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32> + %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32> %ptr128x128 = tt.splat %ptr : (!tt.ptr) -> tensor<128x128x!tt.ptr> %ptr32x32 = tt.splat %ptr : (!tt.ptr) -> tensor<32x32x!tt.ptr> diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index 7beb42b69..df5b85050 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -5,7 +5,7 @@ func @ops() { %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> %c = arith.constant dense<3.00e+00> : tensor<128x128xf32> - %0 = tt.dot %a, %b, %c {allowTF32 = true} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> + %0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> return } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 050c95845..eb9bc404d 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -662,7 +662,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 - %D = tt.dot %AA, %BB, %cst0 {allowTF32 = true} 
: tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0> + %D = tt.dot %AA, %BB, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0> return } diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 7536e08f7..50ab4c456 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -11,12 +11,12 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32 %zero = arith.constant dense<0.0> : tensor<128x128xf32> %d = arith.constant dense<3.0> : tensor<128x128xf32> - %dot_out = tt.dot %a, %b, %zero {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> - // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> %res0 = arith.addf %dot_out, %d : tensor<128x128xf32> - // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> %res1 = arith.addf %d, %dot_out : tensor<128x128xf32> return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index f24dd67f1..ec0a06aee 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -53,7 +53,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> @@ -110,7 +110,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> @@ -160,7 +160,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, 
     %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
-    %c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
+    %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
     %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>
     scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>
   }
diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir
index fc15bdc04..962c8a89f 100644
--- a/test/TritonGPU/matmul.mlir
+++ b/test/TritonGPU/matmul.mlir
@@ -60,9 +60,9 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
   %45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr>
   %46 = arith.index_cast %arg5 : i32 to index
   %47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>) {
     %76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
     %77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
-    %78 = tt.dot %76, %77, %cst_0 {allowTF32 = true} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
+    %78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
     %79 = arith.addf %arg13, %78 : tensor<64x64xf32>
     %80 = arith.muli %arg7, %c64_i32 : i32
     %81 = tt.splat %80 : (i32) -> tensor<64x64xi32>
diff --git a/test/TritonGPU/swizzle.mlir b/test/TritonGPU/swizzle.mlir
index 8fd4d81db..bb97bf1cb 100644
--- a/test/TritonGPU/swizzle.mlir
+++ b/test/TritonGPU/swizzle.mlir
@@ -20,7 +20,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
     %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma8w>
     // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
     // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x256xf16, {{.*}}>) -> tensor<64x256xf16, [[shared_v8p1m8]]>
-    %D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x64xf16, #shared> * tensor<64x256xf16, #shared> -> tensor<128x256xf32, #mma8w>
+    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #shared> * tensor<64x256xf16, #shared> -> tensor<128x256xf32, #mma8w>
     return
   }
 }
@@ -32,7 +32,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
     // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]>
     // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x128xf16, {{.*}}>) -> tensor<64x128xf16, [[shared_v8p1m8]]>
-    %D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
+    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
     return
   }
 }
@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w>
     // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x32xf16, {{.*}}>) -> tensor<128x32xf16, [[shared_v8p2m4]]>
     // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x128xf16, {{.*}}>) -> tensor<32x128xf16, [[shared_v8p1m8]]>
-    %D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
+    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma4w>
     return
   }
 }
@@ -54,7 +54,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
     %cst0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma2w>
     // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
     // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]>
-    %D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<32x32xf16, #shared> * tensor<32x32xf16, #shared> -> tensor<32x32xf32, #mma2w>
+    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<32x32xf16, #shared> * tensor<32x32xf16, #shared> -> tensor<32x32xf32, #mma2w>
     return
   }
 }
@@ -65,7 +65,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1w>
     // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
     // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]>
-    %D = tt.dot %A, %B, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared> * tensor<16x16xf16, #shared> -> tensor<16x16xf32, #mma1w>
+    %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #shared> * tensor<16x16xf16, #shared> -> tensor<16x16xf32, #mma1w>
     return
   }
 }
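
Note (editorial, not part of the patch): every test update above follows the same pattern, so one standalone sketch of the new `tt.dot` form may help when reading the hunks. The function name and shapes below are chosen only for illustration; the attribute syntax is taken verbatim from the updated tests.

  func @dot_trans_attrs_example() {
    %a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
    %b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
    %c = arith.constant dense<0.00e+00> : tensor<128x128xf32>
    // transA/transB are the attributes introduced by this patch; all existing tests set them to false,
    // keeping the previous (non-transposed) behavior, while allowTF32 is unchanged.
    %d = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
    return
  }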