From e575ae3443fcefb6f04b954893696daaaf91bde1 Mon Sep 17 00:00:00 2001 From: Madeleine Thompson Date: Fri, 10 Dec 2021 15:19:20 -0800 Subject: [PATCH] [FRONTEND] Minor accumulated style and warning fixes (#388) - Fix some whitespace. - Make an undeclared dependency on `pytest` explicit. - Fix deprecated `description-file` use. - `#ifdef` out a deprecated `PyEval_InitThreads` call. - Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests. - Fix a deprecated cast in `test_core.py`. - Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7. - Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer. - Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred. - Fix a few bad escapes. --- lib/codegen/selection/generator.cc | 6 +++--- python/requirements-test.txt | 1 + python/setup.cfg | 2 +- python/src/pybind11/detail/internals.h | 2 ++ python/test/unit/language/test_core.py | 5 +++-- python/test/unit/language/test_random.py | 5 ++--- python/triton/code_gen.py | 14 +++++++++----- python/triton/language/__init__.py | 5 ++--- python/triton/language/core.py | 3 +++ python/triton/language/random.py | 4 ++-- 10 files changed, 28 insertions(+), 19 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 3c4fae3d8..d5c5c4902 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define void_ty builder_->getVoidTy() #define f16_ty builder_->getHalfTy() #define f32_ty builder_->getFloatTy() -#define i8_ty builder_->getInt8Ty() +#define i8_ty builder_->getInt8Ty() #define i32_ty builder_->getInt32Ty() #define vec_ty(type, num_el) VectorType::get(type, num_el, false) #define ptr_ty(...) PointerType::get(__VA_ARGS__) @@ -163,8 +163,8 @@ Type *generator::cvt(ir::type *ty) { case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_); case ir::type::FP16TyID: return Type::getHalfTy(*ctx_); case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); - case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); - case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); + case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); + case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_); case ir::type::TokenTyID: return Type::getTokenTy(*ctx_); diff --git a/python/requirements-test.txt b/python/requirements-test.txt index 4a1b49122..48b6d3be3 100644 --- a/python/requirements-test.txt +++ b/python/requirements-test.txt @@ -1 +1,2 @@ +pytest scipy >= 1.7.1 diff --git a/python/setup.cfg b/python/setup.cfg index 224a77957..08aedd7e6 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -1,2 +1,2 @@ [metadata] -description-file = README.md \ No newline at end of file +description_file = README.md diff --git a/python/src/pybind11/detail/internals.h b/python/src/pybind11/detail/internals.h index f1dd38764..4f25759d3 100644 --- a/python/src/pybind11/detail/internals.h +++ b/python/src/pybind11/detail/internals.h @@ -197,7 +197,9 @@ PYBIND11_NOINLINE inline internals &get_internals() { auto *&internals_ptr = *internals_pp; internals_ptr = new internals(); #if defined(WITH_THREAD) + #if PY_VERSION_HEX < 0x03090000 PyEval_InitThreads(); + #endif PyThreadState *tstate = PyThreadState_Get(); #if PY_VERSION_HEX >= 0x03070000 internals_ptr->tstate = PyThread_tss_alloc(); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6359857fe..e85c399b3 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -339,7 +339,8 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): ('float32', 'int32', True) ]) def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): - x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device) + x0 = 43 if dtype_x.startswith('int') else 43.5 + x = torch.tensor([x0], dtype=cvt[dtype_x], device=device) # triton kernel @triton.jit @@ -665,4 +666,4 @@ def test_noop(device='cuda'): def kernel(x): pass x = triton.testing.random((1,), dtype=torch.int32, device=device) - kernel[(1, )](x) \ No newline at end of file + kernel[(1, )](x) diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index 4c1261f1d..4d4501556 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -74,9 +74,8 @@ class CustomPhilox4x: return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) def _raise_key(self, key): - ret0 = key[0] + self._config.PHILOX_KEY_A - ret1 = key[1] + self._config.PHILOX_KEY_B - return np.array([ret0, ret1], dtype=self._dtype) + pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] + return key + np.array(pk, dtype=self._dtype) def random_raw(self): counter = self._counter diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index deede2530..96948e360 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -10,6 +10,7 @@ import os import pickle import subprocess import os +import warnings from .tools.disasm import extract import torch import triton @@ -475,7 +476,11 @@ class CodeGenerator(ast.NodeVisitor): def visit(self, node): if node is not None: self.last_node = node - return super().visit(node) + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) + return super().visit(node) def generic_visit(self, node): typename = type(node).__name__ @@ -512,12 +517,11 @@ class LoadedBinary: class CompilationError(Exception): - def __init__(self, src, node, err): + def __init__(self, src, node): self.message = '\n'.join(src.split('\n')[:node.lineno]) self.message += '\n' + ' ' * node.col_offset + '^' - self.message += '\n Error: ' + str(err) super().__init__(self.message) - self.args = (src, node, err) + self.args = (src, node) class OutOfResources(Exception): @@ -618,7 +622,7 @@ class Kernel: node = generator.last_node if node is None or isinstance(e, (NotImplementedError, CompilationError)): raise e - raise CompilationError(self.fn.src, node, e) + raise CompilationError(self.fn.src, node) from e # Compile to machine code if torch.version.hip is None: backend = _triton.runtime.backend.CUDA diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index b96260c51..2f3f4ea05 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,4 +1,3 @@ -from . import core -from . import random +from . import core, random from .core import * -from .random import * +from .random import rand, randint, randint4x, randn diff --git a/python/triton/language/core.py b/python/triton/language/core.py index d7240fcf8..55b5bc0d9 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -72,6 +72,9 @@ class dtype: ctx = builder.context return self.init(ctx) + def __str__(self): + return f"dtype({self.init.__name__})" + class pointer_dtype: def __init__(self, element_ty): diff --git a/python/triton/language/random.py b/python/triton/language/random.py index a831af487..e1ac3c30a 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -134,7 +134,7 @@ def pair_uniform_to_normal(u1, u2): def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, - returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)` + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. @@ -149,7 +149,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): """ Given a :code:`seed` scalar and an :code:`offset` block, - returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)` + returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for.