[FRONTEND] add python e2e launch empty kernel test (#68)
@@ -6,9 +6,11 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -254,19 +256,105 @@ void parse_args(py::list &args, py::list do_not_specialize,
  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}

// Overload of parse_args used by launch_binary below: packs kernel
// arguments into `params` and collects constexpr values into `constants`.
void parse_args(py::list &args, py::list &arg_names, std::string &params,
                size_t &params_size, py::dict constants) {
  char *params_ptr = params.data();

  size_t len = PyList_Size(args.ptr());
  for (int i = 0; i < len; i++) {
    py::object arg = args[i];
    auto arg_ptr = arg.ptr();

    // argument is `int`
    if (PyLong_Check(arg_ptr)) {
      int overflow{};
      long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);

      if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
        // value fits in i32: align the write cursor to 4 bytes and copy
        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
        std::memcpy(params_ptr, &value, 4);
        params_ptr += 4;
      } else if (!overflow && 0x8000'0000LL <= value &&
                 value <= 0xFFFF'FFFFLL) {
        // value fits in u32
        params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
        std::memcpy(params_ptr, &value, 4);
        params_ptr += 4;
      } else if (!overflow) {
        // i64: align to 8 bytes
        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
        std::memcpy(params_ptr, &value, 8);
        params_ptr += 8;
      } else {
        if (PyErr_Occurred()) {
          throw std::logic_error("An error occurred?");
        }
        // overflowed long long: retry as u64
        unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
        if (PyErr_Occurred()) {
          throw std::runtime_error("integer overflow in argument: " +
                                   std::string(py::str(arg)));
        }
        params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
        std::memcpy(params_ptr, &unsigned_value, 8);
        params_ptr += 8;
      }
      continue;
    }

    // argument is `float` (Python floats are doubles; narrowed to fp32 here)
    if (PyFloat_Check(arg_ptr)) {
      float value = PyFloat_AsDouble(arg_ptr);
      params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
      std::memcpy(params_ptr, &value, 4);
      params_ptr += 4;
      continue;
    }

    // argument is `bool`
    if (PyBool_Check(arg_ptr)) {
      bool value = arg_ptr == Py_True ? true : false;
      std::memcpy(params_ptr, &value, 1);
      params_ptr += 1;
      continue;
    }
    // argument is torch.tensor, get data_ptr as memory address
    if (py::hasattr(arg, "data_ptr")) {
      py::object data_ptr = arg.attr("data_ptr")();
      long value = data_ptr.cast<long>();
      params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
      // copy param
      std::memcpy(params_ptr, &value, 8);
      params_ptr += 8;
      // update cache key
      continue;
    }
    // argument is `constexpr`
    if (py::hasattr(arg, "value")) {
      py::object value = arg.attr("value");
      py::object name = arg_names[i];
      constants[name] = value;
      continue;
    }
    // argument is `LoadedBinary`
    if (py::hasattr(arg, "get_sass")) {
      // Do nothing, just a placeholder here to indicate validity.
      continue;
    }

    std::string ty_str =
        arg.attr("__class__").attr("__name__").cast<std::string>();
    std::string err_msg = "Received type '" + ty_str + "' for argument " +
                          std::to_string(i) + "." +
                          " Only int, float, bool, torch.Tensor, and "
                          "triton.language.constexpr are supported.";
    throw std::runtime_error(err_msg);
  }

  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}
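The packing above repeatedly inlines an align-up idiom, `(p + a - 1) & -a`, choosing a 4-byte slot for i32/u32/fp32 and an 8-byte slot for i64/u64/pointers. A minimal standalone sketch of the same scheme; the `align_up`/`pack` helpers are hypothetical, not part of this commit:

#include <cstdint>
#include <cstring>
#include <string>

// Round `p` up to the next multiple of `align` (align must be a power of two).
static char *align_up(char *p, uintptr_t align) {
  return (char *)(((uintptr_t)p + align - 1) & ~(align - 1));
}

// Append `size` bytes of `src` at `cursor`, aligned to `size`.
static char *pack(char *cursor, const void *src, size_t size) {
  cursor = align_up(cursor, size);
  std::memcpy(cursor, src, size);
  return cursor + size;
}

int main() {
  std::string params(64, '\0');
  char *cursor = params.data();
  bool flag = true;
  int32_t n = 42;
  int64_t dev_ptr = 0x7f00'0000'1000;
  cursor = pack(cursor, &flag, 1);     // 1-byte slot, no padding
  cursor = pack(cursor, &n, 4);        // padded to the next 4-byte boundary
  cursor = pack(cursor, &dev_ptr, 8);  // padded to the next 8-byte boundary
  return 0;
}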

void init_triton_runtime(py::module &&m) {

  // m.def("current_stream", [](uint64_t device){
  //   return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
  // });

  // wrap backend_t
  py::enum_<backend_t>(m, "backend")
      .value("HOST", HOST)
      .value("CUDA", CUDA)
.value("ROCM", ROCM)
|
||||
// .value("ROCM", ROCM)
|
||||
      .export_values();
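For reference, `.export_values()` is what lets the Python side write either `runtime.backend.CUDA` or bare `runtime.CUDA`. A self-contained sketch of the same binding pattern; the toy module name is assumed:

#include <pybind11/pybind11.h>
namespace py = pybind11;

enum backend_t { HOST, CUDA, ROCM };

PYBIND11_MODULE(toy_backend, m) {
  // export_values() additionally injects HOST/CUDA/ROCM as module attributes.
  py::enum_<backend_t>(m, "backend")
      .value("HOST", HOST)
      .value("CUDA", CUDA)
      .value("ROCM", ROCM)
      .export_values();
}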

  // enable peer-to-peer
@@ -347,6 +435,49 @@ void init_triton_runtime(py::module &&m) {
    return -1;
  });

  m.def("launch_binary", [](py::object binary, py::list args,
                            py::list do_not_specialize, py::list arg_names,
                            py::int_ stream, py::int_ num_warps,
                            py::int_ num_stages, py::object grid) {
    long _num_warps = PyLong_AsLong(num_warps.ptr());
    long _num_stages = PyLong_AsLong(num_stages.ptr());

    // get grid
    py::sequence seq;
    py::dict constants;
    std::string params;
    size_t params_size{};
    parse_args(args, arg_names, params, params_size, constants);
    if (!PySequence_Check(grid.ptr()))
      seq = grid(constants);
    else
      seq = grid;

    int size = seq.size();
    int grid_0 = py::cast<int>(seq[0]);
    int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
    int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);

    uint64_t kernel = py::cast<uint64_t>(binary.attr("kernel"));
    uint64_t shared_mem = py::cast<uint64_t>(binary.attr("shared_mem"));

    // actually launch
    void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
                      CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
                      CU_LAUNCH_PARAM_END};
    uint64_t _stream = PyLong_AsLong(stream.ptr());
    const int numGrids = grid_0 * grid_1 * grid_2;
    if (numGrids) {
      // release the gil in case the enqueue blocks
      // cuda will block if too many ops are enqueued
      py::gil_scoped_release allow_threads;
      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
                                    _num_warps * 32, 1, 1, shared_mem,
                                    (CUstream)_stream, nullptr, config);
    }
    return binary;
  });
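launch_binary hands the packed buffer to the driver through cuLaunchKernel's `extra` argument rather than `kernelParams`, so the byte layout produced by parse_args must match the kernel's expected parameter layout exactly. A condensed sketch of that driver-API pattern; the free function is hypothetical and error handling is elided:

#include <cuda.h>
#include <cstddef>

// Launch `f` with arguments packed into one byte buffer, mirroring the
// CU_LAUNCH_PARAM_* path used above.
CUresult launch_packed(CUfunction f, CUstream stream, void *buf,
                       size_t buf_size, int grid_x, int block_x,
                       unsigned shared_bytes) {
  void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, buf,
                    CU_LAUNCH_PARAM_BUFFER_SIZE, &buf_size,
                    CU_LAUNCH_PARAM_END};
  // kernelParams must be null when the `extra` buffer form is used.
  return cuLaunchKernel(f, grid_x, 1, 1, block_x, 1, 1, shared_bytes, stream,
                        /*kernelParams=*/nullptr, config);
}

The driver then reads the buffer with the same alignment rules parse_args applied, which is why the packing code and the launch code must agree.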

  // query maximum shared memory
  m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
    if (backend == HOST)
@@ -1517,7 +1648,7 @@ void init_triton_ir(py::module &&m) {
  });
}

-void init_translation(py::module &m) {
+void init_triton_translation(py::module &m) {
  m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
    llvm::LLVMContext llvmContext;
    auto llvmModule =
@@ -1531,10 +1662,16 @@ void init_translation(py::module &m) {
  });

  m.def("translate_triton_gpu_to_ptx",
-       [](mlir::ModuleOp module, uint64_t device) -> std::string {
+       [](mlir::ModuleOp module, uint64_t device)
+           -> std::tuple<std::string /*ptx code*/, size_t /*shmem size*/> {
          auto [ptxCode, cc, version, ptxasPath] =
              triton::translateTritonGPUToPTX(module, device);
-         return ptxCode;
+
+         mlir::PassManager pm(module->getContext());
+         auto pass = std::make_unique<mlir::Allocation>(module);
+         size_t size = pass->getSharedMemorySize();
+
+         return std::make_tuple(ptxCode, size);
        });
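The binding now returns a std::tuple, which pybind11 converts to a Python tuple, so the frontend can unpack the PTX string and the shared-memory size in one call. A toy illustration of that conversion; the module and function names are assumed:

#include <pybind11/pybind11.h>
#include <string>
#include <tuple>

namespace py = pybind11;

PYBIND11_MODULE(toy, m) {
  // pybind11's built-in caster turns std::tuple into a Python 2-tuple,
  // so callers can write: ptx, shmem_size = toy.compile()
  m.def("compile", []() -> std::tuple<std::string, size_t> {
    return std::make_tuple("// fake ptx", size_t{4096});
  });
}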

  m.def("compile_ptx_to_cubin",
@@ -1550,6 +1687,16 @@ void init_translation(py::module &m) {
      py::bytes bytes(cubin);
      return bytes;
    });

  m.def(
      "load_binary",
      [](backend_t backend, const std::string &name, asm_map_t &asm_map,
         size_t n_shared_bytes, uint64_t dev) {
        py::gil_scoped_release allow_threads;
        assert(backend == CUDA); // Only CUDA is supported now.
        return cu_load_binary(name, asm_map, n_shared_bytes, dev);
      },
      py::return_value_policy::take_ownership);
}
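Both load_binary here and the launch path above wrap the driver call in py::gil_scoped_release, so a blocking CUDA call does not stall other Python threads. The pattern in isolation, with a toy module and a hypothetical slow call:

#include <pybind11/pybind11.h>
namespace py = pybind11;

// Stand-in for a slow call that never touches Python objects.
static int slow_driver_call() { return 0; }

PYBIND11_MODULE(toy_gil, m) {
  m.def("blocking_op", []() {
    // Safe to drop the GIL: the body uses no py:: objects until it returns.
    py::gil_scoped_release allow_threads;
    return slow_driver_call();
  });
}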

void init_triton(py::module &m) {
@@ -1557,5 +1704,5 @@ void init_triton(py::module &m) {
  // init_triton_codegen(std::move(subm.def_submodule("code_gen")));
  init_triton_runtime(std::move(subm.def_submodule("runtime")));
  init_triton_ir(std::move(subm.def_submodule("ir")));
- init_translation(subm);
+ init_triton_translation(subm);
}
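The init_* functions each receive their own submodule. A reduced, self-contained version of this wiring, with toy names rather than the real module:

#include <pybind11/pybind11.h>
namespace py = pybind11;

static void init_runtime(py::module &&m) {
  m.def("ping", []() { return 1; });  // placeholder binding
}

PYBIND11_MODULE(toy_triton, m) {
  // Each init_* function owns one submodule, mirroring init_triton above.
  init_runtime(std::move(m.def_submodule("runtime")));
}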