[FRONTEND] Minor accumulated style and warning fixes (#388)
- Fix some whitespace. - Make an undeclared dependency on `pytest` explicit. - Fix deprecated `description-file` use. - `#ifdef` out a deprecated `PyEval_InitThreads` call. - Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests. - Fix a deprecated cast in `test_core.py`. - Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7. - Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer. - Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred. - Fix a few bad escapes.
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							9def2424ab
						
					
				
				
					commit
					e575ae3443
				
			| @@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ | ||||
| #define void_ty              builder_->getVoidTy() | ||||
| #define f16_ty               builder_->getHalfTy() | ||||
| #define f32_ty               builder_->getFloatTy() | ||||
| #define i8_ty               builder_->getInt8Ty() | ||||
| #define i8_ty                builder_->getInt8Ty() | ||||
| #define i32_ty               builder_->getInt32Ty() | ||||
| #define vec_ty(type, num_el) VectorType::get(type, num_el, false) | ||||
| #define ptr_ty(...)          PointerType::get(__VA_ARGS__) | ||||
| @@ -163,8 +163,8 @@ Type *generator::cvt(ir::type *ty) { | ||||
|     case ir::type::FP8TyID:       return Type::getInt8Ty(*ctx_); | ||||
|     case ir::type::FP16TyID:      return Type::getHalfTy(*ctx_); | ||||
|     case ir::type::BF16TyID:      return Type::getInt16Ty(*ctx_); | ||||
|     case ir::type::FP32TyID:     return Type::getFloatTy(*ctx_); | ||||
|     case ir::type::FP64TyID:    return Type::getDoubleTy(*ctx_); | ||||
|     case ir::type::FP32TyID:      return Type::getFloatTy(*ctx_); | ||||
|     case ir::type::FP64TyID:      return Type::getDoubleTy(*ctx_); | ||||
|     case ir::type::LabelTyID:     return Type::getLabelTy(*ctx_); | ||||
|     case ir::type::MetadataTyID:  return Type::getMetadataTy(*ctx_); | ||||
|     case ir::type::TokenTyID:     return Type::getTokenTy(*ctx_); | ||||
|   | ||||
| @@ -1 +1,2 @@ | ||||
| pytest | ||||
| scipy >= 1.7.1 | ||||
|   | ||||
| @@ -1,2 +1,2 @@ | ||||
| [metadata] | ||||
| description-file = README.md | ||||
| description_file = README.md | ||||
|   | ||||
| @@ -197,7 +197,9 @@ PYBIND11_NOINLINE inline internals &get_internals() { | ||||
|         auto *&internals_ptr = *internals_pp; | ||||
|         internals_ptr = new internals(); | ||||
| #if defined(WITH_THREAD) | ||||
|         #if PY_VERSION_HEX < 0x03090000 | ||||
|         PyEval_InitThreads(); | ||||
|         #endif | ||||
|         PyThreadState *tstate = PyThreadState_Get(); | ||||
|         #if PY_VERSION_HEX >= 0x03070000 | ||||
|             internals_ptr->tstate = PyThread_tss_alloc(); | ||||
|   | ||||
| @@ -339,7 +339,8 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): | ||||
|     ('float32', 'int32', True) | ||||
| ]) | ||||
| def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): | ||||
|     x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device) | ||||
|     x0 = 43 if dtype_x.startswith('int') else 43.5 | ||||
|     x = torch.tensor([x0], dtype=cvt[dtype_x], device=device) | ||||
|  | ||||
|     # triton kernel | ||||
|     @triton.jit | ||||
| @@ -665,4 +666,4 @@ def test_noop(device='cuda'): | ||||
|     def kernel(x): | ||||
|         pass | ||||
|     x = triton.testing.random((1,), dtype=torch.int32, device=device) | ||||
|     kernel[(1, )](x) | ||||
|     kernel[(1, )](x) | ||||
|   | ||||
| @@ -74,9 +74,8 @@ class CustomPhilox4x: | ||||
|         return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) | ||||
|  | ||||
|     def _raise_key(self, key): | ||||
|         ret0 = key[0] + self._config.PHILOX_KEY_A | ||||
|         ret1 = key[1] + self._config.PHILOX_KEY_B | ||||
|         return np.array([ret0, ret1], dtype=self._dtype) | ||||
|         pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] | ||||
|         return key + np.array(pk, dtype=self._dtype) | ||||
|  | ||||
|     def random_raw(self): | ||||
|         counter = self._counter | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import os | ||||
| import pickle | ||||
| import subprocess | ||||
| import os | ||||
| import warnings | ||||
| from .tools.disasm import extract | ||||
| import torch | ||||
| import triton | ||||
| @@ -475,7 +476,11 @@ class CodeGenerator(ast.NodeVisitor): | ||||
|     def visit(self, node): | ||||
|         if node is not None: | ||||
|             self.last_node = node | ||||
|         return super().visit(node) | ||||
|         with warnings.catch_warnings(): | ||||
|             # The ast library added visit_Constant and deprecated some other | ||||
|             # methods but we can't move to that without breaking Python 3.6 and 3.7. | ||||
|             warnings.simplefilter("ignore", DeprecationWarning) | ||||
|             return super().visit(node) | ||||
|  | ||||
|     def generic_visit(self, node): | ||||
|         typename = type(node).__name__ | ||||
| @@ -512,12 +517,11 @@ class LoadedBinary: | ||||
|  | ||||
|  | ||||
| class CompilationError(Exception): | ||||
|     def __init__(self, src, node, err): | ||||
|     def __init__(self, src, node): | ||||
|         self.message = '\n'.join(src.split('\n')[:node.lineno]) | ||||
|         self.message += '\n' + ' ' * node.col_offset + '^' | ||||
|         self.message += '\n Error: ' + str(err) | ||||
|         super().__init__(self.message) | ||||
|         self.args = (src, node, err) | ||||
|         self.args = (src, node) | ||||
|  | ||||
|  | ||||
| class OutOfResources(Exception): | ||||
| @@ -618,7 +622,7 @@ class Kernel: | ||||
|             node = generator.last_node | ||||
|             if node is None or isinstance(e, (NotImplementedError, CompilationError)): | ||||
|                 raise e | ||||
|             raise CompilationError(self.fn.src, node, e) | ||||
|             raise CompilationError(self.fn.src, node) from e | ||||
|         # Compile to machine code | ||||
|         if torch.version.hip is None: | ||||
|             backend = _triton.runtime.backend.CUDA | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
| from . import core | ||||
| from . import random | ||||
| from . import core, random | ||||
| from .core import * | ||||
| from .random import * | ||||
| from .random import rand, randint, randint4x, randn | ||||
|   | ||||
| @@ -72,6 +72,9 @@ class dtype: | ||||
|         ctx = builder.context | ||||
|         return self.init(ctx) | ||||
|  | ||||
|     def __str__(self): | ||||
|         return f"dtype({self.init.__name__})" | ||||
|  | ||||
|  | ||||
| class pointer_dtype: | ||||
|     def __init__(self, element_ty): | ||||
|   | ||||
| @@ -134,7 +134,7 @@ def pair_uniform_to_normal(u1, u2): | ||||
| def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): | ||||
|     """ | ||||
|     Given a :code:`seed` scalar and an :code:`offset` block,  | ||||
|     returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)` | ||||
|     returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` | ||||
|  | ||||
|     :param seed: The seed for generating random numbers. | ||||
|     :param offsets: The offsets to generate random numbers for. | ||||
| @@ -149,7 +149,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): | ||||
| def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): | ||||
|     """ | ||||
|     Given a :code:`seed` scalar and an :code:`offset` block, | ||||
|     returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)` | ||||
|     returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` | ||||
|  | ||||
|     :param seed: The seed for generating random numbers. | ||||
|     :param offsets: The offsets to generate random numbers for. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user