[Triton-MLIR] tt.dot operands now must have DotOperand layout; also added prefetch pass prototype (#712)

Co-authored-by: Jokeren <kerenzhou@openai.com>
Co-authored-by: Phil Tillet <phil@openai.com>
Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Da Yan
2022-11-10 13:57:27 +08:00
committed by GitHub
parent 8832e32683
commit 4946167241
29 changed files with 1227 additions and 507 deletions

View File

@@ -4,6 +4,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -30,7 +31,15 @@ public:
virtual LogicalResult
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding) const = 0;
Attribute &resultEncoding,
Optional<Location> location) const = 0;
// Note: this function only verify operand encoding but doesn't infer result
// encoding
virtual LogicalResult
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
Attribute retEncoding,
Optional<Location> location) const = 0;
};
} // namespace triton

View File

@@ -330,7 +330,6 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
}];
}
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
let mnemonic = "dot_op";

View File

@@ -37,7 +37,7 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
// This is needed because these ops don't
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let summary = "integer comparison operation";

View File

@@ -6,6 +6,9 @@
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
// TODO(Keren): prefetch pass not working yet
std::unique_ptr<Pass> createTritonGPUPrefetchPass();
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
std::unique_ptr<Pass> createTritonGPUSwizzlePass();

View File

@@ -7,7 +7,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let summary = "pipeline";
let description = [{
TODO
Unroll loops to hide global memory -> shared memory latency.
}];
let constructor = "mlir::createTritonGPUPipelinePass()";
@@ -23,6 +23,20 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
];
}
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";
let description = [{
Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency.
}];
let constructor = "mlir::createTritonGPUPrefetchPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithmeticDialect"];
}
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let summary = "coalesce";