[LANG] Added support for constexpr (#361)
This commit is contained in:
@@ -67,8 +67,8 @@ def test_matmul(M, N, K):
|
||||
import triton.language as tl
|
||||
|
||||
@triton.jit
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements, **meta):
|
||||
BLOCK_SIZE = meta['BLOCK_SIZE']
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
@@ -99,7 +99,7 @@ def test_elementwise(N):
|
||||
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
|
||||
x = torch.randn_like(z)
|
||||
y = torch.randn_like(z)
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
|
||||
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
|
||||
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
|
||||
cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
|
||||
|
Reference in New Issue
Block a user