From 62f7609612ce0b87b2f5c9d554ef1197379bc230 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Fri, 8 Apr 2022 19:37:57 +0800 Subject: [PATCH] More on type inference & assembly format --- include/triton/ir/TritonOps.td | 68 +++++++++++++++++++++++----------- lib/ir/Ops.cpp | 31 ++++++++++++++++ 2 files changed, 77 insertions(+), 22 deletions(-) diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td index 6ae554928..3f66b6d65 100644 --- a/include/triton/ir/TritonOps.td +++ b/include/triton/ir/TritonOps.td @@ -11,15 +11,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType // // Types // -// FloatType -// def F8 : Type">, -// /*descr*/"8bit float", -// /*cppClassName*/"::mlir::triton::Float8Type">; - -// def BF8 : Type()">, -// /*descr*/"8bit bfloat", -// /*cppClassName*/"::mlir::triton::BFloat8Type">; - class TritonTypeDef : TypeDef { // Used by printer/parser @@ -35,11 +26,8 @@ def TT_FloatTensor : TensorOf<[TT_Float]>; // IntegerType def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; def TT_IntegerTensor : TensorOf<[TT_Int]>; -def TT_I1Tensor : TensorOf<[I1]>; // PointerType -// def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">; -// def TT_AnyPtr : DialectType; def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> { let summary = "pointer type"; @@ -141,7 +129,17 @@ def TT_EvictionPolicyAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } -def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> { +def TT_LoadOp : TT_Op<"load", + [SameOperandsAndResultShape, + TypesMatchWith<"infer ptr type from result type", + "result", "ptr", + "getPointerTypeFromTensor($_self)">, + TypesMatchWith<"infer mask type from result type", + "result", "mask", + "getI1SameShape($_self)">, + TypesMatchWith<"infer other type from result type", + "result", "other", + "$_self">]> { let summary = "load"; let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other, @@ -157,27 +155,41 @@ def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> { OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile)> ]; + + let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)"; } -def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> { +def TT_StoreOp : TT_Op<"store", + [SameOperandsShape, + TypesMatchWith<"infer ptr type from value type", + "value", "ptr", + "getPointerTypeFromTensor($_self)">, + TypesMatchWith<"infer mask type from value type", + "value", "mask", + "getI1SameShape($_self)">]> { let summary = "store"; let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask); let builders = [ - // for args with default values OpBuilder<(ins "Value":$ptr, "Value":$value)>, ]; - // let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict"; + let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)"; } -def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> { - let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset); +def TT_GEPOp : TT_Op<"getelementptr", + [NoSideEffect, SameOperandsAndResultShape, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">, + TypesMatchWith<"result shape matches offset shape", + "result", "offset", + "getI32SameShape($_self)">]> { + let arguments = (ins TT_PtrTensor:$ptr, I32Tensor:$offset); - let results = (outs TT_Type:$result); + let results = (outs TT_PtrTensor:$result); - // let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)"; + let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)"; } @@ -190,6 +202,8 @@ def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> { let arguments = (ins TT_Tensor:$src); let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; } def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> { @@ -198,6 +212,8 @@ def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> { let arguments = (ins TT_Type:$src); let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; } def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> { @@ -206,6 +222,8 @@ def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> { let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)"; } // @@ -223,9 +241,13 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> { let arguments = (ins I32Attr:$axis); let results = (outs I32:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; } -def TT_DotOp : TT_Op<"dot", [NoSideEffect]> { +def TT_DotOp : TT_Op<"dot", [NoSideEffect, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { let summary = "dot"; let description = [{ @@ -235,6 +257,8 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect]> { let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32); let results = (outs TT_FpIntTensor:$d); + + let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)"; } // reduction @@ -286,7 +310,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> { }]; let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr, - TT_Type:$val, TT_I1Tensor:$mask); + TT_Type:$val, I1Tensor:$mask); let results = (outs TT_Type:$result); } diff --git a/lib/ir/Ops.cpp b/lib/ir/Ops.cpp index 3eb33cbbe..32d9f53ee 100644 --- a/lib/ir/Ops.cpp +++ b/lib/ir/Ops.cpp @@ -6,6 +6,37 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OperationSupport.h" +namespace mlir { +namespace triton { + +// Type inference +static Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorType = type.dyn_cast()) + return RankedTensorType::get(tensorType.getShape(), i1Type); + return Type(); +} + +static Type getI32SameShape(Type type) { + auto i32Type = IntegerType::get(type.getContext(), 32); + if (auto tensorType = type.dyn_cast()) + return RankedTensorType::get(tensorType.getShape(), i32Type); + return Type(); +} + +static Type getPointerTypeFromTensor(Type type) { + if (auto tensorType = type.dyn_cast()) { + Type elementType = tensorType.getElementType(); + auto shape = tensorType.getShape(); + PointerType ptrType = PointerType::get(elementType, 1); + return RankedTensorType::get(shape, ptrType); + } + return Type(); +} + +} +} + #define GET_OP_CLASSES #include "triton/ir/Ops.cpp.inc"