[STYLE] run autopep8 and isort (#421)
Run: ``` isort ./python autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py') ``` with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
This commit is contained in:
committed by
GitHub
parent
120cda015e
commit
8bf551ae7a
@@ -1,16 +1,17 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import numpy as np
|
||||
import pytest
|
||||
import scipy.stats
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from numpy.random import Philox
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
#####################################
|
||||
## Reference Philox Implementation
|
||||
# Reference Philox Implementation
|
||||
#####################################
|
||||
|
||||
|
||||
class PhiloxConfig:
|
||||
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
|
||||
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
|
||||
@@ -103,18 +104,21 @@ class CustomPhilox(CustomPhilox4x):
|
||||
|
||||
|
||||
#####################################
|
||||
## Unit Tests
|
||||
# Unit Tests
|
||||
#####################################
|
||||
|
||||
BLOCK = 1024
|
||||
|
||||
# test generation of random uint32
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in ['10', '4,53', '10000']\
|
||||
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
|
||||
)
|
||||
[(size, seed) for size in ['10', '4,53', '10000']
|
||||
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
|
||||
)
|
||||
def test_randint(size, seed, device='cuda'):
|
||||
size = list(map(int, size.split(',')))
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
|
||||
@@ -132,10 +136,12 @@ def test_randint(size, seed, device='cuda'):
|
||||
assert out_tri == out_ref
|
||||
|
||||
# test uniform PRNG
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in [1000000]\
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
[(size, seed) for size in [1000000]
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
def test_rand(size, seed, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
@@ -151,10 +157,12 @@ def test_rand(size, seed, device='cuda'):
|
||||
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
|
||||
|
||||
# test normal PRNG
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in [1000000]\
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
[(size, seed) for size in [1000000]
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
def test_randn(size, seed, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
|
Reference in New Issue
Block a user