more progress on TritonGPU

This commit is contained in:
Yan Da
2022-04-28 18:51:31 +08:00
parent 012e8c5b2b
commit 2239ac1998
12 changed files with 115 additions and 57 deletions

View File

@@ -1,2 +1,2 @@
add_subdirectory(Triton) add_subdirectory(Triton)
# add_subdirectory(TritonGPU) add_subdirectory(TritonGPU)

View File

@@ -1,5 +1,5 @@
#ifndef TRITON_IR_DIALECT_H_ #ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
#define TRITON_IR_DIALECT_H_ #define TRITON_DIALECT_TRITON_IR_DIALECT_H_
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"

View File

@@ -1,60 +1,14 @@
#ifndef Triton_OPS #ifndef Triton_OPS
#define Triton_OPS #define Triton_OPS
include "TritonDialect.td" include "triton/Dialect/Triton/IR/TritonDialect.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "mlir/IR/EnumAttr.td" include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
//
// Types
//
class TritonTypeDef<string name, string _mnemonic>
: TypeDef<Triton_Dialect, name> {
// Used by printer/parser
let mnemonic = _mnemonic;
}
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
// IntegerType
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def TT_IntegerTensor : TensorOf<[TT_Int]>;
// PointerType
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
let description = [{
Triton PointerType
}];
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
let builders = [
TypeBuilderWithInferredContext<(ins
"Type":$pointeeType,
"int":$addressSpace
), [{
return $_get(pointeeType.getContext(), pointeeType, addressSpace);
}]>
];
let skipDefaultBuilders = 1;
}
def TT_PtrTensor : TensorOf<[TT_Pointer]>;
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
TT_Pointer, TT_PtrTensor]>;
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
// //

View File

@@ -0,0 +1,54 @@
#ifndef TRITON_TYPES
#define TRITON_TYPES
include "triton/Dialect/Triton/IR/TritonDialect.td"
//
// Types
//
// Common base for the types declared in this file: a TypeDef attached to
// Triton_Dialect. `name` is forwarded to TypeDef; `_mnemonic` becomes the
// type's mnemonic.
class TritonTypeDef<string name, string _mnemonic>
: TypeDef<Triton_Dialect, name> {
// Used by printer/parser (the textual keyword for the type)
let mnemonic = _mnemonic;
}
// 8-bit float types declared as Triton-dialect types (mnemonics "f8", "bf8").
// NOTE(review): neither F8 nor BF8 is listed in TT_Float below, so the
// TT_Float* constraints do not accept them — confirm whether that is intended.
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
// Scalar floating-point constraint: any of F16, BF16, F32, F64.
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
// Tensor whose element type satisfies TT_Float.
def TT_FloatTensor : TensorOf<[TT_Float]>;
// IntegerType
// Scalar integer constraint: any of I1, I8, I16, I32, I64.
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
// Tensor whose element type satisfies TT_Int.
def TT_IntegerTensor : TensorOf<[TT_Int]>;
// PointerType
// Triton-dialect type with mnemonic "ptr", parameterized by the pointee type
// and an integer address space.
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
let description = [{
Triton PointerType
}];
// Storage parameters: the pointed-to type and the address space number.
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
// Single custom builder: callers pass only (pointeeType, addressSpace); the
// MLIRContext is taken from pointeeType rather than supplied explicitly.
let builders = [
TypeBuilderWithInferredContext<(ins
"Type":$pointeeType,
"int":$addressSpace
), [{
return $_get(pointeeType.getContext(), pointeeType, addressSpace);
}]>
];
// Do not generate the default builders; the builder declared above is the
// one provided by this definition.
let skipDefaultBuilders = 1;
}
// Tensor whose element type is TT_Pointer.
def TT_PtrTensor : TensorOf<[TT_Pointer]>;
// Tensor of floating-point or integer elements (no pointer elements).
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
// Any tensor accepted above: float, integer, or pointer elements.
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
// Union of every scalar and tensor constraint defined in this file.
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
TT_Pointer, TT_PtrTensor]>;
#endif

View File

@@ -0,0 +1,12 @@
# TableGen rules driven by TritonGPUOps.td: dialect declarations/definitions
# (restricted to the triton_gpu dialect) and op declarations/definitions.
set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
# Target that other CMake targets list under DEPENDS to get the .inc files above.
add_public_tablegen_target(TritonGPUTableGen)
# Second set of rules, driven by TritonGPUAttrDefs.td: attribute
# declarations/definitions, exposed through their own target.
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TritonGPUAttrDefsIncGen)

View File

@@ -1,5 +1,16 @@
#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ #ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ #define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
// TritonGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ #endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

