[FRONTEND] add python e2e launch empty kernel test (#68)

This commit is contained in:
Yan Chunwei
2022-08-20 01:46:01 +08:00
committed by GitHub
parent 9aa00249a6
commit 10ba51c3bb
11 changed files with 311 additions and 69 deletions

View File

@@ -6,9 +6,11 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -254,19 +256,105 @@ void parse_args(py::list &args, py::list do_not_specialize,
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}
//
void parse_args(py::list &args, py::list &arg_names, std::string &params,
size_t &params_size, py::dict constants) {
char *params_ptr = params.data();
size_t len = PyList_Size(args.ptr());
for (int i = 0; i < len; i++) {
py::object arg = args[i];
auto arg_ptr = arg.ptr();
if (PyLong_Check(arg_ptr)) {
int overflow{};
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow && 0x8000'0000LL <= value &&
value <= 0xFFFF'FFFFLL) {
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow) {
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
} else {
if (PyErr_Occurred()) {
throw std::logic_error("An error occurred?");
}
unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
if (PyErr_Occurred()) {
throw std::runtime_error("integer overflow in argument: " +
std::string(py::str(arg)));
}
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &unsigned_value, 8);
params_ptr += 8;
}
continue;
}
if (PyFloat_Check(arg_ptr)) {
float value = PyFloat_AsDouble(arg_ptr);
params_ptr = (char *)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
continue;
}
// argument is `bool`
if (PyBool_Check(arg_ptr)) {
bool value = arg_ptr == Py_True ? true : false;
std::memcpy(params_ptr, &value, 1);
params_ptr += 1;
continue;
}
// argument is torch.tensor, get data_ptr as memory address
if (py::hasattr(arg, "data_ptr")) {
py::object data_ptr = arg.attr("data_ptr")();
long value = data_ptr.cast<long>();
params_ptr = (char *)(((uintptr_t)params_ptr + 7) & (-8));
// copy param
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
// udpate cache key
continue;
}
// argument is `constexpr`
if (py::hasattr(arg, "value")) {
py::object value = arg.attr("value");
py::object name = arg_names[i];
constants[name] = value;
continue;
}
// argument is `LoadedBinary`
if (py::hasattr(arg, "get_sass")) {
// Do nothing, just a placeholder here to indicate validity.
continue;
}
std::string ty_str =
arg.attr("__class__").attr("__name__").cast<std::string>();
std::string err_msg = "Received type '" + ty_str + "' for argument " +
std::to_string(i) + "." +
" Only int, float, bool, torch.Tensor, and "
"triton.language.constexpr are supported.";
throw std::runtime_error(err_msg);
}
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
}
void init_triton_runtime(py::module &&m) {
// m.def("current_stream", [](uint64_t device){
// return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
// });
// wrap backend_t
py::enum_<backend_t>(m, "backend")
.value("HOST", HOST)
.value("CUDA", CUDA)
.value("ROCM", ROCM)
// .value("ROCM", ROCM)
.export_values();
// enable peer-to-peer
@@ -347,6 +435,49 @@ void init_triton_runtime(py::module &&m) {
return -1;
});
m.def("launch_binary", [](py::object binary, py::list args,
py::list do_not_specialize, py::list arg_names,
py::int_ stream, py::int_ num_warps,
py::int_ num_stages, py::object grid) {
long _num_warps = PyLong_AsLong(num_warps.ptr());
long _num_stages = PyLong_AsLong(num_stages.ptr());
// get grid
py::sequence seq;
py::dict constants;
std::string params;
size_t params_size{};
parse_args(args, arg_names, params, params_size, constants);
if (!PySequence_Check(grid.ptr()))
seq = grid(constants);
else
seq = grid;
int size = seq.size();
int grid_0 = py::cast<int>(seq[0]);
int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);
uint64_t kernel = py::cast<uint64_t>(binary.attr("kernel"));
uint64_t shared_mem = py::cast<uint64_t>(binary.attr("shared_mem"));
// actually launch
void *config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
CU_LAUNCH_PARAM_END};
uint64_t _stream = PyLong_AsLong(stream.ptr());
const int numGrids = grid_0 * grid_1 * grid_2;
if (numGrids) {
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
py::gil_scoped_release allow_threads;
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps * 32, 1, 1, shared_mem,
(CUstream)_stream, nullptr, config);
}
return binary;
});
// query maximum shared memory
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
if (backend == HOST)
@@ -1517,7 +1648,7 @@ void init_triton_ir(py::module &&m) {
});
}
void init_translation(py::module &m) {
void init_triton_translation(py::module &m) {
m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
llvm::LLVMContext llvmContext;
auto llvmModule =
@@ -1531,10 +1662,16 @@ void init_translation(py::module &m) {
});
m.def("translate_triton_gpu_to_ptx",
[](mlir::ModuleOp module, uint64_t device) -> std::string {
[](mlir::ModuleOp module, uint64_t device)
-> std::tuple<std::string /*ptx code*/, size_t /*shem size*/> {
auto [ptxCode, cc, version, ptxasPath] =
triton::translateTritonGPUToPTX(module, device);
return ptxCode;
mlir::PassManager pm(module->getContext());
auto pass = std::make_unique<mlir::Allocation>(module);
size_t size = pass->getSharedMemorySize();
return std::make_tuple(ptxCode, size);
});
m.def("compile_ptx_to_cubin",
@@ -1550,6 +1687,16 @@ void init_translation(py::module &m) {
py::bytes bytes(cubin);
return bytes;
});
m.def(
"load_binary",
[](backend_t backend, const std::string &name, asm_map_t &asm_map,
size_t n_shared_bytes, uint64_t dev) {
py::gil_scoped_release allow_threads;
assert(backend == CUDA); // Only CUDA is supported now.
return cu_load_binary(name, asm_map, n_shared_bytes, dev);
},
py::return_value_policy::take_ownership);
}
void init_triton(py::module &m) {
@@ -1557,5 +1704,5 @@ void init_triton(py::module &m) {
// init_triton_codegen(std::move(subm.def_submodule("code_gen")));
init_triton_runtime(std::move(subm.def_submodule("runtime")));
init_triton_ir(std::move(subm.def_submodule("ir")));
init_translation(subm);
init_triton_translation(subm);
}