[driver] adding opencl in the driver API
This commit is contained in:
@@ -46,9 +46,34 @@ namespace triton
|
||||
namespace driver
|
||||
{
|
||||
|
||||
std::string module::compile_llvm_module(llvm::Module* module) {
|
||||
init_llvm();
|
||||
/* ------------------------ */
|
||||
// Base //
|
||||
/* ------------------------ */
|
||||
|
||||
module::module(driver::context* ctx, CUmodule mod, bool has_ownership)
|
||||
: polymorphic_resource(mod, has_ownership), ctx_(ctx) {
|
||||
}
|
||||
|
||||
module::module(driver::context* ctx, cl_program mod, bool has_ownership)
|
||||
: polymorphic_resource(mod, has_ownership), ctx_(ctx) {
|
||||
}
|
||||
|
||||
driver::context* module::context() const {
|
||||
return ctx_;
|
||||
}
|
||||
|
||||
|
||||
/* ------------------------ */
|
||||
// OpenCL //
|
||||
/* ------------------------ */
|
||||
|
||||
|
||||
/* ------------------------ */
|
||||
// CUDA //
|
||||
/* ------------------------ */
|
||||
|
||||
std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
init_llvm();
|
||||
// create machine
|
||||
module->setTargetTriple("nvptx64-nvidia-cuda");
|
||||
std::string error;
|
||||
@@ -67,18 +92,17 @@ std::string module::compile_llvm_module(llvm::Module* module) {
|
||||
layout += "-p3:32:32-p4:32:32-p5:32:32";
|
||||
layout += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
|
||||
module->setDataLayout(layout);
|
||||
|
||||
// emit machine code
|
||||
llvm::legacy::PassManager pass;
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
llvm::raw_svector_ostream stream(buffer);
|
||||
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
|
||||
pass.run(*module);
|
||||
|
||||
// done
|
||||
return std::string(buffer.begin(), buffer.end());
|
||||
}
|
||||
|
||||
void module::init_llvm() {
|
||||
void cu_module::init_llvm() {
|
||||
static bool init = false;
|
||||
if(!init){
|
||||
llvm::InitializeAllTargetInfos();
|
||||
@@ -90,10 +114,10 @@ void module::init_llvm() {
|
||||
}
|
||||
}
|
||||
|
||||
module::module(driver::context const & context, llvm::Module* ll_module): module(context, compile_llvm_module(ll_module)){ }
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
module::module(driver::context const & context, std::string const & source) : context_(context), source_(source){
|
||||
ContextSwitcher ctx_switch(context_);
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
unsigned int errbufsize = 8096;
|
||||
@@ -108,17 +132,11 @@ module::module(driver::context const & context, std::string const & source) : co
|
||||
}
|
||||
}
|
||||
|
||||
driver::context const & module::context() const
|
||||
{ return context_; }
|
||||
|
||||
handle<CUmodule> const & module::cu() const
|
||||
{ return cu_; }
|
||||
|
||||
buffer module::symbol(const char *name) const{
|
||||
cu_buffer cu_module::symbol(const char *name) const{
|
||||
CUdeviceptr handle;
|
||||
size_t size;
|
||||
dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
|
||||
return buffer(context_, handle, false);
|
||||
return cu_buffer(ctx_, handle, false);
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user