[LANG] Added support for device functions (#484)
This commit is contained in:
@@ -608,6 +608,8 @@ void layouts::run(ir::module &mod) {
|
||||
// create temporaries
|
||||
size_t id = values_.size();
|
||||
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
|
||||
// std::cout << "layout: " << std::endl;
|
||||
// i->print(std::cout);
|
||||
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
|
||||
id++;
|
||||
ir::value *arg = red->get_operand(0);
|
||||
|
@@ -13,6 +13,7 @@
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
#include "triton/codegen/transform/prefetch.h"
|
||||
#include "triton/codegen/transform/inline.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/print.h"
|
||||
@@ -33,6 +34,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
|
||||
// create passes
|
||||
codegen::analysis::align align;
|
||||
codegen::transform::inliner inliner;
|
||||
codegen::analysis::axes axes;
|
||||
codegen::transform::cts cts(cts_use_async);
|
||||
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
|
||||
@@ -48,6 +50,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
|
||||
// run passes
|
||||
inliner.run(ir);
|
||||
dce.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
|
@@ -13,6 +13,7 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/IntrinsicsNVPTX.h"
|
||||
@@ -139,6 +140,14 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
* \brief Convert Triton-IR Type to LLVM-IR Type
|
||||
*/
|
||||
Type *generator::cvt(ir::type *ty) {
|
||||
// struct
|
||||
if(ty->is_struct_ty()){
|
||||
std::vector<Type*> tys;
|
||||
for(size_t i = 0; i < ty->get_struct_numel(); i++)
|
||||
tys.push_back(cvt(ty->get_struct_type(i)));
|
||||
return StructType::get(builder_->getContext(), tys, true);
|
||||
}
|
||||
|
||||
// function
|
||||
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
|
||||
Type *ret_ty = cvt(tt->get_return_ty());
|
||||
@@ -266,7 +275,8 @@ void generator::visit_value(ir::value* v) {
|
||||
builder_->SetInsertPoint(&*current->getFirstNonPHI());
|
||||
// visit user
|
||||
if(auto *usr = dynamic_cast<ir::user*>(v)){
|
||||
usr->accept(this);
|
||||
if(!dynamic_cast<ir::function*>(usr))
|
||||
usr->accept(this);
|
||||
}
|
||||
// revert insert point
|
||||
if(phi && !current->empty() && current->getFirstNonPHI())
|
||||
@@ -282,6 +292,81 @@ void generator::visit_phi_node(ir::phi_node* x) {
|
||||
vals_[x][idx] = phi(ty, x->get_num_operands());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `call`
|
||||
*/
|
||||
void generator::visit_call_inst(ir::call_inst* call) {
|
||||
throw std::runtime_error("call not supported! Triton should be inlining everything.");
|
||||
}
|
||||
|
||||
void generator::visit_launch_inst(ir::launch_inst *launch) {
|
||||
ir::function* fn = (ir::function*)launch->get_operand(0);
|
||||
// forward-declare cudaGetParameterBufferV2
|
||||
std::vector<Type*> get_param_arg_tys = {PointerType::get(builder_->getInt8Ty(), 0),
|
||||
ArrayType::get(builder_->getInt32Ty(), 3),
|
||||
ArrayType::get(builder_->getInt32Ty(), 3),
|
||||
builder_->getInt32Ty()};
|
||||
FunctionType* get_param_ty = FunctionType::get(PointerType::get(builder_->getInt8Ty(), 0), get_param_arg_tys, false);
|
||||
Function* get_param_buffer = Function::Create(get_param_ty, Function::ExternalLinkage, "cudaGetParameterBufferV2", mod_);
|
||||
AllocaInst* grid = builder_->CreateAlloca(get_param_arg_tys[1]);
|
||||
AllocaInst* block = builder_->CreateAlloca(get_param_arg_tys[2]);
|
||||
ConstantInt* _0 = builder_->getInt32(0);
|
||||
ConstantInt* _1 = builder_->getInt32(1);
|
||||
ConstantInt* _2 = builder_->getInt32(2);
|
||||
// create basic block
|
||||
BasicBlock* launch_done_bb = BasicBlock::Create(builder_->getContext(), "launch_done", builder_->GetInsertBlock()->getParent());
|
||||
BasicBlock* launch_bb = BasicBlock::Create(builder_->getContext(), "launch", launch_done_bb->getParent(), launch_done_bb);
|
||||
Value *tid = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value *is_first_thread = builder_->CreateICmpEQ(tid, i32(0));
|
||||
builder_->CreateCondBr(is_first_thread, launch_bb, launch_done_bb);
|
||||
builder_->SetInsertPoint(launch_bb);
|
||||
|
||||
//
|
||||
builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0}));
|
||||
builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1}));
|
||||
builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2}));
|
||||
Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]);
|
||||
builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0}));
|
||||
builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1}));
|
||||
builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2}));
|
||||
Function* called_fn = fns_[fn];
|
||||
Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]);
|
||||
Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)});
|
||||
// forwrd-declare cudaLaunchDeviceV2
|
||||
std::vector<Type*> launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()};
|
||||
FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false);
|
||||
Function* launch_device = Function::Create(launch_device_ty, Function::ExternalLinkage, "cudaLaunchDeviceV2", mod_);
|
||||
// TODO: add branch
|
||||
Value* do_not_launch = builder_->CreateICmpEQ(builder_->CreatePtrToInt(arg_ptr, builder_->getInt64Ty()),
|
||||
builder_->getInt64(0));
|
||||
BasicBlock* launch2_bb = BasicBlock::Create(builder_->getContext(), "launch2", launch_done_bb->getParent(), launch_done_bb);
|
||||
builder_->CreateCondBr(do_not_launch, launch_done_bb, launch2_bb);
|
||||
builder_->SetInsertPoint(launch2_bb);
|
||||
|
||||
unsigned addr_space = arg_ptr->getType()->getPointerAddressSpace();
|
||||
unsigned off = 0;
|
||||
unsigned last_size = 0;
|
||||
for(ir::value* arg: launch->get_values()){
|
||||
Value* curr_arg = vals_[arg][{}];
|
||||
Type* curr_arg_ty = curr_arg->getType();
|
||||
// handle struct alignment
|
||||
off += last_size;
|
||||
unsigned size = curr_arg_ty->isPointerTy() ? 8 : curr_arg_ty->getPrimitiveSizeInBits() / 8;
|
||||
off = (off + size - 1) / size * size;
|
||||
// get pointer to current arg
|
||||
Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off));
|
||||
curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space));
|
||||
// store arg
|
||||
builder_->CreateStore(curr_arg, curr_arg_ptr);
|
||||
last_size = size;
|
||||
}
|
||||
builder_->CreateCall(launch_device, {arg_ptr, builder_->getInt64(0)});
|
||||
builder_->CreateBr(launch_done_bb);
|
||||
// done
|
||||
builder_->SetInsertPoint(launch_done_bb);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `binary_operator`
|
||||
*/
|
||||
@@ -311,6 +396,7 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
default: throw std::runtime_error("unreachable switch");
|
||||
}
|
||||
};
|
||||
// x->print(std::cout);
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
Value *lhs = vals_[x->get_operand(0)][idx];
|
||||
Value *rhs = vals_[x->get_operand(1)][idx];
|
||||
@@ -852,6 +938,31 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
|
||||
visit_store_inst(x);
|
||||
}
|
||||
|
||||
// --
|
||||
|
||||
void generator::visit_extract_value_inst(ir::extract_value_inst *x) {
|
||||
auto idxs = idxs_.at(x);
|
||||
ir::value* agg = x->get_operand(0);
|
||||
unsigned insert_idx = x->get_idx();
|
||||
for(size_t i = 0; i < idxs.size(); i++){
|
||||
auto idx = idxs[i];
|
||||
vals_[x][idx] = builder_->CreateExtractValue(vals_[agg][idx], {insert_idx});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void generator::visit_insert_value_inst(ir::insert_value_inst *x){
|
||||
auto idxs = idxs_.at(x);
|
||||
ir::value* agg = x->get_operand(0);
|
||||
ir::value* val = x->get_operand(1);
|
||||
unsigned insert_idx = x->get_idx();
|
||||
for(size_t i = 0; i < idxs.size(); i++){
|
||||
auto idx = idxs[i];
|
||||
vals_[x][idx] = builder_->CreateInsertValue(vals_[agg][idx], vals_[val][idx],{insert_idx});
|
||||
}
|
||||
}
|
||||
|
||||
// --
|
||||
/**
|
||||
* \brief Code Generation for `cat`
|
||||
*/
|
||||
@@ -2686,7 +2797,8 @@ void generator::visit_make_range(ir::make_range* x) {
|
||||
}
|
||||
|
||||
void generator::visit_undef_value(ir::undef_value *x) {
|
||||
Type* ty = cvt(x->get_type()->get_scalar_ty());
|
||||
ir::type* sca_ty = x->get_type()->get_scalar_ty();
|
||||
Type* ty = cvt(sca_ty);
|
||||
for(indices_t idx: idxs_.at(x))
|
||||
vals_[x][idx] = llvm::UndefValue::get(ty);
|
||||
}
|
||||
@@ -2713,8 +2825,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) {
|
||||
}
|
||||
|
||||
|
||||
void generator::visit_function(ir::function* fn) {
|
||||
LLVMContext &ctx = builder_->getContext();
|
||||
void generator::forward_declare(ir::function* fn){
|
||||
FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type());
|
||||
if(!tgt_->is_gpu()){
|
||||
Type *fn_ret_ty = fn_ty->getReturnType();
|
||||
@@ -2727,6 +2838,18 @@ void generator::visit_function(ir::function* fn) {
|
||||
fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false);
|
||||
}
|
||||
Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_);
|
||||
fns_[fn] = ret;
|
||||
}
|
||||
|
||||
void generator::visit_function(ir::function* fn) {
|
||||
idxs_.clear();
|
||||
vals_.clear();
|
||||
seen_.clear();
|
||||
LLVMContext &ctx = builder_->getContext();
|
||||
|
||||
Function* ret = fns_[fn];
|
||||
|
||||
|
||||
// set attributes
|
||||
for(auto attr_pair: fn->attrs()){
|
||||
unsigned id = attr_pair.first;
|
||||
@@ -2751,7 +2874,8 @@ void generator::visit_function(ir::function* fn) {
|
||||
for(unsigned i = 0; i < fn->args().size(); i++)
|
||||
vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i);
|
||||
// create blocks
|
||||
for(ir::basic_block *block: fn->blocks()) {
|
||||
auto blocks = ir::cfg::reverse_post_order(fn);
|
||||
for(ir::basic_block *block: blocks) {
|
||||
BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
|
||||
bbs_[block] = dst_block;
|
||||
}
|
||||
@@ -2761,7 +2885,7 @@ void generator::visit_function(ir::function* fn) {
|
||||
visit_layout(x.second);
|
||||
}
|
||||
// generate LLVM-IR code
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::basic_block *block: blocks)
|
||||
visit_basic_block(block);
|
||||
// finalize
|
||||
finalize_function(fn);
|
||||
@@ -2982,10 +3106,12 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
|
||||
}
|
||||
|
||||
void generator::visit_basic_block(ir::basic_block * block) {
|
||||
|
||||
BasicBlock *parent = bbs_[block];
|
||||
builder_->SetInsertPoint(parent);
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
visit_value(i);
|
||||
}
|
||||
// Update ir bb -> llvm bb mapping
|
||||
bbs_[block] = builder_->GetInsertBlock();
|
||||
}
|
||||
@@ -3168,6 +3294,12 @@ void generator::finalize_phi_node(ir::phi_node *x) {
|
||||
}
|
||||
}
|
||||
|
||||
StructType* generator::packed_type(ir::value* i){
|
||||
Type* dtype = cvt(i->get_type()->get_tile_element_ty());
|
||||
auto* layout = dynamic_cast<analysis::scanline_layout*>(layouts_->get(i));
|
||||
assert(layout);
|
||||
}
|
||||
|
||||
void generator::visit(ir::module &src, llvm::Module &dst) {
|
||||
mod_ = &dst;
|
||||
ctx_ = &dst.getContext();
|
||||
@@ -3184,7 +3316,16 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
|
||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
shmem_ = bit_cast(sh_mem_array, ptr_ty);
|
||||
}
|
||||
// instantiate device functions
|
||||
// for(ir::function *fn: src.get_function_list())
|
||||
// for(ir::basic_block *bb: fn->blocks())
|
||||
// for(ir::instruction *i: bb->get_inst_list())
|
||||
// if(auto *call = dynamic_cast<ir::call_inst*>(i)){
|
||||
// std::cout << "call??" << std::endl;
|
||||
// }
|
||||
// visit functions
|
||||
for(ir::function *fn: src.get_function_list())
|
||||
forward_declare(fn);
|
||||
for(ir::function *fn: src.get_function_list())
|
||||
visit_function(fn);
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
@@ -28,6 +29,8 @@ void dce::run(ir::module &mod) {
|
||||
case ir::INST_ATOMIC_CAS:
|
||||
case ir::INST_ATOMIC_RMW:
|
||||
case ir::INST_ATOMIC_EXCH:
|
||||
case ir::INST_CALL:
|
||||
case ir::INST_LAUNCH:
|
||||
case ir::INST_BARRIER: {
|
||||
work_list.push_back(i);
|
||||
marked.insert(i);
|
||||
@@ -65,6 +68,7 @@ void dce::run(ir::module &mod) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// delete
|
||||
for(ir::instruction* i: to_delete)
|
||||
i->erase_from_parent();
|
||||
|
127
lib/codegen/transform/inline.cc
Normal file
127
lib/codegen/transform/inline.cc
Normal file
@@ -0,0 +1,127 @@
|
||||
#include <iostream>
|
||||
#include "triton/codegen/transform/inline.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
bool fncmp::operator()(ir::function* x, ir::function* y) const {
|
||||
auto fn_list = x->get_parent()->get_function_list();
|
||||
return std::find(fn_list.begin(), fn_list.end(), x) < std::find(fn_list.begin(), fn_list.end(), y);
|
||||
};
|
||||
|
||||
void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder,
|
||||
std::list<ir::call_inst*>& callsites){
|
||||
ir::basic_block* parent_block = callsite->get_parent();
|
||||
ir::function* parent_fn = parent_block->get_parent();
|
||||
// the parent block is split into block A and block B:
|
||||
// - block A (`new_blocks[0]`) is the entry block of the inlined function
|
||||
// - block B (`exit`) resumes execution of the parent function
|
||||
ir::basic_block* entry = parent_block->split_before(callsite, fn->get_name());
|
||||
ir::basic_block* exit = entry->get_successors()[0];
|
||||
std::vector<ir::basic_block*> new_blocks = {entry};
|
||||
for(size_t i = 1; i < fn->blocks().size(); i++){
|
||||
ir::basic_block* block = fn->blocks()[i];
|
||||
ir::context& ctx = block->get_context();
|
||||
const std::string& name = block->get_parent()->get_name() + "_" + block->get_name();
|
||||
new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn));
|
||||
}
|
||||
// a phi node holds the return values of the inlined function
|
||||
if(exit->get_inst_list().empty())
|
||||
builder.set_insert_point(exit);
|
||||
else
|
||||
builder.set_insert_point(exit->get_first_non_phi());
|
||||
ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0);
|
||||
callsite->replace_all_uses_with(exit_val);
|
||||
callsite->erase_from_parent();
|
||||
// get arguments `fn` is called with
|
||||
std::vector<ir::value*> tgt_args(callsite->op_begin(), callsite->op_end());
|
||||
std::vector<ir::argument*> src_args(fn->args().begin(), fn->args().end());
|
||||
// Actually generate the instructions:
|
||||
// - Remove the branch created by basic_block::split_before
|
||||
// - Clone all instructions
|
||||
// - Replace `ret` with incoming nodes to `exit_val` and branches to `exit`
|
||||
ir::instruction* terminator = new_blocks[0]->get_inst_list().back();
|
||||
// new_blocks[0]->get_inst_list().back()->erase_from_parent();
|
||||
terminator->erase_from_parent();
|
||||
std::map<ir::instruction*, ir::instruction*> inst_map;
|
||||
std::map<ir::argument*, ir::value*> arg_map;
|
||||
for(size_t k = 0; k < fn->args().size(); k++)
|
||||
arg_map[fn->args()[k]] = callsite->ops()[k];
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
for(size_t i = 0; i < new_blocks.size(); i++){
|
||||
ir::basic_block* old_block = fn->blocks()[i];
|
||||
ir::basic_block* new_block = new_blocks[i];
|
||||
builder.set_insert_point(new_block);
|
||||
for(ir::instruction* old_inst: old_block->get_inst_list()){
|
||||
// clone instruction
|
||||
ir::instruction* new_inst = old_inst->clone();
|
||||
// replace basic block
|
||||
for(size_t k = 0; k < new_blocks.size(); k++)
|
||||
new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
|
||||
// replace values
|
||||
for(size_t k = 0; k < new_inst->get_num_operands(); k++){
|
||||
ir::value* op = new_inst->get_operand(k);
|
||||
if(auto arg_op = dynamic_cast<ir::argument*>(op))
|
||||
new_inst->set_operand(k, arg_map.at(arg_op));
|
||||
if(auto inst_op = dynamic_cast<ir::instruction*>(op))
|
||||
if(inst_map.find(inst_op) != inst_map.end())
|
||||
new_inst->set_operand(k, inst_map.at(inst_op));
|
||||
}
|
||||
// `ret` instruction is a special case:
|
||||
// instead of returning we need to branch to after the function call
|
||||
if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)){
|
||||
if(ir::value* ret_val = ret->get_return_value())
|
||||
exit_val->add_incoming(ret_val, new_block);
|
||||
new_inst = ir::branch_inst::create(exit);
|
||||
}
|
||||
inst_map[old_inst] = new_inst;
|
||||
builder.insert(new_inst);
|
||||
}
|
||||
}
|
||||
if(exit_val->get_num_incoming() == 1)
|
||||
exit_val->replace_all_uses_with(exit_val->get_incoming_value(0));
|
||||
// done -- make sure insert point is properly set to exit block
|
||||
builder.set_insert_point(exit);
|
||||
}
|
||||
|
||||
void inliner::run(ir::module &mod) {
|
||||
|
||||
// gather all call sites
|
||||
while(true){
|
||||
std::map<ir::function*, size_t> counts;
|
||||
for(ir::function* fn: mod.get_function_list())
|
||||
counts[fn] = 0;
|
||||
|
||||
std::list<ir::call_inst*> callsites;
|
||||
for(ir::function* fn: mod.get_function_list()){
|
||||
for(ir::basic_block* block: fn->blocks())
|
||||
for(ir::instruction* instr: block->get_inst_list())
|
||||
if(ir::call_inst* call = dynamic_cast<ir::call_inst*>(instr)){
|
||||
callsites.push_back(call);
|
||||
counts[call->get_fn()] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for(auto& count: counts){
|
||||
if(!count.first->get_is_kernel() && count.second == 0)
|
||||
count.first->get_parent()->remove_function(count.first);
|
||||
}
|
||||
|
||||
if(callsites.empty())
|
||||
break;
|
||||
|
||||
for(ir::call_inst* call: callsites)
|
||||
do_inline(call->get_fn(), call, mod.get_builder(), callsites);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -150,32 +150,53 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
||||
}
|
||||
|
||||
bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
|
||||
auto binop = dynamic_cast<ir::binary_operator*>(value);
|
||||
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
|
||||
ir::value *lhs = binop->get_operand(0);
|
||||
ir::value *rhs = binop->get_operand(1);
|
||||
ir::constant_int *_1_lhs = nullptr;
|
||||
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
|
||||
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
|
||||
if(cst && cst->get_value() == 1)
|
||||
_1_lhs = cst;
|
||||
}
|
||||
ir::constant_int *_1_rhs = nullptr;
|
||||
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
|
||||
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
|
||||
if(cst && cst->get_value() == 1)
|
||||
_1_rhs = cst;
|
||||
}
|
||||
if(_1_lhs){
|
||||
binop->replace_all_uses_with(rhs);
|
||||
return true;
|
||||
}
|
||||
else if(_1_rhs){
|
||||
binop->replace_all_uses_with(lhs);
|
||||
return true;
|
||||
}
|
||||
auto binop = dynamic_cast<ir::binary_operator*>(value);
|
||||
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
|
||||
ir::value *lhs = binop->get_operand(0);
|
||||
ir::value *rhs = binop->get_operand(1);
|
||||
ir::constant_int *_1_lhs = nullptr;
|
||||
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
|
||||
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
|
||||
if(cst && cst->get_value() == 1)
|
||||
_1_lhs = cst;
|
||||
}
|
||||
ir::constant_int *_1_rhs = nullptr;
|
||||
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
|
||||
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
|
||||
if(cst && cst->get_value() == 1)
|
||||
_1_rhs = cst;
|
||||
}
|
||||
if(_1_lhs){
|
||||
binop->replace_all_uses_with(rhs);
|
||||
return true;
|
||||
}
|
||||
else if(_1_rhs){
|
||||
binop->replace_all_uses_with(lhs);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool peephole::rewrite_insert_extract(ir::instruction *value, ir::builder& builder){
|
||||
auto extracted = dynamic_cast<ir::extract_value_inst*>(value);
|
||||
if(!extracted)
|
||||
return false;
|
||||
size_t extract_idx = extracted->get_idx();
|
||||
ir::value* agg = extracted->get_operand(0);
|
||||
auto insert = dynamic_cast<ir::insert_value_inst*>(agg);
|
||||
while(insert){
|
||||
agg = insert->get_operand(0);
|
||||
ir::value* inserted = insert->get_operand(1);
|
||||
size_t insert_idx = insert->get_idx();
|
||||
insert = dynamic_cast<ir::insert_value_inst*>(agg);
|
||||
if(extract_idx == insert_idx){
|
||||
extracted->replace_all_uses_with(inserted);
|
||||
return true;
|
||||
}
|
||||
insert = dynamic_cast<ir::insert_value_inst*>(agg);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -291,6 +312,7 @@ void peephole::run(ir::module &mod) {
|
||||
was_modified = was_modified || rewrite_mult(i, builder);
|
||||
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
|
||||
// was_modified = was_modified || rewrite_trans_phi(i, builder);
|
||||
was_modified = was_modified || rewrite_insert_extract(i, builder);
|
||||
was_modified = was_modified || rewrite_unit_red(i, builder);
|
||||
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
||||
// TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
|
||||
|
@@ -134,6 +134,7 @@ void pipeline::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
const int num_stages = num_stages_;
|
||||
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
|
||||
|
||||
for(auto info: to_pipeline){
|
||||
ir::load_inst* load = info.load;
|
||||
ir::phi_node* ptr = info.ptr;
|
||||
|
@@ -138,6 +138,7 @@ CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice
|
||||
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*)
|
||||
|
||||
// link management
|
||||
CUDA_DEFINE6(CUresult, cuLinkAddFile_v2, CUlinkState, CUjitInputType, const char *, unsigned int , CUjit_option *, void **);
|
||||
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
|
||||
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
|
||||
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
|
||||
|
@@ -90,7 +90,7 @@ void check(CUresult err)
|
||||
case CUDA_ERROR_NOT_PERMITTED : throw not_permitted();
|
||||
case CUDA_ERROR_NOT_SUPPORTED : throw not_supported();
|
||||
case CUDA_ERROR_UNKNOWN : throw unknown();
|
||||
default : throw unknown();
|
||||
default : throw std::runtime_error("unimplemented code: " + std::to_string(err));
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -174,6 +174,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
init_llvm();
|
||||
// verify and store llvm
|
||||
llvm::legacy::PassManager pm;
|
||||
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
pm.add(llvm::createVerifierPass());
|
||||
pm.run(*module);
|
||||
// module->print(llvm::outs(), nullptr);
|
||||
@@ -213,6 +214,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
|
||||
// compile ptx with ptxas
|
||||
char _fsrc[L_tmpnam];
|
||||
|
@@ -1,3 +1,5 @@
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
@@ -9,23 +11,68 @@ namespace ir {
|
||||
class phi_node;
|
||||
|
||||
|
||||
basic_block::basic_block(context &ctx, const std::string &name, function *parent):
|
||||
basic_block::basic_block(context &ctx, const std::string &name, function *parent, basic_block* next):
|
||||
value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) {
|
||||
if(parent_)
|
||||
parent_->insert_block(this);
|
||||
parent_->insert_block(this, next);
|
||||
}
|
||||
|
||||
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent){
|
||||
return new basic_block(ctx, name, parent);
|
||||
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent, basic_block* next){
|
||||
return new basic_block(ctx, name, parent, next);
|
||||
}
|
||||
|
||||
void basic_block::add_predecessor(basic_block *pred) {
|
||||
preds_.push_back(pred);
|
||||
if(pred)
|
||||
pred->succs_.push_back(this);
|
||||
void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after) {
|
||||
for(ir::instruction* i: inst_list_){
|
||||
auto* curr_phi = dynamic_cast<ir::phi_node*>(i);
|
||||
if(!curr_phi)
|
||||
break;
|
||||
curr_phi->replace_uses_of_with(before, after);
|
||||
}
|
||||
}
|
||||
|
||||
void basic_block::append_instruction(ir::instruction* i){
|
||||
i->set_parent(this);
|
||||
inst_list_.push_back(i);
|
||||
}
|
||||
|
||||
basic_block* basic_block::split_before(ir::instruction* loc, const std::string& name) {
|
||||
basic_block* ret = basic_block::create(ctx_, name, parent_, this);
|
||||
ret->set_name(get_name());
|
||||
set_name("after_" + name);
|
||||
|
||||
// splice instruction list
|
||||
auto loc_it = std::find(inst_list_.begin(), inst_list_.end(), loc);
|
||||
ret->get_inst_list().splice(ret->get_inst_list().begin(), inst_list_, inst_list_.begin(), loc_it);
|
||||
for(ir::instruction* i: ret->get_inst_list())
|
||||
i->set_parent(ret);
|
||||
// the predecessors of `this` becomes the predecessors of `ret`
|
||||
for(ir::basic_block* pred: get_predecessors()){
|
||||
auto* term = dynamic_cast<ir::terminator_inst*>(pred->get_inst_list().back());
|
||||
assert(term);
|
||||
term->replace_uses_of_with(this, ret);
|
||||
replace_phi_uses_with(pred, ret);
|
||||
}
|
||||
ir::branch_inst* br = branch_inst::create(this);
|
||||
ret->append_instruction(br);
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<basic_block*> basic_block::get_predecessors() const {
|
||||
std::vector<basic_block*> ret;
|
||||
for(ir::user* u: users_)
|
||||
if(auto term = dynamic_cast<ir::terminator_inst*>(u))
|
||||
ret.push_back(term->get_parent());
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<basic_block*> basic_block::get_successors() const {
|
||||
std::vector<basic_block*> ret;
|
||||
for(ir::instruction* i: inst_list_)
|
||||
for(ir::value* v: i->ops())
|
||||
if(auto block = dynamic_cast<ir::basic_block*>(v))
|
||||
ret.push_back(block);
|
||||
return ret;
|
||||
}
|
||||
|
||||
basic_block::iterator basic_block::get_first_non_phi(){
|
||||
auto it = begin();
|
||||
|
@@ -117,13 +117,10 @@ type *builder::get_double_ty()
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value* builder::create_br(basic_block *dest){
|
||||
dest->add_predecessor(block_);
|
||||
return insert(branch_inst::create(dest));
|
||||
}
|
||||
|
||||
value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){
|
||||
if_dest->add_predecessor(block_);
|
||||
else_dest->add_predecessor(block_);
|
||||
return insert(branch_inst::create(cond, if_dest, else_dest));
|
||||
}
|
||||
|
||||
@@ -131,6 +128,10 @@ value *builder::create_ret_void() {
|
||||
return insert(return_inst::create(ctx_));
|
||||
}
|
||||
|
||||
value *builder::create_ret(value* val) {
|
||||
return insert(return_inst::create(ctx_, val));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -163,6 +164,19 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved){
|
||||
return insert(phi_node::create(ty, num_reserved));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// call instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_call(function* fn, const std::vector<value*>& args){
|
||||
return insert(call_inst::create(fn, args));
|
||||
}
|
||||
|
||||
value* builder::create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps){
|
||||
return insert(launch_inst::create(fn, args, grid, num_warps));
|
||||
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary float instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -307,6 +321,19 @@ value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
||||
return insert(masked_store_inst::create(ptr, val, mask));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// struct instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
// Struct instructions
|
||||
value *builder::create_insert_value(value* val, value *elt, size_t idx){
|
||||
return insert(insert_value_inst::create(val, elt, idx));
|
||||
}
|
||||
|
||||
value *builder::create_extract_value(value* val, size_t idx) {
|
||||
return insert(extract_value_inst::create(val, idx));
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// block instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -865,6 +865,9 @@ ir::value *dispatch::clock(ir::builder *builder) {
|
||||
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Control FLow
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
|
||||
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
|
||||
|
@@ -33,8 +33,10 @@ void argument::accept(visitor *v) {
|
||||
/* function */
|
||||
function::function(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *parent)
|
||||
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty) {
|
||||
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty), is_kernel_(false) {
|
||||
unsigned num_params = fn_ty_->get_num_params();
|
||||
if(parent)
|
||||
parent->push_function(this);
|
||||
// skip if no parameter
|
||||
if(num_params == 0)
|
||||
return;
|
||||
@@ -44,8 +46,6 @@ function::function(function_type *ty, linkage_types_t linkage,
|
||||
type *param_ty = fn_ty_->get_param_ty(i);
|
||||
args_[i] = argument::create(param_ty, "", this, i);
|
||||
}
|
||||
if(parent)
|
||||
parent->push_function(this);
|
||||
}
|
||||
|
||||
/* basic block */
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/function.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
@@ -79,6 +80,70 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &n
|
||||
return new phi_node(ty, num_reserved, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// call_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string call_inst::repr_impl() const { return "call " + fn_->get_name(); }
|
||||
|
||||
call_inst::call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next)
|
||||
: instruction(fn->get_fn_type()->get_return_ty(), INST_CALL, values.size(), name, next), fn_(fn){
|
||||
for(size_t i = 0; i < values.size(); i++)
|
||||
set_operand(i, values.at(i));
|
||||
}
|
||||
|
||||
call_inst* call_inst::create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name, instruction *next) {
|
||||
return new call_inst(fn, values, name, next);
|
||||
}
|
||||
|
||||
|
||||
// launch
|
||||
|
||||
launch_inst::launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps, const std::string& name, instruction* next)
|
||||
: instruction(fn->get_fn_type()->get_return_ty(), INST_LAUNCH, 1 + values.size() + grid.size() + 1, name, next){
|
||||
int k = 0;
|
||||
if(grid.size() != 3)
|
||||
throw std::runtime_error("grid must have 3 elements");
|
||||
set_operand(k++, fn);
|
||||
val_begin = k;
|
||||
for(ir::value* v: values)
|
||||
set_operand(k++, v);
|
||||
val_end = k;
|
||||
grid_begin = k;
|
||||
for(ir::value* g: grid)
|
||||
set_operand(k++, g);
|
||||
grid_end = k;
|
||||
set_operand(k++, num_warps);
|
||||
}
|
||||
|
||||
|
||||
ir::function* launch_inst::get_fn() {
|
||||
return (ir::function*)get_operand(0);
|
||||
}
|
||||
|
||||
std::vector<ir::value*> launch_inst::get_values() {
|
||||
std::vector<ir::value*> ret;
|
||||
for(int i = val_begin; i < val_end; i++)
|
||||
ret.push_back(get_operand(i));
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<ir::value*> launch_inst::get_grid() {
|
||||
std::vector<ir::value*> ret;
|
||||
for(int i = grid_begin; i < grid_end; i++)
|
||||
ret.push_back(get_operand(i));
|
||||
return ret;
|
||||
}
|
||||
|
||||
ir::value* launch_inst::get_num_warps() {
|
||||
return get_operand(grid_end);
|
||||
}
|
||||
|
||||
|
||||
launch_inst* launch_inst::create(ir::function *fn, const std::vector<ir::value *> &values, const std::vector<ir::value *> &grid, ir::value *num_warps, const std::string &name, instruction *next) {
|
||||
return new launch_inst(fn, values, grid, num_warps, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary_operator classes
|
||||
@@ -324,7 +389,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
|
||||
|
||||
// return_inst
|
||||
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
|
||||
: terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
|
||||
: terminator_inst(ret_val?ret_val->get_type():type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
|
||||
if(ret_val)
|
||||
set_operand(0, ret_val);
|
||||
}
|
||||
@@ -521,6 +586,36 @@ masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
|
||||
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
|
||||
return new masked_store_inst(ptr, val, mask, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// struct classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// insert value
|
||||
|
||||
insert_value_inst::insert_value_inst(value *val, value *elt, size_t idx, const std::string& name, instruction *next)
|
||||
: instruction(val->get_type(), INST_INSERT_VALUE, 2, name, next), idx_(idx) {
|
||||
set_operand(0, val);
|
||||
set_operand(1, elt);
|
||||
}
|
||||
|
||||
insert_value_inst* insert_value_inst::create(value *val, value *elt, size_t idx, const std::string& name, instruction *next){
|
||||
return new insert_value_inst(val, elt, idx, name, next);
|
||||
}
|
||||
|
||||
|
||||
// extract value
|
||||
|
||||
extract_value_inst::extract_value_inst(value *val, size_t idx, const std::string& name, instruction *next)
|
||||
: instruction(val->get_type()->get_struct_type(idx), INST_EXTRACT_VALUE, 1, name, next), idx_(idx) {
|
||||
set_operand(0, val);
|
||||
}
|
||||
|
||||
extract_value_inst* extract_value_inst::create(value *val, size_t idx, const std::string& name, instruction *next){
|
||||
return new extract_value_inst(val, idx, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -575,6 +670,9 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
|
||||
return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// matmul_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -9,17 +9,12 @@
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
/* Module */
|
||||
module::module(const std::string &name, builder &builder)
|
||||
: name_(name), builder_(builder) {
|
||||
/* */
|
||||
value_constructor::value_constructor(ir::builder& builder): builder_(builder){
|
||||
sealed_blocks_.insert(nullptr);
|
||||
}
|
||||
|
||||
ir::builder& module::get_builder() {
|
||||
return builder_;
|
||||
}
|
||||
|
||||
void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
|
||||
void value_constructor::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
|
||||
values_[val_key_t{name, block}] = value;
|
||||
auto it = metadatas_.find(name);
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(value))
|
||||
@@ -29,23 +24,11 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::valu
|
||||
// value->set_name(name);
|
||||
}
|
||||
|
||||
void module::set_value(const std::string& name, ir::value *value){
|
||||
void value_constructor::set_value(const std::string& name, ir::value *value){
|
||||
return set_value(name, builder_.get_insert_block(), value);
|
||||
}
|
||||
|
||||
void module::set_const(const std::string& name){
|
||||
const_.insert(name);
|
||||
}
|
||||
|
||||
void module::set_continue_fn(std::function<ir::value*()> fn) {
|
||||
continue_fn_ = fn;
|
||||
}
|
||||
|
||||
std::function<ir::value*()> module::get_continue_fn() {
|
||||
return continue_fn_;
|
||||
}
|
||||
|
||||
ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
|
||||
ir::phi_node* value_constructor::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
|
||||
basic_block::iterator insert = block->get_first_non_phi();
|
||||
if(insert != block->end()){
|
||||
builder_.set_insert_point(insert);
|
||||
@@ -56,7 +39,7 @@ ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_bloc
|
||||
return res;
|
||||
}
|
||||
|
||||
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
|
||||
ir::value *value_constructor::try_remove_trivial_phis(ir::phi_node *&phi){
|
||||
// find non-self references
|
||||
std::set<ir::value*> non_self_ref;
|
||||
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
|
||||
@@ -69,7 +52,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
|
||||
assert(same != nullptr);
|
||||
phi->replace_all_uses_with(same);
|
||||
phi->erase_from_parent();
|
||||
std::set<ir::user*> users = phi->get_users();
|
||||
std::vector<ir::user*> users = phi->get_users();
|
||||
for(ir::user* u: users)
|
||||
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
|
||||
if(uphi != phi)
|
||||
@@ -78,7 +61,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
|
||||
}
|
||||
|
||||
|
||||
ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){
|
||||
ir::value *value_constructor::add_phi_operands(const std::string& name, ir::phi_node *&phi){
|
||||
// already initialized
|
||||
if(phi->get_num_operands())
|
||||
return phi;
|
||||
@@ -90,12 +73,11 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi)
|
||||
return phi;
|
||||
}
|
||||
|
||||
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
|
||||
ir::value *value_constructor::get_value_recursive(const std::string& name, ir::basic_block *block) {
|
||||
ir::value *result;
|
||||
bool is_const = const_.find(name) != const_.end();
|
||||
auto &preds = block->get_predecessors();
|
||||
auto preds = block->get_predecessors();
|
||||
ir::type *ty = types_.at(name);
|
||||
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
||||
if(block && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
||||
incomplete_phis_[block][name] = make_phi(ty, 1, block);
|
||||
result = (ir::value*)incomplete_phis_[block][name];
|
||||
}
|
||||
@@ -117,10 +99,12 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
|
||||
return result;
|
||||
}
|
||||
|
||||
ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
|
||||
ir::value *value_constructor::get_value(const std::string& name, ir::basic_block *block) {
|
||||
ir::basic_block* save_block = builder_.get_insert_block();
|
||||
ir::basic_block::iterator save_pt = builder_.get_insert_point();
|
||||
val_key_t key(name, block);
|
||||
// std::cout << values_.size() << std::endl;
|
||||
// std::cout << name << " " << block << " " << values_.begin()->first.first << " " << values_.begin()->first.second << std::endl;
|
||||
if(values_.find(key) != values_.end()){
|
||||
return values_.at(key);
|
||||
}
|
||||
@@ -131,15 +115,11 @@ ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
|
||||
return result;
|
||||
}
|
||||
|
||||
ir::value *module::get_value(const std::string& name) {
|
||||
ir::value *value_constructor::get_value(const std::string& name) {
|
||||
return get_value(name, builder_.get_insert_block());
|
||||
}
|
||||
|
||||
const std::string& module::get_name() {
|
||||
return name_;
|
||||
}
|
||||
|
||||
void module::seal_block(ir::basic_block *block){
|
||||
void value_constructor::seal_block(ir::basic_block *block){
|
||||
for(auto &x: incomplete_phis_[block]){
|
||||
add_phi_operands(x.first, x.second);
|
||||
if(get_value(x.first) == x.second)
|
||||
@@ -149,11 +129,40 @@ void module::seal_block(ir::basic_block *block){
|
||||
incomplete_phis_[block].clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Module */
|
||||
|
||||
module::module(const std::string &name, builder &builder)
|
||||
: name_(name), builder_(builder) {
|
||||
}
|
||||
|
||||
void module::reset_ret_ty(const std::string& name, type* ty) {
|
||||
get_function(name)->get_fn_type()->reset_ret_ty(ty);
|
||||
}
|
||||
|
||||
ir::builder& module::get_builder() {
|
||||
return builder_;
|
||||
}
|
||||
|
||||
void module::set_continue_fn(std::function<ir::value*()> fn) {
|
||||
continue_fn_ = fn;
|
||||
}
|
||||
|
||||
std::function<ir::value*()> module::get_continue_fn() {
|
||||
return continue_fn_;
|
||||
}
|
||||
|
||||
const std::string& module::get_name() {
|
||||
return name_;
|
||||
}
|
||||
|
||||
/* functions */
|
||||
function *module::get_or_insert_function(const std::string &name, function_type *ty) {
|
||||
function *&fn = (function*&)symbols_[name];
|
||||
if(fn == nullptr)
|
||||
return fn = function::create(ty, global_value::external, name, this);
|
||||
if(fn == nullptr){
|
||||
fn = function::create(ty, global_value::external, name, this);
|
||||
}
|
||||
return fn;
|
||||
}
|
||||
|
||||
|
@@ -188,7 +188,26 @@ bool composite_type::index_valid(value *idx) const{
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tile_type class
|
||||
// struct_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct_type::struct_type(const contained_tys_vec_t& tys, bool is_packed)
|
||||
: composite_type(tys[0]->get_context(), StructTyID), is_packed_(is_packed) {
|
||||
contained_tys_ = tys;
|
||||
}
|
||||
|
||||
struct_type* struct_type::get(const contained_tys_vec_t& tys, bool is_packed) {
|
||||
assert(tys.size());
|
||||
context_impl* impl = tys[0]->get_context().p_impl.get();
|
||||
struct_type *& entry = impl->struct_tys[tys];
|
||||
if(!entry)
|
||||
entry = new struct_type(tys, is_packed);
|
||||
return entry;
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// block_type class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
block_type::block_type(type *ty, const block_shapes_t &shapes)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
|
||||
@@ -17,11 +18,11 @@ value::value(type *ty, const std::string &name): ty_(ty){
|
||||
}
|
||||
|
||||
void value::add_use(user *arg) {
|
||||
users_.insert(arg);
|
||||
users_.push_back(arg);
|
||||
}
|
||||
|
||||
value::users_t::iterator value::erase_use(user *arg){
|
||||
auto it = users_.find(arg);
|
||||
auto it = std::find(users_.begin(), users_.end(), arg);
|
||||
if(it == users_.end())
|
||||
return it;
|
||||
return users_.erase(it);
|
||||
|
Reference in New Issue
Block a user