[CI] Various improvements to CI (#137)

Add clean-up before CI runs. Now using static LLVM-11 libraries from system rather than recompilation. Still no run-time LLVM dependencies
This commit is contained in:
Philippe Tillet
2021-07-22 11:41:51 -07:00
committed by Philippe Tillet
parent 298aead378
commit 8eb63bcb01
7 changed files with 196 additions and 190 deletions

View File

@@ -59,6 +59,13 @@ std::string exec(const char* cmd) {
return result;
}
void LLVMInitializeNVPTXTargetInfo();
void LLVMInitializeNVPTXTarget();
void LLVMInitializeNVPTXTargetMC();
void LLVMInitializeNVPTXAsmPrinter();
void LLVMInitializeNVPTXAsmParser();
namespace triton
{
namespace driver
@@ -68,14 +75,14 @@ namespace driver
// Base //
/* ------------------------ */
void module::init_llvm() {
static bool init = false;
if(!init){
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmParsers();
llvm::InitializeAllAsmPrinters();
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
init = true;
}
}
@@ -111,80 +118,81 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
/* ------------------------ */
host_module::host_module(std::unique_ptr<llvm::Module> src): module(host_module_t(), true) {
init_llvm();
// create kernel wrapper
llvm::LLVMContext &ctx = src->getContext();
llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src);
llvm::Function* fn = &*src->getFunctionList().begin();
llvm::FunctionType *fn_ty = fn->getFunctionType();
std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
std::vector<llvm::Value*> ptrs(fn_args.size() - 3);
llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main);
llvm::IRBuilder<> ir_builder(ctx);
ir_builder.SetInsertPoint(entry);
auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; };
llvm::Value* base = main->arg_begin();
llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType());
throw std::runtime_error("CPU unsupported");
// init_llvm();
// // create kernel wrapper
// llvm::LLVMContext &ctx = src->getContext();
// llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
// llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
// llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
// std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
// llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
// llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src);
// llvm::Function* fn = &*src->getFunctionList().begin();
// llvm::FunctionType *fn_ty = fn->getFunctionType();
// std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
// std::vector<llvm::Value*> ptrs(fn_args.size() - 3);
// llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main);
// llvm::IRBuilder<> ir_builder(ctx);
// ir_builder.SetInsertPoint(entry);
// auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; };
// llvm::Value* base = main->arg_begin();
// llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType());
size_t offset = 0;
for(unsigned i = 0; i < ptrs.size(); i++){
ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset));
size_t nbytes = get_size(fn_ty->getParamType(i));
offset += nbytes;
if(i < ptrs.size() - 1){
size_t np1bytes = get_size(fn_ty->getParamType(i+1));
offset = (offset + np1bytes - 1) / np1bytes * np1bytes;
}
}
for(unsigned i = 0; i < ptrs.size(); i++)
ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo());
for(unsigned i = 0; i < ptrs.size(); i++)
fn_args[i] = ir_builder.CreateLoad(ptrs[i]);
// size_t offset = 0;
// for(unsigned i = 0; i < ptrs.size(); i++){
// ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset));
// size_t nbytes = get_size(fn_ty->getParamType(i));
// offset += nbytes;
// if(i < ptrs.size() - 1){
// size_t np1bytes = get_size(fn_ty->getParamType(i+1));
// offset = (offset + np1bytes - 1) / np1bytes * np1bytes;
// }
// }
// for(unsigned i = 0; i < ptrs.size(); i++)
// ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo());
// for(unsigned i = 0; i < ptrs.size(); i++)
// fn_args[i] = ir_builder.CreateLoad(ptrs[i]);
fn_args[fn_args.size() - 3] = main->arg_begin() + 1;
fn_args[fn_args.size() - 2] = main->arg_begin() + 2;
fn_args[fn_args.size() - 1] = main->arg_begin() + 3;
ir_builder.CreateCall(fn, fn_args);
ir_builder.CreateRetVoid();
// fn_args[fn_args.size() - 3] = main->arg_begin() + 1;
// fn_args[fn_args.size() - 2] = main->arg_begin() + 2;
// fn_args[fn_args.size() - 1] = main->arg_begin() + 3;
// ir_builder.CreateCall(fn, fn_args);
// ir_builder.CreateRetVoid();
// llvm::legacy::PassManager pm;
// pm.add(llvm::createPrintModulePass(llvm::outs()));
// pm.add(llvm::createVerifierPass());
// pm.run(*src);
//// llvm::legacy::PassManager pm;
//// pm.add(llvm::createPrintModulePass(llvm::outs()));
//// pm.add(llvm::createVerifierPass());
//// pm.run(*src);
// create execution engine
for(llvm::Function& fn: src->functions())
hst_->functions[fn.getName().str()] = &fn;
//// create execution engine
// for(llvm::Function& fn: src->functions())
// hst_->functions[fn.getName().str()] = &fn;
// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
// auto DL = JTMB.getDefaultDataLayoutForTarget();
// auto CIRC = std::unique_ptr<llvm::orc::ConcurrentIRCompiler>(new llvm::orc::ConcurrentIRCompiler(JTMB));
// hst_->ES = new llvm::orc::ExecutionSession();
// hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr<llvm::SectionMemoryManager>(new llvm::SectionMemoryManager()); });
// hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC);
// hst_->DL = new llvm::DataLayout(std::move(*DL));
// hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL);
// hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr<llvm::LLVMContext>(new llvm::LLVMContext()));
// hst_->MainJD = &hst_->ES->createJITDylib("<main>");
// hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
// hst_->DL->getGlobalPrefix())));
// llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx)));
// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress());
//// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
//// auto DL = JTMB.getDefaultDataLayoutForTarget();
//// auto CIRC = std::unique_ptr<llvm::orc::ConcurrentIRCompiler>(new llvm::orc::ConcurrentIRCompiler(JTMB));
//// hst_->ES = new llvm::orc::ExecutionSession();
//// hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr<llvm::SectionMemoryManager>(new llvm::SectionMemoryManager()); });
//// hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC);
//// hst_->DL = new llvm::DataLayout(std::move(*DL));
//// hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL);
//// hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr<llvm::LLVMContext>(new llvm::LLVMContext()));
//// hst_->MainJD = &hst_->ES->createJITDylib("<main>");
//// hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
//// hst_->DL->getGlobalPrefix())));
//// llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx)));
//// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress());
llvm::EngineBuilder builder(std::move(src));
builder.setErrorStr(&hst_->error);
builder.setMCJITMemoryManager(std::make_unique<llvm::SectionMemoryManager>());
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
builder.setEngineKind(llvm::EngineKind::JIT);
hst_->engine = builder.create();
hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main"));
// llvm::EngineBuilder builder(std::move(src));
// builder.setErrorStr(&hst_->error);
// builder.setMCJITMemoryManager(std::make_unique<llvm::SectionMemoryManager>());
// builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
// builder.setEngineKind(llvm::EngineKind::JIT);
// hst_->engine = builder.create();
// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main"));
}
std::unique_ptr<buffer> host_module::symbol(const char *name) const {
@@ -211,7 +219,7 @@ static std::map<int, int> vptx = {
{11010, 71},
{11020, 72},
{11030, 73},
{11040, 74}
{11040, 73}
};
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {