Add more Ops

This commit is contained in:
Yan Da
2022-03-28 19:50:23 +08:00
parent 0d139ec460
commit 38e67b4293
2 changed files with 150 additions and 53 deletions

View File

@@ -26,6 +26,7 @@ def TT_FloatTensor : TensorOf<[TT_Float]>;
// IntegerType
def TT_Int : AnyTypeOf<[I8, I16, I32, I64], "integer">;
def TT_IntegerTensor : TensorOf<[TT_Int]>;
def TT_I1Tensor : TensorOf<[I1]>;
// PointerType
def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
@@ -145,6 +146,14 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]
//
// Shape Manipulation Ops
//
def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> {
let summary = "reshape";
let arguments = (ins TT_Tensor:$src, I64ArrayAttr:$shape);
let results = (outs TT_Tensor:$result);
}
def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
let summary = "broadcast";
@@ -170,6 +179,12 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
let results = (outs I32:$result);
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
}
def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "dot";
@@ -227,7 +242,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
}];
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr,
TT_Type:$val);
TT_Type:$val, TT_I1Tensor:$mask);
let results = (outs TT_Type:$result);
}
@@ -245,7 +260,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
return $old
}];
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$cmp, TT_Type:$val);
let arguments = (ins TT_AnyPtr:$ptr, TT_Type:$cmp, TT_Type:$val);
let results = (outs TT_Type:$result);
}