[LANG] Added support for device functions (#484)
This commit is contained in:
@@ -585,7 +585,6 @@ def test_f8_f16_roundtrip():
|
||||
|
||||
f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
print(f16.dtype, f8_output.dtype)
|
||||
copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
assert torch.all(f8_tensor == f8_output_tensor)
|
||||
@@ -1009,8 +1008,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
|
||||
|
||||
# Parse out the type of the 'VALUE' parameter from the Triton IR.
|
||||
triton_ir = pgm.asm['ttir']
|
||||
ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
|
||||
ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
|
||||
ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir)
|
||||
ir_value_type = None if ir_value_match is None else ir_value_match.group(2)
|
||||
assert ir_value_type == value_type
|
||||
|
||||
|
||||
@@ -1031,3 +1030,28 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
|
||||
kernel[(1, )](value, x)
|
||||
else:
|
||||
kernel[(1, )](value, x)
|
||||
# -------------------------
|
||||
# test dynamic parallelism
|
||||
# -------------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def mult(x, alpha):
|
||||
tl.store(x + tl.program_id(0), alpha)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def stub(X, alpha, grid_0, grid_1, grid_2):
|
||||
tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
|
||||
|
||||
|
||||
def test_dyn_par(cond=True, device='cuda'):
|
||||
n_pids = 10
|
||||
# pids = torch.arange(n_pids, device=device)
|
||||
# alpha = 2.0
|
||||
# x_ref = pids * alpha
|
||||
x_tri = torch.full((10,), fill_value=-1., device=device)
|
||||
# cond = torch.tensor([cond], device=device)
|
||||
stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
|
||||
print(x_tri)
|
||||
# triton.testing.assert_almost_equal(x_ref, x_tri)
|
||||
|
Reference in New Issue
Block a user