Add more Ops
This commit is contained in:
@@ -26,6 +26,7 @@ def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||
// IntegerType
|
||||
def TT_Int : AnyTypeOf<[I8, I16, I32, I64], "integer">;
|
||||
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
||||
def TT_I1Tensor : TensorOf<[I1]>;
|
||||
|
||||
// PointerType
|
||||
def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
|
||||
@@ -145,6 +146,14 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]
|
||||
//
|
||||
// Shape Manipulation Ops
|
||||
//
|
||||
def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> {
|
||||
let summary = "reshape";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src, I64ArrayAttr:$shape);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
}
|
||||
|
||||
def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
|
||||
let summary = "broadcast";
|
||||
|
||||
@@ -170,6 +179,12 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
|
||||
let results = (outs I32:$result);
|
||||
}
|
||||
|
||||
def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
}
|
||||
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let summary = "dot";
|
||||
|
||||
@@ -227,7 +242,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr,
|
||||
TT_Type:$val);
|
||||
TT_Type:$val, TT_I1Tensor:$mask);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
@@ -245,7 +260,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
||||
return $old
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
let arguments = (ins TT_AnyPtr:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
Reference in New Issue
Block a user