[FRONTEND] Minor accumulated style and warning fixes (#388)

- Fix some whitespace.
- Make an undeclared dependency on `pytest` explicit.
- Fix deprecated `description-file` use.
- `#ifdef` out a deprecated `PyEval_InitThreads` call.
- Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests.
- Fix a deprecated cast in `test_core.py`.
- Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7.
- Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer.
- Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred.
- Fix a few bad escapes.
This commit is contained in:
Madeleine Thompson
2021-12-10 15:19:20 -08:00
committed by GitHub
parent 9def2424ab
commit e575ae3443
10 changed files with 28 additions and 19 deletions

View File

@@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define f32_ty builder_->getFloatTy()
#define i8_ty builder_->getInt8Ty()
#define i8_ty builder_->getInt8Ty()
#define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
@@ -163,8 +163,8 @@ Type *generator::cvt(ir::type *ty) {
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_);
case ir::type::TokenTyID: return Type::getTokenTy(*ctx_);

View File

@@ -1 +1,2 @@
pytest
scipy >= 1.7.1

View File

@@ -1,2 +1,2 @@
[metadata]
description-file = README.md
description_file = README.md

View File

@@ -197,7 +197,9 @@ PYBIND11_NOINLINE inline internals &get_internals() {
auto *&internals_ptr = *internals_pp;
internals_ptr = new internals();
#if defined(WITH_THREAD)
#if PY_VERSION_HEX < 0x03090000
PyEval_InitThreads();
#endif
PyThreadState *tstate = PyThreadState_Get();
#if PY_VERSION_HEX >= 0x03070000
internals_ptr->tstate = PyThread_tss_alloc();

View File

@@ -339,7 +339,8 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
('float32', 'int32', True)
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
x0 = 43 if dtype_x.startswith('int') else 43.5
x = torch.tensor([x0], dtype=cvt[dtype_x], device=device)
# triton kernel
@triton.jit
@@ -665,4 +666,4 @@ def test_noop(device='cuda'):
def kernel(x):
pass
x = triton.testing.random((1,), dtype=torch.int32, device=device)
kernel[(1, )](x)
kernel[(1, )](x)

View File

@@ -74,9 +74,8 @@ class CustomPhilox4x:
return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
def _raise_key(self, key):
ret0 = key[0] + self._config.PHILOX_KEY_A
ret1 = key[1] + self._config.PHILOX_KEY_B
return np.array([ret0, ret1], dtype=self._dtype)
pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
return key + np.array(pk, dtype=self._dtype)
def random_raw(self):
counter = self._counter

View File

@@ -10,6 +10,7 @@ import os
import pickle
import subprocess
import os
import warnings
from .tools.disasm import extract
import torch
import triton
@@ -475,7 +476,11 @@ class CodeGenerator(ast.NodeVisitor):
def visit(self, node):
if node is not None:
self.last_node = node
return super().visit(node)
with warnings.catch_warnings():
# The ast library added visit_Constant and deprecated some other
# methods but we can't move to that without breaking Python 3.6 and 3.7.
warnings.simplefilter("ignore", DeprecationWarning)
return super().visit(node)
def generic_visit(self, node):
typename = type(node).__name__
@@ -512,12 +517,11 @@ class LoadedBinary:
class CompilationError(Exception):
def __init__(self, src, node, err):
def __init__(self, src, node):
self.message = '\n'.join(src.split('\n')[:node.lineno])
self.message += '\n' + ' ' * node.col_offset + '^'
self.message += '\n Error: ' + str(err)
super().__init__(self.message)
self.args = (src, node, err)
self.args = (src, node)
class OutOfResources(Exception):
@@ -618,7 +622,7 @@ class Kernel:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.fn.src, node, e)
raise CompilationError(self.fn.src, node) from e
# Compile to machine code
if torch.version.hip is None:
backend = _triton.runtime.backend.CUDA

View File

@@ -1,4 +1,3 @@
from . import core
from . import random
from . import core, random
from .core import *
from .random import *
from .random import rand, randint, randint4x, randn

View File

@@ -72,6 +72,9 @@ class dtype:
ctx = builder.context
return self.init(ctx)
def __str__(self):
return f"dtype({self.init.__name__})"
class pointer_dtype:
def __init__(self, element_ty):

View File

@@ -134,7 +134,7 @@ def pair_uniform_to_normal(u1, u2):
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
@@ -149,7 +149,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.