[FRONTEND] Expose end-to-end compile to python frontend (#58)
This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
// #include "triton/codegen/pass.h"
|
||||
// #include "triton/codegen/target.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
@@ -17,14 +15,15 @@
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include "triton/Target/PTX/PTXTranslation.h"
|
||||
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "Python.h"
|
||||
#include <Python.h>
|
||||
#include <optional>
|
||||
#include <pybind11/buffer_info.h>
|
||||
#include <pybind11/functional.h>
|
||||
@@ -40,21 +39,7 @@ namespace py = pybind11;
|
||||
// namespace ir = triton::ir;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
/*****************************************************************************/
|
||||
/* Python bindings for triton::driver */
|
||||
/*****************************************************************************/
|
||||
// information query
|
||||
template <CUdevice_attribute attr> int cuGetInfo(CUdevice device) {
|
||||
int res;
|
||||
drv::dispatch::cuDeviceGetAttribute(&res, attr, device);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <hipDeviceAttribute_t attr> int hipGetInfo(hipDevice_t device) {
|
||||
int res;
|
||||
drv::dispatch::hipDeviceGetAttribute(&res, attr, device);
|
||||
return res;
|
||||
}
|
||||
using triton::cuGetInfo;
|
||||
|
||||
enum backend_t {
|
||||
HOST,
|
||||
@@ -100,18 +85,6 @@ void cu_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
|
||||
(CUstream)stream, nullptr, config);
|
||||
}
|
||||
|
||||
void hip_enqueue(uint64_t stream, uint64_t kernel, uint64_t grid_0,
|
||||
uint64_t grid_1, uint64_t grid_2, uint64_t block_0,
|
||||
uint64_t block_1, uint64_t block_2, void *args_ptr,
|
||||
size_t args_size, int64_t shared_mem) {
|
||||
void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, (void *)args_ptr,
|
||||
HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
|
||||
HIP_LAUNCH_PARAM_END};
|
||||
drv::dispatch::hipModuleLaunchKernel(
|
||||
(hipFunction_t)kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2,
|
||||
shared_mem, (hipStream_t)stream, nullptr, config);
|
||||
}
|
||||
|
||||
long pow2_divisor(long N) {
|
||||
if (N % 16 == 0)
|
||||
return 16;
|
||||
@@ -381,8 +354,6 @@ void init_triton_runtime(py::module &&m) {
|
||||
if (backend == CUDA)
|
||||
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(
|
||||
device);
|
||||
if (backend == ROCM)
|
||||
return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
|
||||
return -1;
|
||||
});
|
||||
|
||||
@@ -432,9 +403,6 @@ void init_triton_runtime(py::module &&m) {
|
||||
if (backend == CUDA)
|
||||
cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1,
|
||||
block_2, args_ptr, args_size, shared_mem);
|
||||
if (backend == ROCM)
|
||||
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0,
|
||||
block_1, block_2, args_ptr, args_size, shared_mem);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -487,120 +455,6 @@ std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string &name,
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
|
||||
}
|
||||
|
||||
// ROCM
|
||||
std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string &name,
|
||||
asm_map_t &asm_map,
|
||||
size_t n_shared_bytes,
|
||||
uint64_t dev) {
|
||||
py::bytes _assembly = asm_map["hsaco"];
|
||||
std::string assembly = py::cast<std::string>(_assembly);
|
||||
// HSA-CO -> hipModule
|
||||
hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
|
||||
// Handle to the kernel
|
||||
hipFunction_t fun;
|
||||
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
|
||||
// record asm
|
||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
|
||||
}
|
||||
|
||||
// ---------------------------------------
|
||||
// Compile Triton-IR to assembly
|
||||
// ---------------------------------------
|
||||
|
||||
// // CUDA
|
||||
// std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string&
|
||||
// name, ir::module &ir,
|
||||
// uint64_t
|
||||
// device, int
|
||||
// num_warps, int
|
||||
// num_stages,
|
||||
// asm_map_t
|
||||
// &asm_map){
|
||||
|
||||
// int n_shared_bytes;
|
||||
// py::gil_scoped_release allow_threads;
|
||||
// llvm::LLVMContext ctx;
|
||||
// // device properties
|
||||
// CUdevice dev = (CUdevice)device;
|
||||
// size_t major =
|
||||
// cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev); size_t minor
|
||||
// = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev); size_t cc =
|
||||
// major*10 + minor; int version; std::string ptxas_path =
|
||||
// drv::path_to_ptxas(version);
|
||||
// // Triton-IR -> NVPTX LLVM-IR
|
||||
// triton::codegen::nvidia_cu_target target(cc);
|
||||
// auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc,
|
||||
// num_warps, num_stages, n_shared_bytes); std::string tmp;
|
||||
// llvm::raw_string_ostream llir(tmp);
|
||||
// llir << *llvm;
|
||||
// llir.flush();
|
||||
// asm_map["llir"] = py::cast(tmp);
|
||||
// // LLVM-IR -> PTX
|
||||
// std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
|
||||
// asm_map["ptx"] = py::cast(ptx);
|
||||
// // PTX -> Binary
|
||||
// std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
|
||||
// if(!cubin.empty()){
|
||||
// py::bytes bytes(cubin);
|
||||
// asm_map["cubin"] = bytes;
|
||||
// }
|
||||
// return std::make_tuple(name, asm_map, n_shared_bytes);
|
||||
// }
|
||||
|
||||
// // HIP
|
||||
// std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string&
|
||||
// name, ir::module &ir,
|
||||
// uint64_t
|
||||
// device, int
|
||||
// num_warps,
|
||||
// int
|
||||
// num_stages,
|
||||
// asm_map_t
|
||||
// &asm_map){
|
||||
// llvm::LLVMContext ctx;
|
||||
// // Triton-IR -> NVPTX LLVM-IR
|
||||
// triton::codegen::amd_cl_target target;
|
||||
// int n_shared_bytes;
|
||||
// auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70,
|
||||
// num_warps, num_stages, n_shared_bytes); std::string tmp;
|
||||
// llvm::raw_string_ostream llir(tmp);
|
||||
// llir << *llvm;
|
||||
// llir.flush();
|
||||
// asm_map["llir"] = py::cast(tmp);
|
||||
// // LLVM-IR -> HSA-CO
|
||||
// std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
|
||||
// asm_map["hsaco"] = py::cast(path);
|
||||
// return std::make_tuple(name, asm_map, n_shared_bytes);
|
||||
// }
|
||||
|
||||
// void init_triton_codegen(py::module &&m) {
|
||||
// m.def(
|
||||
// "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device,
|
||||
// int num_warps, int num_stages) {
|
||||
// std::string name = ir.get_function_list()[0]->get_name();
|
||||
// // record asm as we generate
|
||||
// asm_map_t asm_map;
|
||||
// std::ostringstream ttir;
|
||||
// ir.print(ttir);
|
||||
// asm_map["ttir"] = py::cast(ttir.str());
|
||||
// llvm::LLVMContext ctx;
|
||||
// if(backend == CUDA)
|
||||
// return cu_compile_ttir(name, ir, device, num_warps, num_stages,
|
||||
// asm_map);
|
||||
// if(backend == ROCM)
|
||||
// return hip_compile_ttir(name, ir, device, num_warps, num_stages,
|
||||
// asm_map);
|
||||
// }, py::return_value_policy::take_ownership);
|
||||
// m.def("load_binary", [](backend_t backend, const std::string& name,
|
||||
// asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||
// py::gil_scoped_release allow_threads;
|
||||
// if(backend == CUDA)
|
||||
// return cu_load_binary(name, asm_map, n_shared_bytes, dev);
|
||||
// if(backend == ROCM)
|
||||
// return hip_load_binary(name, asm_map, n_shared_bytes, dev);
|
||||
// }, py::return_value_policy::take_ownership);
|
||||
// }
|
||||
|
||||
/*****************************************************************************/
|
||||
/* Python bindings for triton::ir */
|
||||
/*****************************************************************************/
|
||||
@@ -1655,9 +1509,45 @@ void init_triton_ir(py::module &&m) {
|
||||
});
|
||||
}
|
||||
|
||||
void init_translation(py::module &m) {
|
||||
m.def("translate_triton_gpu_to_llvmir", [](mlir::ModuleOp op) -> std::string {
|
||||
llvm::LLVMContext llvmContext;
|
||||
auto llvmModule =
|
||||
::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
|
||||
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
llvmModule->print(os, nullptr);
|
||||
os.flush();
|
||||
return str;
|
||||
});
|
||||
|
||||
m.def("translate_triton_gpu_to_ptx",
|
||||
[](mlir::ModuleOp module, uint64_t device) -> std::string {
|
||||
auto [ptxCode, cc, version, ptxasPath] =
|
||||
triton::translateTritonGPUToPTX(module, device);
|
||||
return ptxCode;
|
||||
});
|
||||
|
||||
m.def("compile_ptx_to_cubin",
|
||||
[](const std::string &ptxCode, uint64_t device) -> py::object {
|
||||
py::gil_scoped_release allow_threads;
|
||||
int version;
|
||||
int cc;
|
||||
std::string ptxasPath;
|
||||
triton::getCuCCAndVersionFromDevice(device, &cc, &version,
|
||||
&ptxasPath);
|
||||
|
||||
std::string cubin = drv::ptx_to_cubin(ptxCode, ptxasPath, cc);
|
||||
py::bytes bytes(cubin);
|
||||
return bytes;
|
||||
});
|
||||
}
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
py::module subm = m.def_submodule("triton");
|
||||
// init_triton_codegen(std::move(subm.def_submodule("code_gen")));
|
||||
init_triton_runtime(std::move(subm.def_submodule("runtime")));
|
||||
init_triton_ir(std::move(subm.def_submodule("ir")));
|
||||
init_translation(subm);
|
||||
}
|
||||
|
Reference in New Issue
Block a user