[FRONTEND] Add scalar type support for some ops (#661)

This PR adds basic support for scalar-type inputs to some ops (casts and pointer arithmetic) for Triton-MLIR. It also renames `getelementptr` to `addptr`.
This commit is contained in:
Shintaro Iwasaki
2022-09-15 16:12:52 -07:00
committed by GitHub
parent 2e08450c80
commit 43be75ad42
27 changed files with 203 additions and 129 deletions

View File

@@ -29,22 +29,26 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffec
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast int64 to pointer";
let arguments = (ins I64Tensor:$from);
let arguments = (ins TT_I64Like:$from);
let results = (outs TT_PtrTensor:$result);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast pointer to int64";
let arguments = (ins TT_PtrTensor:$from);
let arguments = (ins TT_PtrLike:$from);
let results = (outs I64Tensor:$result);
let results = (outs TT_I64Like:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Floating point casting for custom types";
let description = [{
@@ -54,9 +58,11 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
BF8 <-> F8, FP16, FP32
}];
let arguments = (ins TT_FloatTensor:$from);
let arguments = (ins TT_FloatLike:$from);
let results = (outs TT_FloatTensor:$result);
let results = (outs TT_FloatLike:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
// TODO: We need a verifier here.
}
@@ -127,16 +133,16 @@ def TT_StoreOp : TT_Op<"store",
let hasCanonicalizer = 1;
}
def TT_GEPOp : TT_Op<"getelementptr",
def TT_AddPtrOp : TT_Op<"addptr",
[NoSideEffect, SameOperandsAndResultShape,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">,
TypesMatchWith<"result shape matches offset shape",
"result", "offset",
"getI32SameShape($_self)">]> {
let arguments = (ins TT_PtrTensor:$ptr, I32Tensor:$offset);
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
let results = (outs TT_PtrTensor:$result);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
}
@@ -278,7 +284,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
return $old
}];
let arguments = (ins TT_Pointer:$ptr, TT_Type:$cmp, TT_Type:$val);
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
let results = (outs TT_Type:$result);
}
@@ -318,7 +324,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
let arguments = (ins I32Attr:$start, I32Attr:$end);
let results = (outs TT_IntegerTensor:$result);
let results = (outs TT_IntTensor:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}

View File

@@ -12,18 +12,36 @@ class TritonTypeDef<string name, string _mnemonic>
let mnemonic = _mnemonic;
}
// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
// IntegerType
// Boolean Type
// TT_Bool -> I1
def TT_BoolTensor : TensorOf<[I1]>;
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
// Integer Type
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def TT_IntegerTensor : TensorOf<[TT_Int]>;
def TT_IntTensor : TensorOf<[TT_Int]>;
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
// PointerType
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
// I32 Type
// TT_I32 -> I32
// TT_I32Tensor -> I32Tensor
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
// I64 Type
// TT_I64 -> I64
// TT_I64Tensor -> I64Tensor
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
// Pointer Type
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
let description = [{
@@ -43,12 +61,12 @@ def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
let skipDefaultBuilders = 1;
}
def TT_PtrTensor : TensorOf<[TT_Pointer]>;
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
TT_Pointer, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;
#endif

View File

@@ -8,7 +8,7 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
let description = [{
dot(a, b, 0) + c => dot(a, b, c)
gep(gep(ptr, idx0), idx1) => gep(ptr, AddI(idx0, idx1))
addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))
select(cond, load(ptrs, broadcast(cond), ???), other) =>
load(ptrs, broadcast(cond), other)

View File

@@ -10,12 +10,6 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
def TT_BoolTensor : TensorOf<[I1]>;
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
def TT_IntegerLike : AnyTypeOf<[TT_Int, TT_IntegerTensor]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;
@@ -48,8 +42,8 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let description = [{}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
TT_IntegerLike:$lhs,
TT_IntegerLike:$rhs);
TT_IntLike:$lhs,
TT_IntLike:$rhs);
let results = (outs TT_BoolLike:$result);
}
@@ -66,7 +60,7 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
let results = (outs TT_BoolLike:$result);
}
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[SameVariadicOperandSize,
MemoryEffects<[MemRead, MemWrite]>,
TypesMatchWith<"infer mask type from src type",
@@ -94,7 +88,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
* other: optional tensor-rank number of other tensors which specify what
values are inserted into the `$dst` tensor if the corresponding
element of the `$mask` tensor is false.
In the future, we may decompose this operation into a sequence of:
* `async` operation to specify a sequence of asynchronous operations
@@ -191,7 +185,7 @@ def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> {
Note: This op can be replaced with a `bufferization.alloc_tensor` in LLVM 16.
}];
let assemblyFormat = [{attr-dict `:` type($result)}];
let assemblyFormat = [{attr-dict `:` type($result)}];
let results = (outs TT_Tensor:$result);