[LANG] Added support for device functions (#484)
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/IntrinsicsNVPTX.h"
|
||||
@@ -139,6 +140,14 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
* \brief Convert Triton-IR Type to LLVM-IR Type
|
||||
*/
|
||||
Type *generator::cvt(ir::type *ty) {
|
||||
// struct
|
||||
if(ty->is_struct_ty()){
|
||||
std::vector<Type*> tys;
|
||||
for(size_t i = 0; i < ty->get_struct_numel(); i++)
|
||||
tys.push_back(cvt(ty->get_struct_type(i)));
|
||||
return StructType::get(builder_->getContext(), tys, true);
|
||||
}
|
||||
|
||||
// function
|
||||
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
|
||||
Type *ret_ty = cvt(tt->get_return_ty());
|
||||
@@ -266,7 +275,8 @@ void generator::visit_value(ir::value* v) {
|
||||
builder_->SetInsertPoint(&*current->getFirstNonPHI());
|
||||
// visit user
|
||||
if(auto *usr = dynamic_cast<ir::user*>(v)){
|
||||
usr->accept(this);
|
||||
if(!dynamic_cast<ir::function*>(usr))
|
||||
usr->accept(this);
|
||||
}
|
||||
// revert insert point
|
||||
if(phi && !current->empty() && current->getFirstNonPHI())
|
||||
@@ -282,6 +292,81 @@ void generator::visit_phi_node(ir::phi_node* x) {
|
||||
vals_[x][idx] = phi(ty, x->get_num_operands());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `call`
|
||||
*/
|
||||
void generator::visit_call_inst(ir::call_inst* call) {
|
||||
throw std::runtime_error("call not supported! Triton should be inlining everything.");
|
||||
}
|
||||
|
||||
/**
 * \brief Code Generation for `launch`
 *
 * Emits a device-side launch of the function held in operand 0 through
 * `cudaGetParameterBufferV2` and `cudaLaunchDeviceV2`:
 *   - only the thread whose local id along axis 0 equals zero takes the
 *     launch path; every other thread branches straight to `launch_done`;
 *   - the grid is taken from `launch->get_grid()`, the block is
 *     `(32 * num_warps, 1, 1)`;
 *   - the launch is skipped when the returned parameter buffer is null;
 *   - each value of `launch->get_values()` is stored into the parameter
 *     buffer at an offset rounded up to a multiple of its own byte size
 *     (pointers are counted as 8 bytes).
 * On return the builder inserts into the `launch_done` block.
 *
 * \throws std::runtime_error if operand 0 is not a function, if that function
 *         has no entry in `fns_`, or if an argument has a byte size of zero.
 */
void generator::visit_launch_inst(ir::launch_inst *launch) {
  // checked downcast: a blind cast would silently accept any operand kind
  auto* fn = dynamic_cast<ir::function*>(launch->get_operand(0));
  if(!fn)
    throw std::runtime_error("launch: operand 0 is not a function");
  Function* called_fn = fns_[fn];
  if(!called_fn)
    throw std::runtime_error("launch: callee has no LLVM declaration");

  // Reuse a declaration that is already in the module. Creating a second
  // function with the same name makes LLVM register it under a uniquified
  // name, which would no longer match the runtime symbol.
  auto declare = [this](const char* name, FunctionType* ty) -> Function* {
    if(Function* known = mod_->getFunction(name))
      return known;
    return Function::Create(ty, Function::ExternalLinkage, name, mod_);
  };

  // forward-declare cudaGetParameterBufferV2
  std::vector<Type*> get_param_arg_tys = {PointerType::get(builder_->getInt8Ty(), 0),
                                          ArrayType::get(builder_->getInt32Ty(), 3),
                                          ArrayType::get(builder_->getInt32Ty(), 3),
                                          builder_->getInt32Ty()};
  FunctionType* get_param_ty = FunctionType::get(PointerType::get(builder_->getInt8Ty(), 0), get_param_arg_tys, false);
  Function* get_param_buffer = declare("cudaGetParameterBufferV2", get_param_ty);

  // stack slots holding the grid and block dimensions (3 x i32 each)
  AllocaInst* grid = builder_->CreateAlloca(get_param_arg_tys[1]);
  AllocaInst* block = builder_->CreateAlloca(get_param_arg_tys[2]);
  ConstantInt* _0 = builder_->getInt32(0);
  ConstantInt* _1 = builder_->getInt32(1);
  ConstantInt* _2 = builder_->getInt32(2);

  // create basic blocks: `launch` is placed before `launch_done`
  BasicBlock* launch_done_bb = BasicBlock::Create(builder_->getContext(), "launch_done", builder_->GetInsertBlock()->getParent());
  BasicBlock* launch_bb = BasicBlock::Create(builder_->getContext(), "launch", launch_done_bb->getParent(), launch_done_bb);

  // only the thread with local id 0 (axis 0) performs the launch
  Value *tid = tgt_->get_local_id(mod_, *builder_, 0);
  Value *is_first_thread = builder_->CreateICmpEQ(tid, i32(0));
  builder_->CreateCondBr(is_first_thread, launch_bb, launch_done_bb);
  builder_->SetInsertPoint(launch_bb);

  // grid = (grid[0], grid[1], grid[2])
  builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0}));
  builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1}));
  builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2}));
  // block = (32 * num_warps, 1, 1)
  Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]);
  builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0}));
  builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1}));
  builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2}));

  // request the parameter buffer for the callee (last argument is passed as 0)
  Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]);
  Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)});

  // forward-declare cudaLaunchDeviceV2
  std::vector<Type*> launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()};
  FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false);
  Function* launch_device = declare("cudaLaunchDeviceV2", launch_device_ty);

  // skip the launch when the parameter buffer pointer is null
  Value* do_not_launch = builder_->CreateICmpEQ(builder_->CreatePtrToInt(arg_ptr, builder_->getInt64Ty()),
                                                builder_->getInt64(0));
  BasicBlock* launch2_bb = BasicBlock::Create(builder_->getContext(), "launch2", launch_done_bb->getParent(), launch_done_bb);
  builder_->CreateCondBr(do_not_launch, launch_done_bb, launch2_bb);
  builder_->SetInsertPoint(launch2_bb);

  // pack the arguments into the parameter buffer
  unsigned addr_space = arg_ptr->getType()->getPointerAddressSpace();
  unsigned off = 0;
  unsigned last_size = 0;
  for(ir::value* arg: launch->get_values()){
    Value* curr_arg = vals_[arg][{}];
    Type* curr_arg_ty = curr_arg->getType();
    // handle struct alignment
    off += last_size;
    unsigned size = curr_arg_ty->isPointerTy() ? 8 : curr_arg_ty->getPrimitiveSizeInBits() / 8;
    // a zero size would make the rounding below divide by zero
    if(size == 0)
      throw std::runtime_error("launch: argument type has a byte size of zero");
    off = (off + size - 1) / size * size;
    // get pointer to current arg
    Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off));
    curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space));
    // store arg
    builder_->CreateStore(curr_arg, curr_arg_ptr);
    last_size = size;
  }
  builder_->CreateCall(launch_device, {arg_ptr, builder_->getInt64(0)});
  builder_->CreateBr(launch_done_bb);

  // done: subsequent code is emitted after the launch
  builder_->SetInsertPoint(launch_done_bb);
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `binary_operator`
|
||||
*/
|
||||
@@ -311,6 +396,7 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
default: throw std::runtime_error("unreachable switch");
|
||||
}
|
||||
};
|
||||
// x->print(std::cout);
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
Value *lhs = vals_[x->get_operand(0)][idx];
|
||||
Value *rhs = vals_[x->get_operand(1)][idx];
|
||||
@@ -852,6 +938,31 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
|
||||
visit_store_inst(x);
|
||||
}
|
||||
|
||||
// --
|
||||
|
||||
/**
 * \brief Code Generation for `extract_value`
 *
 * For every index registered for `x`, extracts field `x->get_idx()` from the
 * matching value of the aggregate operand (operand 0).
 */
