Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
Philippe Tillet
2022-12-21 01:30:50 -08:00
committed by GitHub
parent 8650b4d1cb
commit 20100a7254
285 changed files with 26312 additions and 50143 deletions

View File

@@ -0,0 +1,2 @@
add_subdirectory(Triton)
add_subdirectory(TritonGPU)

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,20 @@
add_mlir_dialect_library(TritonIR
Interfaces.cpp
Dialect.cpp
Ops.cpp
Types.cpp
Traits.cpp
DEPENDS
TritonTableGen
LINK_LIBS PUBLIC
MLIRIR
MLIRArithmetic
MLIRSCF
# Since LLVM 15
# MLIRFunc
# else
MLIRStandard
)

View File

@@ -0,0 +1,51 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
using namespace mlir;
using namespace mlir::triton;
//===----------------------------------------------------------------------===//
// TritonDialect Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct TritonInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
BlockAndValueMapping &) const final {
return true;
}
};
} // namespace
void TritonDialect::initialize() {
registerTypes();
addOperations<
#define GET_OP_LIST
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
>();
// We can also add interface here.
addInterfaces<TritonInlinerInterface>();
}
Operation *TritonDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<arith::ConstantOp>(loc, type, value);
}

View File

View File

@@ -0,0 +1,346 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
namespace mlir {
namespace triton {
// Type inference
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
return i1Type;
}
static Type getI32SameShape(Type type) {
auto i32Type = IntegerType::get(type.getContext(), 32);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i32Type,
tensorType.getEncoding());
return i32Type;
}
static Type getPointerTypeSameShape(Type type) {
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
Type elementType = tensorType.getElementType();
auto shape = tensorType.getShape();
PointerType ptrType = PointerType::get(elementType, 1);
return RankedTensorType::get(shape, ptrType, tensorType.getEncoding());
} else {
return PointerType::get(type, 1);
}
}
// Parser & printer for assembly forms
ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> allOperands;
Type resultTypes[1];
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(resultTypes[0]))
return failure();
result.addTypes(resultTypes);
SmallVector<Type> operandTypes;
operandTypes.push_back(getPointerTypeSameShape(resultTypes[0])); // ptr
int hasMask = 0, hasOther = 0;
if (allOperands.size() >= 2) {
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
hasMask = 1;
}
if (allOperands.size() >= 3) {
operandTypes.push_back(resultTypes[0]); // other
hasOther = 1;
}
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
// Deduce operand_segment_sizes from the number of the operands.
auto operand_segment_sizesAttrName =
LoadOp::operand_segment_sizesAttrName(result.name);
result.addAttribute(
operand_segment_sizesAttrName,
parser.getBuilder().getI32VectorAttr({1, hasMask, hasOther}));
return success();
}
void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
printer << " ";
printer << loadOp.getOperation()->getOperands();
// "operand_segment_sizes" can be deduced, so we don't print it.
printer.printOptionalAttrDict(loadOp->getAttrs(),
{loadOp.operand_segment_sizesAttrName()});
printer << " : ";
printer.printStrippedAttrOrType(loadOp.result().getType());
}
ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> allOperands;
Type valueType;
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(valueType))
return failure();
SmallVector<Type> operandTypes;
operandTypes.push_back(getPointerTypeSameShape(valueType)); // ptr
operandTypes.push_back(valueType); // value
if (allOperands.size() >= 3)
operandTypes.push_back(getI1SameShape(valueType)); // mask
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
return success();
}
void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
printer << " ";
printer << storeOp.getOperation()->getOperands();
printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{});
printer << " : ";
printer.printStrippedAttrOrType(storeOp.value().getType());
}
} // namespace triton
} // namespace mlir
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
// enum attribute definitions
#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc"
namespace mlir {
namespace triton {
//-- FpToFpOp --
bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
::mlir::TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
auto srcEltType = inputs.front();
auto dstEltType = outputs.front();
auto srcTensorType = srcEltType.dyn_cast<mlir::RankedTensorType>();
auto dstTensorType = dstEltType.dyn_cast<mlir::RankedTensorType>();
if (srcTensorType && dstTensorType) {
srcEltType = srcTensorType.getElementType();
dstEltType = dstTensorType.getElementType();
}
// Check whether fp8 <=> fp16, bf16, f32, f64
// Make `srcEltType` always the fp8 side
if (dstEltType.dyn_cast<mlir::triton::Float8Type>())
std::swap(srcEltType, dstEltType);
if (!srcEltType.dyn_cast<mlir::triton::Float8Type>())
return false;
return dstEltType.isF16() || dstEltType.isBF16() || dstEltType.isF32() ||
dstEltType.isF64();
}
//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value) {
StoreOp::build(builder, state, ptr, value, mlir::Value());
}
//-- LoadOp --
static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) {
auto ptrTensorType = ptrType.dyn_cast<RankedTensorType>();
if (!ptrTensorType)
return ptrType.cast<PointerType>().getPointeeType();
auto shape = ptrTensorType.getShape();
Type elementType =
ptrTensorType.getElementType().cast<PointerType>().getPointeeType();
return RankedTensorType::get(shape, elementType);
}
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
LoadOp::build(builder, state, ptr, mlir::Value(), mlir::Value(), cache, evict,
isVolatile);
}
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value mask,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
LoadOp::build(builder, state, ptr, mask, mlir::Value(), cache, evict,
isVolatile);
}
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
::mlir::triton::CacheModifier cache,
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
Type resultType = getLoadOpResultType(builder, ptr.getType());
state.addOperands(ptr);
if (mask) {
state.addOperands(mask);
if (other) {
state.addOperands(other);
}
}
state.addAttribute(
operand_segment_sizesAttrName(state.name),
builder.getI32VectorAttr({1, (mask ? 1 : 0), (other ? 1 : 0)}));
state.addAttribute(
cacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
state.addAttribute(
evictAttrName(state.name),
::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
state.addAttribute(isVolatileAttrName(state.name),
builder.getBoolAttr(isVolatile));
state.addTypes({resultType});
}
//-- DotOp --
mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// type is the same as the accumulator
auto accTy = operands[2].getType().cast<RankedTensorType>();
inferredReturnTypes.push_back(accTy);
// verify encodings
auto aEnc = operands[0].getType().cast<RankedTensorType>().getEncoding();
auto bEnc = operands[1].getType().cast<RankedTensorType>().getEncoding();
auto retEnc = accTy.getEncoding();
if (aEnc) {
assert(bEnc);
Dialect &dialect = aEnc.getDialect();
auto interface = dyn_cast<DialectInferLayoutInterface>(&dialect);
if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed())
return mlir::failure();
if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed())
return mlir::failure();
}
return mlir::success();
}
//-- ReduceOp --
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// infer shape
Value arg = operands[0];
auto argTy = arg.getType().cast<RankedTensorType>();
auto argEltTy = argTy.getElementType();
auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
auto redOp =
attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
auto retEltTy = withIndex ? i32Ty : argEltTy;
auto retShape = argTy.getShape().vec();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
retShape.erase(retShape.begin() + axis);
if (retShape.empty()) {
// 0d-tensor -> scalar
inferredReturnTypes.push_back(retEltTy);
} else {
// nd-tensor where n >= 1
// infer encoding
Attribute argEncoding = argTy.getEncoding();
Attribute retEncoding;
if (argEncoding) {
Dialect &dialect = argEncoding.getDialect();
auto inferLayoutInterface =
dyn_cast<DialectInferLayoutInterface>(&dialect);
if (inferLayoutInterface
->inferReduceOpEncoding(argEncoding, axis, retEncoding)
.failed()) {
llvm::report_fatal_error("failed to infer layout for ReduceOp");
return mlir::failure();
}
}
// create type
inferredReturnTypes.push_back(
RankedTensorType::get(retShape, retEltTy, retEncoding));
}
return mlir::success();
}
bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
return redOp == mlir::triton::RedOp::ARGMIN ||
redOp == mlir::triton::RedOp::ARGMAX ||
redOp == mlir::triton::RedOp::ARGUMIN ||
redOp == mlir::triton::RedOp::ARGUMAX ||
redOp == mlir::triton::RedOp::ARGFMIN ||
redOp == mlir::triton::RedOp::ARGFMAX;
}
//-- SplatOp --
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
return ret;
}
//-- ExpandDimsOp --
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
MLIRContext *context, Optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// infer shape
auto arg = operands[0];
auto argTy = arg.getType().cast<RankedTensorType>();
auto retShape = argTy.getShape().vec();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
retShape.insert(retShape.begin() + axis, 1);
// infer encoding
Attribute argEncoding = argTy.getEncoding();
Attribute retEncoding;
if (argEncoding) {
Dialect &dialect = argEncoding.getDialect();
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
if (inferLayoutInterface
->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc)
.failed())
return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp");
}
// create type
auto argEltTy = argTy.getElementType();
inferredReturnTypes.push_back(
RankedTensorType::get(retShape, argEltTy, retEncoding));
return mlir::success();
}
//-- BroadcastOp --
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
auto value = constOperand.getValue();
if (auto denseElemsAttr = value.dyn_cast<DenseElementsAttr>()) {
if (!denseElemsAttr.isSplat())
return {};
return SplatElementsAttr::get(shapedType,
denseElemsAttr.getSplatValue<Attribute>());
} else if (value.getType().isIntOrIndexOrFloat()) {
return SplatElementsAttr::get(shapedType, value);
} else {
return {};
}
}
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,71 @@
#include "triton/Dialect/Triton/IR/Traits.h"
static mlir::LogicalResult verifySameEncoding(mlir::Type tyA, mlir::Type tyB) {
using namespace mlir;
auto encA = tyA.dyn_cast<RankedTensorType>();
auto encB = tyA.dyn_cast<RankedTensorType>();
if (!encA || !encB)
return success();
return encA.getEncoding() == encB.getEncoding() ? success() : failure();
}
mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 1)) ||
failed(verifyAtLeastNResults(op, 1)))
return failure();
auto type = op->getOperand(0).getType();
for (auto resultType : op->getResultTypes())
if (failed(verifySameEncoding(resultType, type)))
return op->emitOpError()
<< "requires the same encoding for all operands and results";
return verifySameOperandsEncoding(op);
}
mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsEncoding(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 1)))
return failure();
auto type = op->getOperand(0).getType();
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
if (failed(verifySameEncoding(opType, type)))
return op->emitOpError() << "requires the same encoding for all operands";
return success();
}
mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
for (auto opType : op->getOperandTypes()) {
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
int64_t numElements = 1;
for (int64_t s : tensorType.getShape())
numElements *= s;
if (numElements > maxTensorNumElements)
return op->emitError("Maximum allowed number of elements is ")
<< maxTensorNumElements << ", but " << *op
<< " has more than that";
if ((numElements & (numElements - 1)) != 0)
return op->emitError("Number of elements must be power-of-two, but ")
<< *op << " doesn't follow the rule (" << numElements << ")"
<< " elements";
}
}
for (auto opType : op->getResultTypes()) {
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
int64_t numElements = 1;
for (int64_t s : tensorType.getShape())
numElements *= s;
if (numElements > maxTensorNumElements)
return op->emitError("Maximum allowed number of elements is ")
<< maxTensorNumElements << ", but " << *op
<< " has more than that";
if ((numElements & (numElements - 1)) != 0)
return op->emitError("Number of elements must be power-of-two, but ")
<< *op << " doesn't follow the rule (" << numElements << ")"
<< " elements";
}
}
return success();
}

