diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h
index 445f0beb7..80a2aab2e 100644
--- a/include/triton/Dialect/Triton/IR/Dialect.h
+++ b/include/triton/Dialect/Triton/IR/Dialect.h
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/SCF/SCF.h"

 #include "triton/Dialect/Triton/IR/Traits.h"
+#include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Dialect.h.inc"
 #include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index ae70d683e..dfa5ef864 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -9,6 +9,9 @@

 #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"

+#define GET_ATTRDEF_CLASSES
+#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
+
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index 82e77a798..6368c0056 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -2,6 +2,7 @@
 #define TRITONGPU_ATTRDEFS

 include "TritonGPUDialect.td"
+// include "mlir/IR/TensorEncoding.td"

 class TritonGPU_Attr<string name, list<Trait> traits = []> : AttrDef<TritonGPU_Dialect, name, traits>;
@@ -43,7 +44,7 @@ And the associated TritonGPU MLIR
 );
 }

-def TritonGPUCoalescedEncodingAttr : TritonGPU_Attr<"TritonGPUCoalescedEncoding"> {
+def TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
   let mnemonic = "coalesced encoding";

   let description = [{
@@ -81,7 +82,9 @@ And the associated TritonGPU MLIR
   let parameters = (
     ins
     ArrayRefParameter<"unsigned">:$threadTileSize,
-    ArrayRefParameter<"unsigned">:$blockTileSize
+    ArrayRefParameter<"unsigned">:$blockTileSize,
+    // fastest-changing axis first
+    ArrayRefParameter<"unsigned">:$order
   );

   // let genVerifyDecl = 1;
@@ -90,7 +93,7 @@ And the associated TritonGPU MLIR
 def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
   let mnemonic = "mma encoding";

-  let description = [{TODO: I think we may be able to implement it as a special case of Coalesced encoding with maybe one more warpTileSize attribute!}];
+  let description = [{TODO: I think we may be able to implement it as a special case of Distributed encoding with maybe one more warpTileSize attribute!}];

   let parameters = (
     ins
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index 135ff65d4..9e3431ffc 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -2,6 +2,7 @@
 #define TRITONGPU_OPS

 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
+include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
 include "triton/Dialect/Triton/IR/TritonTypes.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
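The `order` array added to the distributed encoding above lists axes from fastest- to slowest-changing. As a rough illustration of what that convention means (a sketch, not part of this patch; `threadCoord` is a hypothetical helper), a linear thread id can be unpacked into per-axis coordinates by walking the axes in `order`:

```cpp
// Sketch only: decompose a linear thread id into per-axis coordinates,
// consuming the fastest-changing axis first.
#include <vector>

std::vector<unsigned> threadCoord(unsigned tid,
                                  const std::vector<unsigned> &blockTileSize,
                                  const std::vector<unsigned> &order) {
  std::vector<unsigned> coord(blockTileSize.size());
  for (unsigned axis : order) {
    coord[axis] = tid % blockTileSize[axis]; // position along this axis
    tid /= blockTileSize[axis];              // carry into the next axis
  }
  return coord;
}
```

With `blockTileSize = {4, 32}` and `order = {1, 0}`, consecutive thread ids first sweep axis 1, which is the coalesced access pattern the encoding is meant to describe.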
diff --git a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h
new file mode 100644
index 000000000..fddcf2905
--- /dev/null
+++ b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h
@@ -0,0 +1,29 @@
+//===----------------------------------------------------------------------===//
+//
+// Defines utilities to use while converting to the TritonGPU dialect.
+//
+//===----------------------------------------------------------------------===//

+#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
+#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+class TritonGPUTypeConverter : public TypeConverter {
+public:
+  TritonGPUTypeConverter(MLIRContext *context, int numThreads);
+private:
+  MLIRContext *context;
+  int numThreads;
+};
+
+class TritonGPUConversionTarget : public ConversionTarget {
+public:
+  explicit TritonGPUConversionTarget(MLIRContext &ctx);
+};
+
+} // namespace mlir
+
+#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt
index 382b2c977..5c044d026 100644
--- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt
+++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_conversion_library(TritonToTritonGPU
   MLIRPass
   TritonIR
   TritonGPUIR
-)
\ No newline at end of file
+  TritonGPUConversion
+)
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 5387df5d6..461bb7ab5 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -1,7 +1,8 @@
-#include "mlir/Transforms/DialectConversion.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
+#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "../PassDetail.h"

 using namespace mlir;
@@ -39,7 +40,7 @@ void populateArithmeticPatternsAndLegality(
   target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
       // TODO: check above rule here
       [](Operation *op){
-        return false;
+        return true;
       }
   );
   // Rewrite rule
@@ -47,26 +48,27 @@ void populateArithmeticPatternsAndLegality(
 }


-class ConvertTritonToTritonGPU:
+class ConvertTritonToTritonGPU :
     public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
 public:
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    ConversionTarget target(*context);
-    // type converter
-    TypeConverter typeConverter;
-    // rewrite patterns
-    RewritePatternSet patterns(context);
-    // add rules
-    populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    TritonGPUConversionTarget target(*context);
+    ModuleOp mod = getOperation();
+    // int numThreads = mod.getAttr();
+    // type converter
+    TritonGPUTypeConverter typeConverter(context, /*numThreads*/4*32);
+    // rewrite patterns
+    RewritePatternSet patterns(context);
+    // add rules
+    populateArithmeticPatternsAndLegality(typeConverter, patterns, target);

-    if(failed(applyPartialConversion(getOperation(), target,
-                                     std::move(patterns))))
-      return signalPassFailure();
-
-  }
+    if(failed(applyPartialConversion(getOperation(), target,
+                                     std::move(patterns))))
+      return signalPassFailure();
+  }
 };
 }
diff --git a/lib/Dialect/TritonGPU/CMakeLists.txt b/lib/Dialect/TritonGPU/CMakeLists.txt
index f33061b2d..9f57627c3 100644
--- a/lib/Dialect/TritonGPU/CMakeLists.txt
+++ b/lib/Dialect/TritonGPU/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)
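The legality callback above now returns `true` unconditionally and leaves the real rule as a TODO. One plausible refinement (an assumption, not the author's final rule; `markArithLegality` is a hypothetical helper) is to treat an `arith` op as legal only once every tensor operand and result carries an encoding:

```cpp
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Hypothetical legality rule for the TODO above: an arith op is legal iff
// all tensor operands/results already carry a TritonGPU layout encoding.
static void markArithLegality(ConversionTarget &target) {
  target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
      [](Operation *op) {
        auto hasEncoding = [](Type type) {
          auto tensorType = type.dyn_cast<RankedTensorType>();
          return !tensorType || bool(tensorType.getEncoding());
        };
        return llvm::all_of(op->getOperandTypes(), hasEncoding) &&
               llvm::all_of(op->getResultTypes(), hasEncoding);
      });
}
```

Returning `true` for everything keeps the pass runnable while the rewrite patterns are still stubs; the sketch above is one way the check could later be tightened.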
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index dd7019ab4..4524faa6c 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1,12 +1,52 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <numeric>

 #include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"

 using namespace mlir::triton::gpu;

+//===----------------------------------------------------------------------===//
+// Attribute methods
+//===----------------------------------------------------------------------===//
+#define GET_ATTRDEF_CLASSES
+#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
+
+mlir::Attribute
+TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type) {
+  llvm_unreachable("Not implemented");
+}
+
+void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
+  llvm_unreachable("Not implemented");
+}
+
+mlir::Attribute
+TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
+  llvm_unreachable("Not implemented");
+}
+
+void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const {
+  llvm_unreachable("Not implemented");
+}
+
+mlir::Attribute
+TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
+  llvm_unreachable("Not implemented");
+}
+
+void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
+  llvm_unreachable("Not implemented");
+}
+
 void TritonGPUDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
       >();
 }
+
+
+#define GET_OP_CLASSES
+#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
new file mode 100644
index 000000000..5d089f297
--- /dev/null
+++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_mlir_dialect_library(TritonGPUConversion
+  TritonGPUConversion.cpp
+
+  # ADDITIONAL_HEADER_DIRS
+
+  LINK_LIBS PUBLIC
+  TritonIR
+  TritonGPUIR
+  # MLIRTransformUtils
+)
diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
new file mode 100644
index 000000000..afe96d06d
--- /dev/null
+++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
@@ -0,0 +1,68 @@
+#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include <algorithm>
+
+using namespace mlir;
+
+//
+// TypeConverter
+//
+TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
+                                               int numThreads)
+    : context(context), numThreads(numThreads) {
+  addConversion([&](RankedTensorType tensorType) -> RankedTensorType {
+    llvm::ArrayRef<int64_t> shape = tensorType.getShape();
+    Type elementType = tensorType.getElementType();
+    int64_t rank = tensorType.getRank();
+    int64_t numElements = tensorType.getNumElements();
+
+    // TODO: we should raise an exception here.
+    assert(numElements > numThreads);
+    assert(numElements % numThreads == 0);
+
+    // assert no encoding?
+
+    // For now we assume:
+    //   contiguous = 1, order = 0, 1, 2, ...
+    llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
+    llvm::SmallVector<unsigned> blockTileSize(rank);
+    llvm::SmallVector<unsigned> order(rank);
+    int remainingThreads = numThreads;
+    for (int64_t dim = 0; dim < rank; ++dim) {
+      blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
+      order[dim] = dim;
+
+      remainingThreads /= blockTileSize[dim];
+      // TODO: will we need repetition?
+    }
+    Attribute encoding = triton::gpu::TritonGPUDistributedEncodingAttr::get(
+        context, threadTileSize, blockTileSize, order);
+    return RankedTensorType::get(shape, elementType, encoding);
+  });
+}
+
+//
+// TritonGPUConversion
+//
+TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
+    : ConversionTarget(context) {
+  addLegalDialect<triton::gpu::TritonGPUDialect>();
+
+  // Some ops from SCF are illegal
+  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
+               scf::ReduceOp, scf::ReduceReturnOp>();
+
+  // // We have requirements for the data layouts
+  // addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
+  //   Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
+  //   Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
+  //   if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
+  //       bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
+  //     return true;
+  //   return false;
+  // });
+}
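To make the blocking concrete, here is a standalone replica of the tiling loop in `TritonGPUTypeConverter` (illustrative only; the constants mirror the `/*numThreads*/4*32` used by the pass): for a 32x32 tensor and 128 threads it produces `blockTileSize = [32, 4]` and `order = [0, 1]`, i.e. the threads fully cover dimension 0 and the leftover factor of 4 tiles dimension 1.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int64_t> shape = {32, 32}; // tensor shape
  int remainingThreads = 4 * 32;               // 4 warps of 32 threads
  std::vector<unsigned> blockTileSize(shape.size());
  std::vector<unsigned> order(shape.size());
  for (size_t dim = 0; dim < shape.size(); ++dim) {
    // Same computation as the converter: give this dimension as many
    // threads as it can absorb, then carry the rest to the next one.
    blockTileSize[dim] = std::clamp(remainingThreads, 1, int(shape[dim]));
    order[dim] = dim;
    remainingThreads /= blockTileSize[dim];
  }
  for (size_t dim = 0; dim < shape.size(); ++dim)
    std::printf("dim %zu: blockTileSize = %u, order = %u\n", dim,
                blockTileSize[dim], order[dim]);
  return 0;
}
```

Note that `remainingThreads` can stay greater than 1 when the tensor is small, which is exactly the case flagged by the `// TODO: will we need repetition?` comment.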
diff --git a/rewrite-test/jit/matmul/matmul.py b/rewrite-test/jit/matmul/matmul.py
index 10d2b7da5..416d8f0b9 100644
--- a/rewrite-test/jit/matmul/matmul.py
+++ b/rewrite-test/jit/matmul/matmul.py
@@ -101,6 +101,7 @@ pm = _triton.ir.pass_manager(ctx)
 pm.add_inliner_pass()
 pm.add_triton_combine_pass()
 pm.add_canonicalizer_pass()
+pm.add_convert_triton_to_tritongpu_pass()
 pm.run(mod)

 assert mod.verify()
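Finally, a minimal usage sketch of the new utilities from C++ (assumptions: the TritonGPU dialect is registered with the context, and `convertExample` is a hypothetical helper, not part of this patch):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"

using namespace mlir;

// Hypothetical helper: run the RankedTensorType conversion once by hand.
RankedTensorType convertExample(MLIRContext &ctx) {
  ctx.getOrLoadDialect<triton::gpu::TritonGPUDialect>();
  TritonGPUTypeConverter converter(&ctx, /*numThreads=*/4 * 32);
  auto src = RankedTensorType::get({32, 32}, FloatType::getF32(&ctx));
  // Same shape and element type, now with a distributed encoding attached.
  return converter.convertType(src).cast<RankedTensorType>();
}
```

This is the same type conversion the new pass performs while rewriting the module when `pm.add_convert_triton_to_tritongpu_pass()` runs.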