[FRONTEND] Bunch of fixes here and there (#436)
This commit is contained in:
@@ -745,8 +745,11 @@ ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir:
|
|||||||
x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder);
|
x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder);
|
||||||
y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder);
|
y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder);
|
||||||
}
|
}
|
||||||
if(x->get_type()->get_scalar_ty() != y->get_type()->get_scalar_ty())
|
ir::type* x_ty = x->get_type()->get_scalar_ty();
|
||||||
throw_incompatible_types(x->get_type()->get_scalar_ty(), y->get_type()->get_scalar_ty());
|
ir::type* y_ty = y->get_type()->get_scalar_ty();
|
||||||
|
ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO);
|
||||||
|
x = dispatch::cast(x, ty, builder);
|
||||||
|
y = dispatch::cast(y, ty, builder);
|
||||||
return builder->create_select(condition, x, y);
|
return builder->create_select(condition, x, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -329,7 +329,6 @@ void init_triton_runtime(py::module &&m) {
|
|||||||
// cuda will block if too many ops are enqueued
|
// cuda will block if too many ops are enqueued
|
||||||
Py_BEGIN_ALLOW_THREADS
|
Py_BEGIN_ALLOW_THREADS
|
||||||
|
|
||||||
|
|
||||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||||
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
||||||
nullptr, config);
|
nullptr, config);
|
||||||
@@ -466,6 +465,9 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
|
|||||||
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
|
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
|
||||||
uint64_t device, int num_warps, int num_stages,
|
uint64_t device, int num_warps, int num_stages,
|
||||||
asm_map_t &asm_map){
|
asm_map_t &asm_map){
|
||||||
|
|
||||||
|
int n_shared_bytes;
|
||||||
|
Py_BEGIN_ALLOW_THREADS
|
||||||
llvm::LLVMContext ctx;
|
llvm::LLVMContext ctx;
|
||||||
// device properties
|
// device properties
|
||||||
CUdevice dev = (CUdevice)device;
|
CUdevice dev = (CUdevice)device;
|
||||||
@@ -476,7 +478,6 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
|||||||
drv::dispatch::cuDriverGetVersion(&version);
|
drv::dispatch::cuDriverGetVersion(&version);
|
||||||
// Triton-IR -> NVPTX LLVM-IR
|
// Triton-IR -> NVPTX LLVM-IR
|
||||||
triton::codegen::nvidia_cu_target target(cc);
|
triton::codegen::nvidia_cu_target target(cc);
|
||||||
int n_shared_bytes;
|
|
||||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
|
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
|
||||||
std::string tmp;
|
std::string tmp;
|
||||||
llvm::raw_string_ostream llir(tmp);
|
llvm::raw_string_ostream llir(tmp);
|
||||||
@@ -492,6 +493,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
|||||||
py::bytes bytes(cubin);
|
py::bytes bytes(cubin);
|
||||||
asm_map["cubin"] = bytes;
|
asm_map["cubin"] = bytes;
|
||||||
}
|
}
|
||||||
|
Py_END_ALLOW_THREADS
|
||||||
return std::make_tuple(name, asm_map, n_shared_bytes);
|
return std::make_tuple(name, asm_map, n_shared_bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -79,7 +79,7 @@ def to_numpy(x):
|
|||||||
|
|
||||||
|
|
||||||
def patch_kernel(template, to_replace):
|
def patch_kernel(template, to_replace):
|
||||||
kernel = copy.deepcopy(template)
|
kernel = triton.JITFunction(template.fn)
|
||||||
for key, value in to_replace.items():
|
for key, value in to_replace.items():
|
||||||
kernel.src = kernel.src.replace(key, value)
|
kernel.src = kernel.src.replace(key, value)
|
||||||
return kernel
|
return kernel
|
||||||
|
@@ -39,9 +39,9 @@ def kernel_nospec(X, i, BLOCK: tl.constexpr):
|
|||||||
|
|
||||||
|
|
||||||
def apply_src_change(target, old, new):
|
def apply_src_change(target, old, new):
|
||||||
delattr(kernel.fn, 'hash')
|
kernel.hash = None
|
||||||
delattr(function_1.fn, 'hash')
|
function_1.hash = None
|
||||||
delattr(function_2.fn, 'hash')
|
function_2.hash = None
|
||||||
function_1.src = function_1.src.replace(old, new)
|
function_1.src = function_1.src.replace(old, new)
|
||||||
target.src = target.src.replace(old, new)
|
target.src = target.src.replace(old, new)
|
||||||
ret = target.cache_key
|
ret = target.cache_key
|
||||||
|
@@ -651,7 +651,7 @@ class Kernel:
|
|||||||
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||||
# generate Triton-IR
|
# generate Triton-IR
|
||||||
# export symbols visible from self.fn into code-generator object
|
# export symbols visible from self.fn into code-generator object
|
||||||
gscope = self.fn.fn.__globals__
|
gscope = self.fn.__globals__
|
||||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
|
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
|
||||||
try:
|
try:
|
||||||
generator.visit(self.fn.parse())
|
generator.visit(self.fn.parse())
|
||||||
@@ -723,7 +723,7 @@ class Kernel:
|
|||||||
pickle.dump({"binary": binary, "key": key}, f)
|
pickle.dump({"binary": binary, "key": key}, f)
|
||||||
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
||||||
if JITFunction.cache_hook is not None:
|
if JITFunction.cache_hook is not None:
|
||||||
name = self.fn.fn.__name__
|
name = self.fn.__name__
|
||||||
info = key.split('-')[-3:]
|
info = key.split('-')[-3:]
|
||||||
num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
|
num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
|
||||||
# make signature human-readable
|
# make signature human-readable
|
||||||
@@ -885,8 +885,6 @@ def version_key():
|
|||||||
ptxas_version = ''
|
ptxas_version = ''
|
||||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||||
|
|
||||||
# 3
|
|
||||||
|
|
||||||
|
|
||||||
class DependenciesFinder(ast.NodeVisitor):
|
class DependenciesFinder(ast.NodeVisitor):
|
||||||
|
|
||||||
@@ -910,17 +908,14 @@ class DependenciesFinder(ast.NodeVisitor):
|
|||||||
func = self.visit(node.func)
|
func = self.visit(node.func)
|
||||||
if func is None:
|
if func is None:
|
||||||
return
|
return
|
||||||
if isinstance(func, triton.JITFunction):
|
|
||||||
func = func.fn
|
|
||||||
module = inspect.getmodule(func)
|
|
||||||
if module and module.__name__.startswith('triton.'):
|
|
||||||
return
|
|
||||||
if inspect.isbuiltin(func):
|
if inspect.isbuiltin(func):
|
||||||
return
|
return
|
||||||
if not hasattr(func, 'hash'):
|
if func.__module__ and func.__module__.startswith('triton.'):
|
||||||
src = textwrap.dedent(inspect.getsource(func))
|
return
|
||||||
tree = ast.parse(src)
|
assert isinstance(func, triton.JITFunction)
|
||||||
finder = DependenciesFinder(func.__globals__, src)
|
if func.hash is None:
|
||||||
|
tree = ast.parse(func.src)
|
||||||
|
finder = DependenciesFinder(func.__globals__, func.src)
|
||||||
finder.visit(tree)
|
finder.visit(tree)
|
||||||
func.hash = finder.ret
|
func.hash = finder.ret
|
||||||
self.ret = (self.ret + func.hash).encode("utf-8")
|
self.ret = (self.ret + func.hash).encode("utf-8")
|
||||||
@@ -941,10 +936,12 @@ class JITFunction:
|
|||||||
|
|
||||||
self.version = version
|
self.version = version
|
||||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||||
self.do_not_specialize = [] if do_not_specialize is None else\
|
self.src = self.src[self.src.find("def"):]
|
||||||
[self.arg_names.index(arg) for arg in do_not_specialize]
|
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||||
|
self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize]
|
||||||
# cache for callable driver objects (e.g. CUkernel)
|
# cache for callable driver objects (e.g. CUkernel)
|
||||||
self.bin_cache = dict()
|
self.bin_cache = dict()
|
||||||
|
self.hash = None
|
||||||
# JITFunction can be instantiated as kernel
|
# JITFunction can be instantiated as kernel
|
||||||
# when called with a grid using __getitem__
|
# when called with a grid using __getitem__
|
||||||
self.kernel_decorators = []
|
self.kernel_decorators = []
|
||||||
@@ -954,15 +951,19 @@ class JITFunction:
|
|||||||
self.__annotations__ = fn.__annotations__
|
self.__annotations__ = fn.__annotations__
|
||||||
# forward docs
|
# forward docs
|
||||||
self.__doc__ = fn.__doc__
|
self.__doc__ = fn.__doc__
|
||||||
|
self.__name__ = fn.__name__
|
||||||
|
self.__globals__ = fn.__globals__
|
||||||
|
self.__module__ = fn.__module__
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@functools.lru_cache()
|
@functools.lru_cache()
|
||||||
def cache_key(self):
|
def cache_key(self):
|
||||||
if not hasattr(self.fn, 'hash'):
|
# TODO : hash should be attribute of `self`
|
||||||
dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
|
if self.hash is None:
|
||||||
|
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
|
||||||
dependencies_finder.visit(self.parse())
|
dependencies_finder.visit(self.parse())
|
||||||
self.fn.hash = dependencies_finder.ret + version_key()
|
self.hash = dependencies_finder.ret + version_key()
|
||||||
return self.fn.hash
|
return self.hash
|
||||||
|
|
||||||
# we do not parse `src` in the constructor because
|
# we do not parse `src` in the constructor because
|
||||||
# the user might want to monkey-patch self.src dynamically.
|
# the user might want to monkey-patch self.src dynamically.
|
||||||
@@ -974,14 +975,20 @@ class JITFunction:
|
|||||||
assert isinstance(tree.body[0], ast.FunctionDef)
|
assert isinstance(tree.body[0], ast.FunctionDef)
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
def __call__(self, *args, generator: CodeGenerator):
|
def __call__(self, *args, generator: CodeGenerator, **kwargs):
|
||||||
try:
|
try:
|
||||||
|
from inspect import getcallargs
|
||||||
|
arg_values = getcallargs(self.fn, *args, **kwargs)
|
||||||
|
arg_values = [arg_values[name] for name in self.arg_names]
|
||||||
|
arg_values = [arg if isinstance(arg, triton.language.block)
|
||||||
|
else triton.language.constexpr(arg) for arg in arg_values]
|
||||||
|
|
||||||
gscope = generator.gscope.copy()
|
gscope = generator.gscope.copy()
|
||||||
lscope = generator.lscope.copy()
|
lscope = generator.lscope.copy()
|
||||||
values = generator.module.get_values().copy()
|
values = generator.module.get_values().copy()
|
||||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||||
generator.lscope = dict()
|
generator.lscope = dict()
|
||||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
|
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
|
||||||
generator.gscope = gscope
|
generator.gscope = gscope
|
||||||
generator.lscope = lscope
|
generator.lscope = lscope
|
||||||
generator.module.set_values(values)
|
generator.module.set_values(values)
|
||||||
@@ -1001,8 +1008,7 @@ class JITFunction:
|
|||||||
self.kernel = None
|
self.kernel = None
|
||||||
super(JITFunction, self).__setattr__(name, value)
|
super(JITFunction, self).__setattr__(name, value)
|
||||||
if name == 'src':
|
if name == 'src':
|
||||||
if hasattr(self.fn, 'hash'):
|
self.hash = None
|
||||||
delattr(self.fn, 'hash')
|
|
||||||
JITFunction.cache_key.fget.cache_clear()
|
JITFunction.cache_key.fget.cache_clear()
|
||||||
|
|
||||||
def _init_kernel(self):
|
def _init_kernel(self):
|
||||||
|
@@ -168,6 +168,8 @@ class block:
|
|||||||
self.numel = constexpr(self.numel)
|
self.numel = constexpr(self.numel)
|
||||||
# Data-type wrapper
|
# Data-type wrapper
|
||||||
self.dtype = block._init_dtype(self.handle.type.scalar)
|
self.dtype = block._init_dtype(self.handle.type.scalar)
|
||||||
|
# Shape is a constexpr
|
||||||
|
self.shape = [constexpr(s) for s in self.shape]
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
# ex. "float32[3,4]"
|
# ex. "float32[3,4]"
|
||||||
@@ -297,7 +299,7 @@ class block:
|
|||||||
if sl is None:
|
if sl is None:
|
||||||
dst_shape.append(1)
|
dst_shape.append(1)
|
||||||
elif sl == slice(None, None, None):
|
elif sl == slice(None, None, None):
|
||||||
dst_shape.append(src_shape[curr])
|
dst_shape.append(src_shape[curr].value)
|
||||||
curr += 1
|
curr += 1
|
||||||
ret = frontend.reshape(self, dst_shape, _builder)
|
ret = frontend.reshape(self, dst_shape, _builder)
|
||||||
return ret
|
return ret
|
||||||
@@ -320,8 +322,15 @@ class constexpr:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, value):
|
def __init__(self, value):
|
||||||
|
if isinstance(value, constexpr):
|
||||||
|
self.value = value.value
|
||||||
|
else:
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"constexpr[{self.value}]"
|
||||||
|
|
||||||
|
#
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
return self.value + other.value
|
return self.value + other.value
|
||||||
|
|
||||||
@@ -516,6 +525,7 @@ def reshape(input, shape, _builder=None):
|
|||||||
:type shape: Tuple[int]
|
:type shape: Tuple[int]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
shape = [x.value for x in shape]
|
||||||
return frontend.reshape(input, shape, _builder)
|
return frontend.reshape(input, shape, _builder)
|
||||||
|
|
||||||
|
|
||||||
@@ -908,3 +918,8 @@ def swizzle2d(i, j, size_i, size_j, size_g):
|
|||||||
new_i = off_i + (ij % size_g)
|
new_i = off_i + (ij % size_g)
|
||||||
new_j = (ij % size_gj) // size_g
|
new_j = (ij % size_gj) // size_g
|
||||||
return new_i, new_j
|
return new_i, new_j
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def zeros_like(input):
|
||||||
|
return zeros(input.shape, input.dtype)
|
||||||
|
@@ -13,7 +13,7 @@ N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
|
|||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""
|
"""
|
||||||
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
|
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
|
||||||
"""
|
"""
|
||||||
@@ -32,6 +32,14 @@ def philox_f(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
return c0, c1, c2, c3
|
return c0, c1, c2, c3
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
|
seed = seed.to(tl.uint64)
|
||||||
|
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
|
||||||
|
seed_lo = (seed & 0xffffffff).to(tl.uint32)
|
||||||
|
return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||||
"""
|
"""
|
||||||
@@ -60,11 +68,9 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|||||||
:param seed: The seed for generating random numbers.
|
:param seed: The seed for generating random numbers.
|
||||||
:param offsets: The offsets to generate random numbers for.
|
:param offsets: The offsets to generate random numbers for.
|
||||||
"""
|
"""
|
||||||
z = offset * 0 # FIXME: just 0 doesn't work. Likely some error with broadcasting
|
# _0 = tl.zeros(offset.shape, offset.dtype)
|
||||||
seed = seed.to(tl.uint64)
|
_0 = offset * 0
|
||||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
|
return philox(seed, offset, _0, _0, _0, n_rounds)
|
||||||
seed_lo = (seed & 0xffffffff).to(tl.uint32)
|
|
||||||
return philox_f(offset, z, z, z, seed_lo, seed_hi, n_rounds)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------
|
# -------------------
|
||||||
|
Reference in New Issue
Block a user