[CODEGEN] More work on the CPU backend

This commit is contained in:
Philippe Tillet
2020-09-11 11:44:34 -04:00
committed by Philippe Tillet
parent 64eaec016f
commit 840308ab5d
17 changed files with 258 additions and 185 deletions

View File

@@ -135,12 +135,6 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module> src): module(context, host_module_t(), true) {
init_llvm();
// host info
// std::string triple = llvm::sys::getDefaultTargetTriple();
// std::string cpu = llvm::sys::getHostCPUName();
// llvm::SmallVector<char, 0> buffer;
// module::compile_llvm_module(src, triple, cpu, "", buffer, "", Assembly);
// create kernel wrapper
llvm::LLVMContext &ctx = src->getContext();
llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
@@ -148,37 +142,72 @@ host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src);
llvm::Function* fn = src->getFunction("matmul");
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src);
llvm::Function* fn = &*src->getFunctionList().begin();
llvm::FunctionType *fn_ty = fn->getFunctionType();
std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
std::vector<llvm::Value*> ptrs(fn_args.size() - 3);
llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main);
llvm::IRBuilder<> ir_builder(ctx);
ir_builder.SetInsertPoint(entry);
for(unsigned i = 0; i < ptrs.size(); i++)
ptrs[i] = ir_builder.CreateGEP(main->arg_begin(), ir_builder.getInt32(i));
auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; };
llvm::Value* base = main->arg_begin();
llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType());
size_t offset = 0;
for(unsigned i = 0; i < ptrs.size(); i++){
llvm::Value* addr = ir_builder.CreateBitCast(ir_builder.CreateLoad(ptrs[i]), fn_ty->getParamType(i)->getPointerTo());
fn_args[i] = ir_builder.CreateLoad(addr);
ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset));
size_t nbytes = get_size(fn_ty->getParamType(i));
offset += nbytes;
if(i < ptrs.size() - 1){
size_t np1bytes = get_size(fn_ty->getParamType(i+1));
offset = (offset + np1bytes - 1) / np1bytes * np1bytes;
}
}
for(unsigned i = 0; i < ptrs.size(); i++)
ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo());
for(unsigned i = 0; i < ptrs.size(); i++)
fn_args[i] = ir_builder.CreateLoad(ptrs[i]);
fn_args[fn_args.size() - 3] = main->arg_begin() + 1;
fn_args[fn_args.size() - 2] = main->arg_begin() + 2;
fn_args[fn_args.size() - 1] = main->arg_begin() + 3;
ir_builder.CreateCall(fn, fn_args);
ir_builder.CreateRetVoid();
// llvm::legacy::PassManager pm;
// pm.add(llvm::createPrintModulePass(llvm::outs()));
// pm.add(llvm::createVerifierPass());
// pm.run(*src);
// create execution engine
// create execution engine
for(llvm::Function& fn: src->functions())
hst_->functions[fn.getName()] = &fn;
// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
// auto DL = JTMB.getDefaultDataLayoutForTarget();
// auto CIRC = std::unique_ptr<llvm::orc::ConcurrentIRCompiler>(new llvm::orc::ConcurrentIRCompiler(JTMB));
// hst_->ES = new llvm::orc::ExecutionSession();
// hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr<llvm::SectionMemoryManager>(new llvm::SectionMemoryManager()); });
// hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC);
// hst_->DL = new llvm::DataLayout(std::move(*DL));
// hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL);
// hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr<llvm::LLVMContext>(new llvm::LLVMContext()));
// hst_->MainJD = &hst_->ES->createJITDylib("<main>");
// hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
// hst_->DL->getGlobalPrefix())));
// llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx)));
// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress());
llvm::EngineBuilder builder(std::move(src));
builder.setErrorStr(&hst_->error);
builder.setMCJITMemoryManager(llvm::make_unique<llvm::SectionMemoryManager>());
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
builder.setEngineKind(llvm::EngineKind::JIT);
builder.setUseOrcMCJITReplacement(true);
hst_->engine = builder.create();
hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main"));
}
std::unique_ptr<buffer> host_module::symbol(const char *name) const {