View File

@@ -0,0 +1,39 @@
#include "triton/Dialect/Triton/IR/Types.h"
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
using namespace mlir;
using namespace mlir::triton;
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
//===----------------------------------------------------------------------===//
// Triton Dialect
//===----------------------------------------------------------------------===//
void TritonDialect::registerTypes() {
addTypes<
#define GET_TYPEDEF_LIST
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
>();
}
Type PointerType::parse(AsmParser &parser) {
if (parser.parseLess())
return Type();
Type pointeeType;
if (parser.parseType(pointeeType))
return Type();
if (parser.parseGreater())
return Type();
// TODO: also print address space?
return PointerType::get(pointeeType, 1);
}
void PointerType::print(AsmPrinter &printer) const {
printer << "<" << getPointeeType() << ">";
}

View File

@@ -0,0 +1,11 @@
set(LLVM_TARGET_DEFINITIONS Combine.td)
mlir_tablegen(TritonCombine.inc -gen-rewriters)
add_public_tablegen_target(TritonCombineIncGen)
add_mlir_dialect_library(TritonTransforms
Combine.cpp
DEPENDS
TritonTransformsIncGen
TritonCombineIncGen
)

View File

@@ -0,0 +1,209 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include <memory>
using namespace mlir;
namespace {
bool isZero(mlir::Value val) {
if (mlir::matchPattern(val, mlir::m_Zero()) ||
mlir::matchPattern(val, mlir::m_AnyZeroFloat()))
return true;
// broadcast(constant_0)
if (auto bc = val.getDefiningOp<mlir::triton::BroadcastOp>()) {
if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat()))
return true;
}
return false;
}
bool isBroadcastConstantCombinable(Attribute value) {
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
return denseValue.isSplat();
}
return value.isa<FloatAttr, IntegerAttr>();
}
DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
Value bcast_res) {
Type resType = bcast_res.getType();
DenseElementsAttr res;
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
res =
DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
} else {
res = DenseElementsAttr::get(resType, value);
}
return res;
}
#include "TritonCombine.inc"
} // anonymous namespace
// select(cond, load(ptrs, broadcast(cond), ???), other)
// => load(ptrs, broadcast(cond), other)
class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
public:
CombineSelectMaskedLoadPattern(mlir::MLIRContext *context)
: mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context,
{triton::LoadOp::getOperationName()}) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op);
if (!selectOp)
return mlir::failure();
mlir::Value trueValue = selectOp.getTrueValue();
mlir::Value falseValue = selectOp.getFalseValue();
auto *loadOpCandidate = trueValue.getDefiningOp();
auto loadOp = llvm::dyn_cast_or_null<triton::LoadOp>(loadOpCandidate);
if (!loadOp)
return mlir::failure();
mlir::Value mask = loadOp.mask();
if (!mask)
return mlir::failure();
auto *broadcastOpCandidate = mask.getDefiningOp();
auto broadcastOp =
llvm::dyn_cast_or_null<triton::BroadcastOp>(broadcastOpCandidate);
if (!broadcastOp)
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, loadOp.ptr(), loadOp.mask(), falseValue, loadOp.cache(),
loadOp.evict(), loadOp.isVolatile());
return mlir::success();
}
};
// load(ptr, splat(1), ...) -> load(ptr, ...)
// load(ptr, splat(0), other, ...) -> other
struct CanonicalizeMaskedLoadPattern
: public mlir::OpRewritePattern<triton::LoadOp> {
CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context)
: OpRewritePattern<triton::LoadOp>(context, 1) {}
mlir::LogicalResult
matchAndRewrite(triton::LoadOp loadOp,
mlir::PatternRewriter &rewriter) const override {
auto mask = loadOp.mask();
if (!mask)
return mlir::failure();
auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
if (!splatMask)
return mlir::failure();
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
// mask = splat(1)
rewriter.replaceOpWithNewOp<triton::LoadOp>(
loadOp, loadOp.getType(), loadOp.ptr(), Value(), Value(),
loadOp.cache(), loadOp.evict(), loadOp.isVolatile());
} else {
// mask = splat(0)
// If there's no "other", the value is "undef". Perhaps we want to
// optimize it in the future.x
auto otherVal = loadOp.other();
if (!otherVal)
return mlir::failure();
rewriter.replaceOp(loadOp, otherVal);
}
return mlir::success();
}
};
void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonicalizeMaskedLoadPattern>(context);
}
// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
// store(ptr, value, splat(0), ...) -> [none]
struct CanonicalizeMaskedStorePattern
: public mlir::OpRewritePattern<triton::StoreOp> {
CanonicalizeMaskedStorePattern(mlir::MLIRContext *context)
: OpRewritePattern<triton::StoreOp>(context, 1) {}
mlir::LogicalResult
matchAndRewrite(triton::StoreOp storeOp,
mlir::PatternRewriter &rewriter) const override {
auto mask = storeOp.mask();
if (!mask)
return mlir::failure();
auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
if (!splatMask)
return mlir::failure();
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
// mask = splat(1)
rewriter.replaceOpWithNewOp<triton::StoreOp>(storeOp, storeOp.ptr(),
storeOp.value());
} else {
// mask = splat(0)
rewriter.eraseOp(storeOp);
}
return mlir::success();
}
};
void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonicalizeMaskedStorePattern>(context);
}
#define GEN_PASS_CLASSES
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
public:
void runOnOperation() override {
mlir::MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
mlir::ModuleOp m = getOperation();
// Dot Add %{
patterns.add<CombineDotAddIPattern>(context);
patterns.add<CombineDotAddFPattern>(context);
patterns.add<CombineDotAddIRevPattern>(context);
patterns.add<CombineDotAddFRevPattern>(context);
// %}
patterns.add<CombineSelectMaskedLoadPattern>(context);
// patterns.add<CombineAddPtrPattern>(context);
patterns.add<CombineBroadcastConstantPattern>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
}
};
std::unique_ptr<mlir::Pass> mlir::triton::createCombineOpsPass() {
return std::make_unique<CombineOpsPass>();
}

