[LANG] Added support for constexpr (#361)

This commit is contained in:
Philippe Tillet
2021-10-30 00:32:58 -07:00
committed by GitHub
parent 770ea96cca
commit 2acaa4d0dd
16 changed files with 355 additions and 365 deletions

View File

@@ -67,8 +67,8 @@ def test_matmul(M, N, K):
import triton.language as tl
@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements, **meta):
BLOCK_SIZE = meta['BLOCK_SIZE']
def _add(x_ptr, y_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -99,7 +99,7 @@ def test_elementwise(N):
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6