void generator::visit_extract_value_inst(ir::extract_value_inst *x) {
  ir::value* aggregate = x->get_operand(0);
  unsigned field = x->get_idx();
  for(indices_t idx: idxs_.at(x)){
    Value* source = vals_[aggregate][idx];
    vals_[x][idx] = builder_->CreateExtractValue(source, {field});
  }
}
|
||||
|
||||
|
||||
/**
 * \brief Code Generation for `insert_value`
 *
 * For every index registered for `x`, inserts the matching value of operand 1
 * into field `x->get_idx()` of the matching value of the aggregate operand
 * (operand 0).
 */
void generator::visit_insert_value_inst(ir::insert_value_inst *x){
  ir::value* aggregate = x->get_operand(0);
  ir::value* element = x->get_operand(1);
  unsigned field = x->get_idx();
  for(indices_t idx: idxs_.at(x)){
    Value* target = vals_[aggregate][idx];
    Value* inserted = vals_[element][idx];
    vals_[x][idx] = builder_->CreateInsertValue(target, inserted, {field});
  }
}
|
||||
|
||||
// --
|
||||
/**
|
||||
* \brief Code Generation for `cat`
|
||||
*/
|
||||
@@ -2686,7 +2797,8 @@ void generator::visit_make_range(ir::make_range* x) {
|
||||
}
|
||||
|
||||
/**
 * \brief Code Generation for `undef_value`
 *
 * Maps every index registered for `x` to an LLVM undef of the converted
 * scalar type of `x`.
 */
