2022-08-18 12:49:37 -07:00
|
|
|
#include <numeric>
|
|
|
|
|
2022-05-01 22:06:54 +08:00
|
|
|
#include "mlir/IR/DialectImplementation.h"
|
2022-07-26 10:50:11 -07:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
2022-08-18 12:49:37 -07:00
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
2022-05-01 22:06:54 +08:00
|
|
|
#include "llvm/ADT/TypeSwitch.h"
|
2022-04-28 18:51:31 +08:00
|
|
|
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
|
|
|
|
2022-05-31 11:43:21 +00:00
|
|
|
using namespace mlir;
|
2022-04-28 18:51:31 +08:00
|
|
|
using namespace mlir::triton::gpu;
|
|
|
|
|
2022-08-24 12:55:49 -07:00
|
|
|
// Utility
|
|
|
|
namespace mlir {
|
|
|
|
namespace triton {
|
|
|
|
|
|
|
|
// Type inference
|
|
|
|
static Type getI1SameShape(Type type) {
|
|
|
|
auto i1Type = IntegerType::get(type.getContext(), 1);
|
|
|
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
|
|
|
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
|
|
|
tensorType.getEncoding());
|
|
|
|
return Type();
|
|
|
|
}
|
|
|
|
|
|
|
|
static Type getPointeeType(Type type) {
|
|
|
|
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
|
|
|
// Tensor of pointers
|
|
|
|
auto shape = tensorType.getShape();
|
|
|
|
auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
|
|
|
|
Type pointeeType = ptrType.getPointeeType();
|
|
|
|
return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
|
|
|
|
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
|
|
|
|
// scalar pointer
|
|
|
|
Type pointeeType = ptrType.getPointeeType();
|
|
|
|
return pointeeType;
|
|
|
|
}
|
|
|
|
return Type();
|
|
|
|
}
|
|
|
|
|
|
|
|
} // namespace triton
|
|
|
|
} // namespace mlir
|
|
|
|
|
2022-08-11 21:20:47 -07:00
|
|
|
/// Extracts a non-negative integer that fits in `unsigned` from `attr`.
///
/// Accepts an IntegerAttr whose type is a signed, signless, unsigned, or
/// index integer. On failure, emits a parser error that mentions `desc` and
/// returns failure. Failure cases are:
///   - `attr` is not an IntegerAttr;
///   - the value is negative;
///   - the value does not fit in `unsigned`.
/// On success, stores the value in `value`.
static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr,
                                       unsigned &value, StringRef desc) {
  auto intAttr = attr.dyn_cast<IntegerAttr>();
  if (!intAttr) {
    parser.emitError(parser.getNameLoc(), "expected an integer type in ")
        << desc;
    return failure();
  }

  Type attrType = intAttr.getType();
  uint64_t rawValue;
  if (attrType.isUnsignedInteger()) {
    rawValue = intAttr.getUInt();
  } else {
    // Signed integers must be read with getSInt(). Signless integers and the
    // index type (which getUInt() does not accept) are read with getInt().
    int64_t signedValue =
        attrType.isSignedInteger() ? intAttr.getSInt() : intAttr.getInt();
    if (signedValue < 0) {
      parser.emitError(parser.getNameLoc(),
                       "expected an unsigned integer value in ")
          << desc;
      return failure();
    }
    rawValue = static_cast<uint64_t>(signedValue);
  }

  // Reject values that would be silently truncated when narrowed to unsigned.
  if (static_cast<uint64_t>(static_cast<unsigned>(rawValue)) != rawValue) {
    parser.emitError(parser.getNameLoc(), "integer value out of range in ")
        << desc;
    return failure();
  }
  value = static_cast<unsigned>(rawValue);
  return success();
}
|
|
|
|
|
2022-05-31 11:43:21 +00:00
|
|
|
// parse an array of integers
|
|
|
|
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
|
|
|
const NamedAttribute &attr,
|
2022-07-27 01:32:10 -07:00
|
|
|
SmallVector<unsigned, 2> &res,
|
2022-07-26 17:25:03 -07:00
|
|
|
StringRef desc) {
|
2022-05-31 11:43:21 +00:00
|
|
|
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
|
|
|
|
if (!arrayAttr) {
|
2022-07-26 17:25:03 -07:00
|
|
|
parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
|
2022-05-31 11:43:21 +00:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
for (Attribute i : arrayAttr) {
|
2022-08-11 21:20:47 -07:00
|
|
|
unsigned value;
|
|
|
|
if (parseIntAttrValue(parser, i, value, desc).failed())
|
2022-05-31 11:43:21 +00:00
|
|
|
return failure();
|
2022-08-11 21:20:47 -07:00
|
|
|
res.push_back(value);
|
2022-05-31 11:43:21 +00:00
|
|
|
}
|
|
|
|
return success();
|
|
|
|
};
|
|
|
|
|
2022-07-31 13:59:44 -07:00
|
|
|
// Parses the integer value of the dictionary entry `attr` into `value`.
// This is a thin wrapper over parseIntAttrValue; on failure it emits a parser
// error mentioning `desc` and returns failure.
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
                               unsigned &value, StringRef desc) {
  return parseIntAttrValue(parser, attr.getValue(), value, desc);
}
|
|
|
|
|
2022-05-01 22:06:54 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Attribute methods
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_ATTRDEF_CLASSES
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
/// Wraps this blocked encoding in a SliceEncodingAttr along `axis`; this
/// encoding becomes the slice's parent.
SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
  MLIRContext *ctx = getContext();
  Attribute parentEncoding = *this;
  return SliceEncodingAttr::get(ctx, axis, parentEncoding);
}
|
|
|
|
|
2022-07-31 13:59:44 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Blocked Encoding
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
/// Parses `<{sizePerThread = [...], threadsPerWarp = [...],
///          warpsPerCTA = [...], order = [...]}>`.
/// Any other key is rejected. Missing keys leave their vector empty; the
/// final attribute is built through getChecked so the attribute's verifier
/// (if any) runs.
Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
  // The payload is a dictionary attribute wrapped in angle brackets.
  DictionaryAttr dict;
  if (parser.parseLess() || parser.parseAttribute(dict) ||
      parser.parseGreater())
    return {};

  SmallVector<unsigned, 2> sizePerThread;
  SmallVector<unsigned, 2> threadsPerWarp;
  SmallVector<unsigned, 2> warpsPerCTA;
  SmallVector<unsigned, 2> order;

  for (const NamedAttribute &entry : dict) {
    StringRef key = entry.getName().strref();

    // Route each recognized key to its destination vector and the label used
    // in diagnostics.
    SmallVector<unsigned, 2> *dest = nullptr;
    StringRef label;
    if (key == "sizePerThread") {
      dest = &sizePerThread;
      label = "number of elements per thread";
    } else if (key == "threadsPerWarp") {
      dest = &threadsPerWarp;
      label = "number of threads per warp";
    } else if (key == "warpsPerCTA") {
      dest = &warpsPerCTA;
      label = "number of warps per CTA";
    } else if (key == "order") {
      dest = &order;
      label = "order";
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << key;
      return {};
    }

    if (parseIntArrayAttr(parser, entry, *dest, label).failed())
      return {};
  }

  return parser.getChecked<BlockedEncodingAttr>(
      parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
}
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
/// Prints the encoding in the same `<{key = [...], ...}>` form that
/// BlockedEncodingAttr::parse accepts.
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{";
  printer << "sizePerThread = [" << getSizePerThread() << "]";
  printer << ", threadsPerWarp = [" << getThreadsPerWarp() << "]";
  printer << ", warpsPerCTA = [" << getWarpsPerCTA() << "]";
  printer << ", order = [" << getOrder() << "]";
  printer << "}>";
}
|
|
|
|
|
2022-07-31 13:59:44 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// MMA encoding
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-06-18 21:16:45 +08:00
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
/// Parses `<{version = N, warpsPerCTA = [...]}>`.
/// Unknown keys are rejected with "unexpected key: ", matching the blocked
/// and shared encoding parsers. A missing `version` defaults to 0 and a
/// missing `warpsPerCTA` leaves the vector empty. The result is built through
/// getChecked so the attribute's verifier (if any) runs.
Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  unsigned version = 0;
  SmallVector<unsigned, 2> warpsPerCTA;

  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "version") {
      if (parseUInt(parser, attr, version, "version").failed())
        return {};
    } else if (attr.getName() == "warpsPerCTA") {
      if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
        return {};
    } else {
      // Previously unknown keys were silently dropped; reject them instead.
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }

  return parser.getChecked<MmaEncodingAttr>(parser.getContext(), version,
                                            warpsPerCTA);
}
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
/// Prints the encoding as `<{version = ..., warpsPerCTA = [...]}>`, the form
/// read back by MmaEncodingAttr::parse.
void MmaEncodingAttr::print(AsmPrinter &printer) const {
  printer << "<{";
  printer << "version = " << getVersion() << ", ";
  printer << "warpsPerCTA = [" << getWarpsPerCTA() << "]";
  printer << "}>";
}
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sliced Encoding
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Parses `<{dim = N, parent = <encoding attribute>}>`.
/// The parent encoding is taken directly from the already-parsed dictionary
/// entry. Unknown keys are rejected with "unexpected key: ", matching the
/// other encoding parsers. The result is built through getChecked so the
/// attribute's verifier (if any) runs.
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  // Parse the data as a dictionary
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  unsigned dim = 0;
  Attribute parent;

  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "dim") {
      if (parseUInt(parser, attr, dim, "dim").failed())
        return {};
    } else if (attr.getName() == "parent") {
      // The parent was parsed as part of the dictionary above. Calling
      // parser.parseAttribute() here would instead try to read a *new*
      // attribute from the token stream after the closing `>`.
      parent = attr.getValue();
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }

  return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
}
|
|
|
|
|
|
|
|
/// Prints the encoding as `<{dim = ..., parent = ...}>`. The parent encoding
/// is printed with the printer's regular attribute printing.
void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{";
  printer << "dim = " << getDim() << ", ";
  printer << "parent = " << getParent();
  printer << "}>";
}
|
|
|
|
|
2022-07-31 13:59:44 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Shared encoding
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-06-18 21:16:45 +08:00
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
/// Parses `<{vec = N, perPhase = N, maxPhase = N, order = [...]}>`.
/// Any other key is rejected. Missing scalar keys default to 0 and a missing
/// `order` stays empty. The result is built through getChecked so the
/// attribute's verifier (if any) runs.
Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) {
  // The payload is a dictionary attribute wrapped in angle brackets.
  DictionaryAttr dict;
  if (parser.parseLess() || parser.parseAttribute(dict) ||
      parser.parseGreater())
    return {};

  unsigned vec = 0;
  unsigned perPhase = 0;
  unsigned maxPhase = 0;
  SmallVector<unsigned, 2> order;

  for (const NamedAttribute &entry : dict) {
    StringRef key = entry.getName().strref();

    LogicalResult parsed = success();
    if (key == "vec") {
      parsed = parseUInt(parser, entry, vec, "vec");
    } else if (key == "perPhase") {
      parsed = parseUInt(parser, entry, perPhase, "perPhase");
    } else if (key == "maxPhase") {
      parsed = parseUInt(parser, entry, maxPhase, "maxPhase");
    } else if (key == "order") {
      parsed = parseIntArrayAttr(parser, entry, order, "order");
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << key;
      return {};
    }

    if (failed(parsed))
      return {};
  }

  return parser.getChecked<SharedEncodingAttr>(parser.getContext(), vec,
                                               perPhase, maxPhase, order);
}
|
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
/// Prints the encoding as
/// `<{vec = ..., perPhase = ..., maxPhase = ..., order = [...]}>`, the form
/// read back by SharedEncodingAttr::parse.
void SharedEncodingAttr::print(AsmPrinter &printer) const {
  printer << "<{";
  printer << "vec = " << getVec();
  printer << ", perPhase = " << getPerPhase();
  printer << ", maxPhase = " << getMaxPhase();
  printer << ", order = [" << getOrder() << "]";
  printer << "}>";
}
|
|
|
|
|
2022-08-24 12:55:49 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// CopyAsyncOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Custom parser for copy_async. Expected form:
///   operand-list attr-dict? `:` ptr-type `->` result-type
/// Operand types are derived from the pointer type:
///   - operand 0 (ptr) has `ptr-type`;
///   - an optional operand 1 (mask) gets triton::getI1SameShape(ptr-type);
///   - an optional operand 2 (other) gets triton::getPointeeType(ptr-type).
ParseResult parseCopyAsyncOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> allOperands;
  Type resultTypes[1], ptrType;
  SMLoc allOperandLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(allOperands) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseCustomTypeWithFallback(ptrType) || parser.parseArrow() ||
      parser.parseCustomTypeWithFallback(resultTypes[0]))
    return failure();
  result.addTypes(resultTypes);

  SmallVector<Type> operandTypes;
  operandTypes.push_back(ptrType); // ptr
  if (allOperands.size() >= 2)
    operandTypes.push_back(triton::getI1SameShape(ptrType)); // mask
  if (allOperands.size() >= 3)
    operandTypes.push_back(triton::getPointeeType(ptrType)); // other

  // The type helpers can return a null Type when ptrType is not a supported
  // pointer shape. Report a diagnostic rather than resolving operands against
  // a null type.
  for (Type operandType : operandTypes) {
    if (!operandType)
      return parser.emitError(allOperandLoc,
                              "cannot infer copy_async operand types from ")
             << ptrType;
  }

  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
                             result.operands))
    return failure();
  return success();
}
|
|
|
|
|
|
|
|
/// Custom printer for copy_async, mirroring parseCopyAsyncOp:
///   operands attr-dict : ptr-type -> result-type
/// Both types are emitted with printStrippedAttrOrType.
void printCopyAsyncOp(OpAsmPrinter &printer, CopyAsyncOp copyAsyncOp) {
  Operation *op = copyAsyncOp.getOperation();
  Type ptrType = copyAsyncOp.ptr().getType();
  Type resultType = copyAsyncOp.result().getType();

  // Operand list followed by the (possibly empty) attribute dictionary;
  // no attributes are elided.
  printer << " ";
  printer << op->getOperands();
  printer.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});

  // Trailing type signature.
  printer << " : ";
  printer.printStrippedAttrOrType(ptrType);
  printer << " -> ";
  printer.printStrippedAttrOrType(resultType);
}
|
|
|
|
|
2022-07-31 13:59:44 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ASM Interface (i.e.: alias)
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-07-26 10:50:11 -07:00
|
|
|
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
|
2022-07-26 17:25:03 -07:00
|
|
|
public:
|
2022-07-26 10:50:11 -07:00
|
|
|
using OpAsmDialectInterface::OpAsmDialectInterface;
|
|
|
|
|
|
|
|
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
2022-08-18 12:49:37 -07:00
|
|
|
if (auto mmaAttr = attr.dyn_cast<MmaEncodingAttr>()) {
|
2022-07-26 10:50:11 -07:00
|
|
|
os << "mma";
|
|
|
|
return AliasResult::FinalAlias;
|
2022-08-18 12:49:37 -07:00
|
|
|
} else if (auto sharedAttr = attr.dyn_cast<SharedEncodingAttr>()) {
|
2022-07-26 10:50:11 -07:00
|
|
|
os << "shared";
|
|
|
|
return AliasResult::FinalAlias;
|
2022-08-18 12:49:37 -07:00
|
|
|
} else if (auto blockedAttr = attr.dyn_cast<BlockedEncodingAttr>()) {
|
2022-07-26 10:50:11 -07:00
|
|
|
os << "blocked";
|
2022-07-27 01:32:10 -07:00
|
|
|
return AliasResult::FinalAlias;
|
2022-08-18 12:49:37 -07:00
|
|
|
} /* else if (auto sliceAttr = attr.dyn_cast<SliceEncodingAttr>()) {
|
|
|
|
os << "slice";
|
|
|
|
return AliasResult::FinalAlias;
|
|
|
|
} */
|
2022-08-24 12:55:49 -07:00
|
|
|
return OpAsmDialectInterface::getAlias(attr, os);
|
2022-07-26 10:50:11 -07:00
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2022-04-28 18:51:31 +08:00
|
|
|
/// Registers the dialect's attributes, operations, and ASM interface.
void TritonGPUDialect::initialize() {
  // Attribute classes come from the generated list in
  // TritonGPUAttrDefs.cpp.inc.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
      >();

  // Operation classes come from the generated list in Ops.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
      >();

  // Alias printing for layout encodings (see TritonGPUOpAsmInterface).
  addInterfaces<TritonGPUOpAsmInterface>();
}
|
2022-05-01 22:06:54 +08:00
|
|
|
|
2022-08-18 12:49:37 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Verification
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-06-07 19:34:59 +08:00
|
|
|
/// Verifies that copy_async produces a ranked tensor whose encoding is a
/// SharedEncodingAttr.
static LogicalResult verify(CopyAsyncOp op) {
  Type resType = op.getResult().getType();
  auto tensorType = resType.dyn_cast<RankedTensorType>();
  if (!tensorType)
    return op.emitOpError("copy_async should return a tensor");

  // A ranked tensor may carry no encoding at all. isa<> on a null attribute
  // asserts, so a missing encoding is treated as "not shared memory".
  Attribute encoding = tensorType.getEncoding();
  if (!encoding || !encoding.isa<SharedEncodingAttr>())
    return op.emitOpError("copy_async should return a shared memory tensor");

  return success();
}
|
|
|
|
|
2022-05-01 22:06:54 +08:00
|
|
|
#define GET_OP_CLASSES
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
2022-05-24 19:48:56 +08:00
|
|
|
|
|
|
|
// verify TritonGPU ops
|
2022-07-26 17:25:03 -07:00
|
|
|
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
|
|
|
NamedAttribute attr) {
|
2022-05-24 19:48:56 +08:00
|
|
|
// TODO: fill this.
|
|
|
|
return success();
|
2022-08-24 12:55:49 -07:00
|
|
|
}
|