[code generation] more flexibility in backend selection
This commit is contained in:
@@ -157,7 +157,7 @@ int main() {
|
||||
1, 8,
|
||||
1,
|
||||
};
|
||||
// params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
|
||||
params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
|
||||
|
||||
// jit.autotune(src, benchmark);
|
||||
jit.add_module(src, params);
|
||||
|
@@ -49,14 +49,17 @@ public:
|
||||
allocation(&liveness, &buffer_info),
|
||||
barriers(&allocation, &buffer_info),
|
||||
vectorize(&tune),
|
||||
selection(&allocation, &tune, &buffer_info, target) { }
|
||||
selection(&allocation, &tune, &buffer_info, target),
|
||||
target_(target) { }
|
||||
|
||||
void init(ir::module &module) {
|
||||
// buffer_info.run(module);
|
||||
// shared.run(module);
|
||||
// liveness.run(module);
|
||||
// allocation.run();
|
||||
// barriers.run(module);
|
||||
if(target_->is_gpu()){
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
}
|
||||
vectorize.run(module);
|
||||
// triton::ir::print(module, std::cout);
|
||||
}
|
||||
@@ -69,6 +72,7 @@ public:
|
||||
codegen::barriers barriers;
|
||||
codegen::vectorize vectorize;
|
||||
codegen::selection selection;
|
||||
codegen::target* target_;
|
||||
};
|
||||
|
||||
private:
|
||||
|
@@ -810,14 +810,17 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
for(ir::function *fn: src.get_function_list()) {
|
||||
// create LLVM function
|
||||
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx);
|
||||
Type *dst_fn_ret_ty = fn_ty->getReturnType();
|
||||
std::vector<Type*> dst_fn_args_ty;
|
||||
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
|
||||
dst_fn_args_ty.push_back(fn_ty->getParamType(i));
|
||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
||||
FunctionType *dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
|
||||
FunctionType *dst_fn_ty = fn_ty;
|
||||
if(!tgt_->is_gpu()){
|
||||
Type *dst_fn_ret_ty = fn_ty->getReturnType();
|
||||
std::vector<Type*> dst_fn_args_ty;
|
||||
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
|
||||
dst_fn_args_ty.push_back(fn_ty->getParamType(i));
|
||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
||||
dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
|
||||
}
|
||||
// grid indices
|
||||
fn->get_fn_type()->get_return_ty();
|
||||
Function *dst_fn = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
|
||||
@@ -845,15 +848,16 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
|
||||
// allocate shared memory
|
||||
Value *sh_mem_ptr = nullptr;
|
||||
// if(unsigned alloc_size = alloc_->get_allocated_size()){
|
||||
// Type *int_8_ty = Type::getInt8Ty(dst_ctx);
|
||||
// ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
|
||||
// Type *ptr_ty = PointerType::get(int_8_ty, 3);
|
||||
// GlobalVariable *sh_mem_array =
|
||||
// new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
|
||||
// nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
// sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
|
||||
// }
|
||||
if(tgt_->is_gpu())
|
||||
if(unsigned alloc_size = alloc_->get_allocated_size()){
|
||||
Type *int_8_ty = Type::getInt8Ty(dst_ctx);
|
||||
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
|
||||
Type *ptr_ty = PointerType::get(int_8_ty, 3);
|
||||
GlobalVariable *sh_mem_array =
|
||||
new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
|
||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
|
||||
}
|
||||
|
||||
// create grids
|
||||
init_grids(fn, dst_builder, sh_mem_ptr);
|
||||
|
@@ -246,7 +246,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
||||
int num_threads = 1;
|
||||
for(size_t k = 0; k < shapes.size(); k++)
|
||||
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
|
||||
if(num_threads % 1 != 0)
|
||||
if(num_threads % 64 != 0)
|
||||
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
|
||||
if(num_threads != num_threads_)
|
||||
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
|
||||
|
@@ -48,15 +48,15 @@ void backend::platforms::init() {
|
||||
if(dispatch::cuinit()){
|
||||
cache_.push_back(new cu_platform());
|
||||
}
|
||||
// //if OpenCL is here
|
||||
// if(dispatch::clinit()){
|
||||
// cl_uint num_platforms;
|
||||
// dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
|
||||
// std::vector<cl_platform_id> ids(num_platforms);
|
||||
// dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
|
||||
// for(cl_platform_id id: ids)
|
||||
// cache_.push_back(new cl_platform(id));
|
||||
// }
|
||||
//if OpenCL is here
|
||||
if(dispatch::clinit()){
|
||||
cl_uint num_platforms;
|
||||
dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
|
||||
std::vector<cl_platform_id> ids(num_platforms);
|
||||
dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
|
||||
for(cl_platform_id id: ids)
|
||||
cache_.push_back(new cl_platform(id));
|
||||
}
|
||||
//if host is here
|
||||
bool host_visible = true;
|
||||
if(host_visible){
|
||||
|
@@ -179,10 +179,10 @@ host_module::host_module(driver::context * context, llvm::Module* src): module(c
|
||||
|
||||
|
||||
// create execution engine
|
||||
llvm::legacy::PassManager pass;
|
||||
pass.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
pass.add(llvm::createVerifierPass());
|
||||
pass.run(*src);
|
||||
// llvm::legacy::PassManager pass;
|
||||
// pass.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
// pass.add(llvm::createVerifierPass());
|
||||
// pass.run(*src);
|
||||
auto cloned = llvm::CloneModule(*src);
|
||||
for(llvm::Function& fn: cloned->functions())
|
||||
hst_->functions[fn.getName()] = &fn;
|
||||
|
@@ -92,7 +92,7 @@ std::unique_ptr<ir::module> jit::make_triton_module(const std::string &src) {
|
||||
}
|
||||
|
||||
|
||||
jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::cpu_target()) {
|
||||
jit::jit(driver::context *context): driver_context_(context), target_(new triton::codegen::amd_cl_target()) {
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user