[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)
@@ -272,7 +272,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
 //
 // SPMD Ops
 //
-def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
+def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> {
   let arguments = (ins I32Attr:$axis);

   let results = (outs I32:$result);
@@ -280,7 +280,7 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
   let assemblyFormat = "attr-dict `:` type($result)";
 }

-def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
+def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
   let arguments = (ins I32Attr:$axis);

   let results = (outs I32:$result);
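Adding NoSideEffect declares these two SPMD ops pure, so generic cleanups such as CSE and DCE may fold duplicate reads of the grid coordinates. A minimal sketch of the effect, using the assemblyFormat declared above (SSA names are illustrative):

// Before CSE: two identical reads of the same grid axis.
%pid0 = tt.get_program_id {axis = 0 : i32} : i32
%pid1 = tt.get_program_id {axis = 0 : i32} : i32
// After CSE, which is only legal once the op is NoSideEffect:
%pid = tt.get_program_id {axis = 0 : i32} : i32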
@@ -301,7 +301,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
       $d = matrix_multiply($a, $b) + $c
   }];

-  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
+  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);

   let results = (outs TT_FpIntTensor:$d);
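The two new BoolAttr flags let tt.dot record operand transposition on the op itself instead of materializing a separate transpose, the pattern fused attention kernels hit when multiplying by a transposed K. A hedged sketch of an instance; the `*` separator and the attribute spelling in textual form are assumptions, since this hunk does not show tt.dot's assemblyFormat:

// Hypothetical: %d = matrix_multiply(%a, transpose(%b)) + %c
%d = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = true}
     : tensor<128x64xf16> * tensor<128x64xf16> -> tensor<128x128xf32>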
@@ -324,6 +324,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
   ];

   let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";

 }

 //
@@ -328,4 +328,23 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
 }

+def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
+  let mnemonic = "dot_op";
+
+  let description = [{
+    In the TritonGPU dialect, considering `d = tt.dot a, b, c`,
+    tt.dot's operands a and b must have a DotOperandEncodingAttr layout.
+    a's opIdx is 0, b's opIdx is 1.
+    The parent field in DotOperandEncodingAttr is the layout of d.
+  }];
+
+  let parameters = (
+    ins
+    "unsigned":$opIdx,
+    "Attribute":$parent
+  );
+
+  let extraClassDeclaration = extraBaseClassDeclaration;
+}
+
 #endif
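Given the mnemonic and the two parameters, tensors feeding a dot would carry types along these lines. `#dot_layout` is a hypothetical alias standing in for whatever layout d actually has, and the brace syntax is an assumption about how the parameters print:

// a (opIdx = 0) and b (opIdx = 1) wrap d's layout as their parent:
tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dot_layout}>>   // a
tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dot_layout}>>   // b
tensor<128x128xf32, #dot_layout>                                            // d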
@@ -14,7 +14,7 @@ class TTG_Op<string mnemonic, list<Trait> traits = []> :
     Op<TritonGPU_Dialect, mnemonic, traits>;

 def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
-                                 [NoSideEffect]> {
+                                 [SameOperandsAndResultShape, NoSideEffect]> {
   let summary = "convert layout";

   let arguments = (ins TT_Tensor:$src);
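With SameOperandsAndResultShape, the verifier now enforces what convert_layout is meant to do: change only the encoding, never the shape. A sketch, assuming a conventional `(src-type) -> dst-type` print form and two placeholder layout aliases:

// Legal: identical 128x128 shape, different encodings #layout_a / #layout_b.
%dst = triton_gpu.convert_layout %src
       : (tensor<128x128xf32, #layout_a>) -> tensor<128x128xf32, #layout_b>
// A shape change such as 128x128 -> 64x256 would now be rejected.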
@@ -32,10 +32,10 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
   let assemblyFormat = "attr-dict";
 }

-// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
-// This is needed because Arith's Cmp ops don't
+// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
+// This is needed because these ops don't
 // handle encodings
-// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
+// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
 def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
   let summary = "integer comparison operation";
@@ -48,7 +48,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
   let results = (outs TT_BoolLike:$result);
 }

-def TTG_CmpFOp : TTG_Op<"cmpf"> {
+def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> {
   let summary = "floating-point comparison operation";

   let description = [{}];
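Because these ports declare their own TT_BoolLike result, a comparison can keep the operand encoding on its i1 result rather than dropping it as the upstream builders do. A sketch in MLIR's generic op form, which stays valid without a custom assemblyFormat; the integer encoding of the predicate attribute is an assumption carried over from arith.cmpi:

// "slt" comparison; the #layout encoding survives on the mask.
%mask = "triton_gpu.cmpi"(%lhs, %rhs) {predicate = 2 : i64}
    : (tensor<128xi32, #layout>, tensor<128xi32, #layout>)
    -> tensor<128xi1, #layout>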
@@ -60,6 +60,20 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
   let results = (outs TT_BoolLike:$result);
 }

+// TODO: migrate to arith::SelectOp on LLVM16
+def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
+  let summary = "select operation";
+
+  let description = [{}];
+
+  let arguments = (ins TT_BoolLike:$condition,
+                       TT_Tensor:$true_value,
+                       TT_Tensor:$false_value);
+
+  let results = (outs TT_Tensor:$result);
+}
+

 def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
     [SameVariadicOperandSize,
      // MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?
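The new TTG_SelectOp above exists for the same reason as the cmp ports: Std_SelectOp would not propagate the tensor encoding. A hedged example, again in generic form with a placeholder #layout alias:

// Element-wise: pick %t where %cond is true, else %f; encoding preserved.
%r = "triton_gpu.select"(%cond, %t, %f)
    : (tensor<128xi1, #layout>, tensor<128xf32, #layout>, tensor<128xf32, #layout>)
    -> tensor<128xf32, #layout>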