more pass template
@@ -1345,14 +1345,18 @@ void init_triton_ir(py::module &&m) {
                 mlir::createTritonGPUCombineOpsPass(computeCapability));
           })
      .def("add_tritongpu_optimize_load_convert_pass",
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUOptimizeLoadConvertPass());
           })
      .def("add_tritongpu_sink_conversions_from_shared_pass",
           [](mlir::PassManager &self) {
             self.addPass(
                 mlir::createTritonGPUSinkConversionsFromSharedPass());
           })
      .def("add_tritongpu_decompose_conversions_to_dot_operand_pass",
           [](mlir::PassManager &self) {
             self.addPass(
                 mlir::createTritonGPUDecomposeConversionsToDotOperandPass());
           })
      .def("add_triton_gpu_to_llvm",
           [](mlir::PassManager &self) {
             self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
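Each `.def("add_*", ...)` above exposes the wrapped `addPass` call as a method on the Python-side `PassManager` object, which is how the `ttir_to_ttgir` hunk below drives the new passes. A minimal sketch of the resulting Python surface (the `_triton.ir.pass_manager(...)` constructor is an assumption; the `add_*` names come directly from the bindings above):

    # Sketch only: how the bindings above surface in Python. The pass_manager
    # constructor is an assumption; the add_* names mirror the .def calls.
    pm = _triton.ir.pass_manager(mod.context)
    pm.add_tritongpu_optimize_load_convert_pass()
    pm.add_tritongpu_sink_conversions_from_shared_pass()
    pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
    pm.run(mod)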
@@ -906,6 +906,8 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
    pm.add_tritongpu_combine_pass(compute_capability)
    pm.add_cse_pass()
    # pm.add_tritongpu_optimize_load_convert_pass()
    pm.add_tritongpu_sink_conversions_from_shared_pass()
    pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
    pm.run(mod)
    return mod
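For reference, a hedged sketch of invoking this stage on its own; `ttir_module` stands for the TTIR-level MLIR module produced by the earlier lowering step, and the argument values are only illustrative (num_warps=8 and num_stages=1 match the flash-attention backward launch further down, compute_capability=80 is an assumed target, not taken from the commit):

    # Illustrative driver for the stage above; values are assumptions.
    ttgir_module = ttir_to_ttgir(ttir_module,
                                 num_warps=8,
                                 num_stages=1,
                                 compute_capability=80)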
@@ -194,8 +194,10 @@ def _bwd_kernel(
# _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
# _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)

empty = torch.empty(128, device="cuda")


class _attention(torch.autograd.Function):

    @staticmethod
@@ -255,7 +257,7 @@ class _attention(torch.autograd.Function):
            do_scaled, delta,
            BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )

        # _bwd_kernel[(ctx.grid[1],1,1)](
        #     q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
        #     o.data_ptr(), do_scaled.data_ptr(),
@@ -284,8 +286,8 @@ class _attention(torch.autograd.Function):
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
            num_stages=1,
        )
        print(pgm.asm["ttgir"])
        exit(1)
        # print(pgm.asm["ttgir"])
        # exit()
        return dq, dk, dv, None
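The `print(pgm.asm["ttgir"])` / `exit(1)` pair above is a debugging aid: `pgm` is the handle returned by the kernel launch, and its `asm` mapping holds the generated IR, which makes it easy to inspect what the newly enabled TritonGPU passes did. A small sketch of capturing that dump to a file instead of printing and exiting (the output path and the existence of keys other than "ttgir" are assumptions):

    # Write the TritonGPU IR of the backward kernel to disk for inspection.
    with open("bwd_dump.ttgir", "w") as f:   # hypothetical output path
        f.write(pgm.asm["ttgir"])            # same key as the debug print above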
@@ -327,6 +329,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    triton.testing.assert_almost_equal(ref_dk, tri_dk)
    triton.testing.assert_almost_equal(ref_dq, tri_dq)


BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
@@ -358,8 +361,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
        do = torch.randn_like(o)
        fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
        flops_per_matmul = 2.*BATCH*H*N_CTX*N_CTX*D_HEAD*0.5
        total_flops = 2*flops_per_matmul
        flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
        total_flops = 2 * flops_per_matmul
        # print(total_flops/ms*1e-9)
        print(ms)
        return ms
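For the record, the commented-out print above reports throughput in TFLOP/s: since `ms` is in milliseconds, FLOP/s is `total_flops / ms * 1e3`, and the tera scaling of 1e-12 folds in to give the `1e-9` factor. With the module-level defaults BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 (the benchmark itself sweeps the sequence length), `flops_per_matmul = 2 * 4 * 48 * 4096 * 4096 * 64 * 0.5 ≈ 2.06e11` and `total_flops ≈ 4.12e11`. A hypothetical helper that mirrors the same formula:

    # Hypothetical helper mirroring the throughput formula used above; the
    # 2 * ... * 0.5 structure is copied as-is from the benchmark body.
    def bwd_tflops(batch, heads, n_ctx, d_head, ms):
        flops_per_matmul = 2. * batch * heads * n_ctx * n_ctx * d_head * 0.5
        total_flops = 2 * flops_per_matmul
        return total_flops / ms * 1e-9   # ms is in milliseconds

    # e.g. bwd_tflops(4, 48, 4096, 64, ms) for the default shapes above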
@@ -376,4 +379,5 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
    ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
    return ms

bench_flash_attention.run(save_path='.', print_data=True)

# bench_flash_attention.run(save_path='.', print_data=True)