update mma encoding & triton-opt
This commit is contained in:
@@ -126,7 +126,7 @@ def TT_GEPOp : TT_Op<"getelementptr",
|
||||
|
||||
let results = (outs TT_PtrTensor:$result);
|
||||
|
||||
let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
|
||||
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
|
||||
|
@@ -9,6 +9,10 @@ namespace triton {
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
|
||||
}
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@@ -14,7 +14,7 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
|
||||
load(ptrs, broadcast(cond), other)
|
||||
}];
|
||||
|
||||
let constructor = "mlir::triton::createCombineOpsPass";
|
||||
let constructor = "mlir::triton::createCombineOpsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
/*SelectOp*/"mlir::StandardOpsDialect"];
|
||||
|
@@ -130,7 +130,7 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
|
||||
// TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout)
|
||||
ArrayRefParameter<"unsigned">:$shapePerTile,
|
||||
// TODO: should Distributed layout also
|
||||
ArrayRefParameter<"unsigned">:$reptitions,
|
||||
ArrayRefParameter<"unsigned">:$repetitions,
|
||||
ArrayRefParameter<"unsigned">:$contigPerThread
|
||||
// "AffineMap":$warpOrdering,
|
||||
// "AffineMap":$blockOrdering
|
||||
|
@@ -26,6 +26,8 @@ def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
||||
|
@@ -4,19 +4,15 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages);
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
|
||||
|
||||
namespace triton {
|
||||
namespace gpu {
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUVerifier();
|
||||
}
|
||||
}
|
||||
|
||||
// /// Generate the code for registering passes.
|
||||
// #define GEN_PASS_REGISTRATION
|
||||
// #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
#endif
|
||||
|
@@ -19,7 +19,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||
...
|
||||
}];
|
||||
|
||||
let constructor = "mlir::triton::gpu::createPipelinePass";
|
||||
let constructor = "mlir::createTritonGPUPipelinePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
@@ -45,7 +45,7 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
||||
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
|
||||
}];
|
||||
|
||||
let constructor = "mlir::triton::gpu::createCombineOpsPass";
|
||||
let constructor = "mlir::createTritonGPUCombineOpsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
@@ -56,7 +56,7 @@ def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let constructor = "mlir::triton::gpu::createTritonGPUVerifier";
|
||||
let constructor = "mlir::createTritonGPUVerifier()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
Reference in New Issue
Block a user