[CODEGEN] Removed dedicated reassociate pass to merge it into LLVM isel (#101)
This massively simplifies implementation of `reassociate` and also fixes a bunch of bug. The pass could still be improved, but can already be used to generate constant pointer offsets in eg the matmul epilogue
This commit is contained in:
committed by
Philippe Tillet
parent
e16bee1a27
commit
840140bf26
@@ -174,8 +174,6 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
||||
return is_constant_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
|
||||
if(dynamic_cast<ir::make_range_sta*>(v))
|
||||
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||
return populate_is_constant_phi(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
@@ -322,8 +320,6 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
|
||||
auto shapes = v->get_type()->get_block_shapes();
|
||||
if(dynamic_cast<ir::make_range*>(v))
|
||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||
if(dynamic_cast<ir::make_range_sta*>(v))
|
||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
||||
}
|
||||
|
||||
@@ -486,10 +482,6 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
||||
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::make_range*>(v))
|
||||
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::make_range_dyn*>(v))
|
||||
return add_to_cache(x, {128}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::make_range_sta*>(v))
|
||||
return add_to_cache(x, {(unsigned)x->get_range()->get_first()->get_value()}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
return populate_starting_multiple_gep(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
|
@@ -12,7 +12,6 @@
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/driver/module.h"
|
||||
@@ -48,7 +47,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole(target.get(), &layouts);
|
||||
codegen::transform::reassociate reassociate;
|
||||
// codegen::transform::reassociate reassociate;
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
|
||||
// run passes
|
||||
@@ -76,7 +75,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
align.run(ir);
|
||||
dce.run(ir);
|
||||
if (target->is_gpu()) {
|
||||
reassociate.run(ir);
|
||||
// reassociate.run(ir);
|
||||
cts.run(ir);
|
||||
}
|
||||
dce.run(ir);
|
||||
@@ -100,4 +99,4 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
} // namespace triton
|
||||
|
@@ -23,6 +23,63 @@ namespace codegen{
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
Value* adder::operator()(Value *x, Value *y, const std::string& name) {
|
||||
// (x + cst) + y -> (x + y) + cst
|
||||
if(auto* bin = dyn_cast<BinaryOperator>(x))
|
||||
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||
if(dyn_cast<Constant>(bin->getOperand(1))){
|
||||
return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y),
|
||||
bin->getOperand(1));
|
||||
}
|
||||
// (x + (y + cst)) -> (x + y) + cst
|
||||
if(auto* bin = dyn_cast<BinaryOperator>(y))
|
||||
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||
if(dyn_cast<Constant>(bin->getOperand(1))){
|
||||
return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)),
|
||||
bin->getOperand(1));
|
||||
}
|
||||
|
||||
// default
|
||||
return (*builder_)->CreateAdd(x, y, name);
|
||||
}
|
||||
|
||||
Value* multiplier::operator()(Value *x, Value *y, const std::string &name) {
|
||||
// (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2)
|
||||
if(auto* bin = dyn_cast<BinaryOperator>(x))
|
||||
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||
if(dyn_cast<Constant>(bin->getOperand(1)))
|
||||
if(dyn_cast<Constant>(y)){
|
||||
return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y),
|
||||
(*builder_)->CreateMul(bin->getOperand(1), y));
|
||||
}
|
||||
// default
|
||||
return (*builder_)->CreateMul(x, y, name);
|
||||
}
|
||||
|
||||
Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
// (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2)
|
||||
if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
|
||||
if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
|
||||
return (*builder_)->CreateGEP(gep->getPointerOperand(),
|
||||
(*builder_)->CreateAdd(cst1, cst2));
|
||||
}
|
||||
// ptr + (off + cst) -> (ptr + off) + cst
|
||||
if(auto* bin = dyn_cast<BinaryOperator>(off))
|
||||
if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
|
||||
if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){
|
||||
return (*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)),
|
||||
bin->getOperand(1));
|
||||
}
|
||||
// default
|
||||
return (*builder_)->CreateGEP(ptr, off, name);
|
||||
}
|
||||
|
||||
//Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value *> vals, const std::string &name) {
|
||||
// return (*builder_)->CreateGEP(ty, ptr, vals, name);
|
||||
//}
|
||||
|
||||
|
||||
// types
|
||||
#define void_ty builder_->getVoidTy()
|
||||
#define f16_ty builder_->getHalfTy()
|
||||
@@ -34,7 +91,6 @@ using namespace llvm;
|
||||
// constants
|
||||
#define i32(...) builder_->getInt32(__VA_ARGS__)
|
||||
// ops
|
||||
#define add(...) builder_->CreateAdd(__VA_ARGS__)
|
||||
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
|
||||
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
|
||||
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
|
||||
@@ -52,7 +108,6 @@ using namespace llvm;
|
||||
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
|
||||
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
|
||||
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
|
||||
#define gep(...) builder_->CreateGEP(__VA_ARGS__)
|
||||
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
|
||||
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
|
||||
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
|
||||
@@ -64,7 +119,6 @@ using namespace llvm;
|
||||
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
||||
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
||||
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
||||
#define mul(...) builder_->CreateMul(__VA_ARGS__)
|
||||
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
|
||||
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
|
||||
#define ret(...) builder_->CreateRet(__VA_ARGS__)
|
||||
@@ -144,7 +198,7 @@ generator::generator(analysis::axes *a_axes,
|
||||
target *tgt,
|
||||
unsigned num_warps)
|
||||
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
|
||||
tgt_(tgt), num_warps_(num_warps) {
|
||||
tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
|
||||
|
||||
}
|
||||
|
||||
@@ -207,8 +261,8 @@ void generator::visit_phi_node(ir::phi_node* x) {
|
||||
* \brief Code Generation for `binary_operator`
|
||||
*/
|
||||
void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
using ll = llvm::Instruction::BinaryOps;
|
||||
auto cvt = [](ir::binary_op_t op){
|
||||
using ll = llvm::Instruction::BinaryOps;
|
||||
using tt = ir::binary_op_t;
|
||||
switch(op) {
|
||||
case tt::Add: return ll::Add;
|
||||
@@ -235,7 +289,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
Value *lhs = vals_[x->get_operand(0)][idx];
|
||||
Value *rhs = vals_[x->get_operand(1)][idx];
|
||||
vals_[x][idx] = bin_op(cvt(x->get_op()), lhs, rhs);
|
||||
auto op = cvt(x->get_op());
|
||||
if(op == ll::Add)
|
||||
vals_[x][idx] = add(lhs, rhs);
|
||||
else if(op == ll::Mul)
|
||||
vals_[x][idx] = mul(lhs, rhs);
|
||||
else
|
||||
vals_[x][idx] = bin_op(op, lhs, rhs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,8 +308,8 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) {
|
||||
std::vector<Value*> vals;
|
||||
for(auto it= x->idx_begin(); it != x->idx_end(); it++)
|
||||
vals.push_back(vals_[*it][idx]);
|
||||
Type *ty = cvt(x->get_source_elt_ty()->get_scalar_ty());
|
||||
vals_[x][idx] = gep(ty, ptr, vals);
|
||||
assert(vals.size() == 1);
|
||||
vals_[x][idx] = gep(ptr, vals[0]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -640,7 +700,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
||||
Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634);
|
||||
std::vector<llvm::Type*> tys = {f32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false);
|
||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false);
|
||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
|
||||
for(auto idx: idxs_.at(x)){
|
||||
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
|
||||
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
|
||||
@@ -1576,7 +1636,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
||||
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
|
||||
Value *in_base = in_gep->getPointerOperand();
|
||||
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
|
||||
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0;
|
||||
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
|
||||
in_base = cst ? in_base : in_gep;
|
||||
// output ptr info
|
||||
Value* out_base = shared[i].first;
|
||||
@@ -1683,34 +1743,34 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
||||
call(iasm);
|
||||
}
|
||||
|
||||
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
assert(idx.size() == 1);
|
||||
if(idx[0] == i32(0))
|
||||
vals_[x][idx] = idx[0];
|
||||
else{
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
assert(bin_add);
|
||||
vals_[x][idx] = bin_add->getOperand(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
||||
// for(indices_t idx: idxs_.at(x)){
|
||||
// assert(idx.size() == 1);
|
||||
// if(idx[0] == i32(0))
|
||||
// vals_[x][idx] = idx[0];
|
||||
// else{
|
||||
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
// assert(bin_add);
|
||||
// vals_[x][idx] = bin_add->getOperand(0);
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
void generator::visit_make_range_sta(ir::make_range_sta* x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
assert(idx.size() == 1);
|
||||
if(idx[0] == i32(0)){
|
||||
vals_[x][idx] = idx[0];
|
||||
}
|
||||
else{
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
assert(bin_add);
|
||||
Value *cst = bin_add->getOperand(1);
|
||||
assert(isa<Constant>(cst));
|
||||
vals_[x][idx] = cst;
|
||||
}
|
||||
};
|
||||
}
|
||||
//void generator::visit_make_range_sta(ir::make_range_sta* x) {
|
||||
// for(indices_t idx: idxs_.at(x)){
|
||||
// assert(idx.size() == 1);
|
||||
// if(idx[0] == i32(0)){
|
||||
// vals_[x][idx] = idx[0];
|
||||
// }
|
||||
// else{
|
||||
// BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
// assert(bin_add);
|
||||
// Value *cst = bin_add->getOperand(1);
|
||||
// assert(isa<Constant>(cst));
|
||||
// vals_[x][idx] = cst;
|
||||
// }
|
||||
// };
|
||||
//}
|
||||
|
||||
void generator::visit_make_range(ir::make_range* x) {
|
||||
for(indices_t idx: idxs_.at(x)){
|
||||
|
@@ -1,267 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
|
||||
ir::binary_operator *bin_op = dynamic_cast<ir::binary_operator*>(x);
|
||||
bool is_bin_add = bin_op && bin_op->get_op()== ir::binary_op_t::Add;
|
||||
if(is_bin_add)
|
||||
return (ir::instruction*)x;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline bool is_cst(ir::value *x) {
|
||||
if(dynamic_cast<ir::constant*>(x))
|
||||
return true;
|
||||
if(dynamic_cast<ir::make_range*>(x))
|
||||
return true;
|
||||
if(auto *v = dynamic_cast<ir::retile_inst*>(x))
|
||||
return is_cst(v->get_operand(0));
|
||||
return false;
|
||||
}
|
||||
|
||||
ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
ir::builder &builder,
|
||||
ir::value *&noncst,
|
||||
ir::value *&cst){
|
||||
// value doesn't change by default
|
||||
ir::value* new_value = old_value;
|
||||
cst = nullptr;
|
||||
noncst = old_value;
|
||||
|
||||
// handle retiling
|
||||
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
|
||||
auto shapes = op->get_type()->get_block_shapes();
|
||||
ir::value *old_arg = op->get_operand(0);
|
||||
ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst);
|
||||
// retile(x + y) = retile(x) + retile(y)
|
||||
if(ir::instruction* bin_add = is_bin_add(new_arg))
|
||||
if(cst){
|
||||
ir::value *old_lhs = bin_add->get_operand(0);
|
||||
ir::value *old_rhs = bin_add->get_operand(1);
|
||||
ir::value *new_lhs = nullptr;
|
||||
ir::value *new_rhs = nullptr;
|
||||
if(dynamic_cast<ir::reshape_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_reshape(old_lhs, shapes);
|
||||
new_rhs = builder.create_reshape(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
if(dynamic_cast<ir::broadcast_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_broadcast(old_lhs, shapes);
|
||||
new_rhs = builder.create_broadcast(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
if(dynamic_cast<ir::splat_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
new_lhs = builder.create_splat(old_lhs, shapes);
|
||||
new_rhs = builder.create_splat(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle binary addition
|
||||
if(ir::instruction* op = is_bin_add(old_value)){
|
||||
builder.set_insert_point(op);
|
||||
std::string name = op->get_name();
|
||||
ir::value *lhs = reassociate_idx(op->get_operand (0), builder, noncst, cst);
|
||||
ir::value *rhs = reassociate_idx(op->get_operand(1), builder, noncst, cst);
|
||||
builder.set_insert_point(op);
|
||||
// (x + y) + z
|
||||
if(ir::instruction* bin_lhs = is_bin_add(lhs)){
|
||||
ir::value *llhs = bin_lhs->get_operand(0);
|
||||
ir::value *rlhs = bin_lhs->get_operand(1);
|
||||
// (cst + x) + y -> cst + (x + y)
|
||||
if(is_cst(llhs))
|
||||
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs));
|
||||
// (x + cst) + y -> cst + (x + y)
|
||||
if(is_cst(rlhs))
|
||||
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs));
|
||||
}
|
||||
// x + (y + z)
|
||||
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
|
||||
ir::value *lrhs = bin_rhs->get_operand(0);
|
||||
ir::value *rrhs = bin_rhs->get_operand(1);
|
||||
// x + (cst + y) -> cst + (x + y)
|
||||
if(is_cst(lrhs))
|
||||
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), cst);
|
||||
// x + (y + cst) -> cst + (x + y)
|
||||
if(is_cst(rrhs))
|
||||
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), cst);
|
||||
}
|
||||
}
|
||||
// extract constant and non-constant
|
||||
if(ir::instruction *bin_add = is_bin_add(new_value)){
|
||||
ir::value *new_lhs = bin_add->get_operand(0);
|
||||
ir::value *new_rhs = bin_add->get_operand(1);
|
||||
if(is_cst(new_lhs)){
|
||||
cst = new_lhs;
|
||||
noncst = new_rhs;
|
||||
}
|
||||
if(is_cst(new_rhs)){
|
||||
cst = new_rhs;
|
||||
noncst = new_lhs;
|
||||
}
|
||||
}
|
||||
// clean-up if some re-ordering happened
|
||||
if(old_value != new_value)
|
||||
old_value->replace_all_uses_with(new_value);
|
||||
return new_value;
|
||||
}
|
||||
|
||||
/* run */
|
||||
void reassociate::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
|
||||
// constant_range -> nv_dynamic_program_idx + nv_static_program_idx
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::make_range*> ranges;
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
for(ir::basic_block *block: rpo){
|
||||
// iterate through instruction
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
for(ir::value* op: i->ops())
|
||||
if(auto *range = dynamic_cast<ir::make_range*>(op))
|
||||
ranges.push_back(range);
|
||||
}
|
||||
|
||||
builder.set_insert_point(rpo.front()->get_first_non_phi());
|
||||
for(ir::make_range* old_range: ranges){
|
||||
ir::value* dyn_range = builder.insert(ir::make_range_dyn::create(old_range->get_type()));
|
||||
ir::value* static_range = ir::make_range_sta::get(old_range);
|
||||
ir::value* new_range = builder.create_add(dyn_range, static_range);
|
||||
old_range->replace_all_uses_with(new_range);
|
||||
}
|
||||
}
|
||||
|
||||
// reassociate
|
||||
std::map<ir::value*, cst_info> infos;
|
||||
std::set<ir::value*> replaced;
|
||||
size_t n_replaced;
|
||||
do{
|
||||
n_replaced = replaced.size();
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
// iterate through blocks
|
||||
for(ir::basic_block *block: rpo){
|
||||
// iterate through instruction
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
// retiling
|
||||
if(ir::retile_inst *rt = dynamic_cast<ir::retile_inst*>(i)) {
|
||||
ir::value* op = rt->get_operand(0);
|
||||
if(infos.find(op) != infos.end()){
|
||||
builder.set_insert_point(rt);
|
||||
ir::getelementptr_inst* sta = infos.at(op).sta_ptr;
|
||||
ir::value* dyn = infos.at(op).dyn_ptr;
|
||||
ir::value* cst = *sta->idx_begin();
|
||||
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
|
||||
auto shapes = rt->get_type()->get_block_shapes();
|
||||
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
|
||||
ir::value* broadcast = builder.create_broadcast(cst, shapes);
|
||||
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
|
||||
infos[rt] = cst_info{ndyn, nsta};
|
||||
}
|
||||
}
|
||||
}
|
||||
// getelementptr instruction
|
||||
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
|
||||
if(replaced.find(pz) != replaced.end())
|
||||
continue;
|
||||
// unpack GEP instruction
|
||||
ir::value* py = pz->get_pointer_operand();
|
||||
ir::value* offset = *pz->idx_begin();
|
||||
// reassociate index
|
||||
ir::value *sta = nullptr;
|
||||
ir::value *dyn = offset;
|
||||
reassociate_idx(offset, builder, dyn, sta);
|
||||
if(sta){
|
||||
builder.set_insert_point(pz);
|
||||
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
|
||||
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
|
||||
pz->replace_all_uses_with(sta_ptr);
|
||||
infos[sta_ptr].dyn_ptr = dyn_ptr;
|
||||
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
|
||||
replaced.insert(pz);
|
||||
}
|
||||
// reassociate pointer argument
|
||||
if(infos.find(py) != infos.end()){
|
||||
builder.set_insert_point(pz);
|
||||
ir::getelementptr_inst *sta = infos[py].sta_ptr;
|
||||
ir::value *dyn = infos[py].dyn_ptr;
|
||||
ir::value *cst = *sta->idx_begin();
|
||||
ir::value *off = *pz->idx_begin();
|
||||
ir::value *pz_dyn = builder.create_gep(dyn, {off});
|
||||
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst});
|
||||
pz->replace_all_uses_with(pz_sta);
|
||||
infos[pz_sta].dyn_ptr = pz_dyn;
|
||||
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
|
||||
replaced.insert(pz);
|
||||
}
|
||||
// reassociate phi-node pointer
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
|
||||
// only optimize the case where py = phi pa, pz for now
|
||||
std::vector<ir::value*> ops = phi->ops();
|
||||
if(ops.size() != 2)
|
||||
continue;
|
||||
if(ops[0] != pz && ops[1] != pz)
|
||||
continue;
|
||||
// grab incoming
|
||||
size_t idx_z = (ops[0] == pz) ? 0 : 1;
|
||||
size_t idx_a = (ops[0] == pz) ? 1 : 0;
|
||||
// check if pa is known to have constant offset
|
||||
ir::value *vpa = phi->get_incoming_value(idx_a);
|
||||
auto it_a = infos.find(vpa);
|
||||
if(it_a == infos.end())
|
||||
continue;
|
||||
// unpack dynamically/statically offset pointer
|
||||
ir::value *pa_dyn = it_a->second.dyn_ptr;
|
||||
ir::getelementptr_inst *pa_sta = it_a->second.sta_ptr;
|
||||
ir::value *pz = phi->get_incoming_value(idx_z);
|
||||
// extract offset
|
||||
ir::value *off = *pa_sta->idx_begin();
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node *phi_dyn = builder.create_phi(phi->get_type(), 2);
|
||||
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
|
||||
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
|
||||
// re-add the offset
|
||||
ir::value *phi_sta = builder.create_gep(phi_dyn, {off});
|
||||
phi_sta->set_name( phi->get_name() + "_sta");
|
||||
phi->replace_all_uses_with(phi_sta);
|
||||
// remove offset from pz
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
|
||||
auto insts = x->get_parent()->get_inst_list();
|
||||
auto it = std::find(insts.begin(), insts.end(), x);
|
||||
it++;
|
||||
builder.set_insert_point(*it);
|
||||
}
|
||||
ir::value *_0 = builder.get_int32(0);
|
||||
if(off->get_type()->is_block_ty())
|
||||
_0 = builder.create_splat(_0, off->get_type()->get_block_shapes());
|
||||
ir::value *neg_off = builder.create_sub(_0, off);
|
||||
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
|
||||
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
|
||||
infos[phi_sta].dyn_ptr = phi_dyn;
|
||||
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
|
||||
replaced.insert(phi);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}while(replaced.size() != n_replaced);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,22 +1,22 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
@@ -81,6 +81,7 @@ cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(progra
|
||||
dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
|
||||
dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
|
||||
dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
|
||||
std::cout << n_reg << std::endl;
|
||||
if (shared_optin > 49152){
|
||||
// std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
|
||||
dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
|
||||
|
@@ -833,27 +833,27 @@ async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string
|
||||
}
|
||||
|
||||
|
||||
// nv_dynamic_program_idx
|
||||
make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
||||
: instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
|
||||
//// nv_dynamic_program_idx
|
||||
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
||||
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
|
||||
|
||||
make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
|
||||
return new make_range_dyn(ty, name, next);
|
||||
}
|
||||
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
|
||||
// return new make_range_dyn(ty, name, next);
|
||||
//}
|
||||
|
||||
// nv_static_program_idx
|
||||
make_range_sta::make_range_sta(make_range *range)
|
||||
: constant(range->get_type(), 0), range_(range) { }
|
||||
//// nv_static_program_idx
|
||||
//make_range_sta::make_range_sta(make_range *range)
|
||||
// : constant(range->get_type(), 0), range_(range) { }
|
||||
|
||||
make_range* make_range_sta::get_range() const
|
||||
{ return range_; }
|
||||
//make_range* make_range_sta::get_range() const
|
||||
//{ return range_; }
|
||||
|
||||
make_range_sta* make_range_sta::get(make_range* range) {
|
||||
static std::map<make_range*, make_range_sta*> cache;
|
||||
if(cache.find(range) == cache.end())
|
||||
cache.insert({range, new make_range_sta(range)});
|
||||
return cache.at(range);
|
||||
}
|
||||
//make_range_sta* make_range_sta::get(make_range* range) {
|
||||
// static std::map<make_range*, make_range_sta*> cache;
|
||||
// if(cache.find(range) == cache.end())
|
||||
// cache.insert({range, new make_range_sta(range)});
|
||||
// return cache.at(range);
|
||||
//}
|
||||
|
||||
|
||||
// make_range
|
||||
|
Reference in New Issue
Block a user