[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)
This commit is contained in:
@@ -1,10 +1,19 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonDialect.td)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||
mlir_tablegen(Types.h.inc -gen-typedef-decls)
|
||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
|
||||
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
|
||||
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
|
||||
|
||||
add_public_tablegen_target(TritonTableGen)
|
||||
|
@@ -17,4 +17,24 @@
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/Ops.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
class DialectInferLayoutInterface
|
||||
: public DialectInterface::Base<DialectInferLayoutInterface> {
|
||||
public:
|
||||
DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
|
||||
|
||||
virtual LogicalResult
|
||||
inferReduceOpEncoding(Attribute operandEncoding, int axis,
|
||||
Attribute &resultEncoding) const = 0;
|
||||
|
||||
virtual LogicalResult
|
||||
inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
|
||||
Attribute &resultEncoding) const = 0;
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_IR_DIALECT_H_
|
||||
|
9
include/triton/Dialect/Triton/IR/Interfaces.h
Normal file
9
include/triton/Dialect/Triton/IR/Interfaces.h
Normal file
@@ -0,0 +1,9 @@
|
||||
#ifndef TRITON_IR_INTERFACES_H_
|
||||
#define TRITON_IR_INTERFACES_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
@@ -10,50 +10,47 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
// TODO: should have `namespace triton {}` here
|
||||
|
||||
// These functions are out-of-line implementations of the methods in the
|
||||
// corresponding trait classes. This avoids them being template
|
||||
// instantiated/duplicated.
|
||||
namespace impl {
|
||||
LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
|
||||
LogicalResult verifySameOperandsEncoding(Operation *op);
|
||||
// The rationale for this trait is to prevent users from creating programs
|
||||
// that would have catastrophic register pressure and cause the compiler to
|
||||
// hang.
|
||||
// Since H100 has 256KB registers, we should allow users to create tensors
|
||||
// of size up to 256K elements. It will spill for datatypes wider than 1B,
|
||||
// but we probably should limit number of elements (rather than bytes) to
|
||||
// keep specs simple
|
||||
int constexpr maxTensorNumElements = 1048576;
|
||||
LogicalResult verifyTensorSize(Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
template <class ConcreteType>
|
||||
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
|
||||
public:
|
||||
// TODO: move impl to .cc files
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
// The rationale for this number is to prevent users from creating programs
|
||||
// that would have catastrophic register pressure and cause the compiler to
|
||||
// hang.
|
||||
// Since H100 has 256KB registers, we should allow users to create tensors
|
||||
// of size up to 256K elements. It will spill for datatypes wider than 1B,
|
||||
// but we probably should limit number of elements (rather than bytes) to
|
||||
// keep specs simple
|
||||
int constexpr maxElement = 1048576;
|
||||
for (auto opType : op->getOperandTypes()) {
|
||||
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||
int64_t numElements = 1;
|
||||
for (int64_t s : tensorType.getShape())
|
||||
numElements *= s;
|
||||
if (numElements > maxElement)
|
||||
return op->emitError("Maximum allowed number of elements is ")
|
||||
<< maxElement << ", but " << *op << " has more than that";
|
||||
if ((numElements & (numElements - 1)) != 0)
|
||||
return op->emitError("Number of elements must be power-of-two, but ")
|
||||
<< *op << " doesn't follow the rule";
|
||||
}
|
||||
}
|
||||
return impl::verifyTensorSize(op);
|
||||
}
|
||||
};
|
||||
|
||||
for (auto opType : op->getResultTypes()) {
|
||||
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
|
||||
int64_t numElements = 1;
|
||||
for (int64_t s : tensorType.getShape())
|
||||
numElements *= s;
|
||||
if (numElements > maxElement)
|
||||
return op->emitError("Maximum allowed number of elements is ")
|
||||
<< maxElement << ", but " << *op << " has more than that";
|
||||
if ((numElements & (numElements - 1)) != 0)
|
||||
return op->emitError("Number of elements must be power-of-two, but ")
|
||||
<< *op << " doesn't follow the rule";
|
||||
}
|
||||
}
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultEncoding
|
||||
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameOperandsAndResultEncoding(op);
|
||||
}
|
||||
};
|
||||
|
||||
return success();
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsEncoding
|
||||
: public TraitBase<ConcreteType, SameOperandsEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameOperandsEncoding(op);
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -41,4 +41,7 @@ def Triton_Dialect : Dialect {
|
||||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
|
||||
|
||||
#endif // TRITON_DIALECT
|
||||
|
6
include/triton/Dialect/Triton/IR/TritonInterfaces.td
Normal file
6
include/triton/Dialect/Triton/IR/TritonInterfaces.td
Normal file
@@ -0,0 +1,6 @@
|
||||
#ifndef TRITON_INTERFACES
|
||||
#define TRITON_INTERFACES
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
#endif // TRITON_INTERFACES
|
@@ -4,18 +4,23 @@
|
||||
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
|
||||
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
|
||||
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
|
||||
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
|
||||
|
||||
//
|
||||
// Op Base
|
||||
//
|
||||
class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])>;
|
||||
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])> {
|
||||
}
|
||||
|
||||
//
|
||||
// CastOps
|
||||
@@ -25,7 +30,9 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
// fptoui, fptosi, uitofp, sitofp,
|
||||
// extf, tructf,
|
||||
// extui, extsi, tructi
|
||||
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffect,
|
||||
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast int64 to pointer";
|
||||
|
||||
@@ -36,7 +43,9 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffec
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
|
||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast pointer to int64";
|
||||
|
||||
@@ -47,7 +56,9 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffec
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
||||
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Floating point casting for custom types";
|
||||
|
||||
@@ -67,11 +78,33 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
||||
// TODO: We need a verifier here.
|
||||
}
|
||||
|
||||
//
|
||||
// Pointer Arith Ops
|
||||
//
|
||||
|
||||
def TT_AddPtrOp : TT_Op<"addptr",
|
||||
[NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
TypesMatchWith<"result type matches ptr type",
|
||||
"result", "ptr", "$_self">,
|
||||
TypesMatchWith<"result shape matches offset shape",
|
||||
"result", "offset",
|
||||
"getI32SameShape($_self)">]> {
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
|
||||
|
||||
let results = (outs TT_PtrLike:$result);
|
||||
|
||||
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Load/Store Ops
|
||||
//
|
||||
def TT_LoadOp : TT_Op<"load",
|
||||
[SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize,
|
||||
MemoryEffects<[MemRead]>,
|
||||
TypesMatchWith<"infer ptr type from result type",
|
||||
@@ -110,6 +143,7 @@ def TT_LoadOp : TT_Op<"load",
|
||||
|
||||
def TT_StoreOp : TT_Op<"store",
|
||||
[SameOperandsShape,
|
||||
SameOperandsEncoding,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
TypesMatchWith<"infer ptr type from value type",
|
||||
"value", "ptr",
|
||||
@@ -133,130 +167,11 @@ def TT_StoreOp : TT_Op<"store",
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TT_AddPtrOp : TT_Op<"addptr",
|
||||
[NoSideEffect, SameOperandsAndResultShape,
|
||||
TypesMatchWith<"result type matches ptr type",
|
||||
"result", "ptr", "$_self">,
|
||||
TypesMatchWith<"result shape matches offset shape",
|
||||
"result", "offset",
|
||||
"getI32SameShape($_self)">]> {
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
|
||||
|
||||
let results = (outs TT_PtrLike:$result);
|
||||
|
||||
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Shape Manipulation Ops
|
||||
// Atomic Op
|
||||
//
|
||||
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "expand_dims";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_ViewOp : TT_Op<"view", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "view";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "splat";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "broadcast. No left-padding as of now.";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_CatOp : TT_Op<"cat", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "concatenate 2 tensors";
|
||||
|
||||
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
//
|
||||
// builtin Ops
|
||||
//
|
||||
def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
TypesMatchWith<"result's type matches accumulator's type",
|
||||
"d", "c", "$_self">]> {
|
||||
let summary = "dot";
|
||||
|
||||
let description = [{
|
||||
$d = matrix_multiply($a, $b) + $c
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
||||
|
||||
let results = (outs TT_FpIntTensor:$d);
|
||||
|
||||
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
||||
|
||||
// let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect]> {
|
||||
let summary = "reduce";
|
||||
|
||||
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
|
||||
];
|
||||
|
||||
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
|
||||
}
|
||||
|
||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "atomic rmw";
|
||||
|
||||
let description = [{
|
||||
@@ -271,7 +186,8 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
||||
def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "atomic cas";
|
||||
|
||||
let description = [{
|
||||
@@ -289,10 +205,133 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// External Function Ops
|
||||
// Shape Manipulation Ops
|
||||
//
|
||||
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "splat";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "expand_dims";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_ViewOp : TT_Op<"view", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "view";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
}
|
||||
|
||||
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "broadcast. No left-padding as of now.";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_CatOp : TT_Op<"cat", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "concatenate 2 tensors";
|
||||
|
||||
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
//
|
||||
// SPMD Ops
|
||||
//
|
||||
def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// Dot Op
|
||||
//
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
TypesMatchWith<"result's type matches accumulator's type",
|
||||
"d", "c", "$_self">]> {
|
||||
let summary = "dot";
|
||||
|
||||
let description = [{
|
||||
$d = matrix_multiply($a, $b) + $c
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
||||
|
||||
let results = (outs TT_FpIntTensor:$d);
|
||||
|
||||
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
||||
}
|
||||
|
||||
//
|
||||
// Reduce Op
|
||||
//
|
||||
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "reduce";
|
||||
|
||||
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
|
||||
];
|
||||
|
||||
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// External elementwise op
|
||||
//
|
||||
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize]> {
|
||||
let summary = "ext_elemwise";
|
||||
|
||||
@@ -307,10 +346,11 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOpe
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
// Intrinsics
|
||||
// Make Range Op
|
||||
//
|
||||
// TODO: should have ConstantLike as Trait
|
||||
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
||||
|
@@ -11,6 +11,7 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#define TRITONGPU_ATTRDEFS
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TritonGPU Attribute Definitions
|
||||
@@ -34,6 +35,7 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
|
||||
|
||||
code extraBaseClassDeclaration = [{
|
||||
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
|
||||
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -301,15 +303,15 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
||||
TODO: improve docs
|
||||
|
||||
A = [x x x x x x x x]
|
||||
[x x x x x x x x]
|
||||
L_parent = [0 1 2 3 ]
|
||||
[4 5 6 7 ]
|
||||
[8 9 10 11]
|
||||
[12 13 14 15]
|
||||
|
||||
parent = [0 1 2 3 ]
|
||||
[4 5 6 7 ]
|
||||
[8 9 10 11]
|
||||
[12 13 14 15]
|
||||
dim = 0
|
||||
|
||||
Then the data of A would be distributed as follow between the 16 CUDA threads:
|
||||
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15} ]
|
||||
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ]
|
||||
|
||||
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
|
||||
|
||||
|
@@ -76,15 +76,4 @@ def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::Modu
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
|
||||
let summary = "verify TritonGPU IR";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUVerifier()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::gpu::GPUDialect"];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user