SinkConversionsFromShared template
This commit adds the scaffolding for a new TritonGPU pass, tritongpu-sink-conversions-from-shared: the factory declaration, the tablegen definition, a build entry for the (not yet shown) implementation file, and a Python binding. The flash-attention test is switched from the precompiled TTGIR kernel back to the JIT path so the pass's effect on the generated TTGIR can be inspected.
In the TritonGPU transforms pass declarations (Passes.h), the factory function for the new pass is declared:

@@ -15,6 +15,8 @@ std::unique_ptr<Pass> createTritonGPUCoalescePass();
 std::unique_ptr<Pass> createTritonGPUOptimizeLoadConvertPass();
 
+std::unique_ptr<Pass> createTritonGPUSinkConversionsFromSharedPass();
+
 std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
 
 std::unique_ptr<Pass> createTritonGPUVerifier();
 
In the tablegen pass definitions (Passes.td), the pass gets its command-line name, summary, description, and constructor:

@@ -84,6 +84,18 @@ def TritonGPUOptimizeLoadConvert: Pass<"tritongpu-optimize-load-convert", "mlir:
                            "mlir::triton::TritonDialect"];
 }
 
+def TritonGPUSinkConversionsFromShared: Pass<"tritongpu-sink-conversions-from-shared", "mlir::ModuleOp"> {
+  let summary = "Sink conversions from shared into loops";
+
+  let description = "This pass sinks conversions from shared memory into loops. This will lead the codegen "
+                    "to keep data in shared memory throughout loops, which will reduce register pressure.";
+
+  let constructor = "mlir::createTritonGPUSinkConversionsFromSharedPass()";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
   let summary = "canonicalize scf.ForOp ops";
 
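Because the pass is declared through tablegen, MLIR also derives a command-line registration from the argument string in the Pass<...> entry. A minimal sketch of how that registration is typically pulled in; the include path and the triton-opt invocation follow standard MLIR conventions and are assumptions, not part of this diff:

// Including the generated registration hooks (and calling them at tool
// startup) makes the pass selectable by its declared argument, e.g.
//
//   triton-opt -tritongpu-sink-conversions-from-shared kernel.ttgir
//
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" // assumed path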
The implementation file is added to the TritonGPUTransforms library (CMakeLists.txt):

@@ -9,6 +9,7 @@ add_mlir_dialect_library(TritonGPUTransforms
   Pipeline.cpp
   Prefetch.cpp
   OptimizeLoadConvert.cpp
+  SinkConversionsFromShared.cpp
   TritonGPUConversion.cpp
 
   DEPENDS
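The SinkConversionsFromShared.cpp referenced above is not part of this diff; the commit only adds the template around it. For orientation, a minimal sketch of the boilerplate such a file needs given the Passes.td entry; the include paths and the generated base-class name follow MLIR's usual conventions and are assumptions, not the commit's contents:

#include "triton/Dialect/TritonGPU/Transforms/Passes.h" // assumed path

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

using namespace mlir;

namespace {
// Tablegen generates TritonGPUSinkConversionsFromSharedBase from the
// Passes.td definition; the pass itself only fills in runOnOperation().
struct SinkConversionsFromSharedPass
    : public TritonGPUSinkConversionsFromSharedBase<
          SinkConversionsFromSharedPass> {
  void runOnOperation() override {
    // Rewrite logic goes here: sink layout conversions whose source lives
    // in shared memory into the loops that use them, so loop-carried
    // values stay in shared memory instead of registers.
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createTritonGPUSinkConversionsFromSharedPass() {
  return std::make_unique<SinkConversionsFromSharedPass>();
}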
The pass is exposed to Python through the pybind11 pass-manager bindings (init_triton_ir):

@@ -1348,6 +1348,11 @@ void init_triton_ir(py::module &&m) {
            [](mlir::PassManager &self) {
              self.addPass(mlir::createTritonGPUOptimizeLoadConvertPass());
            })
+      .def("add_tritongpu_sink_conversions_from_shared_pass",
+           [](mlir::PassManager &self) {
+             self.addPass(
+                 mlir::createTritonGPUSinkConversionsFromSharedPass());
+           })
       .def("add_triton_gpu_to_llvm",
            [](mlir::PassManager &self) {
              self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
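The binding above is a thin wrapper over the C++ pass constructor; the equivalent standalone C++ driver is a one-liner on a PassManager. A minimal sketch, with the Passes.h include path assumed:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h" // assumed path

// Mirrors what add_tritongpu_sink_conversions_from_shared_pass does from
// Python: schedule the pass on a module and run it.
mlir::LogicalResult sinkSharedConversions(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createTritonGPUSinkConversionsFromSharedPass());
  return pm.run(module);
}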
In the flash-attention test driver, the precompiled TTGIR backward kernel is disabled in favor of the JIT path:

@@ -191,7 +191,7 @@ def _bwd_kernel(
     tl.store(dv_ptrs, dv)
     tl.store(dk_ptrs, dk)
 
-_bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
+# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
 # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
 
 empty = torch.empty(128, device="cuda")
Further down in the same test, the raw-pointer launch of the precompiled kernel is commented out, and the JIT launch is re-enabled with a TTGIR dump (and an early exit) so the pass's output can be inspected:

@@ -256,36 +256,36 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
 
-        _bwd_kernel[(ctx.grid[1],1,1)](
-            q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
-            o.data_ptr(), do_scaled.data_ptr(),
-            dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
-            l.data_ptr(), m.data_ptr(),
-            delta.data_ptr(),
-            q.stride(0), q.stride(1), q.stride(2),
-            k.stride(0), k.stride(1), k.stride(2),
-            v.stride(0), v.stride(1), v.stride(2),
-            q.shape[0], q.shape[1], q.shape[2],
-            ctx.grid[0]
-        )
+        # _bwd_kernel[(ctx.grid[1],1,1)](
+        #     q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
+        #     o.data_ptr(), do_scaled.data_ptr(),
+        #     dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
+        #     l.data_ptr(), m.data_ptr(),
+        #     delta.data_ptr(),
+        #     q.stride(0), q.stride(1), q.stride(2),
+        #     k.stride(0), k.stride(1), k.stride(2),
+        #     v.stride(0), v.stride(1), v.stride(2),
+        #     q.shape[0], q.shape[1], q.shape[2],
+        #     ctx.grid[0]
+        # )
 
-        # pgm = _bwd_kernel[(ctx.grid[1],)](
-        #     q, k, v, ctx.sm_scale,
-        #     o, do_scaled,
-        #     dq, dk, dv,
-        #     l, m,
-        #     delta,
-        #     q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-        #     k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-        #     v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-        #     q.shape[0], q.shape[1], q.shape[2],
-        #     ctx.grid[0],
-        #     BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-        #     BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
-        #     num_stages=1,
-        # )
-        # print(pgm.asm["ttgir"])
-        # # exit(1)
+        pgm = _bwd_kernel[(ctx.grid[1],)](
+            q, k, v, ctx.sm_scale,
+            o, do_scaled,
+            dq, dk, dv,
+            l, m,
+            delta,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            q.shape[0], q.shape[1], q.shape[2],
+            ctx.grid[0],
+            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+            num_stages=1,
+        )
+        print(pgm.asm["ttgir"])
+        exit(1)
         return dq, dk, dv, None