[RUNTIME] Now using pybind11 to avoid memory leaks (#377)
This commit is contained in:
@@ -110,18 +110,19 @@ std::string pow2_divisor(long N){
|
||||
}
|
||||
|
||||
// Launch
|
||||
void parse_args(py::handle& args, py::handle do_not_specialize, const std::string& func_key, py::handle& arg_names,
|
||||
std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants,
|
||||
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
|
||||
int num_warps, int num_stages) {
|
||||
size_t len = PyList_Size(args.ptr());
|
||||
params.reserve(8*len); // 8 max bytes by argument
|
||||
char* params_ptr = &params[0];
|
||||
cache_key = func_key;
|
||||
for(int i = 0; i < len; i++){
|
||||
PyObject* py_i = PyLong_FromLong(i);
|
||||
bool specialize = !PySequence_Contains(do_not_specialize.ptr(), py_i);
|
||||
auto arg_ptr = PyList_GetItem(args.ptr(), i);
|
||||
auto arg = py::handle(arg_ptr);
|
||||
py::int_ py_i = py::int_(i);
|
||||
bool specialize = std::find(do_not_specialize.begin(), do_not_specialize.end(), py_i) == do_not_specialize.end();
|
||||
py::object arg = args[i];
|
||||
auto arg_ptr = arg.ptr();
|
||||
|
||||
// argument is `long`
|
||||
if(PyLong_Check(arg_ptr)){
|
||||
int overflow;
|
||||
@@ -172,28 +173,28 @@ void parse_args(py::handle& args, py::handle do_not_specialize, const std::strin
|
||||
continue;
|
||||
}
|
||||
// argument is tensor
|
||||
PyObject* data_ptr = PyObject_CallMethod(arg_ptr, "data_ptr", nullptr);
|
||||
if(data_ptr){
|
||||
if(py::hasattr(arg, "data_ptr")){
|
||||
py::object data_ptr = arg.attr("data_ptr")();
|
||||
cache_key += "P";
|
||||
long value = PyLong_AsLong(data_ptr);
|
||||
long value = data_ptr.cast<long>();
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
||||
std::memcpy(params_ptr, &value, 8);
|
||||
params_ptr += 8;
|
||||
PyObject* dtype = PyObject_GetAttrString(arg_ptr, "dtype");
|
||||
PyObject* repr = PyObject_Repr(dtype);
|
||||
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr) + 6; // remove 'torch.'
|
||||
size_t len = PyUnicode_GET_LENGTH(repr) - 6;
|
||||
py::object dtype = arg.attr("dtype");
|
||||
py::object repr = py::repr(dtype);
|
||||
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
|
||||
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
|
||||
cache_key += std::string(start, len);
|
||||
continue;
|
||||
}
|
||||
// argument is `constexpr`
|
||||
PyObject* value = PyObject_GetAttrString(arg_ptr, "value");
|
||||
py::object value = arg.attr("value");
|
||||
if(value){
|
||||
PyObject* name = PyList_GetItem(arg_names.ptr(), i);
|
||||
PyDict_SetItem(constants, name, value);
|
||||
PyObject* repr = PyObject_Repr(value);
|
||||
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr);
|
||||
size_t len = PyUnicode_GET_LENGTH(repr);
|
||||
py::object name = arg_names[i];
|
||||
constants[name] = value;
|
||||
py::object repr = py::repr(value);
|
||||
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
|
||||
size_t len = PyUnicode_GET_LENGTH(repr.ptr());
|
||||
cache_key += std::string(start, len);
|
||||
continue;
|
||||
}
|
||||
@@ -228,37 +229,39 @@ void init_triton_runtime(py::module &&m) {
|
||||
);
|
||||
|
||||
// cache key
|
||||
m.def("launch", [](py::handle args, py::handle do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages,
|
||||
py::handle add_to_cache, py::handle grid){
|
||||
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
|
||||
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
|
||||
py::function add_to_cache, py::object grid){
|
||||
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
|
||||
long _num_warps = PyLong_AsLong(num_warps.ptr());
|
||||
long _num_stages = PyLong_AsLong(num_stages.ptr());
|
||||
std::string cache_key;
|
||||
std::string params;
|
||||
size_t params_size;
|
||||
PyObject* constants = PyDict_New();
|
||||
py::dict constants;
|
||||
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
|
||||
|
||||
// get cached binary
|
||||
PyObject* key = PyUnicode_FromString(cache_key.c_str());
|
||||
PyObject* bin = nullptr;
|
||||
if(!PyDict_Contains(bin_cache.ptr(), key)){
|
||||
add_to_cache(py::handle(key), args, device, num_warps, num_stages);
|
||||
}
|
||||
bin = PyDict_GetItem(bin_cache.ptr(), key);
|
||||
py::str key(cache_key);
|
||||
if(!bin_cache.contains(key))
|
||||
add_to_cache(key, args, device, num_warps, num_stages);
|
||||
py::object bin = bin_cache[key];
|
||||
|
||||
// get grid
|
||||
PyObject* grid_ptr = grid.ptr();
|
||||
if(!PySequence_Check(grid_ptr)){
|
||||
PyObject* grid_call = PyObject_GetAttrString(grid_ptr, "__call__");
|
||||
grid_ptr = PyObject_Call(grid_call, PyTuple_Pack(1, constants), nullptr);
|
||||
}
|
||||
int size = PySequence_Size(grid_ptr);
|
||||
int grid_0 = PyLong_AsLong(PySequence_GetItem(grid_ptr, 0));
|
||||
int grid_1 = size < 2 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 1));
|
||||
int grid_2 = size < 3 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 2));
|
||||
py::sequence seq;
|
||||
if(!PySequence_Check(grid.ptr()))
|
||||
seq = grid(constants);
|
||||
else
|
||||
seq = grid;
|
||||
int size = seq.size();
|
||||
int grid_0 = py::cast<int>(seq[0]);
|
||||
int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
|
||||
int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);
|
||||
|
||||
// enqueue
|
||||
uint64_t kernel = PyLong_AsLong(PyObject_GetAttrString(bin, "kernel"));
|
||||
uint64_t shared_mem = PyLong_AsLong(PyObject_GetAttrString(bin, "shared_mem"));
|
||||
uint64_t kernel = py::cast<uint64_t>(bin.attr("kernel"));
|
||||
uint64_t shared_mem = py::cast<uint64_t>(bin.attr("shared_mem"));
|
||||
|
||||
// actually launch
|
||||
void *config[] = {
|
||||
CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
|
||||
@@ -269,7 +272,7 @@ void init_triton_runtime(py::module &&m) {
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
||||
nullptr, config);
|
||||
return py::handle(bin);
|
||||
return bin;
|
||||
});
|
||||
|
||||
// query maximum shared memory
|
||||
|
Reference in New Issue
Block a user