[FRONTEND] Bunch of fixes here and there (#436)

This commit is contained in:
Philippe Tillet
2022-01-20 10:55:59 -08:00
committed by GitHub
parent e0c5709cc8
commit 4c97d1ecd7
7 changed files with 71 additions and 39 deletions

View File

@@ -745,8 +745,11 @@ ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir:
x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder);
y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder);
}
if(x->get_type()->get_scalar_ty() != y->get_type()->get_scalar_ty())
throw_incompatible_types(x->get_type()->get_scalar_ty(), y->get_type()->get_scalar_ty());
ir::type* x_ty = x->get_type()->get_scalar_ty();
ir::type* y_ty = y->get_type()->get_scalar_ty();
ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO);
x = dispatch::cast(x, ty, builder);
y = dispatch::cast(y, ty, builder);
return builder->create_select(condition, x, y);
}

View File

@@ -329,7 +329,6 @@ void init_triton_runtime(py::module &&m) {
// cuda will block if too many ops are enqueued
Py_BEGIN_ALLOW_THREADS
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
nullptr, config);
@@ -466,6 +465,9 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
uint64_t device, int num_warps, int num_stages,
asm_map_t &asm_map){
int n_shared_bytes;
Py_BEGIN_ALLOW_THREADS
llvm::LLVMContext ctx;
// device properties
CUdevice dev = (CUdevice)device;
@@ -476,7 +478,6 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
drv::dispatch::cuDriverGetVersion(&version);
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::nvidia_cu_target target(cc);
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
std::string tmp;
llvm::raw_string_ostream llir(tmp);
@@ -492,6 +493,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
py::bytes bytes(cubin);
asm_map["cubin"] = bytes;
}
Py_END_ALLOW_THREADS
return std::make_tuple(name, asm_map, n_shared_bytes);
}

View File

@@ -79,7 +79,7 @@ def to_numpy(x):
def patch_kernel(template, to_replace):
kernel = copy.deepcopy(template)
kernel = triton.JITFunction(template.fn)
for key, value in to_replace.items():
kernel.src = kernel.src.replace(key, value)
return kernel

View File

@@ -39,9 +39,9 @@ def kernel_nospec(X, i, BLOCK: tl.constexpr):
def apply_src_change(target, old, new):
delattr(kernel.fn, 'hash')
delattr(function_1.fn, 'hash')
delattr(function_2.fn, 'hash')
kernel.hash = None
function_1.hash = None
function_2.hash = None
function_1.src = function_1.src.replace(old, new)
target.src = target.src.replace(old, new)
ret = target.cache_key

View File

