More on type inference & assembly format
This commit is contained in:
@@ -11,15 +11,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
|||||||
//
|
//
|
||||||
// Types
|
// Types
|
||||||
//
|
//
|
||||||
// FloatType
|
|
||||||
// def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">,
|
|
||||||
// /*descr*/"8bit float",
|
|
||||||
// /*cppClassName*/"::mlir::triton::Float8Type">;
|
|
||||||
|
|
||||||
// def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">,
|
|
||||||
// /*descr*/"8bit bfloat",
|
|
||||||
// /*cppClassName*/"::mlir::triton::BFloat8Type">;
|
|
||||||
|
|
||||||
class TritonTypeDef<string name, string _mnemonic>
|
class TritonTypeDef<string name, string _mnemonic>
|
||||||
: TypeDef<Triton_Dialect, name> {
|
: TypeDef<Triton_Dialect, name> {
|
||||||
// Used by printer/parser
|
// Used by printer/parser
|
||||||
@@ -35,11 +26,8 @@ def TT_FloatTensor : TensorOf<[TT_Float]>;
|
|||||||
// IntegerType
|
// IntegerType
|
||||||
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
|
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
|
||||||
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
||||||
def TT_I1Tensor : TensorOf<[I1]>;
|
|
||||||
|
|
||||||
// PointerType
|
// PointerType
|
||||||
// def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
|
|
||||||
// def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">;
|
|
||||||
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
||||||
let summary = "pointer type";
|
let summary = "pointer type";
|
||||||
|
|
||||||
@@ -141,7 +129,17 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
|
|||||||
let cppNamespace = "::mlir::triton";
|
let cppNamespace = "::mlir::triton";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
|
def TT_LoadOp : TT_Op<"load",
|
||||||
|
[SameOperandsAndResultShape,
|
||||||
|
TypesMatchWith<"infer ptr type from result type",
|
||||||
|
"result", "ptr",
|
||||||
|
"getPointerTypeFromTensor($_self)">,
|
||||||
|
TypesMatchWith<"infer mask type from result type",
|
||||||
|
"result", "mask",
|
||||||
|
"getI1SameShape($_self)">,
|
||||||
|
TypesMatchWith<"infer other type from result type",
|
||||||
|
"result", "other",
|
||||||
|
"$_self">]> {
|
||||||
let summary = "load";
|
let summary = "load";
|
||||||
|
|
||||||
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
|
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
|
||||||
@@ -157,27 +155,41 @@ def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
|
|||||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
||||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
|
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
|
||||||
];
|
];
|
||||||
|
|
||||||
|
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> {
|
def TT_StoreOp : TT_Op<"store",
|
||||||
|
[SameOperandsShape,
|
||||||
|
TypesMatchWith<"infer ptr type from value type",
|
||||||
|
"value", "ptr",
|
||||||
|
"getPointerTypeFromTensor($_self)">,
|
||||||
|
TypesMatchWith<"infer mask type from value type",
|
||||||
|
"value", "mask",
|
||||||
|
"getI1SameShape($_self)">]> {
|
||||||
let summary = "store";
|
let summary = "store";
|
||||||
|
|
||||||
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
|
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
// for args with default values
|
|
||||||
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
||||||
];
|
];
|
||||||
|
|
||||||
// let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict";
|
let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> {
|
def TT_GEPOp : TT_Op<"getelementptr",
|
||||||
let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset);
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
|
TypesMatchWith<"result type matches ptr type",
|
||||||
|
"result", "ptr", "$_self">,
|
||||||
|
TypesMatchWith<"result shape matches offset shape",
|
||||||
|
"result", "offset",
|
||||||
|
"getI32SameShape($_self)">]> {
|
||||||
|
let arguments = (ins TT_PtrTensor:$ptr, I32Tensor:$offset);
|
||||||
|
|
||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_PtrTensor:$result);
|
||||||
|
|
||||||
// let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
|
let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -190,6 +202,8 @@ def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> {
|
|||||||
let arguments = (ins TT_Tensor:$src);
|
let arguments = (ins TT_Tensor:$src);
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
|
def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
|
||||||
@@ -198,6 +212,8 @@ def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
|
|||||||
let arguments = (ins TT_Type:$src);
|
let arguments = (ins TT_Type:$src);
|
||||||
|
|
||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
|
def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
|
||||||
@@ -206,6 +222,8 @@ def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
|
|||||||
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
|
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@@ -223,9 +241,13 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
|||||||
let arguments = (ins I32Attr:$axis);
|
let arguments = (ins I32Attr:$axis);
|
||||||
|
|
||||||
let results = (outs I32:$result);
|
let results = (outs I32:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect]> {
|
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||||
|
TypesMatchWith<"result's type matches accumulator's type",
|
||||||
|
"d", "c", "$_self">]> {
|
||||||
let summary = "dot";
|
let summary = "dot";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@@ -235,6 +257,8 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect]> {
|
|||||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
||||||
|
|
||||||
let results = (outs TT_FpIntTensor:$d);
|
let results = (outs TT_FpIntTensor:$d);
|
||||||
|
|
||||||
|
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduction
|
// reduction
|
||||||
@@ -286,7 +310,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr,
|
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr,
|
||||||
TT_Type:$val, TT_I1Tensor:$mask);
|
TT_Type:$val, I1Tensor:$mask);
|
||||||
|
|
||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
}
|
}
|
||||||
|
@@ -6,6 +6,37 @@
|
|||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/OperationSupport.h"
|
#include "mlir/IR/OperationSupport.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
// Type inference
|
||||||
|
static Type getI1SameShape(Type type) {
|
||||||
|
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||||
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||||
|
return RankedTensorType::get(tensorType.getShape(), i1Type);
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Type getI32SameShape(Type type) {
|
||||||
|
auto i32Type = IntegerType::get(type.getContext(), 32);
|
||||||
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||||
|
return RankedTensorType::get(tensorType.getShape(), i32Type);
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Type getPointerTypeFromTensor(Type type) {
|
||||||
|
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||||
|
Type elementType = tensorType.getElementType();
|
||||||
|
auto shape = tensorType.getShape();
|
||||||
|
PointerType ptrType = PointerType::get(elementType, 1);
|
||||||
|
return RankedTensorType::get(shape, ptrType);
|
||||||
|
}
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "triton/ir/Ops.cpp.inc"
|
#include "triton/ir/Ops.cpp.inc"
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user