[driver] now passing std::unique_ptr<> instead of cloning LLVM module
when compiling it
This commit is contained in:
@@ -8,10 +8,9 @@ option(BUILD_TESTS "Build C++ Triton tests" ON)
|
||||
option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
|
||||
|
||||
# LLVM
|
||||
find_package(LLVM REQUIRED CONFIG)
|
||||
find_package(LLVM REQUIRED)
|
||||
include_directories(${LLVM_INCLUDE_DIRS})
|
||||
add_definitions(${LLVM_DEFINITIONS})
|
||||
llvm_map_components_to_libnames(llvm_libs all)
|
||||
|
||||
# Default build type
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
@@ -21,7 +20,7 @@ endif()
|
||||
|
||||
# Compiler flags
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS} -std=c++11")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||
|
||||
# Tests
|
||||
if(BUILD_TESTS)
|
||||
@@ -53,13 +52,6 @@ endif()
|
||||
# Triton
|
||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
target_link_libraries(triton LLVM)
|
||||
|
||||
# Warning level
|
||||
#if(MSVC)
|
||||
# target_compile_options(triton PRIVATE /W4)
|
||||
#else()
|
||||
# target_compile_options(triton PRIVATE -Wno-unused-parameter -Wall -Wextra -pedantic)
|
||||
#endif()
|
||||
|
||||
link_directories(${LLVM_LIBRARY_DIRS})
|
||||
target_link_libraries(triton ${LLVM_LIBRARIES})
|
||||
|
||||
|
@@ -66,7 +66,6 @@ struct backend
|
||||
|
||||
public:
|
||||
static void release();
|
||||
static driver::module* get(driver::stream* stream, std::string const & name, llvm::Module *src);
|
||||
|
||||
private:
|
||||
static std::map<std::tuple<driver::stream*, std::string>, driver::module*> cache_;
|
||||
|
@@ -38,9 +38,9 @@ public:
|
||||
module(driver::context* ctx, CUmodule mod, bool has_ownership);
|
||||
module(driver::context* ctx, cl_program mod, bool has_ownership);
|
||||
module(driver::context* ctx, host_module_t mod, bool has_ownership);
|
||||
static module* create(driver::context* ctx, llvm::Module *src);
|
||||
static module* create(driver::context* ctx, std::unique_ptr<llvm::Module> src);
|
||||
driver::context* context() const;
|
||||
void compile_llvm_module(llvm::Module* module, const std::string& triple,
|
||||
void compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
|
||||
const std::string &proc, std::string layout,
|
||||
llvm::SmallVectorImpl<char> &buffer,
|
||||
const std::string &features,
|
||||
@@ -53,22 +53,22 @@ protected:
|
||||
// CPU
|
||||
class host_module: public module{
|
||||
public:
|
||||
host_module(driver::context* context, llvm::Module *module);
|
||||
host_module(driver::context* context, std::unique_ptr<llvm::Module> module);
|
||||
};
|
||||
|
||||
// OpenCL
|
||||
class ocl_module: public module{
|
||||
|
||||
public:
|
||||
ocl_module(driver::context* context, llvm::Module *module);
|
||||
ocl_module(driver::context* context, std::unique_ptr<llvm::Module> module);
|
||||
};
|
||||
|
||||
// CUDA
|
||||
class cu_module: public module {
|
||||
std::string compile_llvm_module(llvm::Module* module, driver::device* device);
|
||||
std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);
|
||||
|
||||
public:
|
||||
cu_module(driver::context* context, llvm::Module *module);
|
||||
cu_module(driver::context* context, std::unique_ptr<llvm::Module> module);
|
||||
cu_module(driver::context* context, const std::string& source);
|
||||
cu_buffer* symbol(const char * name) const;
|
||||
|
||||
|
@@ -103,14 +103,6 @@ void backend::modules::release(){
|
||||
cache_.clear();
|
||||
}
|
||||
|
||||
driver::module* backend::modules::get(driver::stream* stream, std::string const & name, llvm::Module* src){
|
||||
std::tuple<driver::stream*, std::string> key(stream, name);
|
||||
if(cache_.find(key)==cache_.end()){
|
||||
return &*cache_.insert({key, driver::module::create(stream->context(), src)}).first->second;
|
||||
}
|
||||
return &*cache_.at(key);
|
||||
}
|
||||
|
||||
std::map<std::tuple<driver::stream*, std::string>, driver::module*> backend::modules::cache_;
|
||||
|
||||
/*-----------------------------------*/
|
||||
|
@@ -76,16 +76,16 @@ driver::context* module::context() const {
|
||||
return ctx_;
|
||||
}
|
||||
|
||||
module* module::create(driver::context* ctx, llvm::Module *src) {
|
||||
module* module::create(driver::context* ctx, std::unique_ptr<llvm::Module> src) {
|
||||
switch(ctx->backend()){
|
||||
case CUDA: return new cu_module(ctx, src);
|
||||
case OpenCL: return new ocl_module(ctx, src);
|
||||
case Host: return new host_module(ctx, src);
|
||||
case CUDA: return new cu_module(ctx, std::move(src));
|
||||
case OpenCL: return new ocl_module(ctx, std::move(src));
|
||||
case Host: return new host_module(ctx, std::move(src));
|
||||
default: throw std::runtime_error("unknown backend");
|
||||
}
|
||||
}
|
||||
|
||||
void module::compile_llvm_module(llvm::Module* module, const std::string& triple,
|
||||
void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
|
||||
const std::string &proc, std::string layout,
|
||||
llvm::SmallVectorImpl<char> &buffer,
|
||||
const std::string& features,
|
||||
@@ -133,7 +133,7 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
|
||||
// Host //
|
||||
/* ------------------------ */
|
||||
|
||||
host_module::host_module(driver::context * context, llvm::Module* src): module(context, host_module_t(), true) {
|
||||
host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module> src): module(context, host_module_t(), true) {
|
||||
init_llvm();
|
||||
// host info
|
||||
// std::string triple = llvm::sys::getDefaultTargetTriple();
|
||||
@@ -147,7 +147,7 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
|
||||
llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
|
||||
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
|
||||
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, {args_ty, int32_ty, int32_ty, int32_ty}, false);
|
||||
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", src);
|
||||
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src);
|
||||
llvm::Function* fn = src->getFunction("matmul");
|
||||
llvm::FunctionType *fn_ty = fn->getFunctionType();
|
||||
std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
|
||||
@@ -169,10 +169,9 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
|
||||
|
||||
|
||||
// create execution engine
|
||||
auto cloned = llvm::CloneModule(*src);
|
||||
for(llvm::Function& fn: cloned->functions())
|
||||
for(llvm::Function& fn: src->functions())
|
||||
hst_->functions[fn.getName()] = &fn;
|
||||
llvm::EngineBuilder builder(std::move(cloned));
|
||||
llvm::EngineBuilder builder(std::move(src));
|
||||
builder.setErrorStr(&hst_->error);
|
||||
builder.setMCJITMemoryManager(llvm::make_unique<llvm::SectionMemoryManager>());
|
||||
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
|
||||
@@ -185,7 +184,7 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
|
||||
// OpenCL //
|
||||
/* ------------------------ */
|
||||
|
||||
ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) {
|
||||
ocl_module::ocl_module(driver::context * context, std::unique_ptr<llvm::Module> src): module(context, cl_program(), true) {
|
||||
throw std::runtime_error("not supported");
|
||||
// init_llvm();
|
||||
// llvm::SmallVector<char, 0> buffer;
|
||||
@@ -217,18 +216,20 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con
|
||||
// CUDA //
|
||||
/* ------------------------ */
|
||||
|
||||
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {
|
||||
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
|
||||
// options
|
||||
auto options = llvm::cl::getRegisteredOptions();
|
||||
// for(auto& opt: options)
|
||||
// std::cout << opt.getKey().str() << std::endl;
|
||||
static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"])->setValue(true);
|
||||
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
|
||||
assert(short_ptr);
|
||||
short_ptr->setValue(true);
|
||||
// compute capability
|
||||
auto cc = ((driver::cu_device*)device)->compute_capability();
|
||||
std::string sm = "sm_" + std::to_string(cc.first) + std::to_string(cc.second);
|
||||
// create
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly);
|
||||
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly);
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
size_t start_replace = result.find(".version");
|
||||
size_t end_replace = result.find('\n', start_replace);
|
||||
@@ -237,7 +238,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device*
|
||||
return result;
|
||||
}
|
||||
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module, context->device())) { }
|
||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
|
@@ -188,6 +188,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
|
||||
|
||||
std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::context *context, const options_t& opt) {
|
||||
std::unique_ptr<codegen::target> target = context->device()->make_target();
|
||||
|
||||
// create passes
|
||||
codegen::analysis::grids grids(opt.num_warps);
|
||||
codegen::analysis::meminfo shmem_info;
|
||||
@@ -201,20 +202,13 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::transform::reassociate reassociate(&alignment_info, &grids);
|
||||
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
|
||||
|
||||
|
||||
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
alignment_info.run(module);
|
||||
grids.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
|
||||
reassociate.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
|
||||
peephole.run(module);
|
||||
|
||||
if(target->is_gpu()){
|
||||
@@ -233,7 +227,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, llvm.get()));
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
return res;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user