[FRONTEND] Minor accumulated style and warning fixes (#388)
- Fix some whitespace.
- Make an undeclared dependency on `pytest` explicit.
- Fix deprecated `description-file` use.
- `#ifdef` out a deprecated `PyEval_InitThreads` call.
- Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests.
- Fix a deprecated cast in `test_core.py`.
- Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7.
- Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer.
- Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred.
- Fix a few bad escapes.
commit e575ae3443, parent 9def2424ab (committed via GitHub)
@@ -1 +1,2 @@
+pytest
 scipy >= 1.7.1
@@ -1,2 +1,2 @@
 [metadata]
-description-file = README.md
+description_file = README.md
@@ -197,7 +197,9 @@ PYBIND11_NOINLINE inline internals &get_internals() {
         auto *&internals_ptr = *internals_pp;
         internals_ptr = new internals();
 #if defined(WITH_THREAD)
+#if PY_VERSION_HEX < 0x03090000
         PyEval_InitThreads();
+#endif
         PyThreadState *tstate = PyThreadState_Get();
 #if PY_VERSION_HEX >= 0x03070000
         internals_ptr->tstate = PyThread_tss_alloc();
@@ -339,7 +339,8 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
     ('float32', 'int32', True)
 ])
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
-    x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
+    x0 = 43 if dtype_x.startswith('int') else 43.5
+    x = torch.tensor([x0], dtype=cvt[dtype_x], device=device)
 
     # triton kernel
     @triton.jit
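Side note on the fix (a hedged sketch, not part of the commit): on the torch builds this targeted, stuffing `43.5` into an integer tensor tripped a deprecated-cast warning; choosing an integer-valued `x0` for integer dtypes sidesteps the lossy conversion entirely. The snippet below only records whatever warnings a given torch version emits rather than asserting any:

    import warnings
    import torch

    cvt = {'int32': torch.int32, 'float32': torch.float32}
    for dtype_x in ('int32', 'float32'):
        # With the fix, the literal already matches the target dtype, so the
        # tensor constructor never has to truncate a float to an integer.
        x0 = 43 if dtype_x.startswith('int') else 43.5
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter('always')
            x = torch.tensor([x0], dtype=cvt[dtype_x])
        print(dtype_x, x.item(), [str(w.message) for w in caught])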
@@ -74,9 +74,8 @@ class CustomPhilox4x:
         return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
 
     def _raise_key(self, key):
-        ret0 = key[0] + self._config.PHILOX_KEY_A
-        ret1 = key[1] + self._config.PHILOX_KEY_B
-        return np.array([ret0, ret1], dtype=self._dtype)
+        pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
+        return key + np.array(pk, dtype=self._dtype)
 
     def random_raw(self):
         counter = self._counter
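The numpy detail behind this change (a minimal sketch; the key constants below are illustrative): adding two numpy integer scalars that overflow emits a RuntimeWarning, while the equivalent array addition wraps modulo 2**32 silently, which is exactly what Philox key-raising wants:

    import warnings
    import numpy as np

    key = np.array([0x9E3779B9, 0xBB67AE85], dtype=np.uint32)
    pk = np.array([0xDEADBEEF, 0xCAFEBABE], dtype=np.uint32)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        _ = key[0] + pk[0]  # scalar + scalar: overflow -> RuntimeWarning
        _ = key + pk        # array + array: wraps mod 2**32, no warning
    print([str(w.message) for w in caught])  # only the scalar addition warned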
@@ -10,6 +10,7 @@ import os
 import pickle
 import subprocess
 import os
+import warnings
 from .tools.disasm import extract
 import torch
 import triton
@@ -475,6 +476,10 @@ class CodeGenerator(ast.NodeVisitor):
     def visit(self, node):
         if node is not None:
             self.last_node = node
-        return super().visit(node)
+        with warnings.catch_warnings():
+            # The ast library added visit_Constant and deprecated some other
+            # methods but we can't move to that without breaking Python 3.6 and 3.7.
+            warnings.simplefilter("ignore", DeprecationWarning)
+            return super().visit(node)
 
     def generic_visit(self, node):
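For context, a sketch of the warning being silenced (assumes Python 3.8+, where `ast.NodeVisitor` routes `Constant` nodes to legacy handlers such as `visit_Num` and emits a `DeprecationWarning` each time it does):

    import ast
    import warnings

    class LegacyVisitor(ast.NodeVisitor):
        # Pre-3.8 style handler; keeping it (instead of visit_Constant)
        # is what triggers the DeprecationWarning on newer Pythons.
        def visit_Num(self, node):
            print('saw number:', node.n)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', DeprecationWarning)
        LegacyVisitor().visit(ast.parse('1 + 2'))  # two Constant nodes, no noise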
@@ -512,12 +517,11 @@ class LoadedBinary:
 
 
 class CompilationError(Exception):
-    def __init__(self, src, node, err):
+    def __init__(self, src, node):
         self.message = '\n'.join(src.split('\n')[:node.lineno])
         self.message += '\n' + ' ' * node.col_offset + '^'
-        self.message += '\n Error: ' + str(err)
         super().__init__(self.message)
-        self.args = (src, node, err)
+        self.args = (src, node)
 
 
 class OutOfResources(Exception):
@@ -618,7 +622,7 @@ class Kernel:
             node = generator.last_node
             if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                 raise e
-            raise CompilationError(self.fn.src, node, e)
+            raise CompilationError(self.fn.src, node) from e
         # Compile to machine code
         if torch.version.hip is None:
             backend = _triton.runtime.backend.CUDA
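What `raise ... from e` buys (a minimal sketch, independent of Triton): the original error survives as `__cause__` and both tracebacks print, joined by "The above exception was the direct cause of the following exception", so `CompilationError` no longer needs to bake `str(err)` into its own message:

    class CompilationError(Exception):
        pass

    def compile_kernel():
        try:
            raise ValueError('unsupported op')  # stand-in for a codegen failure
        except Exception as e:
            raise CompilationError('line 3:\n    ^') from e

    try:
        compile_kernel()
    except CompilationError as err:
        print(repr(err.__cause__))  # ValueError('unsupported op')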
@@ -1,4 +1,3 @@
-from . import core
-from . import random
+from . import core, random
 from .core import *
-from .random import *
+from .random import rand, randint, randint4x, randn
@@ -72,6 +72,9 @@ class dtype:
         ctx = builder.context
         return self.init(ctx)
 
+    def __str__(self):
+        return f"dtype({self.init.__name__})"
+
 
 class pointer_dtype:
     def __init__(self, element_ty):
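A usage sketch (the output shown is hypothetical; the name comes from whatever init function backs the dtype):

    import triton.language as tl

    # With __str__ defined, printing a dtype while debugging a kernel shows
    # the inferred type rather than an opaque default object repr.
    print(tl.float32)  # e.g. dtype(get_fp32)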
@@ -134,7 +134,7 @@ def pair_uniform_to_normal(u1, u2):
 def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
     """
     Given a :code:`seed` scalar and an :code:`offset` block,
-    returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
+    returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
 
     :param seed: The seed for generating random numbers.
     :param offsets: The offsets to generate random numbers for.
@@ -149,7 +149,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
 def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
     """
     Given a :code:`seed` scalar and an :code:`offset` block,
-    returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
+    returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
 
     :param seed: The seed for generating random numbers.
     :param offsets: The offsets to generate random numbers for.
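Why the doubled backslash matters (a sketch): in a non-raw string, `\m` is an invalid escape sequence, and CPython warns at compile time (a `DeprecationWarning` on the Pythons this commit targeted; newer versions use `SyntaxWarning`):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        compile('s = "\\mathcal{N}(0, 1)"', '<docstring>', 'exec')    # '\m'
        compile('s = "\\\\mathcal{N}(0, 1)"', '<docstring>', 'exec')  # '\\m'
    print([str(w.message) for w in caught])  # only the first compile warned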