@@ -651,7 +651,7 @@ class Kernel:
prototype = _triton.ir.type.make_function(ret_type, arg_types)
# generate Triton-IR
# export symbols visible from self.fn into code-generator object
gscope = self.fn.fn.__globals__
gscope = self.fn.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
try:
generator.visit(self.fn.parse())
@@ -723,7 +723,7 @@ class Kernel:
pickle.dump({"binary": binary, "key": key}, f)
os.rename(bin_cache_path + ".tmp", bin_cache_path)
if JITFunction.cache_hook is not None:
name = self.fn.fn.__name__
name = self.fn.__name__
info = key.split('-')[-3:]
num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
# make signature human-readable
@@ -885,8 +885,6 @@ def version_key():
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
# 3
class DependenciesFinder(ast.NodeVisitor):
@@ -910,17 +908,14 @@ class DependenciesFinder(ast.NodeVisitor):
func = self.visit(node.func)
if func is None:
return
if isinstance(func, triton.JITFunction):
func = func.fn
module = inspect.getmodule(func)
if module and module.__name__.startswith('triton.'):
return
if inspect.isbuiltin(func):
return
if not hasattr(func, 'hash'):
src = textwrap.dedent(inspect.getsource(func))
tree = ast.parse(src)
finder = DependenciesFinder(func.__globals__, src)
if func.__module__ and func.__module__.startswith('triton.'):
return
assert isinstance(func, triton.JITFunction)
if func.hash is None:
tree = ast.parse(func.src)
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
self.ret = (self.ret + func.hash).encode("utf-8")
@@ -941,10 +936,12 @@ class JITFunction:
self.version = version
self.src = textwrap.dedent(inspect.getsource(fn))
self.do_not_specialize = [] if do_not_specialize is None else\
[self.arg_names.index(arg) for arg in do_not_specialize]
self.src = self.src[self.src.find("def"):]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize]
# cache for callable driver objects (e.g. CUkernel)
self.bin_cache = dict()
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []
@@ -954,15 +951,19 @@ class JITFunction:
self.__annotations__ = fn.__annotations__
# forward docs
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
self.__globals__ = fn.__globals__
self.__module__ = fn.__module__
@property
@functools.lru_cache()
def cache_key(self):
if not hasattr(self.fn, 'hash'):
dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
# TODO : hash should be attribute of `self`
if self.hash is None:
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.fn.hash = dependencies_finder.ret + version_key()
return self.fn.hash
self.hash = dependencies_finder.ret + version_key()
return self.hash
# we do not parse `src` in the constructor because
# the user might want to monkey-patch self.src dynamically.
@@ -974,14 +975,20 @@ class JITFunction:
assert isinstance(tree.body[0], ast.FunctionDef)
return tree
def __call__(self, *args, generator: CodeGenerator):
def __call__(self, *args, generator: CodeGenerator, **kwargs):
try:
from inspect import getcallargs
arg_values = getcallargs(self.fn, *args, **kwargs)
arg_values = [arg_values[name] for name in self.arg_names]
arg_values = [arg if isinstance(arg, triton.language.block)
else triton.language.constexpr(arg) for arg in arg_values]
gscope = generator.gscope.copy()
lscope = generator.lscope.copy()
values = generator.module.get_values().copy()
generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
generator.gscope = gscope
generator.lscope = lscope
generator.module.set_values(values)
@@ -1001,8 +1008,7 @@ class JITFunction:
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
if name == 'src':
if hasattr(self.fn, 'hash'):
delattr(self.fn, 'hash')
self.hash = None
JITFunction.cache_key.fget.cache_clear()
def _init_kernel(self):

View File

@@ -168,6 +168,8 @@ class block:
self.numel = constexpr(self.numel)
# Data-type wrapper
self.dtype = block._init_dtype(self.handle.type.scalar)
# Shape is a constexpr
self.shape = [constexpr(s) for s in self.shape]
def __str__(self) -> str:
# ex. "float32[3,4]"
@@ -297,7 +299,7 @@ class block:
if sl is None:
dst_shape.append(1)
elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr])
dst_shape.append(src_shape[curr].value)
curr += 1
ret = frontend.reshape(self, dst_shape, _builder)
return ret
@@ -320,8 +322,15 @@ class constexpr:
"""
def __init__(self, value):
if isinstance(value, constexpr):
self.value = value.value
else:
self.value = value
def __repr__(self) -> str:
return f"constexpr[{self.value}]"
#
def __add__(self, other):
return self.value + other.value
@@ -516,6 +525,7 @@ def reshape(input, shape, _builder=None):
:type shape: Tuple[int]
"""
shape = [x.value for x in shape]
return frontend.reshape(input, shape, _builder)
@@ -908,3 +918,8 @@ def swizzle2d(i, j, size_i, size_j, size_g):
new_i = off_i + (ij % size_g)
new_j = (ij % size_gj) // size_g
return new_i, new_j
@triton.jit
def zeros_like(input):
return zeros(input.shape, input.dtype)

View File

@@ -13,7 +13,7 @@ N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
@triton.jit
def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
"""
@@ -32,6 +32,14 @@ def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
return c0, c1, c2, c3
@triton.jit
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
seed = seed.to(tl.uint64)
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
seed_lo = (seed & 0xffffffff).to(tl.uint32)
return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
@triton.jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
@@ -60,11 +68,9 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
:param seed: The seed for generating random numbers.
:param offsets: The offsets to generate random numbers for.
"""
z = offset * 0 # FIXME: just 0 doesn't work. Likely some error with broadcasting
seed = seed.to(tl.uint64)
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
seed_lo = (seed & 0xffffffff).to(tl.uint32)
return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
# _0 = tl.zeros(offset.shape, offset.dtype)
_0 = offset * 0
return philox(seed, offset, _0, _0, _0, n_rounds)
# -------------------