trying more things
This commit is contained in:
@@ -1,19 +0,0 @@
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm,
|
||||
Z, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
|
||||
Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
|
||||
tl.store(Zs, tl.load(Xs))
|
||||
|
||||
|
||||
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
|
||||
print(ret)
|
@@ -1,13 +0,0 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
|
||||
pass
|
||||
|
||||
|
||||
X = torch.randn(1, device="cuda")
|
||||
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)
|
@@ -1562,7 +1562,7 @@ class CompiledKernel:
|
||||
if self.shared > max_shared:
|
||||
raise OutOfResources(self.shared, max_shared, "shared memory")
|
||||
mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
|
||||
print(n_regs, n_spills)
|
||||
print(self.shared, n_regs, n_spills)
|
||||
self.cu_module = mod
|
||||
self.cu_function = func
|
||||
|
||||
|
@@ -194,6 +194,7 @@ def _bwd_kernel(
|
||||
|
||||
|
||||
empty = torch.empty(128, device="cuda")
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@@ -220,7 +221,7 @@ class _attention(torch.autograd.Function):
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
num_stages=2,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
@@ -335,7 +336,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
flops_per_matmul = 2.*BATCH*H*N_CTX*N_CTX*D_HEAD*0.5
|
||||
total_flops = 2*flops_per_matmul
|
||||
print(total_flops/ms*1e-9)
|
||||
# print(total_flops/ms*1e-9)
|
||||
print(ms)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
|
Reference in New Issue
Block a user