More on type inference & assembly format

This commit is contained in:
Yan Da
2022-04-08 19:37:57 +08:00
parent 13aead4808
commit 62f7609612
2 changed files with 77 additions and 22 deletions

View File

@@ -11,15 +11,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
//
// Types
//
// FloatType
// def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">,
// /*descr*/"8bit float",
// /*cppClassName*/"::mlir::triton::Float8Type">;
// def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">,
// /*descr*/"8bit bfloat",
// /*cppClassName*/"::mlir::triton::BFloat8Type">;
class TritonTypeDef<string name, string _mnemonic>
: TypeDef<Triton_Dialect, name> {
// Used by printer/parser
@@ -35,11 +26,8 @@ def TT_FloatTensor : TensorOf<[TT_Float]>;
// IntegerType
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def TT_IntegerTensor : TensorOf<[TT_Int]>;
def TT_I1Tensor : TensorOf<[I1]>;
// PointerType
// def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
// def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">;
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
@@ -141,7 +129,17 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}
def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr",
"getPointerTypeFromTensor($_self)">,
TypesMatchWith<"infer mask type from result type",
"result", "mask",
"getI1SameShape($_self)">,
TypesMatchWith<"infer other type from result type",
"result", "other",
"$_self">]> {
let summary = "load";
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
@@ -157,27 +155,41 @@ def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
];
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)";
}
def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> {
def TT_StoreOp : TT_Op<"store",
[SameOperandsShape,
TypesMatchWith<"infer ptr type from value type",
"value", "ptr",
"getPointerTypeFromTensor($_self)">,
TypesMatchWith<"infer mask type from value type",
"value", "mask",
"getI1SameShape($_self)">]> {
let summary = "store";
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
let builders = [
// for args with default values
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
];
// let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict";
let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)";
}
def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> {
let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset);
def TT_GEPOp : TT_Op<"getelementptr",
[NoSideEffect, SameOperandsAndResultShape,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">,
TypesMatchWith<"result shape matches offset shape",
"result", "offset",
"getI32SameShape($_self)">]> {
let arguments = (ins TT_PtrTensor:$ptr, I32Tensor:$offset);
let results = (outs TT_Type:$result);
let results = (outs TT_PtrTensor:$result);
// let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
}
@@ -190,6 +202,8 @@ def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> {
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
@@ -198,6 +212,8 @@ def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
let arguments = (ins TT_Type:$src);
let results = (outs TT_Type:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
@@ -206,6 +222,8 @@ def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}
//
@@ -223,9 +241,13 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
def TT_DotOp : TT_Op<"dot", [NoSideEffect]> {
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot";
let description = [{
@@ -235,6 +257,8 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect]> {
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let results = (outs TT_FpIntTensor:$d);
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}
// reduction
@@ -286,7 +310,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
}];
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr,
TT_Type:$val, TT_I1Tensor:$mask);
TT_Type:$val, I1Tensor:$mask);
let results = (outs TT_Type:$result);
}

View File

@@ -6,6 +6,37 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
namespace mlir {
namespace triton {
// Type inference
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type);
return Type();
}
static Type getI32SameShape(Type type) {
auto i32Type = IntegerType::get(type.getContext(), 32);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i32Type);
return Type();
}
static Type getPointerTypeFromTensor(Type type) {
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
Type elementType = tensorType.getElementType();
auto shape = tensorType.getShape();
PointerType ptrType = PointerType::get(elementType, 1);
return RankedTensorType::get(shape, ptrType);
}
return Type();
}
}
}
#define GET_OP_CLASSES
#include "triton/ir/Ops.cpp.inc"