431 lines
14 KiB
TableGen
431 lines
14 KiB
TableGen
#ifndef TRITON_OPS
|
|
#define TRITON_OPS
|
|
|
|
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
|
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
|
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
|
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
|
include "mlir/IR/OpBase.td"
|
|
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
|
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
|
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
|
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
|
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
|
|
|
|
// Native op traits: each def binds a TableGen name to the C++ trait class of
// the same name.
// NOTE(review): what each trait actually verifies is defined on the C++ side,
// not in this file.
def TensorSizeTrait               : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameOperandsEncoding          : NativeOpTrait<"SameOperandsEncoding">;
|
|
|
|
//
|
|
// Op Base
|
|
//
|
|
// Common base class for the ops defined in this file.
//   mnemonic - op name without the dialect prefix.
//   traits   - op-specific traits; TensorSizeTrait is appended to this list,
//              so every op deriving from TT_Op carries it.
class TT_Op<string mnemonic, list<Trait> traits = []>
    : Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])>;
|
|
|
|
//
|
|
// CastOps
|
|
//
|
|
// Use cast ops in arith:
|
|
// bitcast
|
|
// fptoui, fptosi, uitofp, sitofp,
|
|
// extf, truncf,
|
|
// extui, extsi, trunci
|
|
// int_to_ptr: casts an int64-like value ($from) to a pointer-like value.
// Operand and result must have the same shape; the native
// SameOperandsAndResultEncoding trait is attached as well. The op is declared
// side-effect free.
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [
    SameOperandsAndResultShape,
    SameOperandsAndResultEncoding,
    NoSideEffect
    // Not enabled: DeclareOpInterfaceMethods<CastOpInterface>
  ]> {
  let summary = "Cast int64 to pointer";

  let arguments = (ins TT_I64Like:$from);
  let results   = (outs TT_PtrLike:$result);

  // Textual form: <from> {attrs} : <type of from> -> <type of result>
  let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
|
|
|
|
// ptr_to_int: casts a pointer-like value ($from) to an int64-like value.
// Operand and result must have the same shape; the native
// SameOperandsAndResultEncoding trait is attached as well. The op is declared
// side-effect free.
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [
    SameOperandsAndResultShape,
    SameOperandsAndResultEncoding,
    NoSideEffect
    // Not enabled: DeclareOpInterfaceMethods<CastOpInterface>
  ]> {
  let summary = "Cast pointer to int64";

  let arguments = (ins TT_PtrLike:$from);
  let results   = (outs TT_I64Like:$result);

  // Textual form: <from> {attrs} : <type of from> -> <type of result>
  let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
|
|
|
|
// arith.bitcast doesn't support pointers
|
|
// bitcast: reinterprets $from as another TT_Type value.
// Shape (and encoding, via the native trait) must match between operand and
// result. Nothing in this definition checks that the bit widths agree; see the
// TODO below.
def TT_BitcastOp : TT_Op<"bitcast", [
    SameOperandsAndResultShape,
    SameOperandsAndResultEncoding,
    NoSideEffect
    // Not enabled: DeclareOpInterfaceMethods<CastOpInterface>
  ]> {
  let summary = "Cast between types of the same bitwidth";

  let arguments = (ins TT_Type:$from);
  let results   = (outs TT_Type:$result);

  // Textual form: <from> {attrs} : <type of from> -> <type of result>
  let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";

  // TODO: Add verifier
}
|
|
|
|
// fp_to_fp: float-like to float-like conversion for the custom F8 types.
// Unlike the other cast ops in this file, this one does declare the
// CastOpInterface methods (implemented in C++).
def TT_FpToFpOp : TT_Op<"fp_to_fp", [
    SameOperandsAndResultShape,
    SameOperandsAndResultEncoding,
    NoSideEffect,
    DeclareOpInterfaceMethods<CastOpInterface>
  ]> {
  let summary = "Floating point casting for custom types";

  let description = [{
    Floating point casting for custom types (F8).

    F8 <-> FP16, BF16, FP32, FP64
  }];

  let arguments = (ins TT_FloatLike:$from);
  let results   = (outs TT_FloatLike:$result);

  // Textual form: <from> {attrs} : <type of from> -> <type of result>
  let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";

  // TODO: We need a verifier here.
}
|
|
|
|
//
|
|
// Pointer Arith Ops
|
|
//
|
|
|
|
// addptr: combines a pointer-like $ptr with an i32-like $offset and yields a
// pointer-like result.
// Only the result type is written in the assembly; the operand types are
// derived from it through the TypesMatchWith constraints:
//   - $ptr has exactly the result type;
//   - $offset has the type produced by getI32SameShape(result type), a C++
//     helper defined outside this file.
def TT_AddPtrOp : TT_Op<"addptr", [
    NoSideEffect,
    SameOperandsAndResultShape,
    SameOperandsAndResultEncoding,
    TypesMatchWith<"result type matches ptr type",
                   "result", "ptr", "$_self">,
    TypesMatchWith<"result shape matches offset shape",
                   "result", "offset", "getI32SameShape($_self)">
  ]> {
  let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
  let results   = (outs TT_PtrLike:$result);

  // Textual form: <ptr>, <offset> {attrs} : <type of result>
  let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
}
|
|
|
|
|
|
//
|
|
// Load/Store Ops
|
|
//
|
|
// load: reads through $ptr and produces $result (declared with a MemRead
// effect).
// Operands: $ptr, plus optional $mask and optional $other; because two
// operands are optional the op uses AttrSizedOperandSegments.
// Operand types are tied to the result type:
//   - $ptr   : getPointerTypeSameShape(result type)
//   - $mask  : getI1SameShape(result type); check skipped when the op has at
//              most one operand
//   - $other : the result type itself; check skipped when the op has at most
//              two operands
// Parsing and printing are delegated to hand-written C++ functions.
def TT_LoadOp : TT_Op<"load", [
    SameOperandsAndResultShape,
    SameOperandsAndResultEncoding,
    AttrSizedOperandSegments,
    MemoryEffects<[MemRead]>,
    TypesMatchWith<"infer ptr type from result type",
                   "result", "ptr", "getPointerTypeSameShape($_self)">,
    TypesMatchWith<"infer mask type from result type or none",
                   "result", "mask", "getI1SameShape($_self)",
                   "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
    TypesMatchWith<"infer other type from result type or none",
                   "result", "other", "$_self",
                   "($_op.getOperands().size() <= 2) || std::equal_to<>()">
  ]> {
  let summary = "load";

  let arguments = (ins
    TT_PtrLike:$ptr,
    Optional<TT_BoolLike>:$mask,
    Optional<TT_Type>:$other,
    TT_CacheModifierAttr:$cache,
    TT_EvictionPolicyAttr:$evict,
    BoolAttr:$isVolatile
  );

  let results = (outs TT_Type:$result);

  // Convenience builders: ptr only, ptr + mask, and ptr + mask + other.
  let builders = [
    OpBuilder<(ins "Value":$ptr,
                   "triton::CacheModifier":$cache,
                   "triton::EvictionPolicy":$evict,
                   "bool":$isVolatile)>,
    OpBuilder<(ins "Value":$ptr, "Value":$mask,
                   "triton::CacheModifier":$cache,
                   "triton::EvictionPolicy":$evict,
                   "bool":$isVolatile)>,
    OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other,
                   "triton::CacheModifier":$cache,
                   "triton::EvictionPolicy":$evict,
                   "bool":$isVolatile)>
  ];

  // Custom C++ parser/printer are used in place of a declarative format:
  // let assemblyFormat = "operands attr-dict `:` type($result)";
  let parser  = [{ return mlir::triton::parseLoadOp(parser, result); }];

  let printer = [{ return mlir::triton::printLoadOp(p, *this); }];

  let hasCanonicalizer = 1;
}
|
|
|
|
// store: writes $value through $ptr (declared with a MemWrite effect); the op
// has no results.
// Operand types are tied to the type of $value:
//   - $ptr  : getPointerTypeSameShape(value type)
//   - $mask : getI1SameShape(value type); check skipped when the op has at
//             most two operands (i.e. no mask was supplied)
// Parsing and printing are delegated to hand-written C++ functions.
def TT_StoreOp : TT_Op<"store", [
    SameOperandsShape,
    SameOperandsEncoding,
    MemoryEffects<[MemWrite]>,
    TypesMatchWith<"infer ptr type from value type",
                   "value", "ptr", "getPointerTypeSameShape($_self)">,
    TypesMatchWith<"infer mask type from value type",
                   "value", "mask", "getI1SameShape($_self)",
                   "($_op.getOperands().size() <= 2) || std::equal_to<>()">
  ]> {
  let summary = "store";

  let arguments = (ins
    TT_PtrLike:$ptr,
    TT_Type:$value,
    Optional<TT_BoolLike>:$mask
  );

  // Convenience builder for the mask-less form.
  let builders = [
    OpBuilder<(ins "Value":$ptr, "Value":$value)>
  ];

  // Custom C++ parser/printer are used in place of a declarative format:
  // let assemblyFormat = "operands attr-dict `:` type($value)";
  let parser  = [{ return mlir::triton::parseStoreOp(parser, result); }];

  let printer = [{ return mlir::triton::printStoreOp(p, *this); }];

  let hasCanonicalizer = 1;
}
|
|
|
|
//
|
|
// Atomic Op
|
|
//
|
|
// atomic_rmw: read-modify-write on the memory behind $ptr, selected by the
// $atomic_rmw_op attribute; declared with both MemRead and MemWrite effects.
// Operand types are tied to the type of $val:
//   - $ptr  : getPointerTypeSameShape(val type)
//   - $mask : getI1SameShape(val type); check skipped when the op has at most
//             two operands (i.e. no mask was supplied)
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
    SameOperandsAndResultShape,
    SameOperandsAndResultEncoding,
    MemoryEffects<[MemRead]>,
    MemoryEffects<[MemWrite]>,
    TypesMatchWith<"infer ptr type from value type",
                   "val", "ptr", "getPointerTypeSameShape($_self)">,
    TypesMatchWith<"infer mask type from value type",
                   "val", "mask", "getI1SameShape($_self)",
                   "($_op.getOperands().size() <= 2) || std::equal_to<>()">
  ]> {
  let summary = "atomic rmw";

  let description = [{
    load data at $ptr, do $rmw_op with $val, and store result to $ptr.

    return old value at $ptr
  }];

  let arguments = (ins
    TT_AtomicRMWAttr:$atomic_rmw_op,
    TT_PtrLike:$ptr,
    TT_Type:$val,
    Optional<TT_BoolLike>:$mask
  );

  let results = (outs TT_Type:$result);
}
|
|
|
|
// atomic_cas: compare-and-swap on the memory behind $ptr; declared with both
// MemRead and MemWrite effects.
// Note that $ptr is constrained with TT_Ptr here, whereas the other memory ops
// in this file use TT_PtrLike.
def TT_AtomicCASOp : TT_Op<"atomic_cas", [
    MemoryEffects<[MemRead]>,
    MemoryEffects<[MemWrite]>,
    SameOperandsAndResultShape,
    SameOperandsAndResultEncoding
  ]> {
  let summary = "atomic cas";

  let description = [{
    compare $cmp with data $old at location $ptr,

    if $old == $cmp, store $val to $ptr,

    else store $old to $ptr,

    return $old
  }];

  let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
  let results   = (outs TT_Type:$result);
}
|
|
|
|
|
|
//
|
|
// Shape Manipulation Ops
|
|
//
|
|
// splat: takes a TT_Type value $src and produces a tensor result with the same
// element type (SameOperandsAndResultElementType). Side-effect free; a folder
// is implemented in C++.
def TT_SplatOp : TT_Op<"splat", [
    NoSideEffect,
    SameOperandsAndResultElementType
  ]> {
  let summary = "splat";

  let arguments = (ins TT_Type:$src);
  let results   = (outs TT_Tensor:$result);

  // Textual form: <src> {attrs} : (<operand types>) -> <result types>
  let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";

  let hasFolder = 1;
}
|
|
|
|
// expand_dims: tensor-to-tensor op parameterised by an I32 `axis` attribute;
// element type is preserved. The InferTypeOpInterface methods are declared
// here and implemented in C++.
// `axis` has no dedicated slot in the assembly format, so it is carried in the
// attribute dictionary.
def TT_ExpandDimsOp : TT_Op<"expand_dims", [
    NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    SameOperandsAndResultElementType
  ]> {
  let summary = "expand_dims";

  let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
  let results   = (outs TT_Tensor:$result);

  // Textual form: <src> {attrs} : (<operand types>) -> <result types>
  let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
|
|
|
|
// view: tensor-to-tensor op that keeps the element type; side-effect free.
// NOTE(review): presumably a reshape-style op; the shape relationship between
// $src and $result is not constrained in this definition.
def TT_ViewOp : TT_Op<"view", [
    NoSideEffect,
    SameOperandsAndResultElementType
  ]> {
  let summary = "view";

  let arguments = (ins TT_Tensor:$src);
  let results   = (outs TT_Tensor:$result);

  // Textual form: <src> {attrs} : (<operand types>) -> <result types>
  let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
|
|
|
|
// broadcast: TT_Type to TT_Type op that keeps the element type; side-effect
// free, with a folder implemented in C++.
def TT_BroadcastOp : TT_Op<"broadcast", [
    NoSideEffect,
    SameOperandsAndResultElementType
  ]> {
  let summary = "broadcast. No left-padding as of now.";

  let arguments = (ins TT_Type:$src);
  let results   = (outs TT_Type:$result);

  // Textual form: <src> {attrs} : (<operand types>) -> <result types>
  let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";

  let hasFolder = 1;
}
|
|
|
|
// cat: combines two tensors ($lhs, $rhs) into one tensor result; all three
// share the same element type. Side-effect free.
def TT_CatOp : TT_Op<"cat", [
    NoSideEffect,
    SameOperandsAndResultElementType
  ]> {
  let summary = "concatenate 2 tensors";

  let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
  let results   = (outs TT_Tensor:$result);

  // Textual form: <lhs>, <rhs> {attrs} : (<operand types>) -> <result types>
  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}
|
|
|
|
// trans: tensor-to-tensor transpose that keeps the element type; side-effect
// free.
def TT_TransOp : TT_Op<"trans", [
    NoSideEffect,
    SameOperandsAndResultElementType
  ]> {
  let summary = "transpose a tensor";

  let arguments = (ins TT_Tensor:$src);
  let results   = (outs TT_Tensor:$result);

  // Textual form: <src> {attrs} : (<operand types>) -> <result types>
  let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
|
|
|
|
//
|
|
// SPMD Ops
|
|
//
|
|
// get_program_id: takes no operands, only an I32 `axis` attribute, and yields
// a single i32. Side-effect free.
// NOTE(review): presumably the id of the current program instance along
// `axis`; confirm against the lowering.
def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> {
  let arguments = (ins I32Attr:$axis);
  let results   = (outs I32:$result);

  // Textual form: {attrs} : <type of result>  (`axis` lives in the attr-dict)
  let assemblyFormat = "attr-dict `:` type($result)";
}
|
|
|
|
// get_num_programs: takes no operands, only an I32 `axis` attribute, and
// yields a single i32. Side-effect free.
// NOTE(review): presumably the number of program instances along `axis`;
// confirm against the lowering.
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
  let arguments = (ins I32Attr:$axis);
  let results   = (outs I32:$result);

  // Textual form: {attrs} : <type of result>  (`axis` lives in the attr-dict)
  let assemblyFormat = "attr-dict `:` type($result)";
}
|
|
|
|
//
|
|
// Dot Op
|
|
//
|
|
// dot: matrix multiply-accumulate over float/int tensors, with a boolean
// `allowTF32` attribute.
// The type of the accumulator $c is not written in the assembly; it is
// required to equal the type of the result $d (TypesMatchWith below).
// The InferTypeOpInterface methods are declared here and implemented in C++.
def TT_DotOp : TT_Op<"dot", [
    NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    TypesMatchWith<"result's type matches accumulator's type",
                   "d", "c", "$_self">
  ]> {
  let summary = "dot";

  let description = [{
    $d = matrix_multiply($a, $b) + $c
  }];

  let arguments = (ins
    TT_FpIntTensor:$a,
    TT_FpIntTensor:$b,
    TT_FpIntTensor:$c,
    BoolAttr:$allowTF32
  );

  let results = (outs TT_FpIntTensor:$d);

  // Textual form: <a>, <b>, <c> {attrs} : <type of a> * <type of b> -> <type of d>
  let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}
|
|
|
|
//
|
|
// Reduce Op
|
|
//
|
|
// reduce: applies the reduction named by the `redOp` attribute to the tensor
// $operand along the I32 `axis` attribute. Side-effect free; the
// InferTypeOpInterface methods are declared here and implemented in C++.
def TT_ReduceOp : TT_Op<"reduce", [
    NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>
  ]> {
  let summary = "reduce";

  let arguments = (ins
    TT_RedOpAttr:$redOp,
    TT_Tensor:$operand,
    I32Attr:$axis
  );

  let results = (outs TT_Type:$result);

  // Builder taking the reduction kind and axis as plain C++ values.
  let builders = [
    OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>
  ];

  // Textual form: <operand> {attrs} : <type of operand> -> <type of result>
  let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";

  let extraClassDeclaration = [{
    // This member function is marked static because we need to call it before the ReduceOp
    // is constructed, see the implementation of create_reduce in triton.cc.
    static bool withIndex(mlir::triton::RedOp redOp);
  }];
}
|
|
|
|
//
|
|
// External elementwise op
|
|
//
|
|
// ext_elemwise: elementwise call into an external library function.
// Takes a variadic list of TT_Type operands plus three string attributes
// (libname, libpath, symbol) that identify the callee. Operands and result
// share shape, and the native SameOperandsAndResultEncoding trait is attached.
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [
    NoSideEffect,
    Elementwise,
    SameOperandsAndResultShape,
    SameOperandsAndResultEncoding,
    SameVariadicOperandSize
  ]> {
  let summary = "ext_elemwise";

  let description = [{
    call an external function $symbol implemented in $libpath/$libname with $args

    return $libpath/$libname:$symbol($args...)
  }];

  let arguments = (ins
    Variadic<TT_Type>:$args,
    StrAttr:$libname,
    StrAttr:$libpath,
    StrAttr:$symbol
  );

  let results = (outs TT_Type:$result);

  // Textual form: <operands> {attrs} : <operand types> -> <type of result>
  let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
}
|
|
|
|
//
|
|
// Make Range Op
|
|
//
|
|
// TODO: should have ConstantLike as Trait
|
|
// make_range: takes no operands; the I32 attributes `start` and `end` define
// the produced integer tensor. Side-effect free.
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
  let summary = "make range";

  let description = [{
    Returns an 1D int32 tensor.

    Values span from $start to $end (exclusive), with step = 1
  }];

  let arguments = (ins I32Attr:$start, I32Attr:$end);
  let results   = (outs TT_IntTensor:$result);

  // Textual form: {attrs} : <type of result>  (start/end live in the attr-dict)
  let assemblyFormat = "attr-dict `:` type($result)";
}
|
|
|
|
//
|
|
// Printf Op
|
|
//
|
|
// printf: device-side debug print. Takes a string `prefix` attribute and any
// number of TT_Type operands; it has no results and is declared with a
// MemWrite effect.
// The arguments are set through the `arguments` field, like the other ops in
// this file, instead of the `Arguments<...>` mixin class.
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]> {
  let summary = "Device-side printf, as in CUDA for debugging";

  let description = [{
    `tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
    format are generated automatically from the arguments.
  }];

  let arguments = (ins
    StrAttr:$prefix,
    Variadic<AnyTypeOf<[TT_Type]>>:$args
  );

  // Textual form: <prefix> {attrs} [<args> : <arg types>]
  // The bracketed group is only present when $args is non-empty.
  let assemblyFormat = [{
    $prefix attr-dict ($args^ `:` type($args))?
  }];
}
|
|
|
|
#endif // TRITON_OPS
|