From 2239ac199864fb97e6a2493da27a8400023f0db1 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Thu, 28 Apr 2022 18:51:31 +0800 Subject: [PATCH] more progress on TritonGPU --- include/triton/Dialect/CMakeLists.txt | 2 +- include/triton/Dialect/Triton/IR/Dialect.h | 4 +- include/triton/Dialect/Triton/IR/TritonOps.td | 50 +---------------- .../triton/Dialect/Triton/IR/TritonTypes.td | 54 +++++++++++++++++++ .../Dialect/TritonGPU/IR/CMakeLists.txt | 12 +++++ include/triton/Dialect/TritonGPU/IR/Dialect.h | 11 ++++ .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 6 +-- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 9 ++-- lib/Dialect/CMakeLists.txt | 1 + lib/Dialect/TritonGPU/CMakeLists.txt | 1 + lib/Dialect/TritonGPU/IR/CMakeLists.txt | 10 ++++ lib/Dialect/TritonGPU/IR/Dialect.cpp | 12 +++++ 12 files changed, 115 insertions(+), 57 deletions(-) create mode 100644 include/triton/Dialect/Triton/IR/TritonTypes.td create mode 100644 include/triton/Dialect/TritonGPU/IR/CMakeLists.txt create mode 100644 lib/Dialect/TritonGPU/CMakeLists.txt create mode 100644 lib/Dialect/TritonGPU/IR/CMakeLists.txt create mode 100644 lib/Dialect/TritonGPU/IR/Dialect.cpp diff --git a/include/triton/Dialect/CMakeLists.txt b/include/triton/Dialect/CMakeLists.txt index 819e542c8..27cb65ce5 100644 --- a/include/triton/Dialect/CMakeLists.txt +++ b/include/triton/Dialect/CMakeLists.txt @@ -1,2 +1,2 @@ add_subdirectory(Triton) -# add_subdirectory(TritonGPU) +add_subdirectory(TritonGPU) diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index d5c8d2f43..445f0beb7 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -1,5 +1,5 @@ -#ifndef TRITON_IR_DIALECT_H_ -#define TRITON_IR_DIALECT_H_ +#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITON_IR_DIALECT_H_ #include "mlir/IR/BuiltinOps.h" diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td 
b/include/triton/Dialect/Triton/IR/TritonOps.td index de7e11d43..ef0126f3a 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1,60 +1,14 @@ #ifndef Triton_OPS #define Triton_OPS -include "TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType -// -// Types -// -class TritonTypeDef<string name, string _mnemonic> - : TypeDef<Triton_Dialect, name> { - // Used by printer/parser - let mnemonic = _mnemonic; -} - -def F8 : TritonTypeDef<"Float8", "f8">; -def BF8 : TritonTypeDef<"BFloat8", "bf8">; - -def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">; -def TT_FloatTensor : TensorOf<[TT_Float]>; - -// IntegerType -def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; -def TT_IntegerTensor : TensorOf<[TT_Int]>; - -// PointerType -def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> { - let summary = "pointer type"; - - let description = [{ - Triton PointerType - }]; - - let parameters = (ins "Type":$pointeeType, "int":$addressSpace); - - let builders = [ - TypeBuilderWithInferredContext<(ins - "Type":$pointeeType, - "int":$addressSpace - ), [{ - return $_get(pointeeType.getContext(), pointeeType, addressSpace); - }]> - ]; - - let skipDefaultBuilders = 1; -} -def TT_PtrTensor : TensorOf<[TT_Pointer]>; - -def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>; -def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>; - -def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor, - TT_Pointer, TT_PtrTensor]>; def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; // diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td new file 
mode 100644 index 000000000..b5238996e --- /dev/null +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -0,0 +1,54 @@ +#ifndef TRITON_TYPES +#define TRITON_TYPES + +include "triton/Dialect/Triton/IR/TritonDialect.td" + +// +// Types +// +class TritonTypeDef<string name, string _mnemonic> + : TypeDef<Triton_Dialect, name> { + // Used by printer/parser + let mnemonic = _mnemonic; +} + +def F8 : TritonTypeDef<"Float8", "f8">; +def BF8 : TritonTypeDef<"BFloat8", "bf8">; + +def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">; +def TT_FloatTensor : TensorOf<[TT_Float]>; + +// IntegerType +def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; +def TT_IntegerTensor : TensorOf<[TT_Int]>; + +// PointerType +def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> { + let summary = "pointer type"; + + let description = [{ + Triton PointerType + }]; + + let parameters = (ins "Type":$pointeeType, "int":$addressSpace); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$pointeeType, + "int":$addressSpace + ), [{ + return $_get(pointeeType.getContext(), pointeeType, addressSpace); + }]> + ]; + + let skipDefaultBuilders = 1; +} +def TT_PtrTensor : TensorOf<[TT_Pointer]>; + +def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>; +def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>; + +def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor, + TT_Pointer, TT_PtrTensor]>; + +#endif diff --git a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..c44acaa3d --- /dev/null +++ b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,12 @@ +set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +add_public_tablegen_target(TritonGPUTableGen) + 
+set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) +mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(TritonGPUAttrDefsIncGen) + diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index bd08e1195..ae70d683e 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -1,5 +1,16 @@ #ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ #define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// TritonGPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.h.inc" + #endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 83b55b836..82e77a798 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -56,7 +56,7 @@ For example, a row-major coalesced layout may distribute a 64x16 tensor over 2 w block| thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3] tile | tile size 2 || A_{1, 0}[T0] A_{1, 1}[T0] ... A_{1, 6}[T3] A_{1, 7}[T3] A_{1, 8}[T0] A_{1, 9}[T0] ... A_{1, 14}[T3] A_{1, 15}[T3] size } .... -16 | A_{30, 0}[T60] A_{14, 1}[T60] ... A_{14, 6}[T63] A_{14, 7}[T63] A_{14, 8}[T60] A_{14, 9}[T60] ... A_{14, 14}[T63] A_{14, 15}[T63] +32 | A_{30, 0}[T60] A_{14, 1}[T60] ... A_{14, 6}[T63] A_{14, 7}[T63] A_{14, 8}[T60] A_{14, 9}[T60] ... A_{14, 14}[T63] A_{14, 15}[T63] | A_{31, 0}[T60] A_{15, 1}[T60] ... A_{15, 6}[T63] A_{15, 7}[T63] A_{15, 8}[T60] A_{15, 9}[T60] ... 
A_{15, 14}[T63] A_{15, 15}[T63] -----------------------------/\----------------------------------- block tile size 8 @@ -71,7 +71,7 @@ size } .... And the associated TritonGPU MLIR #SMEM = #triton_gpu.encoding<{ threadTileSize = {2, 2} - blockTileSize = {16, 8} + blockTileSize = {32, 8} }> // note to Da: In current Triton codebase, `nanoTileSize = threadTileSize`, and `macro-tile size = blockTileSize / threadTileSize` @@ -81,7 +81,7 @@ And the associated TritonGPU MLIR let parameters = ( ins ArrayRefParameter<"unsigned">:$threadTileSize, - ArrayRefParameter<"unsigned">:$blockTileSize, + ArrayRefParameter<"unsigned">:$blockTileSize ); // let genVerifyDecl = 1; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 934d01f5b..135ff65d4 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -1,10 +1,13 @@ #ifndef TRITONGPU_OPS #define TRITONGPU_OPS -include "TritonGPUDialect.td" -// include "" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType -def TTG_Op<string mnemonic, list<Trait> traits = []> : +class TTG_Op<string mnemonic, list<Trait> traits = []> : Op<TritonGPU_Dialect, mnemonic, traits>; def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 5e601271e..27cb65ce5 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(Triton) +add_subdirectory(TritonGPU) diff --git a/lib/Dialect/TritonGPU/CMakeLists.txt b/lib/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/lib/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt 
b/lib/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..7df435006 --- /dev/null +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_dialect_library(TritonGPUIR + Dialect.cpp + + DEPENDS + TritonGPUTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonIR +) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp new file mode 100644 index 000000000..dd7019ab4 --- /dev/null +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -0,0 +1,12 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" + +using namespace mlir::triton::gpu; + +void TritonGPUDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" + >(); +}