View File

@@ -0,0 +1,48 @@
#ifndef TRITON_PATTERNS
#define TRITON_PATTERNS
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d),
(TT_DotOp $a, $b, $d, $allowTF32),
[(Constraint<CPred<"isZero($0)">> $c)]>;
// TODO: this fails for addptr(addptr(ptr, i32), i64)
// Commented out until fixed
// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1))
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
// (ref: ArithmeticCanonicalization.td)
// def CombineAddPtrPattern : Pat<
// (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
// (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
// broadcast(cst) => cst
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
def CombineBroadcastConstantPattern : Pat<
(TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)),
(Arith_ConstantOp (getConstantValue $value, $bcast_res)),
[(Constraint<CPred<"isBroadcastConstantCombinable($0)">> $value)]>;
#endif

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,11 @@
add_mlir_dialect_library(TritonGPUIR
Dialect.cpp
Traits.cpp
DEPENDS
TritonGPUTableGen
TritonGPUAttrDefsIncGen
LINK_LIBS PUBLIC
TritonIR
)

View File

@@ -0,0 +1,783 @@
#include <numeric>
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
using namespace mlir;
using namespace mlir::triton::gpu;
// Utility
namespace mlir {
namespace triton {
// Type inference
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
return Type();
}
static Type getPointeeType(Type type) {
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
// Tensor of pointers
auto shape = tensorType.getShape();
auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
Type pointeeType = ptrType.getPointeeType();
return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
// scalar pointer
Type pointeeType = ptrType.getPointeeType();
return pointeeType;
}
return Type();
}
namespace gpu {
// TODO: Inheritance of layout attributes
// so that all distributed layouts implement
// these utilities
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return sliceLayout.getElemsPerThread(shape);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return mmaLayout.getElemsPerThread(shape);
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
return sharedLayout.getElemsPerThread(shape);
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return dotLayout.getElemsPerThread(shape);
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;
}
}
unsigned getElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::Float8Type>() ||
type.isa<triton::PointerType>())
return 1;
auto tensorType = type.cast<RankedTensorType>();
return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
}
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
blockedLayout.getThreadsPerWarp().end());
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
return {4, 8};
if (mmaLayout.isAmpere())
return {8, 4};
}
assert(0 && "getThreadsPerWarp not implemented");
return {};
}
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
blockedLayout.getWarpsPerCTA().end());
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
mmaLayout.getWarpsPerCTA().end());
}
assert(0 && "getWarpsPerCTA not implemented");
return {};
}
SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return getSizePerThread(sliceLayout.getParent());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
return {2, 2};
} else if (mmaLayout.isVolta()) {
// Note: here the definition of sizePerThread is obscure, which doesn't
// mean vecSize=4 can be supported in the last dimension.
return {2, 4};
} else {
llvm_unreachable("Unexpected mma version");
}
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
assert(parentMmaLayout.isAmpere() &&
"mmaLayout version = 1 is not implemented yet");
auto parentShapePerCTA = getShapePerCTA(parentLayout);
auto opIdx = dotLayout.getOpIdx();
if (opIdx == 0) {
return {2, 4};
} else if (opIdx == 1) {
return {4, 1};
} else {
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
return {};
}
} else {
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
return {};
}
} else {
assert(0 && "getSizePerThread not implemented");
return {};
}
}
SmallVector<unsigned> getContigPerThread(Attribute layout) {
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.isVolta() || mmaLayout.isAmpere());
return {1, 2};
} else {
return getSizePerThread(layout);
}
}
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout) {
SmallVector<unsigned> threads;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
threads.push_back(blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d]);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(0 && "Unimplemented usage of MmaEncodingAttr");
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}
return threads;
}
SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
SmallVector<unsigned> shape;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d)
shape.push_back(blockedLayout.getSizePerThread()[d] *
blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d]);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parent = sliceLayout.getParent();
for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) {
if (d == dim)
continue;
shape.push_back(getShapePerCTA(parent)[d]);
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere())
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
if (mmaLayout.isVolta())
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
assert(0 && "Unexpected MMA layout version found");
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
assert(parentMmaLayout.isAmpere() &&
"mmaLayout version = 1 is not implemented yet");
auto parentShapePerCTA = getShapePerCTA(parentLayout);
auto opIdx = dotLayout.getOpIdx();
if (opIdx == 0) {
return {parentShapePerCTA[0], 16};
} else if (opIdx == 1) {
return {16, parentShapePerCTA[1]};
} else {
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
}
} else {
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
} else if (mmaLayout.isVolta()) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
} else {
llvm_unreachable("Unexpected mma version");
}
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}
return shape;
}
SmallVector<unsigned> getOrder(const Attribute &layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
blockedLayout.getOrder().end());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return {1, 0};
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return {1, 0};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
unsigned dim = sliceLayout.getDim();
SmallVector<unsigned> order;
for (unsigned d : parentOrder) {
if (d == dim)
continue;
else if (d > dim)
order.push_back(d - 1);
else
order.push_back(d);
}
return order;
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
} else {
assert(0 && "Unimplemented usage of getOrder");
return {};
}
};
} // namespace gpu
} // namespace triton
} // namespace mlir
static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr,
unsigned &value, StringRef desc) {
auto intAttr = attr.dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer type in ")
<< desc;
return failure();
}
if (intAttr.getType().isSignedInteger()) {
int64_t attrVal = intAttr.getSInt();
if (attrVal < 0) {
parser.emitError(parser.getNameLoc(),
"expected an unsigned integer value in ")
<< desc;
return failure();
}
value = attrVal;
} else if (intAttr.getType().isSignlessInteger()) {
int64_t attrVal = intAttr.getInt();
if (attrVal < 0) {
parser.emitError(parser.getNameLoc(),
"expected an unsigned integer value in ")
<< desc;
return failure();
}
value = attrVal;
} else {
value = intAttr.getUInt();
}
return success();
}
// parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr,
SmallVector<unsigned, 2> &res,
StringRef desc) {
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr) {
parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
return failure();
}
for (Attribute i : arrayAttr) {
unsigned value;
if (parseIntAttrValue(parser, i, value, desc).failed())
return failure();
res.push_back(value);
}
return success();
};
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
unsigned &value, StringRef desc) {
return parseIntAttrValue(parser, attr.getValue(), value, desc);
};
//===----------------------------------------------------------------------===//
// Attribute methods
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
return SliceEncodingAttr::get(getContext(), axis, *this);
}
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
auto sizePerThread = getSizePerThread();
auto warpsPerCTA = getWarpsPerCTA();
auto threadsPerWarp = getThreadsPerWarp();
assert(rank == sizePerThread.size() &&
"unexpected rank in BlockedEncodingAttr::getElemsPerThread");
SmallVector<unsigned> elemsPerThread(rank);
for (size_t i = 0; i < rank; ++i) {
unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i];
elemsPerThread[i] = ceil<unsigned>(shape[i], t) * sizePerThread[i];
}
return product<unsigned>(elemsPerThread);
}
template <class T>
SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
size_t rank = shape.size();
unsigned dim = getDim();
SmallVector<T> retShape(rank + 1);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d < dim)
retShape[d] = shape[d];
else if (d == dim)
retShape[d] = 1;
else
retShape[d] = shape[d - 1];
}
return retShape;
}
template SmallVector<unsigned>
SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
template SmallVector<int64_t>
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
auto parent = getParent();
return ::getElemsPerThread(parent, paddedShape(shape));
}
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
assert(rank == 2 && "Unexpected rank of mma layout");
assert((isVolta() || isAmpere()) && "Only version 1 and 2 is supported");
int res = 0;
if (isVolta()) {
unsigned mmasRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]);
unsigned mmasCol = ceil<unsigned>(shape[1], 16 * getWarpsPerCTA()[1]);
// Each warp-level mma884 will perform a m16xn16xk4 mma, thus get a m16xn16
// matrix as result.
res = mmasRow * mmasCol * (16 * 16 / 32);
} else if (isAmpere()) {
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
res = elemsCol * elemsRow;
} else {
llvm_unreachable("Unexpected mma version");
}
return res;
}
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
// TODO:
assert(0 && "SharedEncodingAttr::getElemsPerThread not implemented");
return 0;
}
unsigned
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
if (auto blockedLayout = getParent().dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape);
}
assert(0 && "DotOperandEncodingAttr::getElemsPerThread not implemented");
return 0;
}
//===----------------------------------------------------------------------===//
// Blocked Encoding
//===----------------------------------------------------------------------===//
Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};
if (parser.parseGreater().failed())
return {};
SmallVector<unsigned, 2> sizePerThread;
SmallVector<unsigned, 2> threadsPerWarp;
SmallVector<unsigned, 2> warpsPerCTA;
SmallVector<unsigned, 2> order;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "sizePerThread") {
if (parseIntArrayAttr(parser, attr, sizePerThread,
"number of elements per thread")
.failed())
return {};
} else if (attr.getName() == "threadsPerWarp") {
if (parseIntArrayAttr(parser, attr, threadsPerWarp,
"number of threads per warp")
.failed())
return {};
} else if (attr.getName() == "warpsPerCTA") {
if (parseIntArrayAttr(parser, attr, warpsPerCTA,
"number of warps per CTA")
.failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
auto ret = parser.getChecked<BlockedEncodingAttr>(
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
return ret;
}
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "sizePerThread = [" << getSizePerThread() << "]"
<< ", threadsPerWarp = [" << getThreadsPerWarp() << "]"
<< ", warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< ", order = [" << getOrder() << "]"
<< "}>";
}
//===----------------------------------------------------------------------===//
// MMA encoding
//===----------------------------------------------------------------------===//
Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned versionMajor = 0;
unsigned versionMinor = 0;
SmallVector<unsigned, 2> warpsPerCTA;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "versionMajor") {
if (parseUInt(parser, attr, versionMajor, "versionMajor").failed())
return {};
}
if (attr.getName() == "versionMinor") {
if (parseUInt(parser, attr, versionMinor, "versionMinor").failed())
return {};
}
if (attr.getName() == "warpsPerCTA") {
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
return {};
}
}
return parser.getChecked<MmaEncodingAttr>(parser.getContext(), versionMajor,
versionMinor, warpsPerCTA);
}
void MmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "versionMajor = " << getVersionMajor() << ", "
<< "versionMinor = " << getVersionMinor() << ", "
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< "}>";
}
//===----------------------------------------------------------------------===//
// Sliced Encoding
//===----------------------------------------------------------------------===//
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
NamedAttrList attrs;
if (parser.parseOptionalAttrDict(attrs).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned dim = attrs.get("dim").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
}
void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "dim = " << getDim() << ", "
<< "parent = " << getParent() << "}>";
}
//===----------------------------------------------------------------------===//
// Shared encoding
//===----------------------------------------------------------------------===//
Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned vec = 0;
unsigned perPhase = 0;
unsigned maxPhase = 0;
SmallVector<unsigned, 2> order;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "vec") {
if (parseUInt(parser, attr, vec, "vec").failed())
return {};
} else if (attr.getName() == "perPhase") {
if (parseUInt(parser, attr, perPhase, "perPhase").failed())
return {};
} else if (attr.getName() == "maxPhase") {
if (parseUInt(parser, attr, maxPhase, "maxPhase").failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<SharedEncodingAttr>(parser.getContext(), vec,
perPhase, maxPhase, order);
}
void SharedEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "vec = " << getVec() << ", perPhase = " << getPerPhase()
<< ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder()
<< "]"
<< "}>";
}
//===----------------------------------------------------------------------===//
// Mma encoding
//===----------------------------------------------------------------------===//
bool MmaEncodingAttr::isVolta() const { return getVersionMajor() == 1; }
bool MmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; }
// Get [isARow, isBRow, isAVec4, isBVec4] from versionMinor
std::tuple<bool, bool, bool, bool>
MmaEncodingAttr::decodeVoltaLayoutStates() const {
unsigned versionMinor = getVersionMinor();
bool isARow = versionMinor & (1 << 0);
bool isBRow = versionMinor & (1 << 1);
bool isAVec4 = versionMinor & (1 << 2);
bool isBVec4 = versionMinor & (1 << 3);
return std::make_tuple(isARow, isBRow, isAVec4, isBVec4);
}
//===----------------------------------------------------------------------===//
// DotOperand Encoding
//===----------------------------------------------------------------------===//
Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
NamedAttrList attrs;
if (parser.parseOptionalAttrDict(attrs).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
Attribute isMMAv1Row;
if (parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().isVolta()) {
isMMAv1Row = attrs.get("isMMAv1Row");
if (!isMMAv1Row)
llvm::report_fatal_error("isMMAv1Row attribute is missing");
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent, isMMAv1Row);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent();
if (getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
printer << "}>";
}
//===----------------------------------------------------------------------===//
// InsertSliceAsyncOp
//===----------------------------------------------------------------------===//
ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 8> allOperands;
Type srcType, dstType;
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(srcType) || parser.parseArrow() ||
parser.parseCustomTypeWithFallback(dstType))
return failure();
result.addTypes(dstType);
SmallVector<Type> operandTypes;
operandTypes.push_back(srcType); // src
operandTypes.push_back(dstType); // dst
operandTypes.push_back(
IntegerType::get(parser.getBuilder().getContext(), 32)); // index
int hasMask = 0, hasOther = 0;
if (allOperands.size() >= 4) {
operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
hasMask = 1;
}
if (allOperands.size() >= 5) {
operandTypes.push_back(triton::getPointeeType(srcType)); // other
hasOther = 1;
}
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
// Deduce operand_segment_sizes from the number of the operands.
auto operand_segment_sizesAttrName =
InsertSliceAsyncOp::operand_segment_sizesAttrName(result.name);
result.addAttribute(
operand_segment_sizesAttrName,
parser.getBuilder().getI32VectorAttr({1, 1, 1, hasMask, hasOther}));
return success();
}
void printInsertSliceAsyncOp(OpAsmPrinter &printer,
InsertSliceAsyncOp insertSliceAsyncOp) {
printer << " ";
printer << insertSliceAsyncOp.getOperation()->getOperands();
// "operand_segment_sizes" can be deduced, so we don't print it.
printer.printOptionalAttrDict(
insertSliceAsyncOp->getAttrs(),
{insertSliceAsyncOp.operand_segment_sizesAttrName()});
printer << " : ";
printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
printer << " -> ";
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
}
//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
public:
using OpAsmDialectInterface::OpAsmDialectInterface;
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (auto mmaAttr = attr.dyn_cast<MmaEncodingAttr>()) {
os << "mma";
return AliasResult::FinalAlias;
} else if (auto sharedAttr = attr.dyn_cast<SharedEncodingAttr>()) {
os << "shared";
return AliasResult::FinalAlias;
} else if (auto blockedAttr = attr.dyn_cast<BlockedEncodingAttr>()) {
os << "blocked";
return AliasResult::FinalAlias;
} /* else if (auto sliceAttr = attr.dyn_cast<SliceEncodingAttr>()) {
os << "slice";
return AliasResult::FinalAlias;
} */
return OpAsmDialectInterface::getAlias(attr, os);
}
};
struct TritonGPUInferLayoutInterface
: public triton::DialectInferLayoutInterface {
using DialectInferLayoutInterface::DialectInferLayoutInterface;
LogicalResult
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding) const override {
resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
operandEncoding);
return success();
}
LogicalResult
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding,
Optional<Location> location) const override {
auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
if (!sliceEncoding)
return emitOptionalError(
location, "ExpandDimsOp operand encoding must be SliceEncodingAttr");
if (sliceEncoding.getDim() != axis)
return emitOptionalError(
location, "Incompatible slice dimension for ExpandDimsOp operand");
resultEncoding = sliceEncoding.getParent();
return success();
}
LogicalResult inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
Attribute retEncoding,
Optional<Location> location) const override {
if (auto dotOpEnc = operandEncoding.dyn_cast<DotOperandEncodingAttr>()) {
if (opIdx != dotOpEnc.getOpIdx())
return emitOptionalError(location, "Wrong opIdx");
if (retEncoding != dotOpEnc.getParent())
return emitOptionalError(location, "Incompatible parent encoding");
} else
return emitOptionalError(
location, "Dot's a/b's encoding should be of DotOperandEncodingAttr");
return success();
}
};
void TritonGPUDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
>();
addInterfaces<TritonGPUOpAsmInterface>();
addInterfaces<TritonGPUInferLayoutInterface>();
}
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
// verify TritonGPU ops
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// TODO: fill this.
return success();
}

