[FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
@@ -29,22 +29,26 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffec
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast int64 to pointer";
|
||||
|
||||
let arguments = (ins I64Tensor:$from);
|
||||
let arguments = (ins TT_I64Like:$from);
|
||||
|
||||
let results = (outs TT_PtrTensor:$result);
|
||||
let results = (outs TT_PtrLike:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast pointer to int64";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$from);
|
||||
let arguments = (ins TT_PtrLike:$from);
|
||||
|
||||
let results = (outs I64Tensor:$result);
|
||||
let results = (outs TT_I64Like:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Floating point casting for custom types";
|
||||
|
||||
let description = [{
|
||||
@@ -54,9 +58,11 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
||||
BF8 <-> F8, FP16, FP32
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FloatTensor:$from);
|
||||
let arguments = (ins TT_FloatLike:$from);
|
||||
|
||||
let results = (outs TT_FloatTensor:$result);
|
||||
let results = (outs TT_FloatLike:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
|
||||
// TODO: We need a verifier here.
|
||||
}
|
||||
@@ -127,16 +133,16 @@ def TT_StoreOp : TT_Op<"store",
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TT_GEPOp : TT_Op<"getelementptr",
|
||||
def TT_AddPtrOp : TT_Op<"addptr",
|
||||
[NoSideEffect, SameOperandsAndResultShape,
|
||||
TypesMatchWith<"result type matches ptr type",
|
||||
"result", "ptr", "$_self">,
|
||||
TypesMatchWith<"result shape matches offset shape",
|
||||
"result", "offset",
|
||||
"getI32SameShape($_self)">]> {
|
||||
let arguments = (ins TT_PtrTensor:$ptr, I32Tensor:$offset);
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
|
||||
|
||||
let results = (outs TT_PtrTensor:$result);
|
||||
let results = (outs TT_PtrLike:$result);
|
||||
|
||||
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
|
||||
}
|
||||
@@ -278,7 +284,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
||||
return $old
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_Pointer:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
@@ -318,7 +324,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
||||
|
||||
let arguments = (ins I32Attr:$start, I32Attr:$end);
|
||||
|
||||
let results = (outs TT_IntegerTensor:$result);
|
||||
let results = (outs TT_IntTensor:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
@@ -12,18 +12,36 @@ class TritonTypeDef<string name, string _mnemonic>
|
||||
let mnemonic = _mnemonic;
|
||||
}
|
||||
|
||||
// Floating-point Type
|
||||
def F8 : TritonTypeDef<"Float8", "f8">;
|
||||
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
|
||||
|
||||
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
|
||||
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
|
||||
|
||||
// IntegerType
|
||||
// Boolean Type
|
||||
// TT_Bool -> I1
|
||||
def TT_BoolTensor : TensorOf<[I1]>;
|
||||
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
|
||||
|
||||
// Integer Type
|
||||
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
|
||||
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
||||
def TT_IntTensor : TensorOf<[TT_Int]>;
|
||||
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
|
||||
|
||||
// PointerType
|
||||
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
||||
// I32 Type
|
||||
// TT_I32 -> I32
|
||||
// TT_I32Tensor -> I32Tensor
|
||||
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
|
||||
|
||||
// I64 Type
|
||||
// TT_I64 -> I64
|
||||
// TT_I64Tensor -> I64Tensor
|
||||
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
|
||||
|
||||
// Pointer Type
|
||||
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
|
||||
let summary = "pointer type";
|
||||
|
||||
let description = [{
|
||||
@@ -43,12 +61,12 @@ def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
def TT_PtrTensor : TensorOf<[TT_Pointer]>;
|
||||
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
|
||||
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;
|
||||
|
||||
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
|
||||
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
|
||||
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
||||
|
||||
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
|
||||
TT_Pointer, TT_PtrTensor]>;
|
||||
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;
|
||||
|
||||
#endif
|
||||
|
@@ -8,7 +8,7 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
|
||||
let description = [{
|
||||
dot(a, b, 0) + c => dot(a, b, c)
|
||||
|
||||
gep(gep(ptr, idx0), idx1) => gep(ptr, AddI(idx0, idx1))
|
||||
addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))
|
||||
|
||||
select(cond, load(ptrs, broadcast(cond), ???), other) =>
|
||||
load(ptrs, broadcast(cond), other)
|
||||
|
@@ -10,12 +10,6 @@ include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
|
||||
def TT_BoolTensor : TensorOf<[I1]>;
|
||||
|
||||
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
|
||||
def TT_IntegerLike : AnyTypeOf<[TT_Int, TT_IntegerTensor]>;
|
||||
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
|
||||
|
||||
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<TritonGPU_Dialect, mnemonic, traits>;
|
||||
|
||||
@@ -48,8 +42,8 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
|
||||
TT_IntegerLike:$lhs,
|
||||
TT_IntegerLike:$rhs);
|
||||
TT_IntLike:$lhs,
|
||||
TT_IntLike:$rhs);
|
||||
|
||||
let results = (outs TT_BoolLike:$result);
|
||||
}
|
||||
@@ -66,7 +60,7 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
|
||||
let results = (outs TT_BoolLike:$result);
|
||||
}
|
||||
|
||||
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
[SameVariadicOperandSize,
|
||||
MemoryEffects<[MemRead, MemWrite]>,
|
||||
TypesMatchWith<"infer mask type from src type",
|
||||
@@ -94,7 +88,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
* other: optional tensor-rank number of other tensors which specify what
|
||||
values are inserted into the `$dst` tensor if the corresponding
|
||||
element of the `$mask` tensor is false.
|
||||
|
||||
|
||||
In the future, we may decompose this operation into a sequence of:
|
||||
|
||||
* `async` operation to specify a sequence of asynchronous operations
|
||||
@@ -191,7 +185,7 @@ def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> {
|
||||
Note: This op can be repalced to a `bufferization.alloc_tensor` in LLVM 16.
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{attr-dict `:` type($result)}];
|
||||
let assemblyFormat = [{attr-dict `:` type($result)}];
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
|
Reference in New Issue
Block a user