[FRONTEND] Minor accumulated style and warning fixes (#388)
- Fix some whitespace. - Make an undeclared dependency on `pytest` explicit. - Fix deprecated `description-file` use. - `#ifdef` out a deprecated `PyEval_InitThreads` call. - Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests. - Fix a deprecated cast in `test_core.py`. - Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7. - Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer. - Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred. - Fix a few bad escapes.
This commit is contained in:
committed by
GitHub
parent
9def2424ab
commit
e575ae3443
@@ -86,7 +86,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
#define void_ty builder_->getVoidTy()
|
||||
#define f16_ty builder_->getHalfTy()
|
||||
#define f32_ty builder_->getFloatTy()
|
||||
#define i8_ty builder_->getInt8Ty()
|
||||
#define i8_ty builder_->getInt8Ty()
|
||||
#define i32_ty builder_->getInt32Ty()
|
||||
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
||||
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
||||
@@ -163,8 +163,8 @@ Type *generator::cvt(ir::type *ty) {
|
||||
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
|
||||
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
|
||||
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
|
||||
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
|
||||
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
|
||||
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
|
||||
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
|
||||
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
|
||||
case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_);
|
||||
case ir::type::TokenTyID: return Type::getTokenTy(*ctx_);
|
||||
|
@@ -1 +1,2 @@
|
||||
pytest
|
||||
scipy >= 1.7.1
|
||||
|
@@ -1,2 +1,2 @@
|
||||
[metadata]
|
||||
description-file = README.md
|
||||
description_file = README.md
|
||||
|
@@ -197,7 +197,9 @@ PYBIND11_NOINLINE inline internals &get_internals() {
|
||||
auto *&internals_ptr = *internals_pp;
|
||||
internals_ptr = new internals();
|
||||
#if defined(WITH_THREAD)
|
||||
#if PY_VERSION_HEX < 0x03090000
|
||||
PyEval_InitThreads();
|
||||
#endif
|
||||
PyThreadState *tstate = PyThreadState_Get();
|
||||
#if PY_VERSION_HEX >= 0x03070000
|
||||
internals_ptr->tstate = PyThread_tss_alloc();
|
||||
|
@@ -339,7 +339,8 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
|
||||
('float32', 'int32', True)
|
||||
])
|
||||
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
|
||||
x0 = 43 if dtype_x.startswith('int') else 43.5
|
||||
x = torch.tensor([x0], dtype=cvt[dtype_x], device=device)
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -665,4 +666,4 @@ def test_noop(device='cuda'):
|
||||
def kernel(x):
|
||||
pass
|
||||
x = triton.testing.random((1,), dtype=torch.int32, device=device)
|
||||
kernel[(1, )](x)
|
||||
kernel[(1, )](x)
|
||||
|
@@ -74,9 +74,8 @@ class CustomPhilox4x:
|
||||
return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
|
||||
|
||||
def _raise_key(self, key):
|
||||
ret0 = key[0] + self._config.PHILOX_KEY_A
|
||||
ret1 = key[1] + self._config.PHILOX_KEY_B
|
||||
return np.array([ret0, ret1], dtype=self._dtype)
|
||||
pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
|
||||
return key + np.array(pk, dtype=self._dtype)
|
||||
|
||||
def random_raw(self):
|
||||
counter = self._counter
|
||||
|
@@ -10,6 +10,7 @@ import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import os
|
||||
import warnings
|
||||
from .tools.disasm import extract
|
||||
import torch
|
||||
import triton
|
||||
@@ -475,7 +476,11 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit(self, node):
|
||||
if node is not None:
|
||||
self.last_node = node
|
||||
return super().visit(node)
|
||||
with warnings.catch_warnings():
|
||||
# The ast library added visit_Constant and deprecated some other
|
||||
# methods but we can't move to that without breaking Python 3.6 and 3.7.
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
return super().visit(node)
|
||||
|
||||
def generic_visit(self, node):
|
||||
typename = type(node).__name__
|
||||
@@ -512,12 +517,11 @@ class LoadedBinary:
|
||||
|
||||
|
||||
class CompilationError(Exception):
|
||||
def __init__(self, src, node, err):
|
||||
def __init__(self, src, node):
|
||||
self.message = '\n'.join(src.split('\n')[:node.lineno])
|
||||
self.message += '\n' + ' ' * node.col_offset + '^'
|
||||
self.message += '\n Error: ' + str(err)
|
||||
super().__init__(self.message)
|
||||
self.args = (src, node, err)
|
||||
self.args = (src, node)
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
@@ -618,7 +622,7 @@ class Kernel:
|
||||
node = generator.last_node
|
||||
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
||||
raise e
|
||||
raise CompilationError(self.fn.src, node, e)
|
||||
raise CompilationError(self.fn.src, node) from e
|
||||
# Compile to machine code
|
||||
if torch.version.hip is None:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
|
@@ -1,4 +1,3 @@
|
||||
from . import core
|
||||
from . import random
|
||||
from . import core, random
|
||||
from .core import *
|
||||
from .random import *
|
||||
from .random import rand, randint, randint4x, randn
|
||||
|
@@ -72,6 +72,9 @@ class dtype:
|
||||
ctx = builder.context
|
||||
return self.init(ctx)
|
||||
|
||||
def __str__(self):
|
||||
return f"dtype({self.init.__name__})"
|
||||
|
||||
|
||||
class pointer_dtype:
|
||||
def __init__(self, element_ty):
|
||||
|
@@ -134,7 +134,7 @@ def pair_uniform_to_normal(u1, u2):
|
||||
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
returns a block of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
|
||||
returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
@@ -149,7 +149,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
returns a 4 blocks of random :code:`float32` in :math:`\mathcal{N}(0, 1)`
|
||||
returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
|
||||
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
|
Reference in New Issue
Block a user