More on TritonGPU conversion
This commit is contained in:
@@ -41,6 +41,10 @@ void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
}
|
||||
|
||||
void TritonGPUDialect::initialize() {
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
||||
>();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
|
@@ -11,7 +11,12 @@ using namespace mlir;
|
||||
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
int numThreads)
|
||||
: context(context), numThreads(numThreads) {
|
||||
addConversion([&](RankedTensorType tensorType) -> RankedTensorType {
|
||||
// TODO: how does MLIR pick the right conversion?
|
||||
addConversion([](Type type) { return type; });
|
||||
addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
|
||||
MLIRContext *context = this->context;
|
||||
int numThreads = this->numThreads;
|
||||
|
||||
llvm::ArrayRef<int64_t> shape = tensorType.getShape();
|
||||
Type elementType = tensorType.getElementType();
|
||||
int64_t rank = tensorType.getRank();
|
||||
@@ -45,15 +50,28 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
//
|
||||
// TritonGPUConversion
|
||||
//
|
||||
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
|
||||
: ConversionTarget(context) {
|
||||
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||
: ConversionTarget(context), typeConverter(typeConverter) {
|
||||
addLegalDialect<triton::TritonDialect,
|
||||
arith::ArithmeticDialect,
|
||||
StandardOpsDialect,
|
||||
scf::SCFDialect>();
|
||||
|
||||
// Some ops from SCF are illegal
|
||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
|
||||
// // We have requirements for the data layouts
|
||||
// addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
|
||||
|
Reference in New Issue
Block a user