Improve
This commit is contained in:
@@ -959,7 +959,7 @@ def generate_launcher(identifier, constants, signature):
|
|||||||
"int64_t": "L",
|
"int64_t": "L",
|
||||||
}[ty]
|
}[ty]
|
||||||
|
|
||||||
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
|
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + "O"
|
||||||
|
|
||||||
# generate glue code
|
# generate glue code
|
||||||
src = f"""
|
src = f"""
|
||||||
@@ -1022,34 +1022,22 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
|
|||||||
PyObject *launch_enter_hook = NULL;
|
PyObject *launch_enter_hook = NULL;
|
||||||
PyObject *launch_exit_hook = NULL;
|
PyObject *launch_exit_hook = NULL;
|
||||||
PyObject *compiled_kernel = NULL;
|
PyObject *compiled_kernel = NULL;
|
||||||
PyObject *hook_ret = NULL;
|
PyObject *constants = NULL;
|
||||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())}, &constants)) {{
|
||||||
return NULL;
|
return NULL;
|
||||||
}}
|
}}
|
||||||
|
|
||||||
if (launch_enter_hook != Py_None) {{
|
if (launch_enter_hook != Py_None) {{
|
||||||
PyObject *new_args = PyTuple_Pack(1, compiled_kernel);
|
PyObject_CallObject(launch_enter_hook, args);
|
||||||
hook_ret = PyObject_CallObject(launch_enter_hook, new_args);
|
|
||||||
Py_DECREF(new_args);
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
|
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
|
||||||
|
|
||||||
if (launch_exit_hook != Py_None) {{
|
if (launch_exit_hook != Py_None) {{
|
||||||
PyObject *new_args = NULL;
|
PyObject_CallObject(launch_exit_hook, args);
|
||||||
if (hook_ret) {{
|
|
||||||
new_args = PyTuple_Pack(2, compiled_kernel, hook_ret);
|
|
||||||
}} else {{
|
|
||||||
new_args = PyTuple_Pack(1, compiled_kernel);
|
|
||||||
}}
|
|
||||||
hook_ret = PyObject_CallObject(launch_exit_hook, new_args);
|
|
||||||
Py_DECREF(new_args);
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
if (hook_ret) {{
|
|
||||||
Py_DECREF(hook_ret);
|
|
||||||
}}
|
|
||||||
if(PyErr_Occurred()) {{
|
if(PyErr_Occurred()) {{
|
||||||
return NULL;
|
return NULL;
|
||||||
}}
|
}}
|
||||||
@@ -1214,7 +1202,7 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
|
|||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
|
def compile(fn, signature: str, device: int = -1, constants=dict(), constexpr=tuple(), num_warps: int = 4,
|
||||||
num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
|
num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
|
||||||
# we get the kernel, i.e. the first function generated in the module
|
# we get the kernel, i.e. the first function generated in the module
|
||||||
assert len(configs) == 1
|
assert len(configs) == 1
|
||||||
@@ -1250,7 +1238,7 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
|
|||||||
not fn_cache_manager.has_file(llir_name):
|
not fn_cache_manager.has_file(llir_name):
|
||||||
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
|
||||||
extern_libs, "cubin", cc)
|
extern_libs, "cubin", cc)
|
||||||
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages, "constexpr": constexpr}
|
||||||
fn_cache_manager.put(asm["cubin"], cubin_name)
|
fn_cache_manager.put(asm["cubin"], cubin_name)
|
||||||
fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
|
fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
|
||||||
fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
|
fn_cache_manager.put(asm["ttir"], ttir_name, binary=False)
|
||||||
@@ -1282,6 +1270,7 @@ class CompiledKernel:
|
|||||||
self.shared = metadata["shared"]
|
self.shared = metadata["shared"]
|
||||||
self.num_warps = metadata["num_warps"]
|
self.num_warps = metadata["num_warps"]
|
||||||
self.num_stages = metadata["num_stages"]
|
self.num_stages = metadata["num_stages"]
|
||||||
|
self.constexpr = metadata["constexpr"]
|
||||||
# initialize asm dict
|
# initialize asm dict
|
||||||
self.asm = dict()
|
self.asm = dict()
|
||||||
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||||
@@ -1305,7 +1294,7 @@ class CompiledKernel:
|
|||||||
if stream is None:
|
if stream is None:
|
||||||
stream = torch.cuda.current_stream().cuda_stream
|
stream = torch.cuda.current_stream().cuda_stream
|
||||||
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
|
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
|
||||||
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
|
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args, self.constexpr)
|
||||||
return runner
|
return runner
|
||||||
|
|
||||||
def get_sass(self, fun=None):
|
def get_sass(self, fun=None):
|
||||||
|
@@ -254,7 +254,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
try:
|
try:
|
||||||
bin = cache[device][key]
|
bin = cache[device][key]
|
||||||
if not warmup:
|
if not warmup:
|
||||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
|
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args}, constexpr_key)
|
||||||
return bin
|
return bin
|
||||||
# kernel not cached -- compile
|
# kernel not cached -- compile
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@@ -272,9 +272,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
if callable(arg):
|
if callable(arg):
|
||||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||||
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
bin = triton.compile(self, signature, device, constants, constexpr_key, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
||||||
if not warmup:
|
if not warmup:
|
||||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
|
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args, constexpr_key)
|
||||||
self.cache[device][key] = bin
|
self.cache[device][key] = bin
|
||||||
return bin
|
return bin
|
||||||
return None
|
return None
|
||||||
|
Reference in New Issue
Block a user