[code generation] more flexibility in backend selection

This commit is contained in:
Philippe Tillet
2019-03-27 11:29:42 -07:00
parent e04253c0dd
commit bc2a257d5c
7 changed files with 47 additions and 39 deletions

View File

@@ -157,7 +157,7 @@ int main() {
1, 8,
1,
};
// params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
// jit.autotune(src, benchmark);
jit.add_module(src, params);

View File

@@ -49,14 +49,17 @@ public:
allocation(&liveness, &buffer_info),
barriers(&allocation, &buffer_info),
vectorize(&tune),
selection(&allocation, &tune, &buffer_info, target) { }
selection(&allocation, &tune, &buffer_info, target),
target_(target) { }
void init(ir::module &module) {
// buffer_info.run(module);
// shared.run(module);
// liveness.run(module);
// allocation.run();
// barriers.run(module);
if(target_->is_gpu()){
buffer_info.run(module);
shared.run(module);
liveness.run(module);
allocation.run();
barriers.run(module);
}
vectorize.run(module);
// triton::ir::print(module, std::cout);
}
@@ -69,6 +72,7 @@ public:
codegen::barriers barriers;
codegen::vectorize vectorize;
codegen::selection selection;
codegen::target* target_;
};
private:

View File

@@ -810,14 +810,17 @@ void selection::run(ir::module &src, Module &dst) {
for(ir::function *fn: src.get_function_list()) {
// create LLVM function
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx);
Type *dst_fn_ret_ty = fn_ty->getReturnType();
std::vector<Type*> dst_fn_args_ty;
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
dst_fn_args_ty.push_back(fn_ty->getParamType(i));
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
FunctionType *dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
FunctionType *dst_fn_ty = fn_ty;
if(!tgt_->is_gpu()){
Type *dst_fn_ret_ty = fn_ty->getReturnType();
std::vector<Type*> dst_fn_args_ty;
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
dst_fn_args_ty.push_back(fn_ty->getParamType(i));
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
}
// grid indices
fn->get_fn_type()->get_return_ty();
Function *dst_fn = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
@@ -845,15 +848,16 @@ void selection::run(ir::module &src, Module &dst) {
// allocate shared memory
Value *sh_mem_ptr = nullptr;
// if(unsigned alloc_size = alloc_->get_allocated_size()){
// Type *int_8_ty = Type::getInt8Ty(dst_ctx);
// ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
// Type *ptr_ty = PointerType::get(int_8_ty, 3);
// GlobalVariable *sh_mem_array =
// new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
// nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
// sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
// }
if(tgt_->is_gpu())
if(unsigned alloc_size = alloc_->get_allocated_size()){
Type *int_8_ty = Type::getInt8Ty(dst_ctx);
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
Type *ptr_ty = PointerType::get(int_8_ty, 3);
GlobalVariable *sh_mem_array =
new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
}
// create grids
init_grids(fn, dst_builder, sh_mem_ptr);

View File

@@ -246,7 +246,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
int num_threads = 1;
for(size_t k = 0; k < shapes.size(); k++)
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
if(num_threads % 1 != 0)
if(num_threads % 64 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
if(num_threads != num_threads_)
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");

View File

@@ -48,15 +48,15 @@ void backend::platforms::init() {
if(dispatch::cuinit()){
cache_.push_back(new cu_platform());
}
// //if OpenCL is here
// if(dispatch::clinit()){
// cl_uint num_platforms;
// dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
// std::vector<cl_platform_id> ids(num_platforms);
// dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
// for(cl_platform_id id: ids)
// cache_.push_back(new cl_platform(id));
// }
//if OpenCL is here
if(dispatch::clinit()){
cl_uint num_platforms;
dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
std::vector<cl_platform_id> ids(num_platforms);
dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
for(cl_platform_id id: ids)
cache_.push_back(new cl_platform(id));
}
//if host is here
bool host_visible = true;
if(host_visible){

View File

@@ -179,10 +179,10 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
// create execution engine
llvm::legacy::PassManager pass;
pass.add(llvm::createPrintModulePass(llvm::outs()));
pass.add(llvm::createVerifierPass());
pass.run(*src);
// llvm::legacy::PassManager pass;
// pass.add(llvm::createPrintModulePass(llvm::outs()));
// pass.add(llvm::createVerifierPass());
// pass.run(*src);
auto cloned = llvm::CloneModule(*src);
for(llvm::Function& fn: cloned->functions())
hst_->functions[fn.getName()] = &fn;

View File

@@ -92,7 +92,7 @@ std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
}
jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::cpu_target()) {
jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::amd_cl_target()) {
}