[LANG] Added support for device functions (#484)

Author:    Philippe Tillet
Date:      2022-04-03 20:58:16 -07:00
Committed: GitHub
Commit:    2bed6fc850
Parent:    e85c7a7fc7

39 changed files with 1213 additions and 379 deletions
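
In Triton, a "device function" is a @triton.jit function that is called directly from another @triton.jit kernel rather than launched from the host. Below is a minimal sketch of the pattern this commit enables; the function and tensor names are illustrative and not taken from the diff.

import torch
import triton
import triton.language as tl

@triton.jit
def plus_one(x):
    # device function: compiled and inlined into its caller
    return x + 1

@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # call the device function directly from kernel code
    tl.store(Y + offs, plus_one(tl.load(X + offs)))

x = torch.randn(1024, device='cuda')
y = torch.empty_like(x)
kernel[(1,)](x, y, BLOCK=1024)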


@@ -585,7 +585,6 @@ def test_f8_f16_roundtrip():
     f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
     f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
-    print(f16.dtype, f8_output.dtype)
     copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
     assert torch.all(f8_tensor == f8_output_tensor)
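
copy_kernel is referenced by this hunk but defined elsewhere in the file. For readers without the full file, a masked element-wise copy of roughly the following shape would satisfy the call; the parameter names are assumptions, not taken from the diff.

@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # each program copies one BLOCK_SIZE-wide slice, masked at the tail
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(output_ptr + offsets, tl.load(input_ptr + offsets, mask=mask), mask=mask)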
@@ -1009,8 +1008,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
     # Parse out the type of the 'VALUE' parameter from the Triton IR.
     triton_ir = pgm.asm['ttir']
-    ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
-    ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
+    ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir)
+    ir_value_type = None if ir_value_match is None else ir_value_match.group(2)
     assert ir_value_type == value_type
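
The updated regex no longer hard-codes the kernel name "kernel", so it keeps matching when the IR emits a renamed or mangled symbol; the kernel name lands in group(1) and the type of VALUE therefore moves to group(2). A quick illustration against a made-up IR line (the IR text itself is an assumption, not taken from this diff):

import re

triton_ir = "def void kernel_0d1d(i32 VALUE , f32* X) {"  # hypothetical ttir signature
m = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir)
print(m.group(1))  # 'kernel_0d1d' -- the kernel name, now captured separately
print(m.group(2))  # 'i32'         -- the type of VALUE, hence .group(2)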
@@ -1031,3 +1030,28 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
             kernel[(1, )](value, x)
     else:
         kernel[(1, )](value, x)
+
+
+# -------------------------
+# test dynamic parallelism
+# -------------------------
+
+
+@triton.jit
+def mult(x, alpha):
+    tl.store(x + tl.program_id(0), alpha)
+
+
+@triton.jit
+def stub(X, alpha, grid_0, grid_1, grid_2):
+    tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
+
+
+def test_dyn_par(cond=True, device='cuda'):
+    n_pids = 10
+    # pids = torch.arange(n_pids, device=device)
+    # alpha = 2.0
+    # x_ref = pids * alpha
+    x_tri = torch.full((10,), fill_value=-1., device=device)
+    # cond = torch.tensor([cond], device=device)
+    stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
+    print(x_tri)
+    # triton.testing.assert_almost_equal(x_ref, x_tri)
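
In the new test, stub launches mult over a (n_pids, 1, 1) grid from inside the kernel via tl.launch, and each program stores alpha at its own program id, so x_tri should come back as ten 3.14 entries. If one wanted to re-enable the commented-out assertion, a reference consistent with what mult actually stores (rather than the commented pids * alpha version) could look like this sketch inside test_dyn_par:

    x_ref = torch.full((n_pids,), 3.14, device=device)
    triton.testing.assert_almost_equal(x_ref, x_tri)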