[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)

Philippe Tillet
2022-10-11 18:16:41 -07:00
committed by GitHub
parent b6e5a231e5
commit 623c99609f
27 changed files with 494 additions and 348 deletions

View File

@@ -1,10 +1,19 @@
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
set(LLVM_TARGET_DEFINITIONS TritonDialect.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
add_public_tablegen_target(TritonTableGen)

View File

@@ -17,4 +17,24 @@
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.h.inc"
namespace mlir {
namespace triton {
class DialectInferLayoutInterface
    : public DialectInterface::Base<DialectInferLayoutInterface> {
public:
  DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}

  virtual LogicalResult
  inferReduceOpEncoding(Attribute operandEncoding, int axis,
                        Attribute &resultEncoding) const = 0;

  virtual LogicalResult
  inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
                            Attribute &resultEncoding) const = 0;
};
} // namespace triton
} // namespace mlir
#endif // TRITON_IR_DIALECT_H_
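
For context on how this hook is meant to be used: a dialect that owns layout attributes implements the interface and registers it, and Triton-IR's type inference then queries it through the operand encoding's dialect. A minimal sketch, assuming a hypothetical dialect (`MyInferLayoutInterface` and its placeholder bodies are illustrative, not the committed TritonGPU implementation):

#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;

struct MyInferLayoutInterface : public triton::DialectInferLayoutInterface {
  using DialectInferLayoutInterface::DialectInferLayoutInterface;

  LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis,
                                      Attribute &resultEncoding) const override {
    // A real dialect would wrap operandEncoding in its "sliced along `axis`"
    // attribute here; this placeholder just propagates it unchanged.
    resultEncoding = operandEncoding;
    return success();
  }

  LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
                                          Attribute &resultEncoding) const override {
    // Inverse of the reduce case (see the SliceEncoding sketch further down).
    resultEncoding = operandEncoding;
    return success();
  }
};

// Registered from the owning dialect's initialize():
//   addInterfaces<MyInferLayoutInterface>();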

View File

@@ -0,0 +1,9 @@
#ifndef TRITON_IR_INTERFACES_H_
#define TRITON_IR_INTERFACES_H_
#include "mlir/IR/OpDefinition.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#endif // TRITON_IR_INTERFACES_H_

View File

@@ -10,50 +10,47 @@
namespace mlir {
namespace OpTrait {
// TODO: should have `namespace triton {}` here
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
LogicalResult verifySameOperandsEncoding(Operation *op);
// The rationale for this trait is to prevent users from creating programs
// that would have catastrophic register pressure and cause the compiler to
// hang.
// Since the H100 has 256KB of registers, we should allow users to create
// tensors of size up to 256K elements. It will spill for datatypes wider
// than 1B, but we should probably limit the number of elements (rather than
// bytes) to keep the spec simple.
int constexpr maxTensorNumElements = 1048576;
LogicalResult verifyTensorSize(Operation *op);
} // namespace impl
template <class ConcreteType>
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyTensorSize(op);
  }
};

template <typename ConcreteType>
class SameOperandsAndResultEncoding
    : public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameOperandsAndResultEncoding(op);
  }
};

template <typename ConcreteType>
class SameOperandsEncoding
    : public TraitBase<ConcreteType, SameOperandsEncoding> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameOperandsEncoding(op);
  }
};
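
The trait bodies now live out-of-line in a .cc file. The definition of impl::verifyTensorSize can be read off the inline verifier this commit replaces; a sketch of the out-of-line code (the header path and the encoding-verifier body are assumptions, since the diff excerpt does not show the .cc file):

#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/Triton/IR/Traits.h" // assumed header path

using namespace mlir;

// Element-count cap and power-of-two rule, applied to every ranked tensor
// among the op's operands and results (reconstructed from the former inline
// verifier).
static LogicalResult verifyTensorSizeOfType(Operation *op, Type opType) {
  if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
    int64_t numElements = 1;
    for (int64_t s : tensorType.getShape())
      numElements *= s;
    if (numElements > OpTrait::impl::maxTensorNumElements)
      return op->emitError("Maximum allowed number of elements is ")
             << OpTrait::impl::maxTensorNumElements << ", but " << *op
             << " has more than that";
    if ((numElements & (numElements - 1)) != 0)
      return op->emitError("Number of elements must be power-of-two, but ")
             << *op << " doesn't follow the rule";
  }
  return success();
}

LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) {
  for (Type ty : op->getOperandTypes())
    if (failed(verifyTensorSizeOfType(op, ty)))
      return failure();
  for (Type ty : op->getResultTypes())
    if (failed(verifyTensorSizeOfType(op, ty)))
      return failure();
  return success();
}

// Plausible shape of the encoding verifiers (assumed, not shown in this
// excerpt): all ranked-tensor operands (and, for the AndResult variant, the
// results too) must carry one and the same encoding attribute.
LogicalResult OpTrait::impl::verifySameOperandsEncoding(Operation *op) {
  Attribute encoding;
  for (Type ty : op->getOperandTypes())
    if (auto tensorTy = ty.dyn_cast<RankedTensorType>()) {
      if (encoding && tensorTy.getEncoding() != encoding)
        return op->emitError("requires the same encoding for all operands");
      encoding = tensorTy.getEncoding();
    }
  return success();
}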

View File

@@ -41,4 +41,7 @@ def Triton_Dialect : Dialect {
let hasConstantMaterializer = 1;
}
include "triton/Dialect/Triton/IR/TritonTypes.td"
#endif // TRITON_DIALECT

View File

@@ -0,0 +1,6 @@
#ifndef TRITON_INTERFACES
#define TRITON_INTERFACES
include "mlir/IR/OpBase.td"
#endif // TRITON_INTERFACES

View File

@@ -4,18 +4,23 @@
include "triton/Dialect/Triton/IR/TritonDialect.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
//
// Op Base
//
class TT_Op<string mnemonic, list<Trait> traits = []> :
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])> {
}
//
// CastOps
@@ -25,7 +30,9 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
//   fptoui, fptosi, uitofp, sitofp,
//   extf, truncf,
//   extui, extsi, trunci
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect,
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast int64 to pointer";
@@ -36,7 +43,9 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffec
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast pointer to int64";
@@ -47,7 +56,9 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffec
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Floating point casting for custom types";
@@ -67,11 +78,33 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
// TODO: We need a verifier here.
}
//
// Pointer Arith Ops
//
def TT_AddPtrOp : TT_Op<"addptr",
[NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">,
TypesMatchWith<"result shape matches offset shape",
"result", "offset",
"getI32SameShape($_self)">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
}
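
The two TypesMatchWith constraints make tt.addptr's types self-describing: the result type must equal $ptr's type, and the expected $offset type is derived from the result via getI32SameShape. That helper is not part of this diff; the following is a guess at its shape based on its name and use here (hypothetical code, including the encoding handling):

// Given the result type, build the i32 type of the same shape (and, for
// ranked tensors, the same encoding) that the $offset operand must have.
static mlir::Type getI32SameShape(mlir::Type type) {
  auto i32Type = mlir::IntegerType::get(type.getContext(), 32);
  if (auto tensorType = type.dyn_cast<mlir::RankedTensorType>())
    return mlir::RankedTensorType::get(tensorType.getShape(), i32Type,
                                       tensorType.getEncoding());
  return i32Type;
}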
//
// Load/Store Ops
//
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
SameVariadicOperandSize,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
@@ -110,6 +143,7 @@ def TT_LoadOp : TT_Op<"load",
def TT_StoreOp : TT_Op<"store",
[SameOperandsShape,
SameOperandsEncoding,
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"value", "ptr",
@@ -133,130 +167,11 @@ def TT_StoreOp : TT_Op<"store",
let hasCanonicalizer = 1;
}
def TT_AddPtrOp : TT_Op<"addptr",
[NoSideEffect, SameOperandsAndResultShape,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">,
TypesMatchWith<"result shape matches offset shape",
"result", "offset",
"getI32SameShape($_self)">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
}
//
// Shape Manipulation Ops
// Atomic Op
//
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "expand_dims";
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "view";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "splat";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "broadcast. No left-padding as of now.";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Type:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_CatOp : TT_Op<"cat", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "concatenate 2 tensors";
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}
//
// builtin Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot";
let description = [{
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let results = (outs TT_FpIntTensor:$d);
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
// let hasCanonicalizer = 1;
}
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect]> {
let summary = "reduce";
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let builders = [
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
];
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
}
def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "atomic rmw";
let description = [{
@@ -271,7 +186,8 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
let results = (outs TT_Type:$result);
}
def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "atomic cas";
let description = [{
@@ -289,10 +205,133 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
let results = (outs TT_Type:$result);
}
//
// Shape Manipulation Ops
//
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "splat";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {
let summary = "expand_dims";
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_ViewOp : TT_Op<"view", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "view";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "broadcast. No left-padding as of now.";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Type:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_CatOp : TT_Op<"cat", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "concatenate 2 tensors";
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}
//
// SPMD Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Dot Op
//
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot";
let description = [{
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let results = (outs TT_FpIntTensor:$d);
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}
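
Declaring InferTypeOpInterface here means builders no longer have to spell out tt.dot's result type. Because the TypesMatchWith constraint ties $d to $c, the inference reduces to returning the accumulator's type; a sketch under the MLIR API of this era (assuming `using namespace mlir;`; exact signature details are assumptions):

LogicalResult mlir::triton::DotOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // $d matches the accumulator $c (operand #2) in shape, element type, and
  // encoding, so its type can simply be reused.
  inferredReturnTypes.push_back(operands[2].getType());
  return success();
}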
//
// Reduce Op
//
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "reduce";
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let builders = [
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
];
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
}
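
tt.reduce is where the two halves of this commit meet: InferTypeOpInterface computes the result shape by dropping the reduced axis, and the result encoding is obtained by asking the operand encoding's dialect through the new DialectInferLayoutInterface. A hedged sketch of that logic (assuming `using namespace mlir;`; local names are illustrative):

LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Shape: drop the reduced dimension.
  auto argTy = operands[0].getType().cast<RankedTensorType>();
  auto retShape = argTy.getShape().vec();
  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
  retShape.erase(retShape.begin() + axis);

  // Encoding: delegate to the dialect that owns the operand's layout.
  Attribute retEncoding;
  if (Attribute argEncoding = argTy.getEncoding()) {
    auto *inferLayout =
        argEncoding.getDialect()
            .getRegisteredInterface<triton::DialectInferLayoutInterface>();
    if (!inferLayout ||
        failed(inferLayout->inferReduceOpEncoding(argEncoding, axis,
                                                  retEncoding)))
      return failure();
  }
  inferredReturnTypes.push_back(
      RankedTensorType::get(retShape, argTy.getElementType(), retEncoding));
  return success();
}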
//
// External elementwise op
//
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
let summary = "ext_elemwise";
@@ -307,10 +346,11 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOpe
let results = (outs TT_Tensor:$result);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
}
//
// Make Range Op
//
// TODO: should have ConstantLike as Trait
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {

View File

@@ -11,6 +11,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
#define GET_OP_CLASSES

View File

@@ -2,6 +2,7 @@
#define TRITONGPU_ATTRDEFS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
//===----------------------------------------------------------------------===//
// TritonGPU Attribute Definitions
@@ -34,6 +35,7 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
code extraBaseClassDeclaration = [{
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
}];
}
@@ -301,15 +303,15 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
TODO: improve docs
A = [x x x x x x x x]
    [x x x x x x x x]
parent = [0  1  2  3 ]
         [4  5  6  7 ]
         [8  9  10 11]
         [12 13 14 15]
dim = 0
Then the data of A would be distributed as follows between the 16 CUDA threads:
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ]
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
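
This inverse relationship is what a dialect-side implementation of inferExpandDimsOpEncoding can exploit: if the operand carries a slice layout along the expanded axis, the result simply recovers the parent layout. A sketch of the TritonGPU side (the class name and error handling are assumptions; the committed code may differ in details):

LogicalResult TritonGPUInferLayoutInterface::inferExpandDimsOpEncoding(
    Attribute operandEncoding, int axis, Attribute &resultEncoding) const {
  // expand_dims(x, axis) undoes a slice along that same axis, so the operand
  // is expected to carry slice<{dim = axis, parent}> and the result gets
  // `parent` back.
  auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
  if (!sliceEncoding || sliceEncoding.getDim() != axis)
    return failure(); // a real implementation would emit a diagnostic
  resultEncoding = sliceEncoding.getParent();
  return success();
}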

View File

@@ -76,15 +76,4 @@ def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::Modu
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
let summary = "verify TritonGPU IR";
let description = [{}];
let constructor = "mlir::createTritonGPUVerifier()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::gpu::GPUDialect"];
}
#endif