[CI] run clang-format (#24)
This commit is contained in:
@@ -11,19 +11,18 @@ using namespace mlir::triton::gpu;
|
||||
// parse an array of integers
|
||||
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||
const NamedAttribute &attr,
|
||||
/*SmallVector<unsigned, 2>*/auto &res,
|
||||
StringRef desc) {
|
||||
/*SmallVector<unsigned, 2>*/ auto &res,
|
||||
StringRef desc) {
|
||||
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
|
||||
if (!arrayAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected an array for ")
|
||||
<< desc;
|
||||
parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
|
||||
return failure();
|
||||
}
|
||||
for (Attribute i : arrayAttr) {
|
||||
auto intAttr = i.dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
|
||||
<< desc;
|
||||
<< desc;
|
||||
return failure();
|
||||
}
|
||||
res.push_back(intAttr.getUInt());
|
||||
@@ -46,7 +45,7 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
|
||||
SmallVector<unsigned, 2> threadTileSize;
|
||||
SmallVector<unsigned, 2> warpTileSize;
|
||||
SmallVector<unsigned, 2> blockTileSize;
|
||||
@@ -55,19 +54,23 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "threadTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed())
|
||||
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "warpTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed())
|
||||
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "blockTileSize") {
|
||||
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed())
|
||||
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "order") {
|
||||
if (parseIntArrayAttr(parser, attr, order, "order").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "broadcastAxis") {
|
||||
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis").failed())
|
||||
if (parseIntArrayAttr(parser, attr, broadcastAxis, "broadcastAxis")
|
||||
.failed())
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
@@ -76,12 +79,9 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUBlockedEncodingAttr>(parser.getContext(),
|
||||
threadTileSize,
|
||||
warpTileSize,
|
||||
blockTileSize,
|
||||
order,
|
||||
broadcastAxis);
|
||||
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
|
||||
parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order,
|
||||
broadcastAxis);
|
||||
}
|
||||
|
||||
static void printBlocked(AsmPrinter &printer, auto *attr) {
|
||||
@@ -94,8 +94,7 @@ static void printBlocked(AsmPrinter &printer, auto *attr) {
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
parseBlocked(parser, type);
|
||||
}
|
||||
|
||||
@@ -103,8 +102,8 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
printBlocked(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
|
||||
Type type) {
|
||||
parseBlocked(parser, type);
|
||||
}
|
||||
|
||||
@@ -131,38 +130,37 @@ static Attribute parseMma(AsmParser &parser, Type type) {
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "fragmentPerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed())
|
||||
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "shapePerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed())
|
||||
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "warpPerTile") {
|
||||
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "shapePerTile") {
|
||||
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed())
|
||||
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile")
|
||||
.failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "repetitions") {
|
||||
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "contigPerThread") {
|
||||
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed())
|
||||
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread")
|
||||
.failed())
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
<< attr.getName().strref();
|
||||
<< attr.getName().strref();
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
|
||||
fragmentPerWarp,
|
||||
shapePerWarp,
|
||||
warpPerTile,
|
||||
shapePerTile,
|
||||
repetitions,
|
||||
contigPerThread,
|
||||
broadcastAxis);
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(
|
||||
parser.getContext(), fragmentPerWarp, shapePerWarp, warpPerTile,
|
||||
shapePerTile, repetitions, contigPerThread, broadcastAxis);
|
||||
}
|
||||
|
||||
static void printMma(AsmPrinter &printer, auto *attr) {
|
||||
@@ -176,8 +174,7 @@ static void printMma(AsmPrinter &printer, auto *attr) {
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return parseMma(parser, type);
|
||||
}
|
||||
|
||||
@@ -185,8 +182,8 @@ void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printMma(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser,
|
||||
Type type) {
|
||||
return parseMma(parser, type);
|
||||
}
|
||||
|
||||
@@ -194,8 +191,7 @@ void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printMma(printer, this);
|
||||
}
|
||||
|
||||
Attribute
|
||||
TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
@@ -210,8 +206,7 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
unsigned maxPhase = 0;
|
||||
SmallVector<unsigned, 2> order;
|
||||
|
||||
auto parseUInt = [&parser](const NamedAttribute &attr,
|
||||
unsigned &value,
|
||||
auto parseUInt = [&parser](const NamedAttribute &attr, unsigned &value,
|
||||
StringRef desc) -> LogicalResult {
|
||||
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
@@ -237,29 +232,25 @@ TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
<< attr.getName().strref();
|
||||
<< attr.getName().strref();
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUSharedEncodingAttr>(parser.getContext(),
|
||||
vec,
|
||||
perPhase,
|
||||
maxPhase,
|
||||
order);
|
||||
return parser.getChecked<TritonGPUSharedEncodingAttr>(
|
||||
parser.getContext(), vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "vec = " << getVec()
|
||||
<< ", perPhase = " << getPerPhase()
|
||||
<< ", maxPhase = " << getMaxPhase()
|
||||
<< ", order = [" << getOrder() << "]"
|
||||
<< "vec = " << getVec() << ", perPhase = " << getPerPhase()
|
||||
<< ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder()
|
||||
<< "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
||||
public:
|
||||
public:
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
||||
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||
@@ -289,7 +280,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
||||
OpAsmDialectInterface::getAlias(attr, os);
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
static void printMma(const auto &attr, raw_ostream &os) {
|
||||
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
|
||||
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
|
||||
@@ -338,7 +329,7 @@ void TritonGPUDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
>();
|
||||
>();
|
||||
addInterfaces<TritonGPUOpAsmInterface>();
|
||||
}
|
||||
|
||||
@@ -349,7 +340,8 @@ namespace triton {
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
@@ -368,8 +360,8 @@ static Type getPointeeType(Type type) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
static LogicalResult verify(CopyAsyncOp op) {
|
||||
Type resType = op.getResult().getType();
|
||||
@@ -385,11 +377,9 @@ static LogicalResult verify(CopyAsyncOp op) {
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
|
||||
|
||||
// verify TritonGPU ops
|
||||
LogicalResult
|
||||
TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) {
|
||||
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) {
|
||||
// TODO: fill this.
|
||||
return success();
|
||||
}
|
||||
|
Reference in New Issue
Block a user