[STYLE] run autopep8 and isort (#421)

Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
This commit is contained in:
Madeleine Thompson
2022-01-06 14:34:17 -08:00
committed by GitHub
parent 120cda015e
commit 8bf551ae7a
30 changed files with 742 additions and 623 deletions

View File

@@ -86,6 +86,7 @@ def patch_kernel(template, to_replace):
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@triton.jit
def kernel(X, SIZE: tl.constexpr):
pass
@@ -97,6 +98,7 @@ def test_empty_kernel(dtype_x, device='cuda'):
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
@@ -153,6 +155,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, Y, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
@@ -206,11 +209,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
# ---------------
# test binary ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['+', '-', '*', '/', '%']
for dtype_x in dtypes
for dtype_y in dtypes
for op in ['+', '-', '*', '/', '%']
for dtype_x in dtypes
for dtype_y in dtypes
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y'
@@ -242,9 +247,9 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
@pytest.mark.parametrize("dtype_x, dtype_y",
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
)
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
)
def test_floordiv(dtype_x, dtype_y, device='cuda'):
# Triton has IEEE, not numpy/torch, semantics for %, and those carry
# through to //, so we have to use a nonstandard expression to get a
@@ -298,22 +303,24 @@ def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
# test compare ops
# ---------------
ops = ['==', '!=', '>', '<', '>=', '<=']
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", \
# real
[
(dtype_x, dtype_y, op, 'real', 'real') \
for op in ops \
for dtype_x in dtypes \
for dtype_y in dtypes
] + \
# NaNs
[('float32', 'float32', op, mode_x, mode_y) \
for op in ops
for mode_x, mode_y in [('nan' , 'real'),
('real', 'nan'),
('nan' , 'nan')]
])
@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
# real
[
(dtype_x, dtype_y, op, 'real', 'real')
for op in ops
for dtype_x in dtypes
for dtype_y in dtypes
] +
# NaNs
[('float32', 'float32', op, mode_x, mode_y)
for op in ops
for mode_x, mode_y in [('nan', 'real'),
('real', 'nan'),
('nan', 'nan')]
])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
expr = f'x {op} y'
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
@@ -343,6 +350,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# 'exp', 'log', 'cos', 'sin'
# ])
@pytest.mark.parametrize("expr", [
'exp', 'log', 'cos', 'sin'
])
@@ -368,8 +376,8 @@ def make_ptr_str(name, shape):
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', d)
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for d in ['int32', 'uint32', 'uint16']
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
rank_x = expr.count(':')
@@ -413,8 +421,8 @@ def test_index1d(expr, dtype_str, device='cuda'):
@triton.jit
def fn(a, b):
return a + b, \
a - b, \
a * b
a - b, \
a * b
def test_tuples():
@@ -510,8 +518,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
(dtype_x, dtype_z, False)
for dtype_x in dtypes
for dtype_z in dtypes
for dtype_x in dtypes
for dtype_z in dtypes
] + [
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
@@ -534,7 +542,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
@triton.jit
def kernel(X, Z, BITCAST: tl.constexpr):
x = tl.load(X)
z = x.to(Z.dtype.element_ty, bitcast = BITCAST)
z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
tl.store(Z, z)
# triton result
@@ -558,10 +566,12 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# ---------------
# test reduce
# ---------------
@pytest.mark.parametrize("dtype_str, shape",
[(dtype, shape) \
for dtype in dtypes\
for shape in [128, 512]])
[(dtype, shape)
for dtype in dtypes
for shape in [128, 512]])
def test_reduce1d(dtype_str, shape, device='cuda'):
# triton kernel
@@ -591,7 +601,7 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
range_m = tl.arange(0, BLOCK_M)
range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None]*BLOCK_N + range_n[None, :])
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = tl.sum(x, axis=AXIS)
tl.store(Z + range_m, z)
# input
@@ -608,11 +618,13 @@ def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# ---------------
# test permute
# ---------------
@pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm) \
for dtype in ['float32']\
for shape in [(128, 128)]\
for perm in [(1, 0)]])
[(dtype, shape, perm)
for dtype in ['float32']
for shape in [(128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
# triton kernel
@@ -646,6 +658,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# test dot
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
# triton kernel
@@ -687,17 +700,17 @@ def test_dot(epilogue, device='cuda'):
y_tri, y_tri.stride(0), y_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX = epilogue=='add-matrix',
ADD_ROWS = epilogue=='add-rows',
ADD_COLS = epilogue=='add-cols')
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols')
# torch result
z_ref = np.matmul(x, y)
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
z_ref += z[:,0][:, None]
z_ref += z[:, 0][:, None]
if epilogue == 'add-cols':
z_ref += z[0,:][None, :]
z_ref += z[0, :][None, :]
# compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
@@ -705,6 +718,7 @@ def test_dot(epilogue, device='cuda'):
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
def test_dot_without_load():
@triton.jit
def kernel(out):
@@ -713,28 +727,30 @@ def test_dot_without_load():
b = tl.zeros((32, 32), tl.float32)
c = tl.zeros((32, 32), tl.float32)
c = tl.dot(a, b)
pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
tl.store(pout, c)
out = torch.ones((32,32), dtype=torch.float32, device="cuda")
out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
kernel[(1,)](out)
# ---------------
# test arange
# ---------------
@pytest.mark.parametrize("start", [0, 1, 7, 16])
def test_arange(start, device='cuda'):
BLOCK = 128
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
@triton.jit
def _kernel(z, BLOCK: tl.constexpr,
START: tl.constexpr, END: tl.constexpr):
off = tl.arange(0, BLOCK)
val = tl.arange(START, END)
tl.store(z + off, val)
_kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
_kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
triton.testing.assert_almost_equal(z_tri, z_ref)
# ---------------
@@ -742,6 +758,8 @@ def test_arange(start, device='cuda'):
# ---------------
# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermate copy to shared memory run.
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
M = 32
@@ -762,8 +780,8 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
N_offsets = tl.arange(0, N)
K_offsets = tl.arange(0, K)
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
# Load inputs.
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
@@ -773,21 +791,22 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
o = tl.dot(x, w)
# Store output
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
pgm = _kernel[(1,)](in1, in2, out,
in1.stride()[0],
in2.stride()[0],
out.stride()[0],
in1.numel(),
in2.numel(),
out.numel(),
M=M, N=N, K=K)
in1.stride()[0],
in2.stride()[0],
out.stride()[0],
in1.numel(),
in2.numel(),
out.numel(),
M=M, N=N, K=K)
reference_out =torch.matmul(in1, in2)
reference_out = torch.matmul(in1, in2)
triton.testing.allclose(out, reference_out)
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
def test_load_cache_modifier(cache):
src = torch.empty(128, device='cuda')
@@ -796,8 +815,8 @@ def test_load_cache_modifier(cache):
@triton.jit
def _kernel(dst, src, CACHE: tl.constexpr):
offsets = tl.arange(0, 128)
x = tl.load(src+offsets, cache_modifier=CACHE)
tl.store(dst+offsets, x)
x = tl.load(src + offsets, cache_modifier=CACHE)
tl.store(dst + offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
@@ -830,11 +849,14 @@ def test_load_cache_modifier(cache):
# ---------------
# test default
# ---------------
#TODO: can't be local to test_default
# TODO: can't be local to test_default
@triton.jit
def _impl(value = 10):
def _impl(value=10):
return value
def test_default():
value = 5
ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
@@ -851,7 +873,9 @@ def test_default():
# ---------------
# test noop
#----------------
# ----------------
def test_noop(device='cuda'):
@triton.jit
def kernel(x):
@@ -861,9 +885,9 @@ def test_noop(device='cuda'):
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
(2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

View File

@@ -1,16 +1,17 @@
import torch
import triton
import triton.language as tl
import numpy as np
import pytest
import scipy.stats
import numpy as np
import torch
from numpy.random import Philox
import triton
import triton.language as tl
#####################################
## Reference Philox Implementation
# Reference Philox Implementation
#####################################
class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
@@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x):
#####################################
## Unit Tests
# Unit Tests
#####################################
BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']\
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
[(size, seed) for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(',')))
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
@@ -132,10 +136,12 @@ def test_randint(size, seed, device='cuda'):
assert out_tri == out_ref
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\
for seed in [0, 42, 124, 54]]
)
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_rand(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
@@ -151,10 +157,12 @@ def test_rand(size, seed, device='cuda'):
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]\
for seed in [0, 42, 124, 54]]
)
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_randn(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):

View File

@@ -1,6 +1,7 @@
import torch
import triton
import pytest
import torch
import triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -71,7 +72,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
# torch result
rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf"))
# broadcast at_mask to the same shape as rx
if is_causal: at_mask = torch.tril(at_mask)
if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(rx)
rx[M == 0] = float("-inf")
# rx += kp_mask[:, None, None, :]

View File

@@ -1,14 +1,16 @@
import torch
import triton
import pytest
import torch
import triton
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']\
for mode in ['forward', 'backward']
]
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']
for mode in ['forward', 'backward']
]
)
def test_op(M, N, dtype, mode):
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
@@ -30,4 +32,4 @@ def test_op(M, N, dtype, mode):
x.grad.zero_()
th_y.backward(dy)
th_dx = x.grad.clone()
triton.testing.assert_almost_equal(th_dx, tt_dx)
triton.testing.assert_almost_equal(th_dx, tt_dx)

View File

@@ -1,8 +1,10 @@
import pytest
import itertools
import triton
import pytest
import torch
import triton
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
@@ -80,11 +82,11 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
K = BLOCK_K * SPLIT_K if K is None else K
# allocate/transpose inputs
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
# run test
th_c = torch.matmul(a, b)
tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest)
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
triton.testing.assert_almost_equal(th_c, tt_c)

View File

@@ -1,13 +1,16 @@
import torch
import triton
from triton.code_gen import JITFunction
import triton.language as tl
import os
import shutil
import pytest
import torch
import triton
import triton.language as tl
from triton.code_gen import JITFunction
tmpdir = ".tmp"
@triton.jit
def function_1(i):
i = i + 1
@@ -20,18 +23,21 @@ def function_2(i):
i = i + 1
return i
@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
def apply_src_change(target, old, new):
delattr(kernel.fn, 'hash')
delattr(function_1.fn, 'hash')
@@ -42,28 +48,34 @@ def apply_src_change(target, old, new):
target.src = target.src.replace(new, old)
return ret
def test_nochange():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 1')
assert baseline == updated
def test_toplevel_change():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 2')
assert baseline != updated
def test_nested1_change():
baseline = kernel.cache_key
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated
def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
def test_reuse():
counter = 0
def inc_counter(key, binary, repr):
nonlocal counter
counter += 1
@@ -73,11 +85,12 @@ def test_reuse():
for i in range(10):
kernel[(1,)](x, 1, BLOCK=1024)
assert counter == 1
@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
counter = 0
def inc_counter(key, binary, repr):
nonlocal counter
counter += 1

View File

@@ -1,9 +1,11 @@
import torch
import triton
import pytest
import subprocess
import triton.language as tl
import numpy as np
import pytest
import torch
import triton
import triton.language as tl
def get_p2p_matrix():