[RUNTIME] Now using pybind11 to avoid memory leaks (#377)
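Context for the change below (an illustrative sketch, not part of the commit's diff): the raw CPython calls being replaced return new references that must be manually Py_DECREF'ed, and a missed decref leaks memory on every kernel launch; pybind11 wrappers such as py::object, py::list and py::int_ release their references automatically when they go out of scope. A minimal before/after sketch using a toy helper rather than the actual parse_args:

// Illustrative sketch only, not the commit's code.
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Raw CPython API: PyLong_FromLong returns a *new* reference; without a
// matching Py_DECREF, every loop iteration leaks one Python int object.
void leaky_contains(PyObject* list, PyObject* skip) {
  Py_ssize_t len = PyList_Size(list);
  for(Py_ssize_t i = 0; i < len; i++){
    PyObject* py_i = PyLong_FromLong(i);     // new reference
    (void)PySequence_Contains(skip, py_i);   // missing Py_DECREF(py_i) -> leak
  }
}

// pybind11: py::int_ and py::list own their references and release them in
// their destructors, so no manual refcount bookkeeping is needed.
void raii_contains(py::list list, py::list skip) {
  for(size_t i = 0; i < list.size(); i++){
    py::int_ py_i(i);                        // reference released at end of iteration
    (void)skip.contains(py_i);
  }
}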
@@ -110,18 +110,19 @@ std::string pow2_divisor(long N){
}

// Launch
void parse_args(py::handle& args, py::handle do_not_specialize, const std::string& func_key, py::handle& arg_names,
std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants,
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
int num_warps, int num_stages) {
size_t len = PyList_Size(args.ptr());
params.reserve(8*len); // 8 max bytes by argument
char* params_ptr = &params[0];
cache_key = func_key;
for(int i = 0; i < len; i++){
PyObject* py_i = PyLong_FromLong(i);
bool specialize = !PySequence_Contains(do_not_specialize.ptr(), py_i);
auto arg_ptr = PyList_GetItem(args.ptr(), i);
auto arg = py::handle(arg_ptr);
py::int_ py_i = py::int_(i);
bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end();
py::object arg = args[i];
auto arg_ptr = arg.ptr();

// argument is `long`
if(PyLong_Check(arg_ptr)){
int overflow;
@@ -172,28 +173,28 @@ void parse_args(py::handle& args, py::handle do_not_specialize, const std::strin
continue;
}
// argument is tensor
PyObject* data_ptr = PyObject_CallMethod(arg_ptr, "data_ptr", nullptr);
if(data_ptr){
if(py::hasattr(arg, "data_ptr")){
py::object data_ptr = arg.attr("data_ptr")();
cache_key += "P";
long value = PyLong_AsLong(data_ptr);
long value = data_ptr.cast<long>();
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
PyObject* dtype = PyObject_GetAttrString(arg_ptr, "dtype");
PyObject* repr = PyObject_Repr(dtype);
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr) + 6; // remove 'torch.'
size_t len = PyUnicode_GET_LENGTH(repr) - 6;
py::object dtype = arg.attr("dtype");
py::object repr = py::repr(dtype);
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
cache_key += std::string(start, len);
continue;
}
// argument is `constexpr`
PyObject* value = PyObject_GetAttrString(arg_ptr, "value");
py::object value = arg.attr("value");
if(value){
PyObject* name = PyList_GetItem(arg_names.ptr(), i);
PyDict_SetItem(constants, name, value);
PyObject* repr = PyObject_Repr(value);
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr);
size_t len = PyUnicode_GET_LENGTH(repr);
py::object name = arg_names[i];
constants[name] = value;
py::object repr = py::repr(value);
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
size_t len = PyUnicode_GET_LENGTH(repr.ptr());
cache_key += std::string(start, len);
continue;
}
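A note on the hunk above (editorial sketch, not part of the diff): the old branch obtained data_ptr, dtype and repr as raw PyObject* new references, and no matching Py_DECREF appears in the removed lines; the pybind11 version holds each temporary in a py::object, so the reference is dropped automatically at the end of the loop iteration. A minimal standalone sketch of the same pattern, with the caveat that the buffer returned by PyUnicode_1BYTE_DATA is only valid while the owning py::object is alive:

// Illustrative sketch only, assuming a tensor-like `arg` whose dtype reprs as "torch.float16".
#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;

std::string dtype_suffix(py::handle arg) {
  py::object dtype = arg.attr("dtype");    // owned; reference released on return
  py::object repr = py::repr(dtype);       // e.g. "torch.float16"
  // The char buffer belongs to `repr`; copy it into a std::string before `repr` dies.
  const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // skip "torch."
  size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
  return std::string(start, len);          // "float16"
}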
@@ -228,37 +229,39 @@ void init_triton_runtime(py::module &&m) {
);

// cache key
m.def("launch", [](py::handle args, py::handle do_not_specialize, const std::string& func_key, py::list& arg_names,
py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages,
py::handle add_to_cache, py::handle grid){
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
py::function add_to_cache, py::object grid){
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
long _num_warps = PyLong_AsLong(num_warps.ptr());
long _num_stages = PyLong_AsLong(num_stages.ptr());
std::string cache_key;
std::string params;
size_t params_size;
PyObject* constants = PyDict_New();
py::dict constants;
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);

// get cached binary
PyObject* key = PyUnicode_FromString(cache_key.c_str());
PyObject* bin = nullptr;
if(!PyDict_Contains(bin_cache.ptr(), key)){
add_to_cache(py::handle(key), args, device, num_warps, num_stages);
}
bin = PyDict_GetItem(bin_cache.ptr(), key);
py::str key(cache_key);
if(!bin_cache.contains(key))
add_to_cache(key, args, device, num_warps, num_stages);
py::object bin = bin_cache[key];

// get grid
PyObject* grid_ptr = grid.ptr();
if(!PySequence_Check(grid_ptr)){
PyObject* grid_call = PyObject_GetAttrString(grid_ptr, "__call__");
grid_ptr = PyObject_Call(grid_call, PyTuple_Pack(1, constants), nullptr);
}
int size = PySequence_Size(grid_ptr);
int grid_0 = PyLong_AsLong(PySequence_GetItem(grid_ptr, 0));
int grid_1 = size < 2 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 1));
int grid_2 = size < 3 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 2));
py::sequence seq;
if(!PySequence_Check(grid.ptr()))
seq = grid(constants);
else
seq = grid;
int size = seq.size();
int grid_0 = py::cast<int>(seq[0]);
int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);

// enqueue
uint64_t kernel = PyLong_AsLong(PyObject_GetAttrString(bin, "kernel"));
uint64_t shared_mem = PyLong_AsLong(PyObject_GetAttrString(bin, "shared_mem"));
uint64_t kernel = py::cast<uint64_t>(bin.attr("kernel"));
uint64_t shared_mem = py::cast<uint64_t>(bin.attr("shared_mem"));

// actually launch
void *config[] = {
CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
@@ -269,7 +272,7 @@ void init_triton_runtime(py::module &&m) {
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
nullptr, config);
return py::handle(bin);
return bin;
});

// query maximum shared memory
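One further note on the launch binding above (sketch only, with hypothetical names): switching the m.def signature from untyped py::handle parameters to py::list, py::dict, py::int_ and py::function makes pybind11 type-check and convert the Python arguments at the call boundary and keep them reference-counted for the lambda's duration. A minimal toy binding illustrating the idea:

// Hypothetical toy module, not the commit's binding.
#include <pybind11/pybind11.h>
namespace py = pybind11;

PYBIND11_MODULE(toy_runtime, m) {
  // Typed parameters are validated by pybind11 and behave like owning handles
  // inside the lambda.
  m.def("describe", [](py::list args, py::dict constants, py::int_ num_warps) {
    long warps = num_warps.cast<long>();   // no PyLong_AsLong / manual error checks
    return py::str("args={}, constants={}, num_warps={}")
               .format(args.size(), constants.size(), warps);
  });
}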
@@ -493,184 +493,6 @@ class OutOfResources(Exception):
super().__init__(self.message)


# class Kernel:
# @staticmethod
# def _type_name(obj):
# type_names = {
# triton.language.float8: 'f8',
# torch.bfloat16: 'bf16',
# torch.float16: 'f16',
# torch.float32: 'f32',
# torch.float64: 'f64',
# torch.bool: 'i1',
# torch.int8: 'i8',
# torch.int16: 'i16',
# torch.int32: 'i32',
# torch.int64: 'i64',
# }
# if hasattr(obj, 'data_ptr'):
# return type_names[obj.dtype]
# if isinstance(obj, triton.language.core.constexpr):
# obj = obj.value
# if isinstance(obj, int):
# if abs(obj) <= 0xffffffff:
# return 'I'
# return 'L'
# if isinstance(obj, float):
# return 'f'
# if isinstance(obj, bool):
# return 'B'
# if isinstance(obj, str):
# return 'str'
# assert False


# @staticmethod
# def _to_triton_ir(context, obj):
# type_map = {
# 'I': _triton.ir.type.get_int32,
# 'L': _triton.ir.type.get_int64,
# 'f': _triton.ir.type.get_fp32,
# 'B': _triton.ir.type.get_int1,
# 'f8': _triton.ir.type.get_fp8,
# 'f16': _triton.ir.type.get_fp16,
# 'bf16': _triton.ir.type.get_bf16,
# 'f32': _triton.ir.type.get_fp32,
# 'f64': _triton.ir.type.get_fp64,
# 'i1': _triton.ir.type.get_int1,
# 'i8': _triton.ir.type.get_int8,
# 'i16': _triton.ir.type.get_int16,
# 'i32': _triton.ir.type.get_int32,
# 'i64': _triton.ir.type.get_int64,
# }
# # convert torch.Tensor to Triton IR pointers
# if hasattr(obj, 'data_ptr'):
# name = Kernel._type_name(obj)
# elt_ty = type_map[name](context)
# return _triton.ir.type.make_ptr(elt_ty, 1)
# # default path returns triton.ir.type directly
# name = Kernel._type_name(obj)
# return type_map[name](context)

# @staticmethod
# def pow2_divisor(N):
# if N % 16 == 0: return 16
# if N % 8 == 0: return 8
# if N % 4 == 0: return 4
# if N % 2 == 0: return 2
# return 1

# def __init__(self, fn):
# self.fn = fn

# def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
# wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
# # create IR module
# context = _triton.ir.context()
# # get just-in-time proto-type of kernel
# arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs]
# ret_type = _triton.ir.type.get_void(context)
# prototype = _triton.ir.type.make_function(ret_type, arg_types)
# # generate Triton-IR
# # export symbols visible from self.fn into code-generator object
# gscope = sys.modules[self.fn.module].__dict__
# generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
# try:
# generator.visit(self.fn.parse())
# except Exception as e:
# node = generator.last_node
# if node is None or isinstance(e, (NotImplementedError, CompilationError)):
# raise e
# raise CompilationError(self.fn.src, node, e)
# # Compile to machine code
# if torch.version.hip is None:
# backend = _triton.runtime.backend.CUDA
# else:
# backend = _triton.runtime.backend.ROCM
# name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
# max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
# if shared_mem > max_shared_memory:
# raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
# return Binary(backend, name, asm, shared_mem, num_warps)

# def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
# tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# # attributes
# args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
# attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
# if isinstance(a, int) and i not in self.fn.do_not_specialize}

# # transforms ints whose value is one into constants for just-in-time compilation
# constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
# constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
# hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()

# # create cache directory
# cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
# if cache_dir and not os.path.exists(cache_dir):
# os.makedirs(cache_dir, exist_ok=True)

# if cache_dir:
# bin_cache_path = os.path.join(cache_dir, hashed_key)
# bin_lock_path = bin_cache_path + ".lock"
# else:
# bin_cache_path = None
# bin_lock_path = None

# binary = None
# if bin_cache_path and os.path.exists(bin_cache_path):
# assert bin_lock_path is not None
# with FileLock(bin_lock_path):
# with open(bin_cache_path, 'rb') as f:
# binary = pickle.load(f)["binary"]
# if binary is None:
# binary = self._compile(
# *wargs, device=device_idx, attributes=attributes,
# num_warps=num_warps, num_stages=num_stages,
# constants=constants,
# )
# if bin_cache_path:
# assert bin_lock_path is not None
# with FileLock(bin_lock_path):
# with open(bin_cache_path + ".tmp", "wb") as f:
# pickle.dump({"binary": binary, "key": key}, f)
# os.rename(bin_cache_path + ".tmp", bin_cache_path)
# if JITFunction.cache_hook is not None:
# JITFunction.cache_hook(key=key, binary=binary)

# self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)

# def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# # handle arguments passed by name
# kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
# wargs = list(wargs)
# for i, pos in enumerate(sorted(kwargs)):
# wargs.insert(pos + i, kwargs[pos])
# if len(wargs) != len(self.fn.arg_names):
# raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
# # handle annotations
# for pos, _type in self.fn.annotations.items():
# wargs[pos] = _type(wargs[pos])
# # query device index and cuda stream
# device = torch.cuda.current_device()
# torch.cuda.set_device(device)
# cc = torch.cuda.get_device_capability(device)
# cc = str(cc[0]) + '-' + str(cc[1])
# # # query stream
# # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
# # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
# # # building a C wrapper to re-use the unpack function would add a build-time torch dependency
# # # and require different wheels for different torch versions -- undesirable!
# # bits = torch._C._cuda_getCurrentStream(device)
# # mask = 1 << 47
# # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
# stream = torch.cuda.current_stream(device).cuda_stream
# # make key for cache
# return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
# self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)


class Kernel:
@staticmethod
def _type_name(obj):
@@ -700,8 +522,6 @@ class Kernel:
return 'B'
if isinstance(obj, str):
return 'str'
if isinstance(obj, JITFunction):
return ''
assert False


@@ -733,16 +553,6 @@ class Kernel:
name = Kernel._type_name(obj)
return type_map[name](context)

@staticmethod
def _types_key(*wargs, tensor_idxs):
# type inference
types_key = [None] * len(wargs)
for i, arg in enumerate(wargs):
prefix = 'P' if i in tensor_idxs else ''
suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg)
types_key[i] = prefix + suffix
return tuple(types_key)

@staticmethod
def pow2_divisor(N):
if N % 16 == 0: return 16
@@ -784,6 +594,53 @@ class Kernel:
raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
return Binary(backend, name, asm, shared_mem, num_warps)

def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
if isinstance(a, int) and i not in self.fn.do_not_specialize}

# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()

# create cache directory
cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
if cache_dir and not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True)

if cache_dir:
bin_cache_path = os.path.join(cache_dir, hashed_key)
bin_lock_path = bin_cache_path + ".lock"
else:
bin_cache_path = None
bin_lock_path = None

binary = None
if bin_cache_path and os.path.exists(bin_cache_path):
assert bin_lock_path is not None
with FileLock(bin_lock_path):
with open(bin_cache_path, 'rb') as f:
binary = pickle.load(f)["binary"]
if binary is None:
binary = self._compile(
*wargs, device=device_idx, attributes=attributes,
num_warps=num_warps, num_stages=num_stages,
constants=constants,
)
if bin_cache_path:
assert bin_lock_path is not None
with FileLock(bin_lock_path):
with open(bin_cache_path + ".tmp", "wb") as f:
pickle.dump({"binary": binary, "key": key}, f)
os.rename(bin_cache_path + ".tmp", bin_cache_path)
if JITFunction.cache_hook is not None:
JITFunction.cache_hook(key=key, binary=binary)

self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)

def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# handle arguments passed by name
kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
@@ -793,112 +650,25 @@ class Kernel:
if len(wargs) != len(self.fn.arg_names):
raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
# handle annotations
for name, type in self.fn.__annotations__.items():
pos = self.fn.arg_names.index(name)
assert type == triton.language.core.constexpr
wargs[pos] = type(wargs[pos])
# device inference
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
if len(tensor_idxs) == 0:
raise ValueError("No Tensor argument found.")
invalid_args = []
device_ids = []
for idx in tensor_idxs:
curr = wargs[idx]
if not curr.is_cuda:
invalid_args.append(idx)
else:
device_ids.append(curr.device.index)
if invalid_args:
raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
" Only CUDA is supported at the moment")

device = torch.device('cuda', torch.cuda.current_device())
device_idx = device.index
# if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
# # try to enable P2P communication
# for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
# if dst_idx != device_idx:
# try:
# _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr())
# except RuntimeError as e:
# raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
# .format(device_idx, dst_idx, str(e)))

# enqueue kernel on the current device
torch.cuda.set_device(device_idx)
# attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
if isinstance(a, int) and i not in self.fn.do_not_specialize}

# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})

# compute hash for caching this kernel
types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
attr_key = tuple(attributes.items())
const_key = tuple(constants.items())
compute_capability = torch.cuda.get_device_capability(device)
key = (
self.fn.cache_key, version_key(), compute_capability,
types_key, attr_key, num_warps, num_stages, const_key
)
key = repr(key)

# get cached binary
bin_cache = self.fn.bin_cache

if key not in bin_cache:
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()

# create cache directory
cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
if cache_dir and not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True)

if cache_dir:
bin_cache_path = os.path.join(cache_dir, hashed_key)
bin_lock_path = bin_cache_path + ".lock"
else:
bin_cache_path = None
bin_lock_path = None

binary = None
if bin_cache_path and os.path.exists(bin_cache_path):
assert bin_lock_path is not None
with FileLock(bin_lock_path):
with open(bin_cache_path, 'rb') as f:
binary = pickle.load(f)["binary"]
if binary is None:
binary = self._compile(
*wargs, device=device_idx, attributes=attributes,
num_warps=num_warps, num_stages=num_stages,
constants=constants,
)
if bin_cache_path:
assert bin_lock_path is not None
with FileLock(bin_lock_path):
with open(bin_cache_path + ".tmp", "wb") as f:
pickle.dump({"binary": binary, "key": key}, f)
os.rename(bin_cache_path + ".tmp", bin_cache_path)
if JITFunction.cache_hook is not None:
JITFunction.cache_hook(key=key, binary=binary)

bin_cache[key] = LoadedBinary(device_idx, binary)
# pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
# enqueue cached function into stream
callable = bin_cache[key]
stream = torch.cuda.current_stream(device_idx).cuda_stream
csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
grid = grid(csts) if hasattr(grid, '__call__') else grid
if isinstance(grid, int):
grid = tuple(grid)
callable(stream, params, *grid)
return callable
for pos, _type in self.fn.annotations.items():
wargs[pos] = _type(wargs[pos])
# query device index and cuda stream
device = torch.cuda.current_device()
torch.cuda.set_device(device)
cc = torch.cuda.get_device_capability(device)
cc = str(cc[0]) + '-' + str(cc[1])
# # query stream
# # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
# # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
# # building a C wrapper to re-use the unpack function would add a build-time torch dependency
# # and require different wheels for different torch versions -- undesirable!
# bits = torch._C._cuda_getCurrentStream(device)
# mask = 1 << 47
# stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
stream = torch.cuda.current_stream(device).cuda_stream
# make key for cache
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)


class Launcher: