[STYLE] run autopep8 and isort (#421)

Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
This commit is contained in:
Madeleine Thompson
2022-01-06 14:34:17 -08:00
committed by GitHub
parent 120cda015e
commit 8bf551ae7a
30 changed files with 742 additions and 623 deletions

View File

@@ -1,13 +1,16 @@
import torch
import triton
from triton.code_gen import JITFunction
import triton.language as tl
import os
import shutil
import pytest
import torch
import triton
import triton.language as tl
from triton.code_gen import JITFunction
tmpdir = ".tmp"
@triton.jit
def function_1(i):
i = i + 1
@@ -20,18 +23,21 @@ def function_2(i):
i = i + 1
return i
@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
def apply_src_change(target, old, new):
delattr(kernel.fn, 'hash')
delattr(function_1.fn, 'hash')
@@ -42,28 +48,34 @@ def apply_src_change(target, old, new):
target.src = target.src.replace(new, old)
return ret
def test_nochange():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 1')
assert baseline == updated
def test_toplevel_change():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 2')
assert baseline != updated
def test_nested1_change():
baseline = kernel.cache_key
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated
def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
def test_reuse():
counter = 0
def inc_counter(key, binary, repr):
nonlocal counter
counter += 1
@@ -73,11 +85,12 @@ def test_reuse():
for i in range(10):
kernel[(1,)](x, 1, BLOCK=1024)
assert counter == 1
@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
counter = 0
def inc_counter(key, binary, repr):
nonlocal counter
counter += 1

View File

@@ -1,9 +1,11 @@
import torch
import triton
import pytest
import subprocess
import triton.language as tl
import numpy as np
import pytest
import torch
import triton
import triton.language as tl
def get_p2p_matrix():