This commit is contained in:
Phil Tillet
2022-12-28 13:42:43 -08:00
parent 7aba2a60d6
commit 54ae3e8d6e
7 changed files with 47 additions and 82 deletions

View File

@@ -1339,11 +1339,15 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUPrefetchPass());
})
.def("add_triton_gpu_combine_pass",
.def("add_tritongpu_combine_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(
mlir::createTritonGPUCombineOpsPass(computeCapability));
})
.def("add_tritongpu_optimize_load_convert_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUOptimizeLoadConvertPass());
})
.def("add_triton_gpu_to_llvm",
[](mlir::PassManager &self) {
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());

View File

@@ -894,17 +894,18 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
pm.add_coalesce_pass()
# The combine pass converts blocked layout to mma layout
# for dot ops so that pipeline can get shared memory swizzled correctly.
pm.add_triton_gpu_combine_pass(compute_capability)
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_tritongpu_pipeline_pass(num_stages)
# Prefetch must be done after pipeline pass because pipeline pass
# extracts slices from the original tensor.
pm.add_tritongpu_prefetch_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_triton_gpu_combine_pass(compute_capability)
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_licm_pass()
pm.add_triton_gpu_combine_pass(compute_capability)
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_cse_pass()
pm.add_tritongpu_optimize_load_convert_pass()
pm.run(mod)
return mod

View File

@@ -191,7 +191,7 @@ def _bwd_kernel(
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
_fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
# _fwd_kernel = triton.compile("./flash-attention.ttgir", num_warps=4)
empty = torch.empty(128, device="cuda")
@@ -210,28 +210,28 @@ class _attention(torch.autograd.Function):
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
# _fwd_kernel[grid](
# q, k, v, sm_scale,
# L, m,
# o,
# q.stride(0), q.stride(1), q.stride(2), q.stride(3),
# k.stride(0), k.stride(1), k.stride(2), k.stride(3),
# v.stride(0), v.stride(1), v.stride(2), v.stride(3),
# o.stride(0), o.stride(1), o.stride(2), o.stride(3),
# q.shape[0], q.shape[1], q.shape[2],
# BLOCK_M=BLOCK, BLOCK_N=BLOCK,
# BLOCK_DMODEL=Lk, num_warps=num_warps,
# num_stages=1,
# )
_fwd_kernel[grid](
q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale,
L.data_ptr(), m.data_ptr(),
o.data_ptr(),
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
o.stride(0), o.stride(1), o.stride(2),
q.shape[0], q.shape[1], q.shape[2])
q, k, v, sm_scale,
L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=1,
)
# _fwd_kernel[grid](
# q.data_ptr(), k.data_ptr(), v.data_ptr(), sm_scale,
# L.data_ptr(), m.data_ptr(),
# o.data_ptr(),
# q.stride(0), q.stride(1), q.stride(2),
# k.stride(0), k.stride(1), k.stride(2),
# v.stride(0), v.stride(1), v.stride(2),
# o.stride(0), o.stride(1), o.stride(2),
# q.shape[0], q.shape[1], q.shape[2])
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK