[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)

Author: Philippe Tillet
Date: 2022-10-21 16:52:15 -07:00
Committed by: GitHub
Parent: c4726333bf
Commit: bb0f9235d1
26 changed files with 683 additions and 229 deletions


@@ -272,7 +272,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
//
// SPMD Ops
//
-def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
+def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
@@ -280,7 +280,7 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id"> {
let assemblyFormat = "attr-dict `:` type($result)";
}
-def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
+def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
@@ -301,7 +301,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
$d = matrix_multiply($a, $b) + $c
}];
-let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
+let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
let results = (outs TT_FpIntTensor:$d);
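Note: the two new BoolAttr arguments, transA and transB, presumably mark whether a and b are to be consumed in transposed form (the diff only shows the declarations). A schematic example in MLIR's generic op form (shapes, element types and attribute values are invented for illustration):

    // d = matrix_multiply(a, b) + c; both transpose flags left off here.
    %d = "tt.dot"(%a, %b, %c) {allowTF32 = true, transA = false, transB = false}
         : (tensor<128x64xf16>, tensor<64x128xf16>, tensor<128x128xf32>) -> tensor<128x128xf32>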
@@ -324,6 +324,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
];
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
}
//
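Note on the NoSideEffect additions to tt.get_program_id and tt.get_num_programs above: the trait tells MLIR these ops are pure, so generic rewrites such as CSE and DCE may merge or drop redundant uses. A minimal sketch of the effect, following the assemblyFormat declared above (value names are made up):

    // Before CSE: two identical reads of the program id along axis 0.
    %pid0 = tt.get_program_id {axis = 0 : i32} : i32
    %pid1 = tt.get_program_id {axis = 0 : i32} : i32
    %off  = arith.addi %pid0, %pid1 : i32

    // After CSE -- now legal because the op carries NoSideEffect:
    %pid = tt.get_program_id {axis = 0 : i32} : i32
    %off = arith.addi %pid, %pid : i32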


@@ -328,4 +328,23 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
}
+def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
+  let mnemonic = "dot_op";
+  let description = [{
+    In the TritonGPU dialect, considering `d = tt.dot a, b, c`,
+    tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
+    a's opIdx is 0, b's opIdx is 1.
+    The parent field in DotOperandEncodingAttr is the layout of d.
+  }];
+  let parameters = (
+    ins
+    "unsigned":$opIdx,
+    "Attribute":$parent
+  );
+  let extraClassDeclaration = extraBaseClassDeclaration;
+}
#endif
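To make DotOperandEncodingAttr concrete: for d = tt.dot a, b, c, if d's tensor type carries some layout #parent, the operands carry dot_op encodings that point back at that layout. An illustrative sketch (the printed form of the attribute is not shown in this diff, so the #triton_gpu.dot_op<...> spelling below is an assumption):

    // #parent stands for whatever layout the result d carries.
    #a_enc = #triton_gpu.dot_op<{opIdx = 0, parent = #parent}>   // encoding of operand a
    #b_enc = #triton_gpu.dot_op<{opIdx = 1, parent = #parent}>   // encoding of operand b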


@@ -14,7 +14,7 @@ class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
-[NoSideEffect]> {
+[SameOperandsAndResultShape, NoSideEffect]> {
let summary = "convert layout";
let arguments = (ins TT_Tensor:$src);
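The added SameOperandsAndResultShape trait encodes the invariant that convert_layout may change a tensor's encoding but never its shape, and the verifier now enforces this. A schematic example in MLIR's generic op form (types and encoding names are invented):

    // Allowed: same 128x128 shape, only the encoding changes.
    %y = "triton_gpu.convert_layout"(%x)
         : (tensor<128x128xf32, #blocked>) -> tensor<128x128xf32, #mma>

    // Rejected by the verifier after this change: the shape differs.
    %z = "triton_gpu.convert_layout"(%x)
         : (tensor<128x128xf32, #blocked>) -> tensor<64x64xf32, #mma>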
@@ -32,10 +32,10 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let assemblyFormat = "attr-dict";
}
-// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
-// This is needed because Arith's Cmp ops don't
+// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
+// This is needed because these ops don't
// handle encodings
-// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
+// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let summary = "integer comparison operation";
@@ -48,7 +48,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let results = (outs TT_BoolLike:$result);
}
-def TTG_CmpFOp : TTG_Op<"cmpf"> {
+def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> {
let summary = "floating-point comparison operation";
let description = [{}];
@@ -60,6 +60,20 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
let results = (outs TT_BoolLike:$result);
}
+// TODO: migrate to arith::SelectOp on LLVM16
+def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
+  let summary = "select operation";
+  let description = [{}];
+  let arguments = (ins TT_BoolLike:$condition,
+                       TT_Tensor:$true_value,
+                       TT_Tensor:$false_value);
+  let results = (outs TT_Tensor:$result);
+}
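As with cmpi and cmpf, select is cloned into TritonGPU because the upstream op does not handle Triton's tensor encodings (see the comment above). A rough sketch in MLIR's generic op form, since no custom assemblyFormat is defined here (shapes and the #blocked name are invented):

    // A select over tensors that all carry the same #blocked encoding;
    // the result keeps that encoding too.
    %r = "triton_gpu.select"(%cond, %t, %f)
         : (tensor<256xi1, #blocked>, tensor<256xf32, #blocked>, tensor<256xf32, #blocked>)
         -> tensor<256xf32, #blocked>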
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[SameVariadicOperandSize,
// MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?