[CI] run clang-format (#24)

Philippe Tillet
2022-07-26 17:25:03 -07:00
committed by GitHub
parent 25357083e6
commit 6d62d88d4f
62 changed files with 13673 additions and 11367 deletions


@@ -7,7 +7,6 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
 using namespace mlir;
@@ -19,12 +18,13 @@ void TritonDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "triton/Dialect/Triton/IR/Ops.cpp.inc"
-  >();
+      >();
   // We can also add interface here.
 }
-Operation *TritonDialect::materializeConstant(OpBuilder &builder, Attribute value,
-                                              Type type, Location loc) {
+Operation *TritonDialect::materializeConstant(OpBuilder &builder,
+                                              Attribute value, Type type,
+                                              Location loc) {
   return builder.create<arith::ConstantOp>(loc, type, value);
 }

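For context (not part of the commit): materializeConstant is the hook MLIR's folder calls when it needs to turn a folded Attribute back into an SSA constant. A minimal sketch of exercising the hook above; the helper function and its arguments are illustrative, not from this diff:

#include "mlir/IR/Builders.h"

// Illustrative only: the folding infrastructure normally makes this call.
// With the hook above, a folded Triton value becomes an arith constant.
Operation *materializeZero(mlir::Dialect *tritonDialect,
                           mlir::OpBuilder &builder, mlir::Location loc) {
  mlir::Attribute zero = builder.getF32FloatAttr(0.0f);
  // Returns an arith::ConstantOp producing 0.0 : f32.
  return tritonDialect->materializeConstant(builder, zero,
                                            builder.getF32Type(), loc);
}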

@@ -13,14 +13,16 @@ namespace triton {
 static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(type.getContext(), 1);
   if (auto tensorType = type.dyn_cast<RankedTensorType>())
-    return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
+    return RankedTensorType::get(tensorType.getShape(), i1Type,
+                                 tensorType.getEncoding());
   return Type();
 }
 static Type getI32SameShape(Type type) {
   auto i32Type = IntegerType::get(type.getContext(), 32);
   if (auto tensorType = type.dyn_cast<RankedTensorType>())
-    return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding());
+    return RankedTensorType::get(tensorType.getShape(), i32Type,
+                                 tensorType.getEncoding());
   return Type();
 }
@@ -34,8 +36,8 @@ static Type getPointerTypeFromTensor(Type type) {
   return Type();
 }
-}
-}
+} // namespace triton
+} // namespace mlir
 #define GET_OP_CLASSES
 #include "triton/Dialect/Triton/IR/Ops.cpp.inc"
@@ -48,50 +50,48 @@ namespace triton {
 //-- StoreOp --
 // Default mask
-void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) {
+void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
+                    ::mlir::Value ptr, ::mlir::Value value) {
   TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
   auto shape = ptrType.getShape();
   ::mlir::Value mask = builder.create<arith::ConstantOp>(
-    ptr.getLoc(),
-    RankedTensorType::get(shape, builder.getI1Type()),
-    DenseIntElementsAttr::get(
-      RankedTensorType::get(shape, builder.getI1Type()), true
-    )
-  );
+      ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
+      DenseIntElementsAttr::get(
+          RankedTensorType::get(shape, builder.getI1Type()), true));
   state.addOperands(ptr);
   state.addOperands(value);
   state.addOperands(mask);
 }
 //-- LoadOp --
-void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr,
-                   ::mlir::triton::CacheModifier cache, ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
+void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
+                   ::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
+                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
   TensorType ptrType = ptr.getType().dyn_cast<TensorType>();
-  Type elementType = ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
+  Type elementType =
+      ptrType.getElementType().dyn_cast<PointerType>().getPointeeType();
   auto shape = ptrType.getShape();
   // mask
   ::mlir::Value mask = builder.create<arith::ConstantOp>(
-    ptr.getLoc(),
-    RankedTensorType::get(shape, builder.getI1Type()),
-    DenseIntElementsAttr::get(
-      RankedTensorType::get(shape, builder.getI1Type()), true
-    )
-  );
+      ptr.getLoc(), RankedTensorType::get(shape, builder.getI1Type()),
+      DenseIntElementsAttr::get(
+          RankedTensorType::get(shape, builder.getI1Type()), true));
   // other
   Type resultType = RankedTensorType::get(shape, elementType);
   ::mlir::Value other = builder.create<arith::ConstantOp>(
-    ptr.getLoc(),
-    resultType,
-    DenseElementsAttr::get(
-      resultType, builder.getZeroAttr(elementType)
-    )
-  );
+      ptr.getLoc(), resultType,
+      DenseElementsAttr::get(resultType, builder.getZeroAttr(elementType)));
   state.addOperands(ptr);
   state.addOperands(mask);
   state.addOperands(other);
-  state.addAttribute(cacheAttrName(state.name), ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
-  state.addAttribute(evictAttrName(state.name), ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
-  state.addAttribute(isVolatileAttrName(state.name), builder.getBoolAttr(isVolatile));
+  state.addAttribute(
+      cacheAttrName(state.name),
+      ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
+  state.addAttribute(
+      evictAttrName(state.name),
+      ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
+  state.addAttribute(isVolatileAttrName(state.name),
+                     builder.getBoolAttr(isVolatile));
  state.addTypes({resultType});
 }

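A usage note (mine, not from the diff): these build overloads let callers omit the mask and other operands; the builder synthesizes an all-true mask and, for loads, a zero "other" tensor. A hedged sketch, assuming builder, loc, and a tensor-of-pointers ptr are in scope, and assuming the CacheModifier/EvictionPolicy enums expose NONE/NORMAL members:

// Load with defaults: mask = splat(true), other = splat(0).
mlir::Value loaded = builder.create<mlir::triton::LoadOp>(
    loc, ptr, mlir::triton::CacheModifier::NONE,
    mlir::triton::EvictionPolicy::NORMAL, /*isVolatile=*/false);

// Store through the default-mask overload; the all-true mask is added for us.
builder.create<mlir::triton::StoreOp>(loc, ptr, loaded);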

@@ -1,6 +1,6 @@
-#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
 using namespace mlir;
@@ -16,7 +16,7 @@ void TritonDialect::registerTypes() {
   addTypes<
 #define GET_TYPEDEF_LIST
 #include "triton/Dialect/Triton/IR/Types.cpp.inc"
-  >();
+      >();
 }
 Type PointerType::parse(AsmParser &parser) {


@@ -17,21 +17,23 @@ namespace {
 class CombineDotOp : public mlir::RewritePattern {
 public:
   CombineDotOp(mlir::MLIRContext *context)
-      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
-  mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
-                                      mlir::PatternRewriter &rewriter) const override {
+      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
+                             context) {}
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
     if (llvm::isa<mlir::arith::AddIOp, mlir::arith::AddFOp>(op)) {
       if (isCandidate(op->getOperand(0)).succeeded()) {
         auto dotOp = op->getOperand(0).getDefiningOp<mlir::triton::DotOp>();
         rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
-            op, dotOp->getResultTypes().front(), dotOp.a(),
-            dotOp.b(), op->getOperand(1), dotOp.allowTF32());
+            op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
+            op->getOperand(1), dotOp.allowTF32());
         return mlir::success();
       } else if (isCandidate(op->getOperand(1)).succeeded()) {
         auto dotOp = op->getOperand(1).getDefiningOp<mlir::triton::DotOp>();
         rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
-            op, dotOp->getResultTypes().front(), dotOp.a(),
-            dotOp.b(), op->getOperand(0), dotOp.allowTF32());
+            op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(),
+            op->getOperand(0), dotOp.allowTF32());
         return mlir::success();
       }
     }
@@ -54,7 +56,7 @@ private:
       return true;
     // broadcast(constant_0)
     if (auto bc = val.getDefiningOp<mlir::triton::BroadcastOp>()) {
-      if (mlir::matchPattern(bc.src(), mlir::m_Zero()) || 
+      if (mlir::matchPattern(bc.src(), mlir::m_Zero()) ||
           mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat()))
         return true;
     }
@@ -68,18 +70,19 @@ private:
 class CombineGEPOp : public mlir::RewritePattern {
 public:
   CombineGEPOp(mlir::MLIRContext *context)
-      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
-  mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
-                                      mlir::PatternRewriter &rewriter) const override {
+      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
+                             context) {}
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
     if (llvm::isa<mlir::triton::GEPOp>(op)) {
       if (auto gep2 = op->getOperand(0).getDefiningOp<mlir::triton::GEPOp>()) {
         auto loc = op->getLoc();
         mlir::Value newIdx = rewriter.create<mlir::arith::AddIOp>(
-          loc, op->getOperand(1), gep2->getOperand(1));
+            loc, op->getOperand(1), gep2->getOperand(1));
         rewriter.replaceOpWithNewOp<mlir::triton::GEPOp>(
-          op, op->getResultTypes().front(), gep2->getOperand(0), newIdx
-        );
+            op, op->getResultTypes().front(), gep2->getOperand(0), newIdx);
         return mlir::success();
       }
     }
@@ -92,20 +95,21 @@ public:
 class CombineSelectMaskedLoadOp : public mlir::RewritePattern {
 public:
   CombineSelectMaskedLoadOp(mlir::MLIRContext *context)
-      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, context) {}
-  mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
-                                      mlir::PatternRewriter &rewriter) const override {
+      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
+                             context) {}
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
     if (llvm::isa<mlir::SelectOp>(op)) {
       if (auto load = op->getOperand(1).getDefiningOp<mlir::triton::LoadOp>()) {
         mlir::Value cond = op->getOperand(0);
         if (auto bc = load.mask().getDefiningOp<mlir::triton::BroadcastOp>()) {
           if (bc.src().getDefiningOp() == cond.getDefiningOp()) {
             rewriter.replaceOpWithNewOp<mlir::triton::LoadOp>(
-              op, op->getResultTypes().front(),
-              load.ptr(), load.mask(), op->getOperand(2),
-              load.cache(), load.evict(), load.isVolatile()
-            );
+                op, op->getResultTypes().front(), load.ptr(), load.mask(),
+                op->getOperand(2), load.cache(), load.evict(),
+                load.isVolatile());
             return mlir::success();
           }
         }
@@ -120,11 +124,11 @@ public:
 class CombineBroadcastConstantOp : public mlir::RewritePattern {
 public:
   CombineBroadcastConstantOp(mlir::MLIRContext *context)
-    : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
-                           context) {}
+      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
+                             context) {}
   LogicalResult matchAndRewrite(Operation *op,
-                                 PatternRewriter &rewriter) const override {
+                                PatternRewriter &rewriter) const override {
     if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
       if (auto cst = broadcast.src().getDefiningOp<arith::ConstantOp>()) {
         Attribute value = cst.getValue();
@@ -132,15 +136,14 @@ public:
         if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
           if (!denseValue.isSplat())
             return failure();
-          value = DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
+          value = DenseElementsAttr::get(resType,
+                                         denseValue.getSplatValue<Attribute>());
         } else {
           if (!value.isa<FloatAttr, IntegerAttr>())
             return failure();
           value = DenseElementsAttr::get(resType, value);
         }
-        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, value, resType
-        );
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, value, resType);
         return success();
       }
     }

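For orientation (not in the commit): patterns like the four classes above are typically gathered into a RewritePatternSet and run to a fixed point by MLIR's greedy driver, roughly as follows; the wrapper function is hypothetical:

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

void runCombinePatterns(mlir::ModuleOp module) {
  mlir::MLIRContext *context = module.getContext();
  mlir::RewritePatternSet patterns(context);
  // Register the pattern classes shown above.
  patterns.add<CombineDotOp, CombineGEPOp, CombineSelectMaskedLoadOp,
               CombineBroadcastConstantOp>(context);
  if (mlir::failed(
          mlir::applyPatternsAndFoldGreedily(module, std::move(patterns))))
    module.emitError("combine patterns failed to converge");
}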

@@ -11,19 +11,18 @@
 // parse an array of integers
 static LogicalResult parseIntArrayAttr(AsmParser &parser,
                                        const NamedAttribute &attr,
-                                       /*SmallVector<unsigned, 2>*/auto &res,
-                                       StringRef desc) {
+                                       /*SmallVector<unsigned, 2>*/ auto &res,
+                                       StringRef desc) {
   auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
   if (!arrayAttr) {
-    parser.emitError(parser.getNameLoc(), "expected an array for ")
-      << desc;
+    parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
     return failure();
   }
   for (Attribute i : arrayAttr) {
     auto intAttr = i.dyn_cast<IntegerAttr>();
     if (!intAttr) {
       parser.emitError(parser.getNameLoc(), "expected an integer value in ")
-        << desc;
+          << desc;
       return failure();
     }
     res.push_back(intAttr.getUInt());
@@ -46,7 +45,7 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
     return {};
   if (parser.parseGreater().failed())
     return {};
   SmallVector<unsigned, 2> threadTileSize;
   SmallVector<unsigned, 2> warpTileSize;
   SmallVector<unsigned, 2> blockTileSize;
@@ -55,19 +54,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "threadTileSize") {
-      if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed())
+      if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
+              .failed())
         return {};
     } else if (attr.getName() == "warpTileSize") {
-      if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed())
+      if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
+              .failed())
        return {};
     } else if (attr.getName() == "blockTileSize") {
-      if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed())
+      if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
+              .failed())
        return {};
     } else if (attr.getName() == "order") {
       if (parseIntArrayAttr(parser, attr, order, "order").failed())
         return {};
     } else if (attr.getName() == "broadcastAxis") {
-      if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis").failed())
+      if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
+              .failed())
         return {};
     } else {
       parser.emitError(parser.getNameLoc(), "unexpected key: ")
@@ -76,12 +79,9 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
     }
   }
-  return parser.getChecked<TritonGPUBlockedEncodingAttr>(parser.getContext(),
-                                                         threadTileSize,
-                                                         warpTileSize,
-                                                         blockTileSize,
-                                                         order,
-                                                         broadcastAxis);
+  return parser.getChecked<TritonGPUBlockedEncodingAttr>(
+      parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
+      broadcastAxis);
 }
 static void printBlocked(AsmPrinter &printer, auto *attr) {
@@ -94,8 +94,7 @@ static void printBlocked(AsmPrinter &printer, auto *attr) {
           << "}>";
 }
-Attribute
-TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
+Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
   parseBlocked(parser, type);
 }
@@ -103,8 +102,8 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
   printBlocked(printer, this);
 }
-Attribute
-TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
+Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
+                                                       Type type) {
   parseBlocked(parser, type);
 }
@@ -131,38 +130,37 @@ static Attribute parseMma(AsmParser &parser, Type type) {
   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "fragmentPerWarp") {
-      if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed())
+      if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
+              .failed())
         return {};
     } else if (attr.getName() == "shapePerWarp") {
-      if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed())
+      if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
+              .failed())
         return {};
     } else if (attr.getName() == "warpPerTile") {
       if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
         return {};
     } else if (attr.getName() == "shapePerTile") {
-      if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed())
+      if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
+              .failed())
         return {};
     } else if (attr.getName() == "repetitions") {
       if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
         return {};
     } else if (attr.getName() == "contigPerThread") {
-      if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed())
+      if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
+              .failed())
         return {};
     } else {
       parser.emitError(parser.getNameLoc(), "unexpected key: ")
-        << attr.getName().strref();
+          << attr.getName().strref();
       return {};
     }
   }
-  return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
-                                                     fragmentPerWarp,
-                                                     shapePerWarp,
-                                                     warpPerTile,
-                                                     shapePerTile,
-                                                     repetitions,
-                                                     contigPerThread,
-                                                     broadcastAxis);
+  return parser.getChecked<TritonGPUMmaEncodingAttr>(
+      parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
+      shapePerTile, repetitions, contigPerThread, broadcastAxis);
 }
 static void printMma(AsmPrinter &printer, auto *attr) {
@@ -176,8 +174,7 @@ static void printMma(AsmPrinter &printer, auto *attr) {
           << "}>";
 }
-Attribute
-TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
+Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
   return parseMma(parser, type);
 }
@@ -185,8 +182,8 @@ void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
   printMma(printer, this);
 }
-Attribute
-TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
+Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
+                                                   Type type) {
   return parseMma(parser, type);
 }
@@ -194,8 +191,7 @@ void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
   printMma(printer, this);
 }
-Attribute
-TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
+Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   // Parse the data as a dictionary
@@ -210,8 +206,7 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
   unsigned maxPhase = 0;
   SmallVector<unsigned, 2> order;
-  auto parseUInt = [&parser](const NamedAttribute &attr,
-                             unsigned &value,
+  auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
                              StringRef desc) -> LogicalResult {
     auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
     if (!intAttr) {
@@ -237,29 +232,25 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
         return {};
     } else {
       parser.emitError(parser.getNameLoc(), "unexpected key: ")
-        << attr.getName().strref();
+          << attr.getName().strref();
      return {};
     }
   }
-  return parser.getChecked<TritonGPUSharedEncodingAttr>(parser.getContext(),
-                                                        vec,
-                                                        perPhase,
-                                                        maxPhase,
-                                                        order);
+  return parser.getChecked<TritonGPUSharedEncodingAttr>(
+      parser.getContext(), vec, perPhase, maxPhase, order);
 }
 void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
   printer << "<{"
-          << "vec = " << getVec()
-          << ", perPhase = " << getPerPhase()
-          << ", maxPhase = " << getMaxPhase()
-          << ", order = [" << getOrder() << "]"
+          << "vec = " << getVec() << ", perPhase = " << getPerPhase()
+          << ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder()
+          << "]"
          << "}>";
 }
 class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
- public:
+public:
   using OpAsmDialectInterface::OpAsmDialectInterface;
   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
@@ -289,7 +280,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
     OpAsmDialectInterface::getAlias(attr, os);
   }
- private:
+private:
   static void printMma(const auto &attr, raw_ostream &os) {
     TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
     TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
@@ -338,7 +329,7 @@ void TritonGPUDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
-  >();
+      >();
   addInterfaces<TritonGPUOpAsmInterface>();
 }
@@ -349,7 +340,8 @@ namespace triton {
 static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(type.getContext(), 1);
   if (auto tensorType = type.dyn_cast<RankedTensorType>())
-    return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
+    return RankedTensorType::get(tensorType.getShape(), i1Type,
+                                 tensorType.getEncoding());
   return Type();
 }
@@ -368,8 +360,8 @@ static Type getPointeeType(Type type) {
   return Type();
 }
-}
-}
+} // namespace triton
+} // namespace mlir
 static LogicalResult verify(CopyAsyncOp op) {
   Type resType = op.getResult().getType();
@@ -385,11 +377,9 @@ static LogicalResult verify(CopyAsyncOp op) {
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
 // verify TritonGPU ops
-LogicalResult
-TritonGPUDialect::verifyOperationAttribute(Operation *op,
-                                           NamedAttribute attr) {
+LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
+                                                         NamedAttribute attr) {
   // TODO: fill this.
   return success();
 }

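A side note (mine, not from the diff): the parse/print pairs above round-trip encodings written as #triton_gpu.<mnemonic><{key = [...], ...}>. Building the blocked encoding programmatically uses the same parameter order that parseBlocked passes to getChecked; the values below are made up, and an MLIRContext *context is assumed in scope:

// Hypothetical values; the parameter order mirrors the getChecked call above.
SmallVector<unsigned, 2> threadTileSize{1, 1};
SmallVector<unsigned, 2> warpTileSize{4, 8};
SmallVector<unsigned, 2> blockTileSize{16, 16};
SmallVector<unsigned, 2> order{0, 1};
SmallVector<unsigned, 2> broadcastAxis; // empty: nothing broadcast
Attribute enc = TritonGPUBlockedEncodingAttr::get(
    context, threadTileSize, warpTileSize, blockTileSize, order,
    broadcastAxis);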

@@ -27,8 +27,8 @@ namespace {
 #define GEN_PASS_CLASSES
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
-class TritonGPUCombineOpsPass 
-  : public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
+class TritonGPUCombineOpsPass
+    : public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
 public:
   void runOnOperation() override {
     MLIRContext *context = &getContext();


@@ -6,12 +6,11 @@
 //===----------------------------------------------------------------------===//
 //
 // This file implements loop software pipelining
-// The implementation here is inspired by the pipeline pass in Triton (-v2.0) 
+// The implementation here is inspired by the pipeline pass in Triton (-v2.0)
 // and SCF's LoopPipelining.
 //
 //===----------------------------------------------------------------------===//
 using namespace mlir;
 #define GEN_PASS_CLASSES
@@ -41,14 +40,15 @@ class LoopPipeliner {
   /// Block arguments that loads depend on
   DenseSet<BlockArgument> depArgs;
   /// Operations (inside the loop body) that loads depend on
-  DenseSet<Operation*> depOps;
+  DenseSet<Operation *> depOps;
   /// collect values that v depends on and are defined inside the loop
   void collectDeps(Value v, int stages, DenseSet<Value> &deps);
   void setValueMapping(Value origin, Value newValue, int stage);
 public:
-  LoopPipeliner(scf::ForOp forOp, int numStages) 
+  LoopPipeliner(scf::ForOp forOp, int numStages)
       : forOp(forOp), numStages(numStages) {
     // cache yieldOp
     yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
@@ -86,7 +86,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
   if (auto arg = v.dyn_cast<BlockArgument>()) {
     deps.insert(v);
     // Note: we have iv as the first arg, so the op idx is arg.getArgNumber()-1
-    collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages-1, deps);
+    collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
   } else { // value
     // v might be in deps, but we still need to visit v.
     // This is because v might depends on value in previous iterations
@@ -123,8 +123,8 @@ LogicalResult LoopPipeliner::initialize() {
   }
   // for (triton::LoadOp loadOp : allLoads) {
-  //   llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() << " values\n";
-  //   for (Value dep : loadDeps[loadOp])
+  //   llvm::errs() << loadOp << " depends on: #" << loadDeps[loadOp].size() <<
+  //   " values\n"; for (Value dep : loadDeps[loadOp])
   //     llvm::errs() << dep << "\n";
   //   llvm::errs() << "\n";
   // }
@@ -147,9 +147,13 @@ LogicalResult LoopPipeliner::initialize() {
     if (isCandiate && loadOp.getResult().hasOneUse()) {
       isCandiate = false;
       Operation *use = *loadOp.getResult().getUsers().begin();
-      if (auto convertLayout = llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
-        if (auto tensorType = convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
-          if (tensorType.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
+      if (auto convertLayout =
+              llvm::dyn_cast<triton::gpu::ConvertLayoutOp>(use)) {
+        if (auto tensorType = convertLayout.getResult()
+                                  .getType()
+                                  .dyn_cast<RankedTensorType>()) {
+          if (tensorType.getEncoding()
+                  .isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
             isCandiate = true;
             loadsMapping[loadOp] = convertLayout;
           }
@@ -162,7 +166,6 @@ LogicalResult LoopPipeliner::initialize() {
       loads.insert(loadOp);
   }
   // we have some loads to pipeline
   if (!loads.empty()) {
     // update depArgs & depOps
@@ -202,10 +205,10 @@ void LoopPipeliner::emitPrologue() {
   // special handling for loop condition as there is no condition in ForOp
   Value loopCond = builder.create<arith::CmpIOp>(
-    iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
+      iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
   // rematerialize peeled values
-  SmallVector<Operation*> orderedDeps;
+  SmallVector<Operation *> orderedDeps;
   for (Operation &op : forOp.getLoopBody().front()) {
     if (depOps.contains(&op))
       orderedDeps.push_back(&op);
@@ -221,10 +224,9 @@ void LoopPipeliner::emitPrologue() {
       // TODO: check if the hardware supports copyasync
       if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
         newOp = builder.create<triton::gpu::CopyAsyncOp>(
-          op->getLoc(), loadsMapping[loadOp].getType(),
-          loadOp.ptr(), loadOp.mask(), loadOp.other(),
-          loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
-        );
+            op->getLoc(), loadsMapping[loadOp].getType(), loadOp.ptr(),
+            loadOp.mask(), loadOp.other(), loadOp.cache(), loadOp.evict(),
+            loadOp.isVolatile());
       } else
         llvm_unreachable("This should be LoadOp");
     } else
@@ -245,12 +247,10 @@ void LoopPipeliner::emitPrologue() {
       // assert(I1 or TensorOf<[I1]>);
       OpBuilder::InsertionGuard g(builder);
       builder.setInsertionPoint(newOp);
-      Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
-                                                            mask.getType(),
-                                                            loopCond);
-      Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
-                                                    mask,
-                                                    splatCond);
+      Value splatCond = builder.create<triton::BroadcastOp>(
+          mask.getLoc(), mask.getType(), loopCond);
+      Value newMask =
+          builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
       newOp->setOperand(1, newMask);
     }
@@ -264,8 +264,9 @@ void LoopPipeliner::emitPrologue() {
       // update mapping for loop-carried values (args)
       for (OpOperand &operand : yieldOp->getOpOperands()) {
         if (operand.get() == op->getResult(dstIdx))
-          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
-                          newOp->getResult(dstIdx), stage + 1);
+          setValueMapping(
+              forOp.getRegionIterArgs()[operand.getOperandNumber()],
+              newOp->getResult(dstIdx), stage + 1);
       }
     }
   }
@@ -296,21 +297,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   size_t depArgsBeginIdx = newLoopArgs.size();
   for (BlockArgument depArg : depArgs) {
     depArgsIdx[depArg] = newLoopArgs.size();
-    newLoopArgs.push_back(valueMapping[depArg][numStages-1]);
+    newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
   }
   size_t nextIVIdx = newLoopArgs.size();
-  newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]);
+  newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
   for (size_t i = 0; i < newLoopArgs.size(); ++i)
     assert(newLoopArgs[i]);
   // 1. signature of the new ForOp
-  auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
-                                             forOp.getLowerBound(),
-                                             forOp.getUpperBound(),
-                                             forOp.getStep(),
-                                             newLoopArgs);
+  auto newForOp = builder.create<scf::ForOp>(
+      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+      forOp.getStep(), newLoopArgs);
   // 2. body of the new ForOp
   builder.setInsertionPointToStart(newForOp.getBody());
@@ -329,15 +328,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   // 3. replace loads with block args (from prologue)
   for (size_t idx = 0; idx < loads.size(); ++idx) {
     Value load = loads[idx];
-    assert(load.hasOneUse() && "we assume that this load has one use (ConvertLayout)");
+    assert(load.hasOneUse() &&
+           "we assume that this load has one use (ConvertLayout)");
     Value loadUse = load.getUsers().begin()->getResult(0);
     mapping.lookup(loadUse).replaceAllUsesWith(
-      newForOp.getRegionIterArgs()[loadIdx + idx*(numStages-1)]);
+        newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
   }
   // 4. prefetch the next iteration
-  SmallVector<Operation*> orderedDeps;
+  SmallVector<Operation *> orderedDeps;
   for (Operation &op : forOp.getLoopBody().front()) {
     if (depOps.contains(&op))
       orderedDeps.push_back(&op);
@@ -350,41 +349,39 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   DenseMap<BlockArgument, Value> depArgsMapping;
   size_t argIdx = 0;
   for (BlockArgument arg : depArgs) {
-    nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
+    nextMapping.map(arg,
+                    newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
     ++argIdx;
   }
   // special handling for iv & loop condition
-  Value nextIV = builder.create<arith::AddIOp>(newForOp.getInductionVar().getLoc(),
-                                               newForOp.getRegionIterArgs()[nextIVIdx],
-                                               newForOp.getStep());
-  Value nextLoopCond = builder.create<arith::CmpIOp>(
-    nextIV.getLoc(), arith::CmpIPredicate::slt,
-    nextIV, newForOp.getUpperBound());
+  Value nextIV = builder.create<arith::AddIOp>(
+      newForOp.getInductionVar().getLoc(),
+      newForOp.getRegionIterArgs()[nextIVIdx], newForOp.getStep());
+  Value nextLoopCond =
+      builder.create<arith::CmpIOp>(nextIV.getLoc(), arith::CmpIPredicate::slt,
+                                    nextIV, newForOp.getUpperBound());
   for (Operation *op : orderedDeps) {
     Operation *nextOp = nullptr;
     // update loading mask
     if (loads.contains(op->getResult(0))) {
       auto loadOp = llvm::cast<triton::LoadOp>(op);
       Value mask = loadOp.mask();
-      Value splatCond = builder.create<triton::BroadcastOp>(mask.getLoc(),
-                                                            mask.getType(),
-                                                            nextLoopCond);
-      Value newMask = builder.create<arith::AndIOp>(mask.getLoc(),
-                                                    splatCond,
-                                                    nextMapping.lookupOrDefault(mask));
-      // if mask is defined outside the loop, don't update the map more than once
+      Value splatCond = builder.create<triton::BroadcastOp>(
+          mask.getLoc(), mask.getType(), nextLoopCond);
+      Value newMask = builder.create<arith::AndIOp>(
+          mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
+      // if mask is defined outside the loop, don't update the map more than
+      // once
      if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
        nextMapping.map(mask, newMask);
      // TODO: more elegant way to do this?
      nextOp = builder.create<triton::gpu::CopyAsyncOp>(
-        op->getLoc(), loadsMapping[op->getResult(0)].getType(),
-        nextMapping.lookupOrDefault(loadOp.ptr()),
-        nextMapping.lookupOrDefault(loadOp.mask()),
-        nextMapping.lookupOrDefault(loadOp.other()),
-        loadOp.cache(), loadOp.evict(), loadOp.isVolatile()
-      );
-    }
-    else
+          op->getLoc(), loadsMapping[op->getResult(0)].getType(),
+          nextMapping.lookupOrDefault(loadOp.ptr()),
+          nextMapping.lookupOrDefault(loadOp.mask()),
+          nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
+          loadOp.evict(), loadOp.isVolatile());
+    } else
      nextOp = builder.clone(*op, nextMapping);
    // llvm::errs() << "epilogue cloning...: " << *op << "\n";
    // update mapping of results
@@ -411,15 +408,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   for (size_t idx = 0; idx < loads.size(); ++idx) {
     Value load = loads[idx];
     for (int stage = 1; stage < numStages - 1; ++stage) {
-      yieldValues.push_back(newForOp.getRegionIterArgs()[
-        loadIdx + idx*(numStages-1) + stage
-      ]);
+      yieldValues.push_back(
+          newForOp
+              .getRegionIterArgs()[loadIdx + idx * (numStages - 1) + stage]);
     }
     yieldValues.push_back(nextMapping.lookup(load));
   }
   for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
-    yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
+    yieldValues.push_back(
+        depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
   yieldValues.push_back(nextIV);
   builder.setInsertionPointToEnd(newForOp.getBody());
   builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
@@ -430,9 +428,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
 // ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
 struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
   PipelinePass() = default;
-  PipelinePass(int numStages) {
-    this->numStages = numStages;
-  }
+  PipelinePass(int numStages) { this->numStages = numStages; }
   void runOnOperation() override {
     int numStages = this->numStages;

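To situate the pieces (my summary, not part of the diff): the pass walks each scf.for and drives LoopPipeliner through the three steps defined above, approximately as follows:

// Inside PipelinePass::runOnOperation(), roughly:
getOperation()->walk([&](mlir::scf::ForOp forOp) {
  LoopPipeliner pipeliner(forOp, numStages);
  if (pipeliner.initialize().failed()) // is this loop pipelinable?
    return;
  pipeliner.emitPrologue(); // peel the first numStages-1 stages
  mlir::scf::ForOp newForOp = pipeliner.createNewForOp();
  // rewiring of forOp's results to newForOp omitted here
});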

@@ -1,7 +1,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
-#include "mlir/IR/BlockAndValueMapping.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include <algorithm>
 using namespace mlir;
@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
 //
 // TypeConverter
 //
-TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, 
+TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                                int numThreads)
     : context(context), numThreads(numThreads) {
   // TODO: how does MLIR pick the right conversion?
@@ -38,14 +38,14 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
    //   or assert no encoding?
    // Now we assume:
-    //   contiguous = 1, order = 0, 1, 2, ..., 
+    //   contiguous = 1, order = 0, 1, 2, ...,
    llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
    llvm::SmallVector<unsigned> warpTileSize(rank, 1);
    llvm::SmallVector<unsigned> blockTileSize(rank);
    llvm::SmallVector<unsigned> order(rank);
    llvm::SmallVector<unsigned> broadcastAxis;
    int remainingThreads = numThreads;
-    int remainingLanes = /*warp size*/32;
+    int remainingLanes = /*warp size*/ 32;
    for (int64_t dim = 0; dim < rank; ++dim) {
      blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
      warpTileSize[dim] = std::clamp(remainingLanes, 1, int(shape[dim]));
@@ -56,7 +56,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
      // TODO: will we need repetition?
    }
    Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
-        context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
+        context, threadTileSize, warpTileSize, blockTileSize, order,
+        broadcastAxis);
    return RankedTensorType::get(shape, elementType, encoding);
  });
@@ -65,8 +66,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   //
   // This will be called when (newArgType != origArgType)
   // This will create newArg, and map(origArg, newArg)
-  addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
-                                 ValueRange inputs, Location loc) {
+  addArgumentMaterialization([&](OpBuilder &builder,
+                                 RankedTensorType tensorType, ValueRange inputs,
+                                 Location loc) {
    llvm_unreachable("Not implemented");
    return llvm::None;
  });
@@ -74,7 +76,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // If the origValue still has live user(s), use this to
   // convert origValue to newValue
   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
-                              ValueRange inputs, Location loc) {
+                               ValueRange inputs, Location loc) {
    llvm_unreachable("Not implemented");
    return llvm::None;
  });
@@ -83,7 +85,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   //   where, desiredType = typeConverter->convertType(origType)
   // NOTE: only for remapped values.
   addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
-                              ValueRange inputs, Location loc) {
+                               ValueRange inputs, Location loc) {
    llvm_unreachable("Not implemented");
    return llvm::None;
  });
@@ -93,30 +95,31 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
 // TritonGPUConversion
 //
 TritonGPUConversionTarget::TritonGPUConversionTarget(
-  MLIRContext &context, TritonGPUTypeConverter &typeConverter)
+    MLIRContext &context, TritonGPUTypeConverter &typeConverter)
     : ConversionTarget(context), typeConverter(typeConverter) {
   // TODO: we should also verify ops of TritonGPUDialect
   addLegalDialect<triton::gpu::TritonGPUDialect>();
   // Some ops from SCF are illegal
-  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
-               scf::ReduceOp, scf::ReduceReturnOp>();
-  addDynamicallyLegalDialect<arith::ArithmeticDialect,
-                             triton::TritonDialect,
-                             StandardOpsDialect,
-                             scf::SCFDialect>([&](Operation *op) {
-    if (typeConverter.isLegal(op))
-      return true;
-    return false;
-  });
+  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
+               scf::ReduceReturnOp>();
+  addDynamicallyLegalDialect<arith::ArithmeticDialect, triton::TritonDialect,
+                             StandardOpsDialect, scf::SCFDialect>(
+      [&](Operation *op) {
+        if (typeConverter.isLegal(op))
+          return true;
+        return false;
+      });
   // We have requirements for the data layouts
   addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
-    Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
-    Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
-    if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
+    Attribute aEncoding =
+        dotOp.a().getType().cast<RankedTensorType>().getEncoding();
+    Attribute bEncoding =
+        dotOp.b().getType().cast<RankedTensorType>().getEncoding();
+    if (aEncoding &&
+        aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
        bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
      return true;
    // // TODO: we should delete this
@@ -124,7 +127,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
    //   return true;
    return false;
  });
 }
 // %dst = tt.broadcast %src
@@ -133,12 +135,10 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
 //   %bcst = tt.broadcast %newSrc
 //   %dst = convert_layout %bcst
 LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
-                                                      int numThreads) {
+                                                       int numThreads) {
   // collect broadcasts
   SmallVector<triton::BroadcastOp> broadcasts;
-  mod.walk([&](triton::BroadcastOp op) {
-    broadcasts.push_back(op);
-  });
+  mod.walk([&](triton::BroadcastOp op) { broadcasts.push_back(op); });
   BlockAndValueMapping mapping;
   for (auto broadcast : broadcasts) {
@@ -161,20 +161,14 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
          broadcastAxis.push_back(ax);
      Attribute originSrcEnc = tensorType.getEncoding();
-      if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
+      if (auto blockedEnc =
+              originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
        auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
-          blockedEnc.getContext(),
-          blockedEnc.getThreadTileSize(),
-          blockedEnc.getWarpTileSize(),
-          blockedEnc.getBlockTileSize(),
-          blockedEnc.getOrder(),
-          broadcastAxis
-        );
+            blockedEnc.getContext(), blockedEnc.getThreadTileSize(),
+            blockedEnc.getWarpTileSize(), blockedEnc.getBlockTileSize(),
+            blockedEnc.getOrder(), broadcastAxis);
        newSrcType = RankedTensorType::get(
-          tensorType.getShape(),
-          tensorType.getElementType(),
-          newSrcEnc
-        );
+            tensorType.getShape(), tensorType.getElementType(), newSrcEnc);
      } else
        llvm_unreachable("src of broadcast should have blocked encoding");
    } else {
@@ -186,34 +180,25 @@ LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
    // create new src
    if (!isSrcScalar) // we don't need to convert layout for scalar values
-      src = builder.create<triton::gpu::ConvertLayoutOp>(
-        src.getLoc(), newSrcType, src
-      );
+      src = builder.create<triton::gpu::ConvertLayoutOp>(src.getLoc(),
+                                                         newSrcType, src);
    // create new broadcast
    // compute new type (encoding)
    auto originDstEnc = originDstTensorType.getEncoding()
-                          .dyn_cast<TritonGPUBlockedEncodingAttr>();
+                            .dyn_cast<TritonGPUBlockedEncodingAttr>();
    auto newEnc = TritonGPUBlockedEncodingAttr::get(
-      originDstEnc.getContext(),
-      originDstEnc.getThreadTileSize(),
-      originDstEnc.getWarpTileSize(),
-      originDstEnc.getBlockTileSize(),
-      originDstEnc.getOrder(),
-      broadcastAxis
-    );
-    auto newType = RankedTensorType::get(
-      originDstTensorType.getShape(),
-      originDstTensorType.getElementType(),
-      newEnc
-    );
-    Value newBroadcast = builder.create<triton::BroadcastOp>(
-      broadcast.getLoc(), newType, src
-    );
+        originDstEnc.getContext(), originDstEnc.getThreadTileSize(),
+        originDstEnc.getWarpTileSize(), originDstEnc.getBlockTileSize(),
+        originDstEnc.getOrder(), broadcastAxis);
+    auto newType =
+        RankedTensorType::get(originDstTensorType.getShape(),
+                              originDstTensorType.getElementType(), newEnc);
+    Value newBroadcast =
+        builder.create<triton::BroadcastOp>(broadcast.getLoc(), newType, src);
    // we don't want to change the encoding of the result
    Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
-      broadcast.getLoc(), originDstType, newBroadcast
-    );
+        broadcast.getLoc(), originDstType, newBroadcast);
    broadcast.replaceAllUsesWith(newDst);
    mapping.map(broadcast, newDst);

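For context (not in the commit): the converter/target pair above is what a Triton-to-TritonGPU lowering pass would hand to MLIR's dialect conversion driver, together with its conversion patterns. A minimal sketch; the wrapper function is hypothetical and the patterns are assumed to be populated elsewhere:

#include "mlir/Transforms/DialectConversion.h"

mlir::LogicalResult lowerToTritonGPU(mlir::ModuleOp mod, int numThreads) {
  mlir::MLIRContext *context = mod.getContext();
  TritonGPUTypeConverter typeConverter(context, numThreads);
  TritonGPUConversionTarget target(*context, typeConverter);
  mlir::RewritePatternSet patterns(context);
  // populate `patterns` with the Triton -> TritonGPU conversion patterns here
  return mlir::applyPartialConversion(mod, target, std::move(patterns));
}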

@@ -5,7 +5,6 @@
 using namespace mlir;
 #define GEN_PASS_CLASSES
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
@@ -37,28 +36,30 @@ private:
        if (!encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
          return dotOp.emitError() << name << " should be of shared layout";
      } else
-        return dotOp.emitError() << name << "'s type should be of RankedTensorType";
+        return dotOp.emitError()
+               << name << "'s type should be of RankedTensorType";
    }
    Attribute cLayout;
    for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
-                            llvm::SmallVector<char>{'c', 'd'})) {
+                             llvm::SmallVector<char>{'c', 'd'})) {
      Type type = std::get<0>(it);
      char name = std::get<1>(it);
      if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
        Attribute encoding = tensorType.getEncoding();
        if (!encoding)
          return dotOp.emitError() << name << " should have encoding";
-        if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() && 
+        if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
            !encoding.isa<triton::gpu::TritonGPUBlockedEncodingAttr>())
-          return dotOp.emitError() << name << " should be of distributed layout";
+          return dotOp.emitError()
+                 << name << " should be of distributed layout";
        if (name == 'c')
          cLayout = encoding;
        else if (encoding != cLayout)
          return dotOp.emitError() << "d & c should have the same layout";
      } else
-        return dotOp.emitError() << name
-                                 << "'s type should be of RankedTensorType";
+        return dotOp.emitError()
+               << name << "'s type should be of RankedTensorType";
    }
    // signalPassFailure();
@@ -89,7 +90,7 @@ private:
  }
  void verifyImpl(Operation *op) {
-    if(verifySingleOp(op).failed())
+    if (verifySingleOp(op).failed())
      signalPassFailure();
    // verify that all child regions are ok