void generator::visit_undef_value(ir::undef_value *x) {
  // `ty` must be declared exactly once: keeping both the one-line form and
  // the two-step form side by side is a redeclaration error
  ir::type* sca_ty = x->get_type()->get_scalar_ty();
  Type* ty = cvt(sca_ty);
  for(indices_t idx: idxs_.at(x))
    vals_[x][idx] = llvm::UndefValue::get(ty);
}
|
||||
@@ -2713,8 +2825,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
||||
}
|
||||
|
||||
|
||||
void generator::visit_function(ir::function* fn) {
|
||||
LLVMContext &ctx = builder_->getContext();
|
||||
void generator::forward_declare(ir::function* fn){
|
||||
FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type());
|
||||
if(!tgt_->is_gpu()){
|
||||
Type *fn_ret_ty = fn_ty->getReturnType();
|
||||
@@ -2727,6 +2838,18 @@ void generator::visit_function(ir::function* fn) {
|
||||
fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false);
|
||||
}
|
||||
Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_);
|
||||
fns_[fn] = ret;
|
||||
}
|
||||
|
||||
void generator::visit_function(ir::function* fn) {
|
||||
idxs_.clear();
|
||||
vals_.clear();
|
||||
seen_.clear();
|
||||
LLVMContext &ctx = builder_->getContext();
|
||||
|
||||
Function* ret = fns_[fn];
|
||||
|
||||
|
||||
// set attributes
|
||||
for(auto attr_pair: fn->attrs()){
|
||||
unsigned id = attr_pair.first;
|
||||
@@ -2751,7 +2874,8 @@ void generator::visit_function(ir::function* fn) {
|
||||
for(unsigned i = 0; i < fn->args().size(); i++)
|
||||
vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i);
|
||||
// create blocks
|
||||
for(ir::basic_block *block: fn->blocks()) {
|
||||
auto blocks = ir::cfg::reverse_post_order(fn);
|
||||
for(ir::basic_block *block: blocks) {
|
||||
BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
|
||||
bbs_[block] = dst_block;
|
||||
}
|
||||
@@ -2761,7 +2885,7 @@ void generator::visit_function(ir::function* fn) {
|
||||
visit_layout(x.second);
|
||||
}
|
||||
// generate LLVM-IR code
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::basic_block *block: blocks)
|
||||
visit_basic_block(block);
|
||||
// finalize
|
||||
finalize_function(fn);
|
||||
@@ -2982,10 +3106,12 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
|
||||
}
|
||||
|
||||
/**
 * \brief Code Generation for a basic block
 *
 * Positions the builder on the LLVM block mapped to `block`, lowers each
 * instruction of `block` once, in list order, and then re-maps `block` to the
 * builder's current insertion block.
 */
void generator::visit_basic_block(ir::basic_block * block) {
  BasicBlock *parent = bbs_[block];
  builder_->SetInsertPoint(parent);
  // a single pass over the instruction list: stacking a second identical
  // loop header on top of this one would revisit every instruction once per
  // instruction
  for(ir::instruction *i: block->get_inst_list()){
    visit_value(i);
  }
  // Update ir bb -> llvm bb mapping: lowering an instruction may leave the
  // builder in a different LLVM block (visit_launch_inst ends in `launch_done`)
  bbs_[block] = builder_->GetInsertBlock();
}
|
||||
@@ -3168,6 +3294,12 @@ void generator::finalize_phi_node(ir::phi_node *x) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * \brief Packed struct type associated with value `i` (not implemented)
 *
 * Converts the tile element type of `i` and checks that its layout is a
 * scanline layout, but no struct type is built from them yet.
 *
 * \throws std::runtime_error always; falling off the end of a function that
 *         returns `StructType*` would be undefined behaviour.
 */
StructType* generator::packed_type(ir::value* i){
  Type* dtype = cvt(i->get_type()->get_tile_element_ty());
  auto* layout = dynamic_cast<analysis::scanline_layout*>(layouts_->get(i));
  assert(layout);
  // neither value is consumed yet; silence unused-variable warnings
  // (`layout` is otherwise unused when asserts are compiled out)
  (void)dtype;
  (void)layout;
  throw std::runtime_error("packed_type is not implemented");
}
|
||||
|
||||
void generator::visit(ir::module &src, llvm::Module &dst) {
|
||||
mod_ = &dst;
|
||||
ctx_ = &dst.getContext();
|
||||
@@ -3184,7 +3316,16 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
|
||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
shmem_ = bit_cast(sh_mem_array, ptr_ty);
|
||||
}
|
||||
// instantiate device functions
|
||||
// for(ir::function *fn: src.get_function_list())
|
||||
// for(ir::basic_block *bb: fn->blocks())
|
||||
// for(ir::instruction *i: bb->get_inst_list())
|
||||
// if(auto *call = dynamic_cast<ir::call_inst*>(i)){
|
||||
// std::cout << "call??" << std::endl;
|
||||
// }
|
||||
// visit functions
|
||||
for(ir::function *fn: src.get_function_list())
|
||||
forward_declare(fn);
|
||||
for(ir::function *fn: src.get_function_list())
|
||||
visit_function(fn);
|
||||
}
|
||||
|
Reference in New Issue
Block a user