[LANG] Added support for device functions (#484)

This commit is contained in:
Philippe Tillet
2022-04-03 20:58:16 -07:00
committed by GitHub
parent e85c7a7fc7
commit 2bed6fc850
39 changed files with 1213 additions and 379 deletions

View File

@@ -608,6 +608,8 @@ void layouts::run(ir::module &mod) {
// create temporaries
size_t id = values_.size();
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
// std::cout << "layout: " << std::endl;
// i->print(std::cout);
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
id++;
ir::value *arg = red->get_operand(0);

View File

@@ -13,6 +13,7 @@
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/codegen/transform/inline.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
@@ -33,6 +34,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::transform::inliner inliner;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
@@ -48,6 +50,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
// run passes
inliner.run(ir);
dce.run(ir);
peephole.run(ir);
dce.run(ir);

View File

@@ -13,6 +13,7 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "triton/ir/utils.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
@@ -139,6 +140,14 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
* \brief Convert Triton-IR Type to LLVM-IR Type
*/
Type *generator::cvt(ir::type *ty) {
// struct
if(ty->is_struct_ty()){
std::vector<Type*> tys;
for(size_t i = 0; i < ty->get_struct_numel(); i++)
tys.push_back(cvt(ty->get_struct_type(i)));
return StructType::get(builder_->getContext(), tys, true);
}
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
Type *ret_ty = cvt(tt->get_return_ty());
@@ -266,7 +275,8 @@ void generator::visit_value(ir::value* v) {
builder_->SetInsertPoint(&*current->getFirstNonPHI());
// visit user
if(auto *usr = dynamic_cast<ir::user*>(v)){
usr->accept(this);
if(!dynamic_cast<ir::function*>(usr))
usr->accept(this);
}
// revert insert point
if(phi && !current->empty() && current->getFirstNonPHI())
@@ -282,6 +292,81 @@ void generator::visit_phi_node(ir::phi_node* x) {
vals_[x][idx] = phi(ty, x->get_num_operands());
}
/**
* \brief Code Generation for `call`
*/
void generator::visit_call_inst(ir::call_inst* call) {
throw std::runtime_error("call not supported! Triton should be inlining everything.");
}
void generator::visit_launch_inst(ir::launch_inst *launch) {
ir::function* fn = (ir::function*)launch->get_operand(0);
// forward-declare cudaGetParameterBufferV2
std::vector<Type*> get_param_arg_tys = {PointerType::get(builder_->getInt8Ty(), 0),
ArrayType::get(builder_->getInt32Ty(), 3),
ArrayType::get(builder_->getInt32Ty(), 3),
builder_->getInt32Ty()};
FunctionType* get_param_ty = FunctionType::get(PointerType::get(builder_->getInt8Ty(), 0), get_param_arg_tys, false);
Function* get_param_buffer = Function::Create(get_param_ty, Function::ExternalLinkage, "cudaGetParameterBufferV2", mod_);
AllocaInst* grid = builder_->CreateAlloca(get_param_arg_tys[1]);
AllocaInst* block = builder_->CreateAlloca(get_param_arg_tys[2]);
ConstantInt* _0 = builder_->getInt32(0);
ConstantInt* _1 = builder_->getInt32(1);
ConstantInt* _2 = builder_->getInt32(2);
// create basic block
BasicBlock* launch_done_bb = BasicBlock::Create(builder_->getContext(), "launch_done", builder_->GetInsertBlock()->getParent());
BasicBlock* launch_bb = BasicBlock::Create(builder_->getContext(), "launch", launch_done_bb->getParent(), launch_done_bb);
Value *tid = tgt_->get_local_id(mod_, *builder_, 0);
Value *is_first_thread = builder_->CreateICmpEQ(tid, i32(0));
builder_->CreateCondBr(is_first_thread, launch_bb, launch_done_bb);
builder_->SetInsertPoint(launch_bb);
//
builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0}));
builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1}));
builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2}));
Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]);
builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0}));
builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1}));
builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2}));
Function* called_fn = fns_[fn];
Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]);
Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)});
// forwrd-declare cudaLaunchDeviceV2
std::vector<Type*> launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()};
FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false);
Function* launch_device = Function::Create(launch_device_ty, Function::ExternalLinkage, "cudaLaunchDeviceV2", mod_);
// TODO: add branch
Value* do_not_launch = builder_->CreateICmpEQ(builder_->CreatePtrToInt(arg_ptr, builder_->getInt64Ty()),
builder_->getInt64(0));
BasicBlock* launch2_bb = BasicBlock::Create(builder_->getContext(), "launch2", launch_done_bb->getParent(), launch_done_bb);
builder_->CreateCondBr(do_not_launch, launch_done_bb, launch2_bb);
builder_->SetInsertPoint(launch2_bb);
unsigned addr_space = arg_ptr->getType()->getPointerAddressSpace();
unsigned off = 0;
unsigned last_size = 0;
for(ir::value* arg: launch->get_values()){
Value* curr_arg = vals_[arg][{}];
Type* curr_arg_ty = curr_arg->getType();
// handle struct alignment
off += last_size;
unsigned size = curr_arg_ty->isPointerTy() ? 8 : curr_arg_ty->getPrimitiveSizeInBits() / 8;
off = (off + size - 1) / size * size;
// get pointer to current arg
Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off));
curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space));
// store arg
builder_->CreateStore(curr_arg, curr_arg_ptr);
last_size = size;
}
builder_->CreateCall(launch_device, {arg_ptr, builder_->getInt64(0)});
builder_->CreateBr(launch_done_bb);
// done
builder_->SetInsertPoint(launch_done_bb);
}
/**
* \brief Code Generation for `binary_operator`
*/
@@ -311,6 +396,7 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
default: throw std::runtime_error("unreachable switch");
}
};
// x->print(std::cout);
for(indices_t idx: idxs_.at(x)){
Value *lhs = vals_[x->get_operand(0)][idx];
Value *rhs = vals_[x->get_operand(1)][idx];
@@ -852,6 +938,31 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
visit_store_inst(x);
}
// --
void generator::visit_extract_value_inst(ir::extract_value_inst *x) {
auto idxs = idxs_.at(x);
ir::value* agg = x->get_operand(0);
unsigned insert_idx = x->get_idx();
for(size_t i = 0; i < idxs.size(); i++){
auto idx = idxs[i];
vals_[x][idx] = builder_->CreateExtractValue(vals_[agg][idx], {insert_idx});
}
}
void generator::visit_insert_value_inst(ir::insert_value_inst *x){
auto idxs = idxs_.at(x);
ir::value* agg = x->get_operand(0);
ir::value* val = x->get_operand(1);
unsigned insert_idx = x->get_idx();
for(size_t i = 0; i < idxs.size(); i++){
auto idx = idxs[i];
vals_[x][idx] = builder_->CreateInsertValue(vals_[agg][idx], vals_[val][idx],{insert_idx});
}
}
// --
/**
* \brief Code Generation for `cat`
*/
@@ -2686,7 +2797,8 @@ void generator::visit_make_range(ir::make_range* x) {
}
void generator::visit_undef_value(ir::undef_value *x) {
Type* ty = cvt(x->get_type()->get_scalar_ty());
ir::type* sca_ty = x->get_type()->get_scalar_ty();
Type* ty = cvt(sca_ty);
for(indices_t idx: idxs_.at(x))
vals_[x][idx] = llvm::UndefValue::get(ty);
}
@@ -2713,8 +2825,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) {
}
void generator::visit_function(ir::function* fn) {
LLVMContext &ctx = builder_->getContext();
void generator::forward_declare(ir::function* fn){
FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type());
if(!tgt_->is_gpu()){
Type *fn_ret_ty = fn_ty->getReturnType();
@@ -2727,6 +2838,18 @@ void generator::visit_function(ir::function* fn) {
fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false);
}
Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_);
fns_[fn] = ret;
}
void generator::visit_function(ir::function* fn) {
idxs_.clear();
vals_.clear();
seen_.clear();
LLVMContext &ctx = builder_->getContext();
Function* ret = fns_[fn];
// set attributes
for(auto attr_pair: fn->attrs()){
unsigned id = attr_pair.first;
@@ -2751,7 +2874,8 @@ void generator::visit_function(ir::function* fn) {
for(unsigned i = 0; i < fn->args().size(); i++)
vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i);
// create blocks
for(ir::basic_block *block: fn->blocks()) {
auto blocks = ir::cfg::reverse_post_order(fn);
for(ir::basic_block *block: blocks) {
BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
bbs_[block] = dst_block;
}
@@ -2761,7 +2885,7 @@ void generator::visit_function(ir::function* fn) {
visit_layout(x.second);
}
// generate LLVM-IR code
for(ir::basic_block *block: fn->blocks())
for(ir::basic_block *block: blocks)
visit_basic_block(block);
// finalize
finalize_function(fn);
@@ -2982,10 +3106,12 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
}
void generator::visit_basic_block(ir::basic_block * block) {
BasicBlock *parent = bbs_[block];
builder_->SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list())
for(ir::instruction *i: block->get_inst_list()){
visit_value(i);
}
// Update ir bb -> llvm bb mapping
bbs_[block] = builder_->GetInsertBlock();
}
@@ -3168,6 +3294,12 @@ void generator::finalize_phi_node(ir::phi_node *x) {
}
}
StructType* generator::packed_type(ir::value* i){
Type* dtype = cvt(i->get_type()->get_tile_element_ty());
auto* layout = dynamic_cast<analysis::scanline_layout*>(layouts_->get(i));
assert(layout);
}
void generator::visit(ir::module &src, llvm::Module &dst) {
mod_ = &dst;
ctx_ = &dst.getContext();
@@ -3184,7 +3316,16 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
shmem_ = bit_cast(sh_mem_array, ptr_ty);
}
// instantiate device functions
// for(ir::function *fn: src.get_function_list())
// for(ir::basic_block *bb: fn->blocks())
// for(ir::instruction *i: bb->get_inst_list())
// if(auto *call = dynamic_cast<ir::call_inst*>(i)){
// std::cout << "call??" << std::endl;
// }
// visit functions
for(ir::function *fn: src.get_function_list())
forward_declare(fn);
for(ir::function *fn: src.get_function_list())
visit_function(fn);
}

View File

@@ -3,6 +3,7 @@
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
#include <iostream>
namespace triton {
namespace codegen{
@@ -28,6 +29,8 @@ void dce::run(ir::module &mod) {
case ir::INST_ATOMIC_CAS:
case ir::INST_ATOMIC_RMW:
case ir::INST_ATOMIC_EXCH:
case ir::INST_CALL:
case ir::INST_LAUNCH:
case ir::INST_BARRIER: {
work_list.push_back(i);
marked.insert(i);
@@ -65,6 +68,7 @@ void dce::run(ir::module &mod) {
}
}
// delete
for(ir::instruction* i: to_delete)
i->erase_from_parent();

View File

@@ -0,0 +1,127 @@
#include <iostream>
#include "triton/codegen/transform/inline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace transform{
bool fncmp::operator()(ir::function* x, ir::function* y) const {
auto fn_list = x->get_parent()->get_function_list();
return std::find(fn_list.begin(), fn_list.end(), x) < std::find(fn_list.begin(), fn_list.end(), y);
};
void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder,
std::list<ir::call_inst*>& callsites){
ir::basic_block* parent_block = callsite->get_parent();
ir::function* parent_fn = parent_block->get_parent();
// the parent block is split into block A and block B:
// - block A (`new_blocks[0]`) is the entry block of the inlined function
// - block B (`exit`) resumes execution of the parent function
ir::basic_block* entry = parent_block->split_before(callsite, fn->get_name());
ir::basic_block* exit = entry->get_successors()[0];
std::vector<ir::basic_block*> new_blocks = {entry};
for(size_t i = 1; i < fn->blocks().size(); i++){
ir::basic_block* block = fn->blocks()[i];
ir::context& ctx = block->get_context();
const std::string& name = block->get_parent()->get_name() + "_" + block->get_name();
new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn));
}
// a phi node holds the return values of the inlined function
if(exit->get_inst_list().empty())
builder.set_insert_point(exit);
else
builder.set_insert_point(exit->get_first_non_phi());
ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0);
callsite->replace_all_uses_with(exit_val);
callsite->erase_from_parent();
// get arguments `fn` is called with
std::vector<ir::value*> tgt_args(callsite->op_begin(), callsite->op_end());
std::vector<ir::argument*> src_args(fn->args().begin(), fn->args().end());
// Actually generate the instructions:
// - Remove the branch created by basic_block::split_before
// - Clone all instructions
// - Replace `ret` with incoming nodes to `exit_val` and branches to `exit`
ir::instruction* terminator = new_blocks[0]->get_inst_list().back();
// new_blocks[0]->get_inst_list().back()->erase_from_parent();
terminator->erase_from_parent();
std::map<ir::instruction*, ir::instruction*> inst_map;
std::map<ir::argument*, ir::value*> arg_map;
for(size_t k = 0; k < fn->args().size(); k++)
arg_map[fn->args()[k]] = callsite->ops()[k];
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
for(size_t i = 0; i < new_blocks.size(); i++){
ir::basic_block* old_block = fn->blocks()[i];
ir::basic_block* new_block = new_blocks[i];
builder.set_insert_point(new_block);
for(ir::instruction* old_inst: old_block->get_inst_list()){
// clone instruction
ir::instruction* new_inst = old_inst->clone();
// replace basic block
for(size_t k = 0; k < new_blocks.size(); k++)
new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
// replace values
for(size_t k = 0; k < new_inst->get_num_operands(); k++){
ir::value* op = new_inst->get_operand(k);
if(auto arg_op = dynamic_cast<ir::argument*>(op))
new_inst->set_operand(k, arg_map.at(arg_op));
if(auto inst_op = dynamic_cast<ir::instruction*>(op))
if(inst_map.find(inst_op) != inst_map.end())
new_inst->set_operand(k, inst_map.at(inst_op));
}
// `ret` instruction is a special case:
// instead of returning we need to branch to after the function call
if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)){
if(ir::value* ret_val = ret->get_return_value())
exit_val->add_incoming(ret_val, new_block);
new_inst = ir::branch_inst::create(exit);
}
inst_map[old_inst] = new_inst;
builder.insert(new_inst);
}
}
if(exit_val->get_num_incoming() == 1)
exit_val->replace_all_uses_with(exit_val->get_incoming_value(0));
// done -- make sure insert point is properly set to exit block
builder.set_insert_point(exit);
}
void inliner::run(ir::module &mod) {
// gather all call sites
while(true){
std::map<ir::function*, size_t> counts;
for(ir::function* fn: mod.get_function_list())
counts[fn] = 0;
std::list<ir::call_inst*> callsites;
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* instr: block->get_inst_list())
if(ir::call_inst* call = dynamic_cast<ir::call_inst*>(instr)){
callsites.push_back(call);
counts[call->get_fn()] += 1;
}
}
for(auto& count: counts){
if(!count.first->get_is_kernel() && count.second == 0)
count.first->get_parent()->remove_function(count.first);
}
if(callsites.empty())
break;
for(ir::call_inst* call: callsites)
do_inline(call->get_fn(), call, mod.get_builder(), callsites);
}
}
}
}
}

View File

@@ -150,32 +150,53 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
}
bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
auto binop = dynamic_cast<ir::binary_operator*>(value);
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
ir::value *lhs = binop->get_operand(0);
ir::value *rhs = binop->get_operand(1);
ir::constant_int *_1_lhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_lhs = cst;
}
ir::constant_int *_1_rhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_rhs = cst;
}
if(_1_lhs){
binop->replace_all_uses_with(rhs);
return true;
}
else if(_1_rhs){
binop->replace_all_uses_with(lhs);
return true;
}
auto binop = dynamic_cast<ir::binary_operator*>(value);
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
ir::value *lhs = binop->get_operand(0);
ir::value *rhs = binop->get_operand(1);
ir::constant_int *_1_lhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_lhs = cst;
}
ir::constant_int *_1_rhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_rhs = cst;
}
if(_1_lhs){
binop->replace_all_uses_with(rhs);
return true;
}
else if(_1_rhs){
binop->replace_all_uses_with(lhs);
return true;
}
}
return false;
}
bool peephole::rewrite_insert_extract(ir::instruction *value, ir::builder& builder){
auto extracted = dynamic_cast<ir::extract_value_inst*>(value);
if(!extracted)
return false;
size_t extract_idx = extracted->get_idx();
ir::value* agg = extracted->get_operand(0);
auto insert = dynamic_cast<ir::insert_value_inst*>(agg);
while(insert){
agg = insert->get_operand(0);
ir::value* inserted = insert->get_operand(1);
size_t insert_idx = insert->get_idx();
insert = dynamic_cast<ir::insert_value_inst*>(agg);
if(extract_idx == insert_idx){
extracted->replace_all_uses_with(inserted);
return true;
}
insert = dynamic_cast<ir::insert_value_inst*>(agg);
}
return false;
}
@@ -291,6 +312,7 @@ void peephole::run(ir::module &mod) {
was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
// was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_insert_extract(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
// TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD

View File

@@ -134,6 +134,7 @@ void pipeline::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
const int num_stages = num_stages_;
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
for(auto info: to_pipeline){
ir::load_inst* load = info.load;
ir::phi_node* ptr = info.ptr;