[FRONTEND] Significantly reduce kernel launch time (#367)
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include "Python.h"
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include "llvm/IR/Module.h"
|
||||
@@ -23,6 +24,7 @@ namespace py = pybind11;
|
||||
namespace ir = triton::ir;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
|
||||
/*****************************************************************************/
|
||||
/* Python bindings for triton::driver */
|
||||
/*****************************************************************************/
|
||||
@@ -99,8 +101,113 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
|
||||
|
||||
}
|
||||
|
||||
std::string pow2_divisor(long N){
|
||||
if(N % 16 == 0) return "16";
|
||||
if(N % 8 == 0) return "8";
|
||||
if(N % 4 == 0) return "4";
|
||||
if(N % 2 == 0) return "2";
|
||||
return "1";
|
||||
}
|
||||
|
||||
// Launch
|
||||
void parse_args(py::handle& args, const std::string& func_key, py::handle& arg_names,
|
||||
std::string& cache_key, std::string& params, size_t& params_size, PyObject* constants,
|
||||
int num_warps, int num_stages) {
|
||||
size_t len = PyList_Size(args.ptr());
|
||||
params.reserve(8*len); // 8 max bytes by argument
|
||||
char* params_ptr = ¶ms[0];
|
||||
cache_key = func_key;
|
||||
for(int i = 0; i < len; i++){
|
||||
auto arg_ptr = PyList_GetItem(args.ptr(), i);
|
||||
auto arg = py::handle(arg_ptr);
|
||||
// argument is `long`
|
||||
if(PyLong_Check(arg_ptr)){
|
||||
int overflow;
|
||||
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
|
||||
// long and int have different kernels
|
||||
if(!overflow & (std::abs(value) <= 0xffffffff)){
|
||||
cache_key += 'I';
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
|
||||
std::memcpy(params_ptr, &value, 4);
|
||||
params_ptr += 4;
|
||||
}
|
||||
else{
|
||||
cache_key += 'L';
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
||||
if(overflow){
|
||||
unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
|
||||
std::memcpy(&value, &uvalue, 8);
|
||||
}
|
||||
std::memcpy(params_ptr, &value, 8);
|
||||
params_ptr += 8;
|
||||
}
|
||||
// values equal to 1 are specialized
|
||||
if(value == 1)
|
||||
cache_key += '1';
|
||||
else
|
||||
cache_key += 'x';
|
||||
// values divisible by small powers of 2 are specialized
|
||||
cache_key += pow2_divisor(value);
|
||||
continue;
|
||||
}
|
||||
// argument is `float`
|
||||
if(PyFloat_Check(arg_ptr)){
|
||||
cache_key += "f";
|
||||
float value = PyFloat_AsDouble(arg_ptr);
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
|
||||
std::memcpy(params_ptr, &value, 4);
|
||||
params_ptr += 4;
|
||||
continue;
|
||||
}
|
||||
// argument is `bool`
|
||||
if(PyBool_Check(arg_ptr)){
|
||||
cache_key += "B";
|
||||
bool value = arg_ptr == Py_True ? true : false;
|
||||
std::memcpy(params_ptr, &value, 1);
|
||||
params_ptr += 1;
|
||||
continue;
|
||||
}
|
||||
// argument is tensor
|
||||
PyObject* data_ptr = PyObject_CallMethod(arg_ptr, "data_ptr", nullptr);
|
||||
if(data_ptr){
|
||||
cache_key += "P";
|
||||
long value = PyLong_AsLong(data_ptr);
|
||||
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
|
||||
std::memcpy(params_ptr, &value, 8);
|
||||
params_ptr += 8;
|
||||
PyObject* dtype = PyObject_GetAttrString(arg_ptr, "dtype");
|
||||
PyObject* repr = PyObject_Repr(dtype);
|
||||
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr) + 6; // remove 'torch.'
|
||||
size_t len = PyUnicode_GET_LENGTH(repr) - 6;
|
||||
cache_key += std::string(start, len);
|
||||
continue;
|
||||
}
|
||||
// argument is `constexpr`
|
||||
PyObject* value = PyObject_GetAttrString(arg_ptr, "value");
|
||||
if(value){
|
||||
PyObject* name = PyList_GetItem(arg_names.ptr(), i);
|
||||
PyDict_SetItem(constants, name, value);
|
||||
PyObject* repr = PyObject_Repr(value);
|
||||
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr);
|
||||
size_t len = PyUnicode_GET_LENGTH(repr);
|
||||
cache_key += std::string(start, len);
|
||||
continue;
|
||||
}
|
||||
assert(false);
|
||||
}
|
||||
cache_key += std::to_string(num_warps);
|
||||
cache_key += std::to_string(num_stages);
|
||||
params_size = (std::ptrdiff_t)(params_ptr - ¶ms[0]);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
void init_triton_runtime(py::module &&m) {
|
||||
|
||||
// m.def("current_stream", [](uint64_t device){
|
||||
// return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
|
||||
// });
|
||||
|
||||
// wrap backend_t
|
||||
py::enum_<backend_t>(m, "backend")
|
||||
.value("HOST", HOST)
|
||||
@@ -116,6 +223,51 @@ void init_triton_runtime(py::module &&m) {
|
||||
}
|
||||
);
|
||||
|
||||
// cache key
|
||||
m.def("launch", [](py::handle args, const std::string& func_key, py::list& arg_names,
|
||||
py::handle device, py::handle stream, py::handle bin_cache, py::handle num_warps, py::handle num_stages,
|
||||
py::handle add_to_cache, py::handle grid){
|
||||
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
|
||||
long _num_warps = PyLong_AsLong(num_warps.ptr());
|
||||
long _num_stages = PyLong_AsLong(num_stages.ptr());
|
||||
std::string cache_key;
|
||||
std::string params;
|
||||
size_t params_size;
|
||||
PyObject* constants = PyDict_New();
|
||||
parse_args(args, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
|
||||
// get cached binary
|
||||
PyObject* key = PyUnicode_FromString(cache_key.c_str());
|
||||
PyObject* bin = nullptr;
|
||||
if(!PyDict_Contains(bin_cache.ptr(), key)){
|
||||
add_to_cache(py::handle(key), args, device, num_warps, num_stages);
|
||||
}
|
||||
bin = PyDict_GetItem(bin_cache.ptr(), key);
|
||||
// get grid
|
||||
PyObject* grid_ptr = grid.ptr();
|
||||
if(!PySequence_Check(grid_ptr)){
|
||||
PyObject* grid_call = PyObject_GetAttrString(grid_ptr, "__call__");
|
||||
grid_ptr = PyObject_Call(grid_call, PyTuple_Pack(1, constants), nullptr);
|
||||
}
|
||||
int size = PySequence_Size(grid_ptr);
|
||||
int grid_0 = PyLong_AsLong(PySequence_GetItem(grid_ptr, 0));
|
||||
int grid_1 = size < 2 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 1));
|
||||
int grid_2 = size < 3 ? 1 : PyLong_AsLong(PySequence_GetItem(grid_ptr, 2));
|
||||
// enqueue
|
||||
uint64_t kernel = PyLong_AsLong(PyObject_GetAttrString(bin, "kernel"));
|
||||
uint64_t shared_mem = PyLong_AsLong(PyObject_GetAttrString(bin, "shared_mem"));
|
||||
// actually launch
|
||||
void *config[] = {
|
||||
CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
|
||||
CU_LAUNCH_PARAM_BUFFER_SIZE, ¶ms_size,
|
||||
CU_LAUNCH_PARAM_END
|
||||
};
|
||||
uint64_t _stream = PyLong_AsLong(stream.ptr());
|
||||
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
|
||||
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
|
||||
nullptr, config);
|
||||
return py::handle(bin);
|
||||
});
|
||||
|
||||
// query maximum shared memory
|
||||
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
|
||||
if (backend == HOST)
|
||||
|
Reference in New Issue
Block a user