[Triton-MLIR] tt.dot operands now must have DotOperand layout; also added prefetch pass prototype (#712)
Co-authored-by: Jokeren <kerenzhou@openai.com> Co-authored-by: Phil Tillet <phil@openai.com> Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
@@ -30,7 +31,15 @@ public:
|
||||
|
||||
virtual LogicalResult
|
||||
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
|
||||
Attribute &resultEncoding) const = 0;
|
||||
Attribute &resultEncoding,
|
||||
Optional<Location> location) const = 0;
|
||||
|
||||
// Note: this function only verifies the operand encoding; it does not infer the result
|
||||
// encoding
|
||||
virtual LogicalResult
|
||||
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
|
||||
Attribute retEncoding,
|
||||
Optional<Location> location) const = 0;
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
|
@@ -330,7 +330,6 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
|
||||
let mnemonic = "dot_op";
|
||||
|
||||
|
@@ -37,7 +37,7 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
||||
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
|
||||
// This is needed because these ops don't
|
||||
// handle encodings
|
||||
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
|
||||
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
|
||||
let summary = "integer comparison operation";
|
||||
|
||||
|
@@ -6,6 +6,9 @@
|
||||
namespace mlir {
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
|
||||
|
||||
// TODO(Keren): prefetch pass not working yet
|
||||
std::unique_ptr<Pass> createTritonGPUPrefetchPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUSwizzlePass();
|
||||
|
@@ -7,7 +7,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||
let summary = "pipeline";
|
||||
|
||||
let description = [{
|
||||
TODO
|
||||
Unroll loops to hide global memory -> shared memory latency.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUPipelinePass()";
|
||||
@@ -23,6 +23,20 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
|
||||
let summary = "prefetch";
|
||||
|
||||
let description = [{
|
||||
Prefetch operands (a and b) of tt.dot from shared memory into registers to hide shared memory -> register latency.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUPrefetchPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithmeticDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
|
||||
let summary = "coalesce";
|
||||
|
||||
|
Reference in New Issue
Block a user