[FRONTEND] Minor accumulated style and warning fixes (#388)
- Fix some whitespace. - Make an undeclared dependency on `pytest` explicit. - Fix deprecated `description-file` use. - `#ifdef` out a deprecated `PyEval_InitThreads` call. - Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests. - Fix a deprecated cast in `test_core.py`. - Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7. - Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer. - Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred. - Fix a few bad escapes.
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							9def2424ab
						
					
				
				
					commit
					e575ae3443
				
			| @@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ | |||||||
| #define void_ty              builder_->getVoidTy() | #define void_ty              builder_->getVoidTy() | ||||||
| #define f16_ty               builder_->getHalfTy() | #define f16_ty               builder_->getHalfTy() | ||||||
| #define f32_ty               builder_->getFloatTy() | #define f32_ty               builder_->getFloatTy() | ||||||
| #define i8_ty               builder_->getInt8Ty() | #define i8_ty                builder_->getInt8Ty() | ||||||
| #define i32_ty               builder_->getInt32Ty() | #define i32_ty               builder_->getInt32Ty() | ||||||
| #define vec_ty(type, num_el) VectorType::get(type, num_el, false) | #define vec_ty(type, num_el) VectorType::get(type, num_el, false) | ||||||
| #define ptr_ty(...)          PointerType::get(__VA_ARGS__) | #define ptr_ty(...)          PointerType::get(__VA_ARGS__) | ||||||
| @@ -163,8 +163,8 @@ Type *generator::cvt(ir::type *ty) { | |||||||
|     case ir::type::FP8TyID:       return Type::getInt8Ty(*ctx_); |     case ir::type::FP8TyID:       return Type::getInt8Ty(*ctx_); | ||||||
|     case ir::type::FP16TyID:      return Type::getHalfTy(*ctx_); |     case ir::type::FP16TyID:      return Type::getHalfTy(*ctx_); | ||||||
|     case ir::type::BF16TyID:      return Type::getInt16Ty(*ctx_); |     case ir::type::BF16TyID:      return Type::getInt16Ty(*ctx_); | ||||||
|     case ir::type::FP32TyID:     return Type::getFloatTy(*ctx_); |     case ir::type::FP32TyID:      return Type::getFloatTy(*ctx_); | ||||||
|     case ir::type::FP64TyID:    return Type::getDoubleTy(*ctx_); |     case ir::type::FP64TyID:      return Type::getDoubleTy(*ctx_); | ||||||
|     case ir::type::LabelTyID:     return Type::getLabelTy(*ctx_); |     case ir::type::LabelTyID:     return Type::getLabelTy(*ctx_); | ||||||
|     case ir::type::MetadataTyID:  return Type::getMetadataTy(*ctx_); |     case ir::type::MetadataTyID:  return Type::getMetadataTy(*ctx_); | ||||||
|     case ir::type::TokenTyID:     return Type::getTokenTy(*ctx_); |     case ir::type::TokenTyID:     return Type::getTokenTy(*ctx_); | ||||||
|   | |||||||
| @@ -1 +1,2 @@ | |||||||
|  | pytest | ||||||
| scipy >= 1.7.1 | scipy >= 1.7.1 | ||||||
|   | |||||||
| @@ -1,2 +1,2 @@ | |||||||
| [metadata] | [metadata] | ||||||
| description-file = README.md | description_file = README.md | ||||||
|   | |||||||
| @@ -197,7 +197,9 @@ PYBIND11_NOINLINE inline internals &get_internals() { | |||||||
|         auto *&internals_ptr = *internals_pp; |         auto *&internals_ptr = *internals_pp; | ||||||
|         internals_ptr = new internals(); |         internals_ptr = new internals(); | ||||||
| #if defined(WITH_THREAD) | #if defined(WITH_THREAD) | ||||||
|  |         #if PY_VERSION_HEX < 0x03090000 | ||||||
|         PyEval_InitThreads(); |         PyEval_InitThreads(); | ||||||
|  |         #endif | ||||||
|         PyThreadState *tstate = PyThreadState_Get(); |         PyThreadState *tstate = PyThreadState_Get(); | ||||||
|         #if PY_VERSION_HEX >= 0x03070000 |         #if PY_VERSION_HEX >= 0x03070000 | ||||||
|             internals_ptr->tstate = PyThread_tss_alloc(); |             internals_ptr->tstate = PyThread_tss_alloc(); | ||||||
|   | |||||||
| @@ -339,7 +339,8 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'): | |||||||
|     ('float32', 'int32', True) |     ('float32', 'int32', True) | ||||||
| ]) | ]) | ||||||
| def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): | def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): | ||||||
|     x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device) |     x0 = 43 if dtype_x.startswith('int') else 43.5 | ||||||
|  |     x = torch.tensor([x0], dtype=cvt[dtype_x], device=device) | ||||||
|  |  | ||||||
|     # triton kernel |     # triton kernel | ||||||
|     @triton.jit |     @triton.jit | ||||||
|   | |||||||
| @@ -74,9 +74,8 @@ class CustomPhilox4x: | |||||||
|         return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) |         return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) | ||||||
|  |  | ||||||
|     def _raise_key(self, key): |     def _raise_key(self, key): | ||||||
|         ret0 = key[0] + self._config.PHILOX_KEY_A |         pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] | ||||||
|         ret1 = key[1] + self._config.PHILOX_KEY_B |         return key + np.array(pk, dtype=self._dtype) | ||||||
|         return np.array([ret0, ret1], dtype=self._dtype) |  | ||||||
|  |  | ||||||
|     def random_raw(self): |     def random_raw(self): | ||||||
|         counter = self._counter |         counter = self._counter | ||||||
|   | |||||||
| @@ -10,6 +10,7 @@ import os | |||||||
| import pickle | import pickle | ||||||
| import subprocess | import subprocess | ||||||
| import os | import os | ||||||
|  | import warnings | ||||||
| from .tools.disasm import extract | from .tools.disasm import extract | ||||||
| import torch | import torch | ||||||
| import triton | import triton | ||||||
| @@ -475,7 +476,11 @@ class CodeGenerator(ast.NodeVisitor): | |||||||
|     def visit(self, node): |     def visit(self, node): | ||||||
|         if node is not None: |         if node is not None: | ||||||
|             self.last_node = node |             self.last_node = node | ||||||
|         return super().visit(node) |         with warnings.catch_warnings(): | ||||||
|  |             # The ast library added visit_Constant and deprecated some other | ||||||
|  |             # methods but we can't move to that without breaking Python 3.6 and 3.7. | ||||||
|  |             warnings.simplefilter("ignore", DeprecationWarning) | ||||||
|  |             return super().visit(node) | ||||||
|  |  | ||||||
|     def generic_visit(self, node): |     def generic_visit(self, node): | ||||||
|         typename = type(node).__name__ |         typename = type(node).__name__ | ||||||
| @@ -512,12 +517,11 @@ class LoadedBinary: | |||||||
|  |  | ||||||
|  |  | ||||||
| class CompilationError(Exception): | class CompilationError(Exception): | ||||||
|     def __init__(self, src, node, err): |     def __init__(self, src, node): | ||||||
|         self.message = '\n'.join(src.split('\n')[:node.lineno]) |         self.message = '\n'.join(src.split('\n')[:node.lineno]) | ||||||
|         self.message += '\n' + ' ' * node.col_offset + '^' |         self.message += '\n' + ' ' * node.col_offset + '^' | ||||||
|         self.message += '\n Error: ' + str(err) |  | ||||||
|         super().__init__(self.message) |         super().__init__(self.message) | ||||||
|         self.args = (src, node, err) |         self.args = (src, node) | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutOfResources(Exception): | class OutOfResources(Exception): | ||||||
| @@ -618,7 +622,7 @@ class Kernel: | |||||||
|             node = generator.last_node |             node = generator.last_node | ||||||
|             if node is None or isinstance(e, (NotImplementedError, CompilationError)): |             if node is None or isinstance(e, (NotImplementedError, CompilationError)): | ||||||
|                 raise e |                 raise e | ||||||
|             raise CompilationError(self.fn.src, node, e) |             raise CompilationError(self.fn.src, node) from e | ||||||
|         # Compile to machine code |         # Compile to machine code | ||||||
|         if torch.version.hip is None: |         if torch.version.hip is None: | ||||||
|             backend = _triton.runtime.backend.CUDA |             backend = _triton.runtime.backend.CUDA | ||||||
|   | |||||||
| @@ -1,4 +1,3 @@ | |||||||
| from . import core | from . import core, random | ||||||
| from . import random |  | ||||||
| from .core import * | from .core import * | ||||||
| from .random import * | from .random import rand, randint, randint4x, randn | ||||||
|   | |||||||
| @@ -72,6 +72,9 @@ class dtype: | |||||||
|         ctx = builder.context |         ctx = builder.context | ||||||
|         return self.init(ctx) |         return self.init(ctx) | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return f"dtype({self.init.__name__})" | ||||||
|  |  | ||||||
|  |  | ||||||
| class pointer_dtype: | class pointer_dtype: | ||||||
|     def __init__(self, element_ty): |     def __init__(self, element_ty): | ||||||
|   | |||||||
| @@ -134,7 +134,7 @@ def pair_uniform_to_normal(u1, u2): | |||||||
| def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): | def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): | ||||||
|     """ |     """ | ||||||
|     Given a :code:`seed` scalar and an :code:`offset` block,  |     Given a :code:`seed` scalar and an :code:`offset` block,  | ||||||
|     returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)` |     returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` | ||||||
|  |  | ||||||
|     :param seed: The seed for generating random numbers. |     :param seed: The seed for generating random numbers. | ||||||
|     :param offsets: The offsets to generate random numbers for. |     :param offsets: The offsets to generate random numbers for. | ||||||
| @@ -149,7 +149,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): | |||||||
| def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): | def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): | ||||||
|     """ |     """ | ||||||
|     Given a :code:`seed` scalar and an :code:`offset` block, |     Given a :code:`seed` scalar and an :code:`offset` block, | ||||||
|     returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)` |     returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)` | ||||||
|  |  | ||||||
|     :param seed: The seed for generating random numbers. |     :param seed: The seed for generating random numbers. | ||||||
|     :param offsets: The offsets to generate random numbers for. |     :param offsets: The offsets to generate random numbers for. | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user