[FRONTEND] Backport new runtime from master (#706)
This PR merges the new runtime back into the `triton-mlir` branch. It adds caching and just-in-time compilation functionality to the triton-mlir project, and paves the way for reusing tests from the master branch.
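For context, a minimal sketch of how the two backported entry points fit together, based only on the signatures in this diff. The `add_kernel` example, its signature string, and the launch parameters are hypothetical illustrations, and the `triton.compiler` module path is an assumption; only `compile()` and `CompiledKernel` come from this PR, and the exact argument conventions may differ.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # hypothetical example kernel, not part of this PR
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


# compile() builds the cubin plus a C launcher stub and caches both under
# ~/.triton/cache (or $TRITON_CACHE_DIR); repeated calls hit the cache.
binary = triton.compiler.compile(add_kernel,
                                 signature="*fp32, *fp32, *fp32, i32",
                                 constants={"BLOCK": 1024})

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
# CompiledKernel.__getitem__ returns a runner bound to a (gridX, gridY, gridZ)
# tuple; pointer arguments may be tensors, since the generated stub calls
# their data_ptr() method.
binary[(4096 // 1024, 1, 1)](x, y, out, x.numel())
```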
@@ -1,10 +1,26 @@
from __future__ import annotations

import ast
+import contextlib
+import functools
+import hashlib
+import io
+import json
+import os
+import shutil
+import subprocess
import sys
+import sysconfig
+import tempfile
import warnings
+from collections import namedtuple
+from sysconfig import get_paths
from typing import Any, Dict, Tuple, Union

+import setuptools
+import torch
+from filelock import FileLock
+
import triton
import triton._C.libtriton.triton as _triton

@@ -85,7 +101,7 @@ class enter_sub_region:


class CodeGenerator(ast.NodeVisitor):
-    def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
+    def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict()):
        self.builder = _triton.ir.builder(context)
        self.module = self.builder.create_module() if module is None else module
        self.function_ret_types = function_types
@@ -94,6 +110,7 @@ class CodeGenerator(ast.NodeVisitor):
        self.lscope = dict()
        self.attributes = attributes
        self.constants = constants
+        self.function_name = function_name
        self.is_kernel = is_kernel
        self.last_node = None
        self.builtins = {
@@ -194,8 +211,7 @@ class CodeGenerator(ast.NodeVisitor):
            init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
            self.visit(init_node)
        # initialize function
-        fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
-        fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
+        fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder))
        self.module.push_back(fn)
        entry = fn.add_entry_block()
        arg_values = []
@@ -206,10 +222,10 @@ class CodeGenerator(ast.NodeVisitor):
                if not isinstance(cst, triton.language.constexpr):
                    cst = triton.language.constexpr(self.constants[i])
                arg_values.append(cst)
                continue
            else:
                pass
            if i in self.attributes:
-                fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
+                fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1])
            arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
            idx += 1
@@ -746,21 +762,40 @@ class OutOfResources(Exception):
        return (type(self), (self.required, self.limit, self.name))


-def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
+def kernel_suffix(signature, specialization):
+    # suffix format:
+    # <argid><'c' if equal to 1><'d' if divisible by 16>
+    suffix = ''
+    for i, _ in enumerate(signature):
+        suffix += str(i)
+        if i in specialization.equal_to_1:
+            suffix += 'c'
+        if i in specialization.divisible_by_16:
+            suffix += 'd'
+    return suffix
+
+# ------------------------------------------------------------------------------
+# ------------------------------------------------------------------------------
+
+
+def make_triton_ir(fn, signature, specialization, constants):
    context = _triton.ir.context()
    context.load_triton()
    # create kernel prototype
-    constants = {fn.arg_names.index(name): value for name, value in constants.items()}
-    attributes = {fn.arg_names.index(name): value for name, value in attributes.items()}
-    if signature.replace(' ', '') != '':
-        arg_types = signature.replace(' ', '').split(',')
-        arg_types = [str_to_ty(x) for x in arg_types]
-    else:
-        arg_types = []
-    prototype = triton.language.function_type([], arg_types)
+    cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
+    constants = {cst_key(key): value for key, value in constants.items()}
    # visit kernel AST
    gscope = fn.__globals__.copy()
-    generator = CodeGenerator(context, prototype, gscope=gscope, constants=constants, attributes=attributes, is_kernel=True)
+    function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
+    tys = list(signature.values())
+    new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1}
+    new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
+    all_constants = constants.copy()
+    all_constants.update(new_constants)
+    arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
+
+    prototype = triton.language.function_type([], arg_types)
+    generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True)
    try:
        generator.visit(fn.parse())
    except Exception as e:
@@ -769,9 +804,9 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
            raise e
        raise CompilationError(fn.src, node) from e
    ret = generator.module
-    # module takes ownership of the MLIR context
+    # module takes ownership of the context
    ret.context = context
-    return ret
+    return ret, generator


def optimize_triton_ir(mod):
@@ -842,14 +877,21 @@ def ptx_get_kernel_name(ptx: str) -> str:
            return line.split()[-1]


-def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=dict(), num_warps: int = 4, num_stages: int = 3, output: str = "ttgir") -> Tuple[str, int, str]:
+instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
+
+
+def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
+    if isinstance(signature, str):
+        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
    valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
    assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)

    # triton-ir
-    module = make_triton_ir(fn, signature, constants, attributes)
+    module, _ = make_triton_ir(fn, signature, specialization, constants)
    module = optimize_triton_ir(module)
    if output == "ttir":
        return module.str()

    # tritongpu-ir
    module = make_tritongpu_ir(module, num_warps)
    module = optimize_tritongpu_ir(module, num_stages)
@@ -865,6 +907,357 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), attributes=d

    cubin = make_cubin(ptx, device)
    if output == "cubin":
-        return cubin, shem_size, kernel_name
+        return cubin, ptx, shem_size, kernel_name

    assert False
+
+
+# ------------------------------------------------------------------------------
+# compiler
+# ------------------------------------------------------------------------------
+
+
+def ty_to_cpp(ty):
+    if ty[0] == '*':
+        return "CUdeviceptr"
+    return {
+        "i1": "int32_t",
+        "i8": "int8_t",
+        "i16": "int16_t",
+        "i32": "int32_t",
+        "i64": "int64_t",
+        "u32": "uint32_t",
+        "u64": "uint64_t",
+        "fp32": "float",
+    }[ty]
+
+
+def generate_name_initializer(signature):
+    src = "int i = 0;\n"
+    tys = signature.split(',')
+    for i, ty in enumerate(tys):
+        src
+
+
+def binary_name_to_header_name(name):
+    if len(name) > 128:
+        # avoid filename too long errors (filename limit is 255)
+        name = "kernel_" + hashlib.sha256(name.encode("utf-8")).hexdigest()
+    return f"{name}.h"
+
+
+def generate_launcher(identifier, constants, signature):
+    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
+
+    def _extracted_type(ty):
+        if ty[0] == '*':
+            return "PyObject*"
+        return {
+            'i1': 'int32_t',
+            'i32': 'int32_t',
+            'i64': 'int64_t',
+            'u32': 'uint32_t',
+            'u64': 'uint64_t',
+            'fp32': 'float',
+            'fp64': 'double',
+        }[ty]
+
+    def format_of(ty):
+        return {
+            "PyObject*": "O",
+            "float": "f",
+            "double": "d",
+            "long": "l",
+            "uint32_t": "I",
+            "int32_t": "i",
+            "uint64_t": "K",
+            "int64_t": "L",
+        }[ty]
+
+    format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
+
+    # generate glue code
+    src = f"""
+#include \"cuda.h\"
+#include <Python.h>
+static inline void gpuAssert(CUresult code, const char *file, int line)
+{{
+  if (code != CUDA_SUCCESS)
+  {{
+    const char* prefix = "Triton Error [CUDA]: ";
+    const char* str;
+    cuGetErrorString(code, &str);
+    char err[1024] = {{0}};
+    strcat(err, prefix);
+    strcat(err, str);
+    PyErr_SetString(PyExc_RuntimeError, err);
+  }}
+}}
+#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
+void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
+  void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
+  if(gridX*gridY*gridZ > 0){{
+    CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
+  }}
+}}
+static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
+  if (PyLong_Check(obj)) {{
+    return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
+  }}
+  if (obj == Py_None) {{
+    return (CUdeviceptr)0;
+  }}
+  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
+  if(ptr){{
+    PyObject *empty_tuple = PyTuple_New(0);
+    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
+    Py_DECREF(empty_tuple);
+    Py_DECREF(ptr);
+    if (!PyLong_Check(ret)) {{
+      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
+    }}
+    return (CUdeviceptr)PyLong_AsUnsignedLongLong(ret);
+  }}
+  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
+  return (CUdeviceptr)0;
+}}
+static PyObject* launch(PyObject* self, PyObject* args) {{
+  int gridX, gridY, gridZ;
+  uint64_t _stream;
+  uint64_t _function;
+  int num_warps;
+  int shared_memory;
+  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
+  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
+    return NULL;
+  }}
+  _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
+  if(PyErr_Occurred()) {{
+    return NULL;
+  }}
+  // return None
+  Py_INCREF(Py_None);
+  return Py_None;
+}}
+static PyMethodDef ModuleMethods[] = {{
+  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
+  {{NULL, NULL, 0, NULL}} // sentinel
+}};
+static struct PyModuleDef ModuleDef = {{
+  PyModuleDef_HEAD_INIT,
+  \"launcher\",
+  NULL, //documentation
+  -1, //size
+  ModuleMethods
+}};
+PyMODINIT_FUNC PyInit_launcher(void) {{
+  PyObject *m = PyModule_Create(&ModuleDef);
+  if(m == NULL) {{
+    return NULL;
+  }}
+  PyModule_AddFunctions(m, ModuleMethods);
+  return m;
+}}
+"""
+
+    return src
+
+
+def default_cache_dir():
+    return os.path.join(os.environ["HOME"], ".triton", "cache")
+
+
+class CacheManager:
+
+    def __init__(self, key):
+        self.key = key
+        self.lock_path = None
+        # create cache directory if it doesn't exist
+        self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
+        if self.cache_dir:
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+            self.lock_path = os.path.join(self.cache_dir, "lock")
+            os.makedirs(self.cache_dir, exist_ok=True)
+
+    def _make_path(self, filename):
+        return os.path.join(self.cache_dir, filename)
+
+    def has_file(self, filename):
+        if not self.cache_dir:
+            return False
+        return os.path.exists(self._make_path(filename))
+
+    def put(self, data, filename, binary=True):
+        if not self.cache_dir:
+            return
+        assert self.lock_path is not None
+        filepath = self._make_path(filename)
+        with FileLock(self.lock_path):
+            # use tempfile to be robust against program interruptions
+            mode = "wb" if binary else "w"
+            with open(filepath + ".tmp", mode) as f:
+                f.write(data)
+            os.rename(filepath + ".tmp", filepath)
+
+
+# utilties for generating and compiling C wrappers
+
+
+@functools.lru_cache()
+def libcuda_dir():
+    loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1]
+    return os.path.dirname(loc)
+
+
+@contextlib.contextmanager
+def quiet():
+    old_stdout, old_stderr = sys.stdout, sys.stderr
+    sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
+    try:
+        yield
+    finally:
+        sys.stdout, sys.stderr = old_stdout, old_stderr
+
+
+def _build(name, src, srcdir):
+    cuda_lib_dir = libcuda_dir()
+    cu_include_dir = "/usr/local/cuda/include"
+    suffix = sysconfig.get_config_var('EXT_SUFFIX')
+    so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
+    # try to avoid setuptools if possible
+    cc = os.environ.get("CC")
+    if cc is None:
+        # TODO: support more things here.
+        clang = shutil.which("clang")
+        gcc = shutil.which("gcc")
+        cc = gcc if gcc is not None else clang
+    py_include_dir = get_paths()["include"]
+    ret = subprocess.check_call([cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{cuda_lib_dir}", "-lcuda", "-o", so])
+    if ret == 0:
+        return so
+    # fallback on setuptools
+    extra_compile_args = []
+    library_dirs = [cuda_lib_dir]
+    include_dirs = [srcdir, cu_include_dir]
+    libraries = ['cuda']
+    # extra arguments
+    extra_link_args = []
+    # create extension module
+    ext = setuptools.Extension(
+        name=name,
+        language='c',
+        sources=[src],
+        include_dirs=include_dirs,
+        extra_compile_args=extra_compile_args + ['-O3'],
+        extra_link_args=extra_link_args,
+        library_dirs=library_dirs,
+        libraries=libraries,
+    )
+    # build extension module
+    args = ['build_ext']
+    args.append('--build-temp=' + srcdir)
+    args.append('--build-lib=' + srcdir)
+    args.append('-q')
+    args = dict(
+        name=name,
+        ext_modules=[ext],
+        script_args=args,
+    )
+    with quiet():
+        setuptools.setup(**args)
+    return so
+
+
+def make_so_cache_key(signature, constants):
+    # Get unique key for the compiled code
+    signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
+    key = f"{''.join(signature.values())}{constants}"
+    key = hashlib.md5(key.encode("utf-8")).hexdigest()
+    return key
+
+
+def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_stages):
+    # Get unique key for the compiled code
+    get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
+    configs_key = [get_conf_key(conf) for conf in configs]
+    key = f"{fn_hash}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
+    key = hashlib.md5(key.encode("utf-8")).hexdigest()
+    return key
+
+
+def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
+    if isinstance(signature, str):
+        signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
+    # we get the kernel, i.e. the first function generated in the module
+    if configs is None:
+        configs = [instance_descriptor()]
+    assert len(configs) == 1
+    # cache manager
+    name = fn.__name__
+    # name of files that are cached
+    so_cache_key = make_so_cache_key(signature, constants)
+    so_cache_manager = CacheManager(so_cache_key)
+    so_name = f"{name}.so"
+    # retrieve stub from cache if it exists
+    if not so_cache_manager.has_file(so_name):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            src = generate_launcher(name, constants, signature)
+            src_path = os.path.join(tmpdir, "main.c")
+            with open(src_path, "w") as f:
+                f.write(src)
+            so = _build(fn.__name__, src_path, tmpdir)
+            with open(so, "rb") as f:
+                so_cache_manager.put(f.read(), so_name, binary=True)
+
+    # retrieve cached shared object if it exists
+    fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
+    fn_cache_manager = CacheManager(fn_cache_key)
+    ptx_name = f"{name}.ptx"
+    cubin_name = f"{name}.cubin"
+    data_name = f"{name}.json"
+    if not fn_cache_manager.has_file(cubin_name) or \
+       not fn_cache_manager.has_file(data_name) or \
+       not fn_cache_manager.has_file(ptx_name):
+        cubin, ptx, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
+        metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
+        fn_cache_manager.put(cubin, cubin_name)
+        fn_cache_manager.put(ptx, ptx_name, binary=False)
+        fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
+
+    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
+
+
+class CompiledKernel:
+
+    def __init__(self, fn_name, so_path, cache_dir):
+
+        # initialize launcher
+        import importlib.util
+        spec = importlib.util.spec_from_file_location("launcher", so_path)
+        mod = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(mod)
+        self.c_wrapper = getattr(mod, "launch")
+        # initialize metadata
+        with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
+            metadata = json.load(f)
+        self.shared = metadata["shared"]
+        self.num_warps = metadata["num_warps"]
+        self.num_stages = metadata["num_stages"]
+        # initialize asm dict
+        self.asm = dict()
+        with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
+            self.asm["cubin"] = f.read()
+        with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
+            self.asm["ptx"] = f.read()

+        device = torch.cuda.current_device()
+        mod, func, n_regs, n_spills = _triton.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
+        self.cu_module = mod
+        self.cu_function = func
+
+    def __getitem__(self, grid):
+        def runner(*args, stream=None):
+            if stream is None:
+                stream = torch.cuda.current_stream().cuda_stream
+            self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
+        return runner