View File

@@ -0,0 +1,14 @@
#include "triton/Dialect/TritonGPU/IR/Traits.h"
#include "triton/Analysis/Utility.h"
mlir::LogicalResult
mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
if (failed(verifyAtLeastNResults(op, 1)))
return failure();
for (auto result : op->getResults())
if (!isSharedEncoding(result))
return op->emitOpError() << "requires all results to be shared encoding";
return success();
};

View File

@@ -0,0 +1,21 @@
set(LLVM_TARGET_DEFINITIONS Combine.td)
mlir_tablegen(TritonGPUCombine.inc -gen-rewriters)
add_public_tablegen_target(TritonGPUCombineIncGen)
add_mlir_dialect_library(TritonGPUTransforms
Coalesce.cpp
CanonicalizeLoops.cpp
Combine.cpp
Pipeline.cpp
Prefetch.cpp
TritonGPUConversion.cpp
DEPENDS
TritonGPUTransformsIncGen
TritonGPUCombineIncGen
LINK_LIBS PUBLIC
TritonIR
TritonGPUIR
MLIRTransformUtils
)

View File

@@ -0,0 +1,55 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
struct CanonicalizePass
: public TritonGPUCanonicalizeLoopsBase<CanonicalizePass> {
CanonicalizePass() = default;
void runOnOperation() override {
// Canonicalize pass may have created dead code that
// standard scf.for canonicalization cannot handle
// as of LLVM 14. For example, the iteration arguments
// for the pointer of the synchronous loads that are
// discarded.
// The following piece of code is a workaround to
// very crudely remove dead code, by making an iteration
// argument yield itself if it is not used to create
// side effects anywhere.
getOperation()->walk([&](scf::ForOp forOp) -> void {
for (size_t i = 0; i < forOp.getNumResults(); ++i) {
// condition 1: no other iter arguments depend on it
SetVector<Operation *> fwdSlice;
mlir::getForwardSlice(forOp.getRegionIterArgs()[i], &fwdSlice);
Operation *yieldOp = forOp.getBody()->getTerminator();
bool noOtherDependency = std::all_of(
yieldOp->operand_begin(), yieldOp->operand_end(), [&](Value arg) {
return arg == yieldOp->getOperand(i) ||
!fwdSlice.contains(arg.getDefiningOp());
});
// condition 2: final value is not used after the loop
auto retVal = forOp.getResult(i);
bool noUserAfterLoop = retVal.getUsers().empty();
// yielding the region iter arg will cause loop canonicalization
// to clean up the dead code
if (noOtherDependency && noUserAfterLoop) {
yieldOp->setOperand(i, forOp.getRegionIterArgs()[i]);
}
}
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUCanonicalizeLoopsPass() {
return std::make_unique<CanonicalizePass>();
}

View File

@@ -0,0 +1,139 @@
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <numeric>
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
int numWarps) {
auto origType = ptr.getType().cast<RankedTensorType>();
// Get the shape of the tensor.
size_t rank = origType.getRank();
AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
// Layout order in decreasing order of contiguity
SmallVector<unsigned, 4> order(rank);
std::iota(order.begin(), order.end(), 0);
auto contiguity = info.getContiguity();
std::sort(order.begin(), order.end(), [&](unsigned x, unsigned y) {
return contiguity[x] > contiguity[y];
});
int numElems = product(origType.getShape());
int numThreads = numWarps * 32;
int numElemsPerThread = std::max(numElems / numThreads, 1);
// Thread tile size depends on memory alignment
SmallVector<unsigned, 4> sizePerThread(rank, 1);
PointerType ptrType = origType.getElementType().cast<PointerType>();
auto pointeeType = ptrType.getPointeeType();
unsigned numBits = pointeeType.isa<triton::Float8Type>()
? 8
: pointeeType.getIntOrFloatBitWidth();
unsigned maxMultiple = info.getDivisibility(order[0]);
unsigned maxContig = info.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
unsigned perThread = std::min(alignment, 128 / numBits);
sizePerThread[order[0]] = std::min<int>(perThread, numElemsPerThread);
SmallVector<unsigned> dims(rank);
std::iota(dims.begin(), dims.end(), 0);
// create encoding
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
&getContext(), origType.getShape(), sizePerThread, order, numWarps);
return encoding;
}
std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
Value ptr, int numWarps) {
Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps);
return [encoding](Type _type) {
RankedTensorType type = _type.cast<RankedTensorType>();
return RankedTensorType::get(type.getShape(), type.getElementType(),
encoding);
};
}
template <class T>
void coalesceOp(AxisInfoAnalysis &axisInfo, Operation *op, Value ptr,
OpBuilder builder) {
RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
if (!ty)
return;
auto mod = op->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
// convert operands
SmallVector<Value, 4> newArgs;
for (auto v : op->getOperands()) {
auto vTy = v.getType().dyn_cast<RankedTensorType>();
if (vTy && !vTy.getEncoding().isa<triton::gpu::SharedEncodingAttr>())
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), convertType(v.getType()), v));
else
newArgs.push_back(v);
}
// convert output types
SmallVector<Type, 4> newTypes;
for (auto t : op->getResultTypes()) {
bool is_async = std::is_same<T, triton::gpu::InsertSliceAsyncOp>::value;
newTypes.push_back(is_async ? t : convertType(t));
}
// construct new op with the new encoding
Operation *newOp =
builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
// cast the results back to the original layout
for (size_t i = 0; i < op->getNumResults(); i++) {
Value newResult = newOp->getResult(i);
if (newTypes[i] != op->getResultTypes()[i]) {
newResult = builder.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), op->getResult(i).getType(), newResult);
}
op->getResult(i).replaceAllUsesWith(newResult);
}
op->erase();
}
void runOnOperation() override {
Operation *op = getOperation();
// Run axis info analysis
AxisInfoAnalysis axisInfo(&getContext());
axisInfo.run(op);
OpBuilder builder(op);
// For each memory op that has a layout L1:
// 1. Create a coalesced memory layout L2 of the pointer operands
// 2. Convert all operands from layout L1 to layout L2
// 3. Create a new memory op that consumes these operands and
// produces a tensor with layout L2
// 4. Convert the output of this new memory op back to L1
// 5. Replace all the uses of the original memory op by the new one
op->walk([&](Operation *curr) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(curr);
if (auto load = dyn_cast<triton::LoadOp>(curr))
coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
if (auto op = dyn_cast<triton::AtomicRMWOp>(curr))
coalesceOp<triton::AtomicRMWOp>(axisInfo, curr, op.ptr(), builder);
if (auto op = dyn_cast<triton::AtomicCASOp>(curr))
coalesceOp<triton::AtomicCASOp>(axisInfo, curr, op.ptr(), builder);
if (auto load = dyn_cast<triton::gpu::InsertSliceAsyncOp>(curr))
coalesceOp<triton::gpu::InsertSliceAsyncOp>(axisInfo, curr, load.src(),
builder);
if (auto store = dyn_cast<triton::StoreOp>(curr))
coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
});
}
};
std::unique_ptr<Pass> mlir::createTritonGPUCoalescePass() {
return std::make_unique<CoalescePass>();
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,7 @@
#ifndef TRITONGPU_PATTERNS
#define TRITONGPU_PATTERNS
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
#endif

View File

@@ -0,0 +1,656 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
//===----------------------------------------------------------------------===//
//
// This file implements loop software pipelining
// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
// and SCF's LoopPipelining.
//
//===----------------------------------------------------------------------===//
using namespace mlir;
namespace ttg = triton::gpu;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
static Type getI1SameShape(Value v) {
Type vType = v.getType();
auto i1Type = IntegerType::get(vType.getContext(), 1);
auto tensorType = vType.cast<RankedTensorType>();
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
}
#define int_attr(num) builder.getI64IntegerAttr(num)
namespace {
class LoopPipeliner {
/// Cache forOp we are working on
scf::ForOp forOp;
/// Cache YieldOp for this forOp
scf::YieldOp yieldOp;
/// Loads to be pipelined
SetVector<Value> loads;
/// The value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// load => buffer
DenseMap<Value, Value> loadsBuffer;
/// load => buffer type (with shared layout after swizzling)
DenseMap<Value, RankedTensorType> loadsBufferType;
/// load => buffer at stage N
DenseMap<Value, SmallVector<Value>> loadStageBuffer;
/// load => after extract
DenseMap<Value, Value> loadsExtract;
///
Value pipelineIterIdx;
///
Value loopIterIdx;
/// Comments on numStages:
/// [0, numStages-1) are in the prologue
/// numStages-1 is appended after the loop body
int numStages;
/// value (in loop) => value at stage N
DenseMap<Value, SmallVector<Value>> valueMapping;
/// Block arguments that loads depend on
DenseSet<BlockArgument> depArgs;
/// Operations (inside the loop body) that loads depend on
DenseSet<Operation *> depOps;
/// collect values that v depends on and are defined inside the loop
void collectDeps(Value v, int stages, DenseSet<Value> &deps);
void setValueMapping(Value origin, Value newValue, int stage);
Value lookupOrDefault(Value origin, int stage);
/// Returns a empty buffer of size <numStages, ...>
ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder);
public:
LoopPipeliner(scf::ForOp forOp, int numStages)
: forOp(forOp), numStages(numStages) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
/// Collect loads to pipeline. Return success if we can pipeline this loop
LogicalResult initialize();
/// Emit pipelined loads (before loop body)
void emitPrologue();
/// emit pipelined loads (after loop body)
void emitEpilogue();
/// create the new ForOp (add new args & insert prefetched ops)
scf::ForOp createNewForOp();
friend struct PipelinePass;
};
// helpers
void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
if (valueMapping.find(origin) == valueMapping.end())
valueMapping[origin] = SmallVector<Value>(numStages);
valueMapping[origin][stage] = newValue;
}
Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
if (valueMapping.find(origin) == valueMapping.end())
return origin;
return valueMapping[origin][stage];
}
void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
// Loop-invariant value, skip
if (v.getParentRegion() != &forOp.getLoopBody())
return;
// Since we only need to peel the loop numStages-1 times, don't worry about
// depends that are too far away
if (stages < 0)
return;
if (auto arg = v.dyn_cast<BlockArgument>()) {
if (arg.getArgNumber() > 0) {
// Skip the first arg (loop induction variable)
// Otherwise the op idx is arg.getArgNumber()-1
deps.insert(v);
collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1,
deps);
}
} else { // value
// v might be in deps, but we still need to visit v.
// This is because v might depend on value in previous iterations
deps.insert(v);
for (Value op : v.getDefiningOp()->getOperands())
collectDeps(op, stages, deps);
}
}
ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
OpBuilder &builder) {
// Allocate a buffer for each pipelined tensor
// shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16>
Value convertLayout = loadsMapping[op->getResult(0)];
if (auto tensorType = convertLayout.getType().dyn_cast<RankedTensorType>()) {
return builder.create<ttg::AllocTensorOp>(
convertLayout.getLoc(), loadsBufferType[op->getResult(0)]);
}
llvm_unreachable("Async copy's return should be of RankedTensorType");
}
/// A load instruction can be pipelined if:
/// - the load doesn't depend on any other loads (after loop peeling)
/// - (?) this load is not a loop-invariant value (we should run LICM before
/// this pass?)
LogicalResult LoopPipeliner::initialize() {
Block *loop = forOp.getBody();
// can we use forOp.walk(...) here?
SmallVector<triton::LoadOp, 2> allLoads;
for (Operation &op : *loop)
if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
allLoads.push_back(loadOp);
// Early stop: no need to continue if there is no load in the loop.
if (allLoads.empty())
return failure();
// load => values that it depends on
DenseMap<Value, DenseSet<Value>> loadDeps;
for (triton::LoadOp loadOp : allLoads) {
DenseSet<Value> deps;
for (Value op : loadOp->getOperands())
collectDeps(op, numStages - 1, deps);
loadDeps[loadOp] = deps;
}
// Don't pipeline loads that depend on other loads
// (Because if a load depends on another load, this load needs to wait on the
// other load in the prologue, which is against the point of the pipeline
// pass)
for (triton::LoadOp loadOp : allLoads) {
bool isCandidate = true;
for (triton::LoadOp other : allLoads) {
if (loadDeps[loadOp].contains(other)) {
isCandidate = false;
break;
}
}
// We only pipeline loads that have one covert_layout (to dot_op) use
// TODO: lift this constraint in the future
if (isCandidate && loadOp.getResult().hasOneUse()) {
isCandidate = false;
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (auto dotOpEnc = tensorType.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
loadsBufferType[loadOp] = RankedTensorType::get(
bufferShape, ty.getElementType(), sharedEnc);
}
}
}
} else
isCandidate = false;
if (isCandidate)
loads.insert(loadOp);
}
// We have some loads to pipeline
if (!loads.empty()) {
// Update depArgs & depOps
for (Value loadOp : loads) {
for (Value dep : loadDeps[loadOp]) {
// TODO: we should record the stage that the value is depended on
if (auto arg = dep.dyn_cast<BlockArgument>())
depArgs.insert(arg);
else
depOps.insert(dep.getDefiningOp());
}
}
return success();
}
return failure();
}
void LoopPipeliner::emitPrologue() {
// llvm::errs() << "loads to pipeline...:\n";
// for (Value load : loads)
// llvm::errs() << load << "\n";
OpBuilder builder(forOp);
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
setValueMapping(arg, operand.get(), 0);
}
// prologue from [0, numStage-1)
Value iv = forOp.getLowerBound();
pipelineIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (int stage = 0; stage < numStages - 1; ++stage) {
// Special handling for induction variable as the increment is implicit
if (stage != 0)
iv = builder.create<arith::AddIOp>(iv.getLoc(), iv, forOp.getStep());
setValueMapping(forOp.getInductionVar(), iv, stage);
// Special handling for loop condition as there is no condition in ForOp
Value loopCond = builder.create<arith::CmpIOp>(
iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
// Rematerialize peeled values
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
else if (loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
"depOps contains invalid values");
for (Operation *op : orderedDeps) {
Operation *newOp = nullptr;
if (loads.contains(op->getResult(0))) {
// Allocate empty buffer
if (stage == 0) {
loadsBuffer[op->getResult(0)] = allocateEmptyBuffer(op, builder);
loadStageBuffer[op->getResult(0)] = {loadsBuffer[op->getResult(0)]};
}
// load => copy async
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
Value mask = lookupOrDefault(loadOp.mask(), stage);
Value newMask;
if (mask) {
Value splatCond = builder.create<triton::SplatOp>(
mask.getLoc(), mask.getType(), loopCond);
newMask =
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
} else {
newMask = builder.create<triton::SplatOp>(
loopCond.getLoc(), getI1SameShape(loadOp), loopCond);
}
// TODO: check if the hardware supports async copy
newOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
op->getLoc(), loadsBuffer[loadOp].getType(),
lookupOrDefault(loadOp.ptr(), stage),
loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask,
lookupOrDefault(loadOp.other(), stage), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
loadStageBuffer[loadOp].push_back(newOp->getResult(0));
} else
llvm_unreachable("This should be LoadOp");
} else {
newOp = builder.clone(*op);
// Update loop-carried uses
for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) {
auto it = valueMapping.find(op->getOperand(opIdx));
if (it != valueMapping.end()) {
Value v = it->second[stage];
assert(v);
newOp->setOperand(opIdx, v);
} // else, op at opIdx is a loop-invariant value
}
}
// Update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
Value originalResult = op->getResult(dstIdx);
// copy_async will update the value of its only use
// TODO: load should not be used in the preheader?
if (loads.contains(originalResult)) {
break;
// originalResult = loadsMapping[originalResult];
}
setValueMapping(originalResult, newOp->getResult(dstIdx), stage);
// update mapping for loop-carried values (args)
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx))
setValueMapping(
forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(dstIdx), stage + 1);
}
}
} // for (Operation *op : orderedDeps)
pipelineIterIdx = builder.create<arith::AddIOp>(
iv.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
} // for (int stage = 0; stage < numStages - 1; ++stage)
// async.wait & extract_slice
builder.create<ttg::AsyncWaitOp>(loads[0].getLoc(),
loads.size() * (numStages - 2));
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
for (Value loadOp : loads) {
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
sliceType =
RankedTensorType::get(sliceType.getShape(), sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
int_attr(sliceType.getShape()[0]),
int_attr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
loadsExtract[loadOp] = extractSlice;
}
// Bump up loopIterIdx, this is used for getting the correct slice for the
// *next* iteration
loopIterIdx = builder.create<arith::AddIOp>(
loopIterIdx.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1, 32));
}
void LoopPipeliner::emitEpilogue() {
// If there's any outstanding async copies, we need to wait for them.
OpBuilder builder(forOp);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointAfter(forOp);
builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
}
scf::ForOp LoopPipeliner::createNewForOp() {
OpBuilder builder(forOp);
// Order of new args:
// (original args)
// (insertSliceAsync buffer at stage numStages - 1) for each load
// (extracted tensor) for each load
// (depArgs at stage numStages - 1)
// (iv at stage numStages - 2)
// (pipeline iteration index)
// (loop iteration index)
SmallVector<Value> newLoopArgs;
// We need this to update operands for yield
// original block arg => new arg's idx
DenseMap<BlockArgument, size_t> depArgsIdx;
for (auto v : forOp.getIterOperands())
newLoopArgs.push_back(v);
size_t bufferIdx = newLoopArgs.size();
for (Value loadOp : loads)
newLoopArgs.push_back(loadStageBuffer[loadOp].back());
size_t loadIdx = newLoopArgs.size();
for (Value loadOp : loads)
newLoopArgs.push_back(loadsExtract[loadOp]);
size_t depArgsBeginIdx = newLoopArgs.size();
for (BlockArgument depArg : depArgs) {
depArgsIdx[depArg] = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
}
size_t nextIVIdx = newLoopArgs.size();
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
newLoopArgs.push_back(pipelineIterIdx);
newLoopArgs.push_back(loopIterIdx);
for (size_t i = 0; i < newLoopArgs.size(); ++i)
assert(newLoopArgs[i]);
// 1. signature of the new ForOp
auto newForOp = builder.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newLoopArgs);
// 2. body of the new ForOp
builder.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
// 2.1 clone the loop body, replace original args with args of the new ForOp
// Insert async wait if necessary.
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping);
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
}
// 3. replace loads with block args (from prologue)
for (size_t idx = 0; idx < loads.size(); ++idx) {
Value load = loads[idx];
assert(load.hasOneUse() &&
"we assume that this load has one use (ConvertLayout)");
Value loadUse = load.getUsers().begin()->getResult(0);
mapping.lookup(loadUse).replaceAllUsesWith(
newForOp.getRegionIterArgs()[loadIdx + idx]);
// delete old load and layout conversion
mapping.lookup(loadUse).getDefiningOp()->erase();
mapping.lookup(load).getDefiningOp()->erase();
}
// 4. prefetch the next iteration
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
else if (loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
"depOps contains invalid values");
BlockAndValueMapping nextMapping;
DenseMap<BlockArgument, Value> depArgsMapping;
size_t argIdx = 0;
for (BlockArgument arg : depArgs) {
nextMapping.map(arg,
newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
++argIdx;
}
// Special handling for iv & loop condition
Value nextIV = builder.create<arith::AddIOp>(
newForOp.getInductionVar().getLoc(),
newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
Value nextLoopCond =
builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
nextIV, newForOp.getUpperBound());
nextMapping.map(forOp.getInductionVar(), nextIV);
// Slice index
SmallVector<Value> nextBuffers;
SmallVector<Value> extractSlices;
pipelineIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 1];
Value insertSliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
loopIterIdx = newForOp.getRegionIterArgs()[nextIVIdx + 2];
Value extractSliceIndex = builder.create<arith::RemSIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
extractSliceIndex = builder.create<arith::IndexCastOp>(
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
for (Operation *op : orderedDeps) {
Operation *nextOp = nullptr;
// Update loading mask
if (loads.contains(op->getResult(0))) {
auto loadOp = llvm::cast<triton::LoadOp>(op);
Value mask = loadOp.mask();
Value newMask;
if (mask) {
Value splatCond = builder.create<triton::SplatOp>(
mask.getLoc(), mask.getType(), nextLoopCond);
newMask = builder.create<arith::AndIOp>(
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
// If mask is defined outside the loop, don't update the map more than
// once
if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
nextMapping.map(mask, newMask);
newMask = nextMapping.lookupOrDefault(loadOp.mask());
} else
newMask = builder.create<triton::SplatOp>(
loadOp.getLoc(), getI1SameShape(loadOp), nextLoopCond);
Value insertAsyncOp = builder.create<triton::gpu::InsertSliceAsyncOp>(
op->getLoc(), loadsBuffer[loadOp].getType(),
nextMapping.lookupOrDefault(loadOp.ptr()),
newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()],
insertSliceIndex, newMask,
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
nextBuffers.push_back(insertAsyncOp);
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
sliceType = RankedTensorType::get(sliceType.getShape(),
sliceType.getElementType(),
loadsBufferType[loadOp].getEncoding());
nextOp = builder.create<tensor::ExtractSliceOp>(
op->getLoc(), sliceType, insertAsyncOp,
SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
int_attr(0)},
SmallVector<OpFoldResult>{int_attr(1),
int_attr(sliceType.getShape()[0]),
int_attr(sliceType.getShape()[1])},
SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
extractSlices.push_back(nextOp->getResult(0));
} else
nextOp = builder.clone(*op, nextMapping);
// Update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
// If this is a loop-carried value, update the mapping for yield
auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (OpOperand &operand : originYield->getOpOperands()) {
if (operand.get() == op->getResult(dstIdx)) {
size_t originIdx = operand.getOperandNumber();
size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
depArgsMapping[newArg] = nextOp->getResult(dstIdx);
}
}
}
}
{
OpBuilder::InsertionGuard guard(builder);
for (Operation &op : *newForOp.getBody()) {
if (auto dotOp = llvm::dyn_cast<triton::DotOp>(&op)) {
builder.setInsertionPoint(&op);
auto dotType = dotOp.getType().cast<RankedTensorType>();
Value a = dotOp.a();
Value b = dotOp.b();
auto layoutCast = [&](Value dotOperand, int opIdx) -> Value {
auto tensorType = dotOperand.getType().cast<RankedTensorType>();
if (!tensorType.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
auto newEncoding = ttg::DotOperandEncodingAttr::get(
tensorType.getContext(), opIdx, dotType.getEncoding());
auto newType =
RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), newEncoding);
return builder.create<ttg::ConvertLayoutOp>(dotOperand.getLoc(),
newType, dotOperand);
}
return dotOperand;
};
a = layoutCast(a, 0);
b = layoutCast(b, 1);
dotOp->setOperand(0, a);
dotOp->setOperand(1, b);
}
}
}
// async.wait & extract_slice
Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
loads[0].getLoc(), loads.size() * (numStages - 2));
for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) {
// move extract_slice after asyncWait
it->getDefiningOp()->moveAfter(asyncWait);
}
// Bump iteration count
pipelineIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
loopIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
// Finally, the YieldOp, need to sync with the order of newLoopArgs
SmallVector<Value> yieldValues;
for (Value v : forOp.getBody()->getTerminator()->getOperands())
yieldValues.push_back(mapping.lookup(v));
for (Value nextBuffer : nextBuffers)
yieldValues.push_back(nextBuffer);
for (Value nextSlice : extractSlices)
yieldValues.push_back(nextSlice);
for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i) {
auto arg = newForOp.getRegionIterArgs()[i];
assert(depArgsMapping.count(arg) && "Missing loop-carried value");
yieldValues.push_back(depArgsMapping[arg]);
}
yieldValues.push_back(nextIV);
yieldValues.push_back(pipelineIterIdx);
yieldValues.push_back(loopIterIdx);
builder.setInsertionPointToEnd(newForOp.getBody());
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
yieldValues);
return newForOp;
}
// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
PipelinePass() = default;
PipelinePass(int numStages) { this->numStages = numStages; }
void runOnOperation() override {
int numStages = this->numStages;
if (numStages <= 1)
return;
getOperation()->walk([&](scf::ForOp forOp) -> void {
LoopPipeliner pipeliner(forOp, numStages);
if (pipeliner.initialize().failed())
return;
pipeliner.emitPrologue();
scf::ForOp newForOp = pipeliner.createNewForOp();
pipeliner.emitEpilogue();
// replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
forOp->erase();
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages) {
return std::make_unique<PipelinePass>(numStages);
}

View File

@@ -0,0 +1,313 @@
//===----------------------------------------------------------------------===//
//
// This pass tries to prefetch operands (a and b) of tt.dot.
// Those ConvertLayoutOps will be lowered to shared memory loads.
//
// For example:
// %a: tensor<128x32xf16, #enc>
// scf.for %iv = ... iter_args(%a_arg = %a, ...) {
// %d = tt.dot %a_arg, %b, %c
// ...
// scf.yield %a_next, ...
// }
//
// will be translated to
//
// %a: tensor<128x32xf16, #enc>
// %a_tmp = tensor.extract_slice %a[0, 0] [128, 16]
// %a_prefetch = triton_gpu.convert_layout %a_tmp
// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch)
// {
// %x = tt.dot %a_arg, %b, %c
// %a_tmp_rem = tensor.extract_slice %a_buf[0, 16] [128, 16]
// %a_prefetch_next = triton_gpu.convert_layout %a_tmp_rem
// ...
// scf.yield %next_a, ..., %a_prefetch_next
// }
//===----------------------------------------------------------------------===//
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
class Prefetcher {
/// cache the ForOp we are working on
scf::ForOp forOp;
/// cache the YieldOp of this ForOp
scf::YieldOp yieldOp;
///
// TODO: add a hook to infer prefetchWidth
unsigned prefetchWidth = 16;
/// dots to be prefetched
SetVector<Value> dots;
/// dot => dot operand
DenseMap<Value, Value> dot2aLoopArg;
DenseMap<Value, Value> dot2aHeaderDef;
DenseMap<Value, Value> dot2bLoopArg;
DenseMap<Value, Value> dot2bHeaderDef;
DenseMap<Value, Value> dot2aYield;
DenseMap<Value, Value> dot2bYield;
/// operand => defining
DenseMap<Value, Value> operand2headPrefetch;
LogicalResult isForOpOperand(Value v);
Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
Attribute dotEncoding, OpBuilder &builder,
llvm::Optional<int64_t> offsetK = llvm::None,
llvm::Optional<int64_t> shapeK = llvm::None);
public:
Prefetcher() = delete;
Prefetcher(scf::ForOp forOp) : forOp(forOp) {
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
LogicalResult initialize();
void emitPrologue();
scf::ForOp createNewForOp();
};
Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
Attribute dotEncoding, OpBuilder &builder,
llvm::Optional<int64_t> offsetK,
llvm::Optional<int64_t> shapeK) {
// opIdx: 0 => a, 1 => b
auto type = v.getType().cast<RankedTensorType>();
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
SmallVector<int64_t> offset{0, 0};
Type elementType = type.getElementType();
auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); };
// k => (prefetchWidth, k - prefetchWidth)
int64_t kIdx = opIdx == 0 ? 1 : 0;
offset[kIdx] = isPrologue ? 0 : prefetchWidth;
shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth);
if (shapeK)
shape[kIdx] = *shapeK;
if (offsetK)
offset[kIdx] = *offsetK;
Value newSmem = builder.create<tensor::ExtractSliceOp>(
v.getLoc(),
// TODO: encoding?
RankedTensorType::get(shape, elementType, type.getEncoding()), v,
SmallVector<OpFoldResult>{intAttr(offset[0]), intAttr(offset[1])},
SmallVector<OpFoldResult>{intAttr(shape[0]), intAttr(shape[1])},
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding);
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
return prefetchSlice;
}
LogicalResult Prefetcher::initialize() {
Block *loop = forOp.getBody();
SmallVector<triton::DotOp> dotsInFor;
for (Operation &op : *loop)
if (auto dotOp = dyn_cast<triton::DotOp>(op))
dotsInFor.push_back(dotOp);
if (dotsInFor.empty())
return failure();
// TODO: segfault (original for still has uses)
// when used in flash attention that has 2 dots in the loop
if (dotsInFor.size() > 1)
return failure();
// returns source of cvt
auto getPrefetchSrc = [](Value v) -> Value {
if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
if (isSharedEncoding(cvt.getOperand()))
return cvt.src();
return Value();
};
auto getIncomingOp = [this](Value v) -> Value {
if (auto arg = v.dyn_cast<BlockArgument>())
if (arg.getOwner()->getParentOp() == forOp.getOperation())
return forOp.getOpOperandForRegionIterArg(arg).get();
return Value();
};
auto getYieldOp = [this](Value v) -> Value {
auto arg = v.cast<BlockArgument>();
unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars();
return yieldOp.getOperand(yieldIdx);
};
for (triton::DotOp dot : dotsInFor) {
auto kSize = dot.a().getType().cast<RankedTensorType>().getShape()[1];
// Skip prefetching if kSize is less than prefetchWidth
if (kSize < prefetchWidth)
continue;
Value aSmem = getPrefetchSrc(dot.a());
Value bSmem = getPrefetchSrc(dot.b());
if (aSmem && bSmem) {
Value aHeaderDef = getIncomingOp(aSmem);
Value bHeaderDef = getIncomingOp(bSmem);
// Only prefetch loop arg
if (aHeaderDef && bHeaderDef) {
dots.insert(dot);
dot2aHeaderDef[dot] = aHeaderDef;
dot2bHeaderDef[dot] = bHeaderDef;
dot2aLoopArg[dot] = aSmem;
dot2bLoopArg[dot] = bSmem;
dot2aYield[dot] = getYieldOp(aSmem);
dot2bYield[dot] = getYieldOp(bSmem);
}
}
}
return success();
}
void Prefetcher::emitPrologue() {
OpBuilder builder(forOp);
for (Value dot : dots) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
Value aPrefetched =
generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder);
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()] = aPrefetched;
Value bPrefetched =
generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder);
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()] = bPrefetched;
}
}
scf::ForOp Prefetcher::createNewForOp() {
OpBuilder builder(forOp);
SmallVector<Value> loopArgs;
for (auto v : forOp.getIterOperands())
loopArgs.push_back(v);
for (Value dot : dots) {
loopArgs.push_back(
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().a()]);
loopArgs.push_back(
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().b()]);
}
auto newForOp = builder.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), loopArgs);
auto largestPow2 = [](int64_t n) -> int64_t {
while ((n & (n - 1)) != 0)
n = n & (n - 1);
return n;
};
builder.setInsertionPointToStart(newForOp.getBody());
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping);
auto dot = dyn_cast<triton::DotOp>(&op);
if (dots.contains(dot)) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
// prefetched dot
Operation *firstDot = builder.clone(*dot, mapping);
if (Value a = operand2headPrefetch.lookup(dot.a()))
firstDot->setOperand(
0, newForOp.getRegionIterArgForOpOperand(*a.use_begin()));
if (Value b = operand2headPrefetch.lookup(dot.b()))
firstDot->setOperand(
1, newForOp.getRegionIterArgForOpOperand(*b.use_begin()));
// remaining part
int64_t kOff = prefetchWidth;
int64_t kRem = dot.a().getType().cast<RankedTensorType>().getShape()[1] -
prefetchWidth;
Operation *prevDot = firstDot;
while (kRem != 0) {
int64_t kShape = largestPow2(kRem);
Value aRem =
generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false,
dotEncoding, builder, kOff, kShape);
Value bRem =
generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false,
dotEncoding, builder, kOff, kShape);
newOp = builder.clone(*dot, mapping);
newOp->setOperand(0, aRem);
newOp->setOperand(1, bRem);
newOp->setOperand(2, prevDot->getResult(0));
prevDot = newOp;
kOff += kShape;
kRem -= kShape;
}
}
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
}
// prefetch next iteration
SmallVector<Value> yieldValues;
for (Value v : forOp.getBody()->getTerminator()->getOperands())
yieldValues.push_back(mapping.lookup(v));
for (Value dot : dots) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0,
true, dotEncoding, builder));
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1,
true, dotEncoding, builder));
}
// Update ops of yield
builder.create<scf::YieldOp>(yieldOp.getLoc(), yieldValues);
return newForOp;
}
struct PrefetchPass : public TritonGPUPrefetchBase<PrefetchPass> {
void runOnOperation() override {
getOperation()->walk([&](scf::ForOp forOp) {
Prefetcher prefetcher(forOp);
if (prefetcher.initialize().failed())
return;
prefetcher.emitPrologue();
scf::ForOp newForOp = prefetcher.createNewForOp();
// replace the original loop
for (unsigned i = 0; i < forOp->getNumResults(); ++i)
forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
forOp->erase();
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUPrefetchPass() {
return std::make_unique<PrefetchPass>();
}

View File

@@ -0,0 +1,103 @@
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
using namespace mlir;
using namespace mlir::triton::gpu;
//
// TypeConverter
//
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
int numWarps)
: context(context), numWarps(numWarps) {
// TODO: how does MLIR pick the right conversion?
addConversion([](Type type) { return type; });
addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
// types with encoding are already in the right format
// TODO: check for layout encodings specifically
if (tensorType.getEncoding())
return tensorType;
// pessimistic values for attributes:
// - 1 element per thread
// - order = arange(rank)
ArrayRef<int64_t> shape = tensorType.getShape();
int rank = shape.size();
llvm::SmallVector<unsigned> order(rank);
std::iota(order.begin(), order.end(), 0);
llvm::SmallVector<unsigned> sizePerThread(rank, 1);
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
this->context, shape, sizePerThread, order, this->numWarps);
return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
});
//
// Materializations
//
// This will be called when (newArgType != origArgType)
// This will create newArg, and map(origArg, newArg)
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) {
llvm_unreachable("Argument rematerialization not implemented");
return llvm::None;
});
// If the origValue still has live user(s), use this to
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
llvm_unreachable("Source rematerialization not implemented");
return llvm::None;
});
// This will be called when (desiredType != newOperandType)
// where, desiredType = typeConverter->convertType(origType)
// NOTE: only for remapped values.
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
auto cast =
builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType, inputs);
return Optional<Value>(cast.getResult());
// return Optional<Value>(cast.getResult(0));
// llvm_unreachable("Not implemented");
// return llvm::None;
});
}
//
// TritonGPUConversion
//
TritonGPUConversionTarget::TritonGPUConversionTarget(
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
: ConversionTarget(context) {
// TODO: we should also verify ops of TritonGPUDialect
addLegalDialect<triton::gpu::TritonGPUDialect>();
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect, math::MathDialect,
triton::TritonDialect, StandardOpsDialect,
scf::SCFDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
return false;
});
// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
Attribute aEncoding =
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
Attribute bEncoding =
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
if (aEncoding && aEncoding.isa<triton::gpu::DotOperandEncodingAttr>() &&
bEncoding && bEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
return true;
return false;
});
}