View File

@@ -56,7 +56,7 @@ For example, a row-major coalesced layout may distribute a 64x16 tensor over 2 w
block| thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3] block| thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3]
tile | tile size 2 || A_{1, 0}[T0] A_{1, 1}[T0] ... A_{1, 6}[T3] A_{1, 7}[T3] A_{1, 8}[T0] A_{1, 9}[T0] ... A_{1, 14}[T3] A_{1, 15}[T3] tile | tile size 2 || A_{1, 0}[T0] A_{1, 1}[T0] ... A_{1, 6}[T3] A_{1, 7}[T3] A_{1, 8}[T0] A_{1, 9}[T0] ... A_{1, 14}[T3] A_{1, 15}[T3]
size } .... size } ....
16 | A_{30, 0}[T60] A_{14, 1}[T60] ... A_{14, 6}[T63] A_{14, 7}[T63] A_{14, 8}[T60] A_{14, 9}[T60] ... A_{14, 14}[T63] A_{14, 15}[T63] 32 | A_{30, 0}[T60] A_{30, 1}[T60] ... A_{30, 6}[T63] A_{30, 7}[T63] A_{30, 8}[T60] A_{30, 9}[T60] ... A_{30, 14}[T63] A_{30, 15}[T63]
| A_{31, 0}[T60] A_{15, 1}[T60] ... A_{15, 6}[T63] A_{15, 7}[T63] A_{15, 8}[T60] A_{15, 9}[T60] ... A_{15, 14}[T63] A_{15, 15}[T63] | A_{31, 0}[T60] A_{31, 1}[T60] ... A_{31, 6}[T63] A_{31, 7}[T63] A_{31, 8}[T60] A_{31, 9}[T60] ... A_{31, 14}[T63] A_{31, 15}[T63]
-----------------------------/\----------------------------------- -----------------------------/\-----------------------------------
block tile size 8 block tile size 8
@@ -71,7 +71,7 @@ size } ....
And the associated TritonGPU MLIR And the associated TritonGPU MLIR
#SMEM = #triton_gpu.encoding<{ #SMEM = #triton_gpu.encoding<{
threadTileSize = {2, 2} threadTileSize = {2, 2}
blockTileSize = {16, 8} blockTileSize = {32, 8}
}> }>
// note to Da: In current Triton codebase, `nanoTileSize = threadTileSize`, and `macro-tile size = blockTileSize / threadTileSize` // note to Da: In current Triton codebase, `nanoTileSize = threadTileSize`, and `macro-tile size = blockTileSize / threadTileSize`
@@ -81,7 +81,7 @@ And the associated TritonGPU MLIR
let parameters = ( let parameters = (
ins ins
ArrayRefParameter<"unsigned">:$threadTileSize, ArrayRefParameter<"unsigned">:$threadTileSize,
ArrayRefParameter<"unsigned">:$blockTileSize, ArrayRefParameter<"unsigned">:$blockTileSize
); );
// let genVerifyDecl = 1; // let genVerifyDecl = 1;

View File

@@ -1,10 +1,13 @@
#ifndef TRITONGPU_OPS #ifndef TRITONGPU_OPS
#define TRITONGPU_OPS #define TRITONGPU_OPS
include "TritonGPUDialect.td" include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
// include "" include "triton/Dialect/Triton/IR/TritonTypes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
def TTG_Op<string mnemonic, list<Trait> traits = []> : class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>; Op<TritonGPU_Dialect, mnemonic, traits>;
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",

View File

@@ -1 +1,2 @@
add_subdirectory(Triton) add_subdirectory(Triton)
add_subdirectory(TritonGPU)

View File

@@ -0,0 +1 @@
add_subdirectory(IR)

View File

@@ -0,0 +1,10 @@
# TritonGPUIR dialect library, built from Dialect.cpp.
add_mlir_dialect_library(TritonGPUIR
Dialect.cpp
# Build the TableGen targets first so the generated .inc files exist before
# Dialect.cpp is compiled.
DEPENDS
TritonGPUTableGen
TritonGPUAttrDefsIncGen
# Links against the Triton dialect library; PUBLIC propagates it to consumers.
LINK_LIBS PUBLIC
TritonIR
)

View File

@@ -0,0 +1,12 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
using namespace mlir::triton::gpu;
// Dialect initialization hook: adds the dialect's operations.
// The template argument list of addOperations<> is supplied by including the
// generated Ops.cpp.inc with GET_OP_LIST defined, so the macro definition and
// the #include must stay between the angle brackets in exactly this order.
void TritonGPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
>();
}