[FRONTEND] Minor accumulated style and warning fixes (#388)

- Fix some whitespace.
- Make an undeclared dependency on `pytest` explicit.
- Fix deprecated `description-file` use.
- `#ifdef` out a deprecated `PyEval_InitThreads` call.
- Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests.
- Fix a deprecated cast in `test_core.py`.
- Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7.
- Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer.
- Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred.
- Fix a few bad escapes.
This commit is contained in:
Madeleine Thompson
2021-12-10 15:19:20 -08:00
committed by GitHub
parent 9def2424ab
commit e575ae3443
10 changed files with 28 additions and 19 deletions

View File

@@ -1 +1,2 @@
pytest
scipy >= 1.7.1 scipy >= 1.7.1

View File

@@ -1,2 +1,2 @@
[metadata] [metadata]
description-file = README.md description_file = README.md

View File

@@ -197,7 +197,9 @@ PYBIND11_NOINLINE inline internals &get_internals() {
auto *&internals_ptr = *internals_pp; auto *&internals_ptr = *internals_pp;
internals_ptr = new internals(); internals_ptr = new internals();
#if defined(WITH_THREAD) #if defined(WITH_THREAD)
#if PY_VERSION_HEX < 0x03090000
PyEval_InitThreads(); PyEval_InitThreads();
#endif
PyThreadState *tstate = PyThreadState_Get(); PyThreadState *tstate = PyThreadState_Get();
#if PY_VERSION_HEX >= 0x03070000 #if PY_VERSION_HEX >= 0x03070000
internals_ptr->tstate = PyThread_tss_alloc(); internals_ptr->tstate = PyThread_tss_alloc();

View File

@@ -339,7 +339,8 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
('float32', 'int32', True) ('float32', 'int32', True)
]) ])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device) x0 = 43 if dtype_x.startswith('int') else 43.5
x = torch.tensor([x0], dtype=cvt[dtype_x], device=device)
# triton kernel # triton kernel
@triton.jit @triton.jit

View File

@@ -74,9 +74,8 @@ class CustomPhilox4x:
return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
def _raise_key(self, key): def _raise_key(self, key):
ret0 = key[0] + self._config.PHILOX_KEY_A pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
ret1 = key[1] + self._config.PHILOX_KEY_B return key + np.array(pk, dtype=self._dtype)
return np.array([ret0, ret1], dtype=self._dtype)
def random_raw(self): def random_raw(self):
counter = self._counter counter = self._counter

View File

@@ -10,6 +10,7 @@ import os
import pickle import pickle
import subprocess import subprocess
import os import os
import warnings
from .tools.disasm import extract from .tools.disasm import extract
import torch import torch
import triton import triton
@@ -475,6 +476,10 @@ class CodeGenerator(ast.NodeVisitor):
def visit(self, node): def visit(self, node):
if node is not None: if node is not None:
self.last_node = node self.last_node = node
with warnings.catch_warnings():
# The ast library added visit_Constant and deprecated some other
# methods but we can't move to that without breaking Python 3.6 and 3.7.
warnings.simplefilter("ignore", DeprecationWarning)
return super().visit(node) return super().visit(node)
def generic_visit(self, node): def generic_visit(self, node):
@@ -512,12 +517,11 @@ class LoadedBinary:
class CompilationError(Exception): class CompilationError(Exception):
def __init__(self, src, node, err): def __init__(self, src, node):
self.message = '\n'.join(src.split('\n')[:node.lineno]) self.message = '\n'.join(src.split('\n')[:node.lineno])
self.message += '\n' + ' ' * node.col_offset + '^' self.message += '\n' + ' ' * node.col_offset + '^'
self.message += '\n Error: ' + str(err)
super().__init__(self.message) super().__init__(self.message)
self.args = (src, node, err) self.args = (src, node)
class OutOfResources(Exception): class OutOfResources(Exception):
@@ -618,7 +622,7 @@ class Kernel:
node = generator.last_node node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)): if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e raise e
raise CompilationError(self.fn.src, node, e) raise CompilationError(self.fn.src, node) from e
# Compile to machine code # Compile to machine code
if torch.version.hip is None: if torch.version.hip is None:
backend = _triton.runtime.backend.CUDA backend = _triton.runtime.backend.CUDA

View File

@@ -1,4 +1,3 @@
from . import core from . import core, random
from . import random
from .core import * from .core import *
from .random import * from .random import rand, randint, randint4x, randn

View File

@@ -72,6 +72,9 @@ class dtype:
ctx = builder.context ctx = builder.context
return self.init(ctx) return self.init(ctx)
def __str__(self):
return f"dtype({self.init.__name__})"
class pointer_dtype: class pointer_dtype:
def __init__(self, element_ty): def __init__(self, element_ty):

View File

@@ -134,7 +134,7 @@ def pair_uniform_to_normal(u1, u2):
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
""" """
Given a :code:`seed` scalar and an :code:`offset` block, Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)` returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
:param seed: The seed for generating random numbers. :param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for. :param offsets: The offsets to generate random numbers for.
@@ -149,7 +149,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
""" """
Given a :code:`seed` scalar and an :code:`offset` block, Given a :code:`seed` scalar and an :code:`offset` block,
returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)` returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
:param seed: The seed for generating random numbers. :param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for. :param offsets: The offsets to generate random numbers for.