[FRONTEND] Significantly reduce kernel launch time (#367)
@@ -127,7 +127,7 @@ setup(
     description="A language and compiler for custom Deep Learning operations",
     long_description="",
     packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
-    install_requires=["torch", "filelock"],
+    install_requires=["cmake", "torch", "filelock"],
    package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
    include_package_data=True,
    ext_modules=[CMakeExtension("triton", "triton/_C/")],
@@ -13,6 +13,7 @@
 #include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include "Python.h"
 #include <regex>
 #include <string>
 #include "llvm/IR/Module.h"
@@ -23,6 +24,7 @@ namespace py = pybind11;
 namespace ir = triton::ir;
 namespace drv = triton::driver;

+
 /*****************************************************************************/
 /* Python bindings for triton::driver */
 /*****************************************************************************/
@@ -99,8 +101,113 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,

 }

+std::string pow2_divisor(long N){
+  if(N % 16 == 0) return "16";
+  if(N % 8 == 0) return "8";
+  if(N % 4 == 0) return "4";
+  if(N % 2 == 0) return "2";
+  return "1";
+}
+
+// Launch
+void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_names,
+                std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants,
+                int num_warps, int num_stages) {
+  size_t len = PyList_Size(args.ptr());
+  params.reserve(8*len); // 8 max bytes by argument
+  char* params_ptr = &params[0];
+  cache_key = func_key;
+  for(int i = 0; i < len; i++){
+    auto arg_ptr = PyList_GetItem(args.ptr(), i);
+    auto arg = py::handle(arg_ptr);
+    // argument is `long`
+    if(PyLong_Check(arg_ptr)){
+      int overflow;
+      long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
+      // long and int have different kernels
+      if(!overflow & (std::abs(value) <= 0xffffffff)){
+        cache_key += 'I';
+        params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
+        std::memcpy(params_ptr, &value, 4);
+        params_ptr += 4;
+      }
+      else{
+        cache_key += 'L';
+        params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
+        if(overflow){
+          unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
+          std::memcpy(&value, &uvalue, 8);
+        }
+        std::memcpy(params_ptr, &value, 8);
+        params_ptr += 8;
+      }
+      // values equal to 1 are specialized
+      if(value == 1)
+        cache_key += '1';
+      else
+        cache_key += 'x';
+      // values divisible by small powers of 2 are specialized
+      cache_key += pow2_divisor(value);
+      continue;
+    }
+    // argument is `float`
+    if(PyFloat_Check(arg_ptr)){
+      cache_key += "f";
+      float value = PyFloat_AsDouble(arg_ptr);
+      params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
+      std::memcpy(params_ptr, &value, 4);
+      params_ptr += 4;
+      continue;
+    }
+    // argument is `bool`
+    if(PyBool_Check(arg_ptr)){
+      cache_key += "B";
+      bool value = arg_ptr == Py_True ? true : false;
+      std::memcpy(params_ptr, &value, 1);
+      params_ptr += 1;
+      continue;
+    }
+    // argument is tensor
+    PyObject* data_ptr = PyObject_CallMethod(arg_ptr, "data_ptr", nullptr);
+    if(data_ptr){
+      cache_key += "P";
+      long value = PyLong_AsLong(data_ptr);
+      params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
+      std::memcpy(params_ptr, &value, 8);
+      params_ptr += 8;
+      PyObject* dtype = PyObject_GetAttrString(arg_ptr, "dtype");
+      PyObject* repr = PyObject_Repr(dtype);
+      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr) + 6; // remove 'torch.'
+      size_t len = PyUnicode_GET_LENGTH(repr) - 6;
+      cache_key += std::string(start, len);
+      continue;
+    }
+    // argument is `constexpr`
+    PyObject* value = PyObject_GetAttrString(arg_ptr, "value");
+    if(value){
+      PyObject* name = PyList_GetItem(arg_names.ptr(), i);
+      PyDict_SetItem(constants, name, value);
+      PyObject* repr = PyObject_Repr(value);
+      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr);
+      size_t len = PyUnicode_GET_LENGTH(repr);
+      cache_key += std::string(start, len);
+      continue;
+    }
+    assert(false);
+  }
+  cache_key += std::to_string(num_warps);
+  cache_key += std::to_string(num_stages);
+  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
+}
+
+//
+
 void init_triton_runtime(py::module &&m) {

+  // m.def("current_stream", [](uint64_t device){
+  //   return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
+  // });
+
   // wrap backend_t
   py::enum_<backend_t>(m, "backend")
       .value("HOST", HOST)
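
The cache key that parse_args builds is a flat string: each integer argument contributes 'I' or 'L' (32- vs 64-bit), then '1' or 'x' (value-one specialization), then its largest power-of-two divisor up to 16; floats contribute 'f', bools 'B', tensors 'P' plus their dtype name without the 'torch.' prefix, and constexpr arguments their repr; num_warps and num_stages are appended at the end. A rough Python sketch of the same encoding, for illustration only (the helper name and the pure-Python structure are not part of the commit):

    def sketch_cache_key(func_key, args, num_warps, num_stages):
        # Mirrors the string format built by parse_args() above; not the real implementation.
        key = func_key
        for arg in args:
            if isinstance(arg, bool):          # bools first: in Python, bool is a subclass of int
                key += "B"
            elif isinstance(arg, int):
                key += "I" if abs(arg) <= 0xffffffff else "L"
                key += "1" if arg == 1 else "x"
                key += next(str(d) for d in (16, 8, 4, 2, 1) if arg % d == 0)
            elif isinstance(arg, float):
                key += "f"
            elif hasattr(arg, "data_ptr"):     # torch tensor
                key += "P" + repr(arg.dtype)[len("torch."):]
            else:                              # triton.language.constexpr
                key += repr(arg.value)
        return key + str(num_warps) + str(num_stages)
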
@@ -116,6 +223,51 @@ void init_triton_runtime(py::module &&m) {
    }
  );

+  // cache key
+  m.def("launch", [](py::handle args, const std::string& func_key, py::list& arg_names,
+                     py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages,
+                     py::handle add_to_cache, py::handle grid){
+    // parse arguments to compute cache key, compile-time constants and packed kernel arguments
+    long _num_warps = PyLong_AsLong(num_warps.ptr());
+    long _num_stages = PyLong_AsLong(num_stages.ptr());
+    std::string cache_key;
+    std::string params;
+    size_t params_size;
+    PyObject* constants = PyDict_New();
+    parse_args(args, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
+    // get cached binary
+    PyObject* key = PyUnicode_FromString(cache_key.c_str());
+    PyObject* bin = nullptr;
+    if(!PyDict_Contains(bin_cache.ptr(), key)){
+      add_to_cache(py::handle(key), args, device, num_warps, num_stages);
+    }
+    bin = PyDict_GetItem(bin_cache.ptr(), key);
+    // get grid
+    PyObject* grid_ptr = grid.ptr();
+    if(!PySequence_Check(grid_ptr)){
+      PyObject* grid_call = PyObject_GetAttrString(grid_ptr, "__call__");
+      grid_ptr = PyObject_Call(grid_call, PyTuple_Pack(1, constants), nullptr);
+    }
+    int size = PySequence_Size(grid_ptr);
+    int grid_0 = PyLong_AsLong(PySequence_GetItem(grid_ptr, 0));
+    int grid_1 = size < 2 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 1));
+    int grid_2 = size < 3 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 2));
+    // enqueue
+    uint64_t kernel = PyLong_AsLong(PyObject_GetAttrString(bin, "kernel"));
+    uint64_t shared_mem = PyLong_AsLong(PyObject_GetAttrString(bin, "shared_mem"));
+    // actually launch
+    void *config[] = {
+      CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
+      CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
+      CU_LAUNCH_PARAM_END
+    };
+    uint64_t _stream = PyLong_AsLong(stream.ptr());
+    drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
+                                  _num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
+                                  nullptr, config);
+    return py::handle(bin);
+  });
+
   // query maximum shared memory
   m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
     if (backend == HOST)
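
One detail of the launch binding above: grid may be a Python sequence or a callable, in which case it is invoked with the dict of compile-time constants, and missing dimensions default to 1. A small hedged sketch of that normalization (the helper name is made up):

    def normalize_grid(grid, constants):
        # grid may be a callable of the compile-time constants, as in the binding above
        if not isinstance(grid, (tuple, list)):
            grid = grid(constants)
        grid_0 = grid[0]
        grid_1 = grid[1] if len(grid) > 1 else 1
        grid_2 = grid[2] if len(grid) > 2 else 1
        return grid_0, grid_1, grid_2

    # e.g. normalize_grid(lambda META: (1024 // META['BLOCK'],), {'BLOCK': 128}) -> (8, 1, 1)
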
@@ -464,6 +464,7 @@ class LoadedBinary:
         self.module = module
         self.kernel = kernel
         self.device = device
+        self.shared_mem = bin.shared_mem

     def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
         _triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
@@ -548,16 +549,6 @@ class Kernel:
         name = Kernel._type_name(obj)
         return type_map[name](context)

-    @staticmethod
-    def _types_key(*wargs, tensor_idxs):
-        # type inference
-        types_key = [None] * len(wargs)
-        for i, arg in enumerate(wargs):
-            prefix = 'P' if i in tensor_idxs else ''
-            suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg)
-            types_key[i] = prefix + suffix
-        return tuple(types_key)
-
     @staticmethod
     def pow2_divisor(N):
         if N % 16 == 0: return 16
@@ -599,6 +590,53 @@ class Kernel:
             raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
         return Binary(backend, name, asm, shared_mem, num_warps)

+    def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
+        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
+        # attributes
+        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
+        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
+                      if isinstance(a, int) and i not in self.fn.do_not_specialize}
+
+        # transforms ints whose value is one into constants for just-in-time compilation
+        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
+        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
+        hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
+
+        # create cache directory
+        cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
+        if cache_dir and not os.path.exists(cache_dir):
+            os.makedirs(cache_dir, exist_ok=True)
+
+        if cache_dir:
+            bin_cache_path = os.path.join(cache_dir, hashed_key)
+            bin_lock_path = bin_cache_path + ".lock"
+        else:
+            bin_cache_path = None
+            bin_lock_path = None
+
+        binary = None
+        if bin_cache_path and os.path.exists(bin_cache_path):
+            assert bin_lock_path is not None
+            with FileLock(bin_lock_path):
+                with open(bin_cache_path, 'rb') as f:
+                    binary = pickle.load(f)["binary"]
+        if binary is None:
+            binary = self._compile(
+                *wargs, device=device_idx, attributes=attributes,
+                num_warps=num_warps, num_stages=num_stages,
+                constants=constants,
+            )
+            if bin_cache_path:
+                assert bin_lock_path is not None
+                with FileLock(bin_lock_path):
+                    with open(bin_cache_path + ".tmp", "wb") as f:
+                        pickle.dump({"binary": binary, "key": key}, f)
+                    os.rename(bin_cache_path + ".tmp", bin_cache_path)
+            if JITFunction.cache_hook is not None:
+                JITFunction.cache_hook(key=key, binary=binary)
+
+        self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)
+
     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
         # handle arguments passed by name
         kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
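
add_to_cache keeps the JITFunction.cache_hook behavior: the hook is called with keyword arguments key (the un-hashed cache key string) and binary after a fresh compilation. A minimal hook might look like the sketch below (assuming triton.code_gen is importable as in this file; the hook name is made up):

    import triton

    def log_compilation(key, binary):
        # called once per freshly compiled kernel, per the call site above
        print(f"compiled Triton kernel, cache key = {key}")

    triton.code_gen.JITFunction.cache_hook = log_compilation
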
@@ -608,112 +646,21 @@ class Kernel:
         if len(wargs) != len(self.fn.arg_names):
             raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
         # handle annotations
-        for name, type in self.fn.__annotations__.items():
-            pos = self.fn.arg_names.index(name)
-            assert type == triton.language.core.constexpr
-            wargs[pos] = type(wargs[pos])
-        # device inference
-        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
-        if len(tensor_idxs) == 0:
-            raise ValueError("No Tensor argument found.")
-        invalid_args = []
-        device_ids = []
-        for idx in tensor_idxs:
-            curr = wargs[idx]
-            if not curr.is_cuda:
-                invalid_args.append(idx)
-            else:
-                device_ids.append(curr.device.index)
-        if invalid_args:
-            raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
-                             " Only CUDA is supported at the moment")
-
-        device = torch.device('cuda', torch.cuda.current_device())
-        device_idx = device.index
-        # if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
-        #     # try to enable P2P communication
-        #     for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
-        #         if dst_idx != device_idx:
-        #             try:
-        #                 _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr())
-        #             except RuntimeError as e:
-        #                 raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
-        #                                    .format(device_idx, dst_idx, str(e)))
-
-        # enqueue kernel on the current device
-        torch.cuda.set_device(device_idx)
-        # attributes
-        args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
-        attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
-                      if isinstance(a, int) and i not in self.fn.do_not_specialize}
-
-        # transforms ints whose value is one into constants for just-in-time compilation
-        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
-        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
-
-        # compute hash for caching this kernel
-        types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
-        attr_key = tuple(attributes.items())
-        const_key = tuple(constants.items())
-        compute_capability = torch.cuda.get_device_capability(device)
-        key = (
-            self.fn.cache_key, version_key(), compute_capability,
-            types_key, attr_key, num_warps, num_stages, const_key
-        )
-        key = repr(key)
-
-        # get cached binary
-        drv_cache = self.fn.drv_cache
-
-        if key not in drv_cache:
-            hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
-
-            # create cache directory
-            cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
-            if cache_dir and not os.path.exists(cache_dir):
-                os.makedirs(cache_dir, exist_ok=True)
-
-            if cache_dir:
-                bin_cache_path = os.path.join(cache_dir, hashed_key)
-                bin_lock_path = bin_cache_path + ".lock"
-            else:
-                bin_cache_path = None
-                bin_lock_path = None
-
-            binary = None
-            if bin_cache_path and os.path.exists(bin_cache_path):
-                assert bin_lock_path is not None
-                with FileLock(bin_lock_path):
-                    with open(bin_cache_path, 'rb') as f:
-                        binary = pickle.load(f)["binary"]
-            if binary is None:
-                binary = self._compile(
-                    *wargs, device=device_idx, attributes=attributes,
-                    num_warps=num_warps, num_stages=num_stages,
-                    constants=constants,
-                )
-                if bin_cache_path:
-                    assert bin_lock_path is not None
-                    with FileLock(bin_lock_path):
-                        with open(bin_cache_path + ".tmp", "wb") as f:
-                            pickle.dump({"binary": binary, "key": key}, f)
-                        os.rename(bin_cache_path + ".tmp", bin_cache_path)
-                if JITFunction.cache_hook is not None:
-                    JITFunction.cache_hook(key=key, binary=binary)
-
-            drv_cache[key] = LoadedBinary(device_idx, binary)
-        # pack arguments
-        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
-        params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
-        # enqueue cached function into stream
-        callable = drv_cache[key]
-        stream = torch.cuda.current_stream(device_idx).cuda_stream
-        csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
-        grid = grid(csts) if hasattr(grid, '__call__') else grid
-        if isinstance(grid, int):
-            grid = tuple(grid)
-        callable(stream, params, *grid)
-        return callable
+        for pos, _type in self.fn.annotations.items():
+            wargs[pos] = _type(wargs[pos])
+        # query device index and cuda stream
+        device = torch.cuda.current_device()
+        # query stream
+        # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
+        # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
+        # building a C wrapper to re-use the unpack function would add a build-time torch dependency
+        # and require different wheels for different torch versions -- undesirable!
+        bits = torch._C._cuda_getCurrentStream(device)
+        mask = 1 << 47
+        stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
+        # make key for cache
+        return _triton.runtime.launch(wargs, self.fn.cache_key, self.fn.arg_names, device, stream,
+                                      self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)


 class Launcher:
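
The stream query above skips torch.cuda.current_stream(device).cuda_stream and instead unpacks the raw value returned by torch._C._cuda_getCurrentStream: the code assumes the low 48 bits carry the stream handle (see the linked Stream.h comment) and sign-extends them back into a Python int. A worked example with made-up bit patterns:

    mask = 1 << 47

    def unpack_stream(bits):
        # keep the low 48 bits, then sign-extend bit 47
        return ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask

    assert unpack_stream(0x0000000000000000) == 0           # default stream
    assert unpack_stream(0xDEAD00007F302000) == 0x7F302000  # upper 16 bits are discarded
    assert unpack_stream(0xFFFFFFFFFFFFFFFF) == -1          # negative handles round-trip
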
@@ -725,6 +672,7 @@ class Launcher:
         return self.kernel(*wargs, **kwargs, grid=self.grid)

+

 class Autotuner:
     def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
         if not configs:
@@ -773,6 +721,11 @@ class Autotuner:
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)


+@functools.lru_cache()
+def compute_capability():
+    device = torch.device('cuda', 0)
+    return '-'.join(map(str, torch.cuda.get_device_capability(device)))
+
 @functools.lru_cache()
 def version_key():
     import pkgutil
@@ -784,22 +737,27 @@ def version_key():
     with open(triton._C.libtriton.__file__, "rb") as f:
         contents += [hashlib.md5(f.read()).hexdigest()]
     # language
-    for lib in pkgutil.iter_modules(triton.language.__path__):
+    language_path = os.path.join(*triton.__path__, 'language')
+    for lib in pkgutil.iter_modules([language_path]):
         with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
             contents += [hashlib.md5(f.read()).hexdigest()]
     # ptxas version
     try:
         ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
     except Exception:
-        ptxas_version = None
-    return (triton.__version__, ptxas_version) + tuple(contents)
+        ptxas_version = ''
+    return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)

 class JITFunction:

     cache_hook = None

     def _set_cache_key(self):
-        self.cache_key = (hashlib.md5(self.src.encode("utf-8")).hexdigest(), self.version)
+        self.cache_key = hashlib.md5(self.src.encode("utf-8")).hexdigest()
+        self.cache_key += str(self.version)
+        self.cache_key += version_key()
+        self.cache_key += compute_capability()
+        self.cache_key = hashlib.md5(self.cache_key.encode("utf-8")).hexdigest()

     def __init__(self, fn, version=None, do_not_specialize=None):
         # information of wrapped function
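
After this change the per-function cache key is a single md5 hex string that folds in the source hash, the optional user version, version_key() and the device compute capability, instead of a tuple. An illustrative composition follows; every concrete value is made up and only the structure mirrors the code above:

    import hashlib

    src = "def kernel(X, Y, N, BLOCK: tl.constexpr): ..."   # kernel source text
    version = None                                          # JITFunction(fn, version=None)
    vkey = "1-.-0-.-0-<ptxas md5>-<module md5>"             # shape of version_key()
    capability = "7-5"                                      # compute_capability(), e.g. on a T4

    cache_key = hashlib.md5(src.encode("utf-8")).hexdigest()
    cache_key += str(version) + vkey + capability
    cache_key = hashlib.md5(cache_key.encode("utf-8")).hexdigest()   # final 32-char hex key
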
@@ -811,7 +769,7 @@ class JITFunction:
         self.do_not_specialize = [] if do_not_specialize is None else\
             [self.arg_names.index(arg) for arg in do_not_specialize]
         # cache for callable driver objects (e.g. CUkernel)
-        self.drv_cache = dict()
+        self.bin_cache = dict()
         # cache for binaries (on-disk)
         self._set_cache_key()
         # JITFunction can be instantiated as kernel
@@ -819,6 +777,7 @@ class JITFunction:
         self.kernel_decorators = []
         self.kernel = None
         # annotations
+        self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
         self.__annotations__ = fn.__annotations__
         # forward docs
         self.__doc__ = fn.__doc__
@@ -834,7 +793,7 @@ class JITFunction:
         assert isinstance(tree.body[0], ast.FunctionDef)
         return tree

-    def __call__(self, *args, generator: CodeGenerator, **meta):
+    def __call__(self, *args, generator: CodeGenerator):
         try:
             gscope = generator.gscope.copy()
             lscope = generator.lscope.copy()
@@ -119,6 +119,9 @@ class block:
         self.shape = (1, )
         if self.handle.type.is_block():
             self.shape = self.handle.type.shape
+        self.numel = 1
+        for s in self.shape:
+            self.numel *= s
         # Data-type wrapper
         self.dtype = block._init_dtype(self.handle.type.scalar)

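
The new numel attribute is simply the product of the block's shape. A quick hedged check of the same arithmetic (not taken from the commit or its tests):

    shape = (128, 64)
    numel = 1
    for s in shape:
        numel *= s
    assert numel == 128 * 64 == 8192
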
@@ -352,6 +355,13 @@ def program_id(axis, _builder=None):
     :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
     :type axis: int
     """
+    # if axis == -1:
+    #     pid0 = frontend.program_id(0, _builder)
+    #     pid1 = frontend.program_id(1, _builder)
+    #     pid2 = frontend.program_id(2, _builder)
+    #     npg0 = frontend.num_programs(0, _builder)
+    #     npg1 = frontend.num_programs(0, _builder)
+    #     return pid0 + pid1*npg0 + pid2*npg0*npg1
     return frontend.program_id(axis, _builder)
