[FRONTEND] Added support for element-wise functions defined in external LLVM bitcode (e.g., libdevice) (#562)
@@ -1,5 +1,6 @@
 #include "triton/codegen/pass.h"
 #include "triton/codegen/target.h"
+#include "triton/codegen/extern_lib.h"
 #include "triton/driver/error.h"
 #include "triton/driver/llvm.h"
 #include "triton/ir/builder.h"
@@ -19,7 +20,6 @@
 #include <stdexcept>
 #include <string>
 #include "llvm/IR/Module.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Verifier.h"

 namespace py = pybind11;
@@ -140,7 +140,7 @@ size_t get_pointer_range_size(uint64_t addr){
 // Launch
 void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
                 std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
-                int num_warps, int num_stages) {
+                int num_warps, int num_stages, py::dict& extern_libs) {
   size_t len = PyList_Size(args.ptr());
   params.reserve(8*len); // 8 max bytes by argument
   char* params_ptr = &params[0];
@@ -256,6 +256,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
     throw std::runtime_error(err_msg);
   }
   params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
+
+  for (auto item : extern_libs) {
+    cache_key += "-" + item.first.cast<std::string>();
+    cache_key += "_" + item.second.cast<std::string>();
+  }
 }

 //
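
The loop just added to parse_args folds each external library into the compilation cache key as "-<name>_<path>", so two launches that link different bitcode files (say, two libdevice builds) can never collide in the binary cache. A standalone sketch of the same idea; the base key and library path below are hypothetical, for illustration only:

    #include <iostream>
    #include <map>
    #include <string>

    int main() {
      std::string cache_key = "add_kernel-2-4";  // hypothetical base key
      std::map<std::string, std::string> extern_libs = {
          {"libdevice", "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"}};
      for (const auto &item : extern_libs) {
        cache_key += "-" + item.first;   // library name
        cache_key += "_" + item.second;  // library path
      }
      // add_kernel-2-4-libdevice_/usr/local/cuda/nvvm/libdevice/libdevice.10.bc
      std::cout << cache_key << "\n";
      return 0;
    }
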
@@ -288,7 +293,7 @@ void init_triton_runtime(py::module &&m) {
   // cache key
   m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
                      py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
-                     py::function add_to_cache, py::object grid){
+                     py::dict extern_libs, py::function add_to_cache, py::object grid){
     // parse arguments to compute cache key, compile-time constants and packed kernel arguments
     long _num_warps = PyLong_AsLong(num_warps.ptr());
     long _num_stages = PyLong_AsLong(num_stages.ptr());
@@ -296,13 +301,14 @@
     std::string params;
     size_t params_size;
     py::dict constants;
-    parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
+    parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params,
+               params_size, constants, _num_warps, _num_stages, extern_libs);

     // get cached binary
     py::str key(cache_key);
     py::bool_ noop = false;
     if(!bin_cache.contains(key)) {
-      noop = add_to_cache(key, args, device, num_warps, num_stages);
+      noop = add_to_cache(key, args, device, num_warps, num_stages, extern_libs);
     }
     if (noop)
       return (py::object)py::none();
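
The launch path above is a straightforward compile cache: parse_args folds everything that can affect code generation, now including extern_libs, into cache_key, and add_to_cache compiles only on a miss. A dependency-free sketch of that control flow; compile() stands in for add_to_cache, and the key string is hypothetical:

    #include <iostream>
    #include <map>
    #include <string>

    // Stand-in for add_to_cache: compile once per distinct key.
    std::string compile(const std::string &key) {
      std::cout << "compiling " << key << "\n";
      return "binary-for-" + key;
    }

    int main() {
      std::map<std::string, std::string> bin_cache;
      std::string key = "add_kernel-2-4-libdevice_/path/libdevice.10.bc";
      if (!bin_cache.count(key))  // same shape as !bin_cache.contains(key)
        bin_cache.emplace(key, compile(key));
      std::cout << bin_cache[key] << "\n";
      return 0;
    }
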
@@ -467,11 +473,10 @@ std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> hip_load_binary(const std::st
 // ---------------------------------------

 // CUDA
-std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
-                                                        uint64_t device, int num_warps, int num_stages,
-                                                        asm_map_t &asm_map){
-
-  int n_shared_bytes;
+std::tuple<std::string, asm_map_t, int> cu_compile_ttir(
+    const std::string &name, ir::module &ir, uint64_t device, int num_warps,
+    int num_stages, asm_map_t &asm_map,
+    const triton::codegen::ExternLibMap &extern_lib_map) {
   py::gil_scoped_release allow_threads;
   llvm::LLVMContext ctx;
   // device properties
@@ -483,7 +488,9 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
   std::string ptxas_path = drv::path_to_ptxas(version);
   // Triton-IR -> NVPTX LLVM-IR
   triton::codegen::nvidia_cu_target target(cc);
-  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
+  int n_shared_bytes;
+  auto llvm = triton::codegen::add_passes_to_emit_bin(
+      ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
   std::string tmp;
   llvm::raw_string_ostream llir(tmp);
   llir << *llvm;
@@ -502,14 +509,16 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
 }

 // HIP
-std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
-                                                         uint64_t device, int num_warps, int num_stages,
-                                                         asm_map_t &asm_map){
+std::tuple<std::string, asm_map_t, int> hip_compile_ttir(
+    const std::string &name, ir::module &ir, uint64_t device, int num_warps,
+    int num_stages, asm_map_t &asm_map,
+    const triton::codegen::ExternLibMap &extern_lib_map) {
   llvm::LLVMContext ctx;
   // Triton-IR -> NVPTX LLVM-IR
   triton::codegen::amd_cl_target target;
   int n_shared_bytes;
-  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
+  auto llvm = triton::codegen::add_passes_to_emit_bin(
+      ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
   std::string tmp;
   llvm::raw_string_ostream llir(tmp);
   llir << *llvm;
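
Both backends now call the same add_passes_to_emit_bin overload: the compute-capability argument (cc on the CUDA path, a hard-coded 70 on the HIP path) is gone, and the extern-lib map is threaded through so the codegen passes can link the bitcode. The declaration itself is not part of this diff; inferred from the two call sites, it is roughly the sketch below. The real declaration lives in triton/codegen/pass.h, and the return type is an assumption based on `llir << *llvm;` accepting a dereferenced llvm::Module:

    #include <map>
    #include <memory>
    #include <string>
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"

    namespace triton {
    namespace ir { class module; }  // defined elsewhere in triton/ir
    namespace codegen {
    class target;                   // defined in triton/codegen/target.h
    class ExternLib;                // defined in triton/codegen/extern_lib.h
    using ExternLibMap = std::map<std::string, std::unique_ptr<ExternLib>>;

    // Sketch only, inferred from the call sites above.
    std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
        ir::module &ir, llvm::LLVMContext &ctx, target *tgt, int num_warps,
        int num_stages, int &n_shared_bytes, const ExternLibMap &extern_lib_map);
    } // namespace codegen
    } // namespace triton
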
@@ -523,7 +532,9 @@ std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name

 void init_triton_codegen(py::module &&m) {
   m.def(
-      "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) {
+      "compile_ttir",
+      [](backend_t backend, ir::module &ir, uint64_t device, int num_warps,
+         int num_stages, py::dict& extern_libs) {
        std::string name = ir.get_function_list()[0]->get_name();
        // record asm as we generate
        asm_map_t asm_map;
@@ -531,11 +542,20 @@ void init_triton_codegen(py::module &&m) {
        ir.print(ttir);
        asm_map["ttir"] = py::cast(ttir.str());
        llvm::LLVMContext ctx;
+       // construct extern lib map
+       triton::codegen::ExternLibMap extern_lib_map;
+       for (auto item : extern_libs) {
+         auto name = item.first.cast<std::string>();
+         auto path = item.second.cast<std::string>();
+         extern_lib_map.emplace(
+             name, triton::codegen::create_extern_lib(name, path));
+       }
        if(backend == CUDA)
-         return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
+         return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map);
        if(backend == ROCM)
-         return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
-      }, py::return_value_policy::take_ownership);
+         return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map);
+      },
+      py::return_value_policy::take_ownership);
   m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
     py::gil_scoped_release allow_threads;
     if(backend == CUDA)
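
The new triton/codegen/extern_lib.h header, included at the top of this file, is not among the hunks shown here. Its two uses above pin down only that ExternLibMap maps a library name to whatever create_extern_lib(name, path) returns. One plausible shape is sketched below; everything beyond those two names is an assumption:

    #include <map>
    #include <memory>
    #include <string>

    namespace triton {
    namespace codegen {

    // Assumed interface: one external LLVM bitcode library (e.g., libdevice).
    class ExternLib {
    public:
      ExternLib(const std::string &name, const std::string &path)
          : name_(name), path_(path) {}
      virtual ~ExternLib() = default;
      const std::string &name() const { return name_; }
      const std::string &path() const { return path_; }

    private:
      std::string name_;
      std::string path_;
    };

    // Assumed: keyed by library name, as built in the loop above.
    using ExternLibMap = std::map<std::string, std::unique_ptr<ExternLib>>;

    // Factory called at the pybind11 boundary; this body is an assumption.
    inline std::unique_ptr<ExternLib> create_extern_lib(const std::string &name,
                                                        const std::string &path) {
      return std::make_unique<ExternLib>(name, path);
    }

    } // namespace codegen
    } // namespace triton
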
@@ -931,7 +951,8 @@ void init_triton_ir(py::module &&m) {
       // Utilities
       .def("create_clock", &ir::builder::create_clock, ret::reference)
       .def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference)
-
+      // Extern instruction
+      .def("create_extern_elementwise", &ir::builder::create_extern_elementwise, ret::reference)
       // Built-in instruction
       .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
       .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)