[CODEGEN] More work on the CPU backend
This commit is contained in:
committed by
Philippe Tillet
parent
64eaec016f
commit
840308ab5d
@@ -47,6 +47,12 @@ void backend::platforms::init() {
|
||||
if(dispatch::cuinit()){
|
||||
cache_.push_back(new cu_platform());
|
||||
}
|
||||
//if host should be added
|
||||
bool host_visible = true;
|
||||
if(host_visible){
|
||||
cache_.push_back(new host_platform());
|
||||
}
|
||||
|
||||
// //if OpenCL is here
|
||||
// if(dispatch::clinit()){
|
||||
// cl_uint num_platforms;
|
||||
@@ -56,11 +62,7 @@ void backend::platforms::init() {
|
||||
// for(cl_platform_id id: ids)
|
||||
// cache_.push_back(new cl_platform(id));
|
||||
// }
|
||||
// //if host is here
|
||||
// bool host_visible = true;
|
||||
// if(host_visible){
|
||||
// cache_.push_back(new host_platform());
|
||||
// }
|
||||
|
||||
if(cache_.empty())
|
||||
throw std::runtime_error("Triton: No backend available. Make sure CUDA is available in your library path");
|
||||
}
|
||||
|
@@ -53,6 +53,14 @@ size_t buffer::size() {
|
||||
return size_;
|
||||
}
|
||||
|
||||
uintptr_t buffer::addr_as_uintptr_t() {
|
||||
switch(backend_){
|
||||
case CUDA: return *cu_;
|
||||
case Host: return (uintptr_t)hst_->data;
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
buffer* buffer::create(driver::context* ctx, size_t size) {
|
||||
switch(ctx->backend()){
|
||||
|
@@ -135,12 +135,6 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
|
||||
|
||||
host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module> src): module(context, host_module_t(), true) {
|
||||
init_llvm();
|
||||
// host info
|
||||
// std::string triple = llvm::sys::getDefaultTargetTriple();
|
||||
// std::string cpu = llvm::sys::getHostCPUName();
|
||||
// llvm::SmallVector<char, 0> buffer;
|
||||
// module::compile_llvm_module(src, triple, cpu, "", buffer, "", Assembly);
|
||||
|
||||
// create kernel wrapper
|
||||
llvm::LLVMContext &ctx = src->getContext();
|
||||
llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
|
||||
@@ -148,37 +142,72 @@ host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module
|
||||
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
|
||||
std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
|
||||
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
|
||||
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src);
|
||||
llvm::Function* fn = src->getFunction("matmul");
|
||||
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src);
|
||||
llvm::Function* fn = &*src->getFunctionList().begin();
|
||||
llvm::FunctionType *fn_ty = fn->getFunctionType();
|
||||
std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
|
||||
std::vector<llvm::Value*> ptrs(fn_args.size() - 3);
|
||||
llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main);
|
||||
llvm::IRBuilder<> ir_builder(ctx);
|
||||
ir_builder.SetInsertPoint(entry);
|
||||
for(unsigned i = 0; i < ptrs.size(); i++)
|
||||
ptrs[i] = ir_builder.CreateGEP(main->arg_begin(), ir_builder.getInt32(i));
|
||||
auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; };
|
||||
llvm::Value* base = main->arg_begin();
|
||||
llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType());
|
||||
|
||||
size_t offset = 0;
|
||||
for(unsigned i = 0; i < ptrs.size(); i++){
|
||||
llvm::Value* addr = ir_builder.CreateBitCast(ir_builder.CreateLoad(ptrs[i]), fn_ty->getParamType(i)->getPointerTo());
|
||||
fn_args[i] = ir_builder.CreateLoad(addr);
|
||||
ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset));
|
||||
size_t nbytes = get_size(fn_ty->getParamType(i));
|
||||
offset += nbytes;
|
||||
if(i < ptrs.size() - 1){
|
||||
size_t np1bytes = get_size(fn_ty->getParamType(i+1));
|
||||
offset = (offset + np1bytes - 1) / np1bytes * np1bytes;
|
||||
}
|
||||
}
|
||||
for(unsigned i = 0; i < ptrs.size(); i++)
|
||||
ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo());
|
||||
for(unsigned i = 0; i < ptrs.size(); i++)
|
||||
fn_args[i] = ir_builder.CreateLoad(ptrs[i]);
|
||||
|
||||
fn_args[fn_args.size() - 3] = main->arg_begin() + 1;
|
||||
fn_args[fn_args.size() - 2] = main->arg_begin() + 2;
|
||||
fn_args[fn_args.size() - 1] = main->arg_begin() + 3;
|
||||
ir_builder.CreateCall(fn, fn_args);
|
||||
ir_builder.CreateRetVoid();
|
||||
|
||||
// llvm::legacy::PassManager pm;
|
||||
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
// pm.add(llvm::createVerifierPass());
|
||||
// pm.run(*src);
|
||||
|
||||
// create execution engine
|
||||
// create execution engine
|
||||
for(llvm::Function& fn: src->functions())
|
||||
hst_->functions[fn.getName()] = &fn;
|
||||
|
||||
// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
|
||||
// auto DL = JTMB.getDefaultDataLayoutForTarget();
|
||||
// auto CIRC = std::unique_ptr<llvm::orc::ConcurrentIRCompiler>(new llvm::orc::ConcurrentIRCompiler(JTMB));
|
||||
// hst_->ES = new llvm::orc::ExecutionSession();
|
||||
// hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr<llvm::SectionMemoryManager>(new llvm::SectionMemoryManager()); });
|
||||
// hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC);
|
||||
// hst_->DL = new llvm::DataLayout(std::move(*DL));
|
||||
// hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL);
|
||||
// hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr<llvm::LLVMContext>(new llvm::LLVMContext()));
|
||||
// hst_->MainJD = &hst_->ES->createJITDylib("<main>");
|
||||
// hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
|
||||
// hst_->DL->getGlobalPrefix())));
|
||||
// llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx)));
|
||||
// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress());
|
||||
|
||||
|
||||
|
||||
llvm::EngineBuilder builder(std::move(src));
|
||||
builder.setErrorStr(&hst_->error);
|
||||
builder.setMCJITMemoryManager(llvm::make_unique<llvm::SectionMemoryManager>());
|
||||
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
|
||||
builder.setEngineKind(llvm::EngineKind::JIT);
|
||||
builder.setUseOrcMCJITReplacement(true);
|
||||
hst_->engine = builder.create();
|
||||
hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main"));
|
||||
}
|
||||
|
||||
std::unique_ptr<buffer> host_module::symbol(const char *name) const {
|
||||
|
@@ -72,21 +72,20 @@ driver::context* stream::context() const {
|
||||
/* ------------------------ */
|
||||
|
||||
host_stream::host_stream(driver::context *ctx): stream(ctx, host_stream_t(), true) {
|
||||
|
||||
hst_->pool.reset(new ThreadPool(8));
|
||||
}
|
||||
|
||||
void host_stream::synchronize() {
|
||||
|
||||
hst_->pool.reset(new ThreadPool(8));
|
||||
}
|
||||
|
||||
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **extra) {
|
||||
driver::host_kernel* hst_kernel = (host_kernel*)kernel;
|
||||
llvm::ExecutionEngine* engine = kernel->module()->hst()->engine;
|
||||
void (*fn)(char**, int32_t, int32_t, int32_t) = (void(*)(char**, int32_t, int32_t, int32_t))engine->getFunctionAddress("main");
|
||||
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **args, size_t args_size) {
|
||||
ThreadPool pool(4);
|
||||
auto hst = kernel->module()->hst();
|
||||
for(size_t i = 0; i < grid[0]; i++)
|
||||
for(size_t j = 0; j < grid[1]; j++)
|
||||
for(size_t k = 0; k < grid[2]; k++)
|
||||
fn((char**)hst_kernel->params().data(), int32_t(i), int32_t(j), int32_t(k));
|
||||
hst_->pool->enqueue(hst->fn, (char**)args, int32_t(i), int32_t(j), int32_t(k));
|
||||
}
|
||||
|
||||
void host_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
|
||||
@@ -112,7 +111,7 @@ void cl_stream::synchronize() {
|
||||
check(dispatch::clFinish(*cl_));
|
||||
}
|
||||
|
||||
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **extra) {
|
||||
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **args, size_t args_size) {
|
||||
std::array<size_t, 3> global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]};
|
||||
check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL));
|
||||
}
|
||||
@@ -149,11 +148,16 @@ void cu_stream::synchronize() {
|
||||
dispatch::cuStreamSynchronize(*cu_);
|
||||
}
|
||||
|
||||
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void** extra) {
|
||||
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void** args, size_t args_size) {
|
||||
cu_context::context_switcher ctx_switch(*ctx_);
|
||||
void *config[] = {
|
||||
CU_LAUNCH_PARAM_BUFFER_POINTER, args,
|
||||
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
|
||||
CU_LAUNCH_PARAM_END
|
||||
};
|
||||
if(event)
|
||||
dispatch::cuEventRecord(event->cu()->first, *cu_);
|
||||
dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, extra);
|
||||
dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, config);
|
||||
if(event)
|
||||
dispatch::cuEventRecord(event->cu()->second, *cu_);
|
||||
}
|
||||
|
Reference in New Issue
Block a user