[driver] adding opencl in the driver API

This commit is contained in:
Philippe Tillet
2019-03-18 23:12:14 -07:00
parent b73c3bdd25
commit 02775a226e
41 changed files with 28700 additions and 398 deletions

View File

@@ -46,9 +46,34 @@ namespace triton
namespace driver
{
std::string module::compile_llvm_module(llvm::Module* module) {
init_llvm();
/* ------------------------ */
// Base //
/* ------------------------ */
module::module(driver::context* ctx, CUmodule mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership), ctx_(ctx) {
}
module::module(driver::context* ctx, cl_program mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership), ctx_(ctx) {
}
driver::context* module::context() const {
return ctx_;
}
/* ------------------------ */
// OpenCL //
/* ------------------------ */
/* ------------------------ */
// CUDA //
/* ------------------------ */
std::string cu_module::compile_llvm_module(llvm::Module* module) {
init_llvm();
// create machine
module->setTargetTriple("nvptx64-nvidia-cuda");
std::string error;
@@ -67,18 +92,17 @@ std::string module::compile_llvm_module(llvm::Module* module) {
layout += "-p3:32:32-p4:32:32-p5:32:32";
layout += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
module->setDataLayout(layout);
// emit machine code
llvm::legacy::PassManager pass;
llvm::SmallVector<char, 0> buffer;
llvm::raw_svector_ostream stream(buffer);
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
pass.run(*module);
// done
return std::string(buffer.begin(), buffer.end());
}
void module::init_llvm() {
void cu_module::init_llvm() {
static bool init = false;
if(!init){
llvm::InitializeAllTargetInfos();
@@ -90,10 +114,10 @@ void module::init_llvm() {
}
}
module::module(driver::context const & context, llvm::Module* ll_module): module(context, compile_llvm_module(ll_module)){ }
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
module::module(driver::context const & context, std::string const & source) : context_(context), source_(source){
ContextSwitcher ctx_switch(context_);
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;
@@ -108,17 +132,11 @@ module::module(driver::context const & context, std::string const & source) : co
}
}
driver::context const & module::context() const
{ return context_; }
handle<CUmodule> const & module::cu() const
{ return cu_; }
buffer module::symbol(const char *name) const{
cu_buffer cu_module::symbol(const char *name) const{
CUdeviceptr handle;
size_t size;
dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
return buffer(context_, handle, false);
return cu_buffer(ctx_, handle, false);
}