test
This commit is contained in:
@@ -82,9 +82,9 @@ int main() {
|
|||||||
// std::vector<unsigned*> params = tune.get_params(module);
|
// std::vector<unsigned*> params = tune.get_params(module);
|
||||||
// std::cout << params.size() << std::endl;
|
// std::cout << params.size() << std::endl;
|
||||||
// selection.run(module, llvm_module);
|
// selection.run(module, llvm_module);
|
||||||
// // print LLVM program
|
// print LLVM program
|
||||||
// llvm::PrintModulePass print(llvm::outs());
|
llvm::PrintModulePass print(llvm::outs());
|
||||||
// llvm::AnalysisManager<llvm::Module> analysis;
|
llvm::AnalysisManager<llvm::Module> analysis;
|
||||||
// print.run(llvm_module, analysis);
|
print.run(llvm_module, analysis);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@@ -78,8 +78,8 @@ class selection{
|
|||||||
private:
|
private:
|
||||||
// LLVM conversions
|
// LLVM conversions
|
||||||
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
|
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
|
||||||
llvm::Value* llvm_value(ir::value *v,llvm:: LLVMContext &ctx);
|
llvm::Value* llvm_value(ir::value *v, llvm:: LLVMContext &ctx, llvm::IRBuilder<> &builder);
|
||||||
llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::LLVMContext &ctx);
|
llvm::Instruction* llvm_inst(ir::instruction *inst, std::function<llvm::Value*(ir::value*)> value, llvm::LLVMContext &ctx, llvm::IRBuilder<> &builder);
|
||||||
llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx);
|
llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx);
|
||||||
|
|
||||||
// grid construction
|
// grid construction
|
||||||
|
@@ -324,7 +324,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
|||||||
ir::value *value = ir::undef_value::get(ty);
|
ir::value *value = ir::undef_value::get(ty);
|
||||||
if(expr_){
|
if(expr_){
|
||||||
value = expr_->codegen(mod);
|
value = expr_->codegen(mod);
|
||||||
explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
|
value = explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
|
||||||
implicit_broadcast(mod, value, ty);
|
implicit_broadcast(mod, value, ty);
|
||||||
}
|
}
|
||||||
value->set_name(name);
|
value->set_name(name);
|
||||||
@@ -336,85 +336,85 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
|||||||
/* Expression */
|
/* Expression */
|
||||||
/*------------------*/
|
/*------------------*/
|
||||||
/* Binary operator */
|
/* Binary operator */
|
||||||
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *arg, ir::value *rhs, const std::string &name) const
|
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
|
||||||
{
|
{
|
||||||
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
|
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
|
||||||
implicit_cast(builder, arg, rhs, is_float, is_ptr, is_int, is_signed);
|
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
|
||||||
implicit_broadcast(mod, arg, rhs);
|
implicit_broadcast(mod, lhs, rhs);
|
||||||
if(op_==MUL && is_float)
|
if(op_==MUL && is_float)
|
||||||
return builder.create_fmul(arg, rhs, name);
|
return builder.create_fmul(lhs, rhs, name);
|
||||||
if(op_==MUL && is_int)
|
if(op_==MUL && is_int)
|
||||||
return builder.create_mul(arg, rhs, name);
|
return builder.create_mul(lhs, rhs, name);
|
||||||
if(op_==DIV && is_float)
|
if(op_==DIV && is_float)
|
||||||
return builder.create_fdiv(arg, rhs, name);
|
return builder.create_fdiv(lhs, rhs, name);
|
||||||
if(op_==DIV && is_int && is_signed)
|
if(op_==DIV && is_int && is_signed)
|
||||||
return builder.create_sdiv(arg, rhs, name);
|
return builder.create_sdiv(lhs, rhs, name);
|
||||||
if(op_==DIV && is_int && !is_signed)
|
if(op_==DIV && is_int && !is_signed)
|
||||||
return builder.create_udiv(arg, rhs, name);
|
return builder.create_udiv(lhs, rhs, name);
|
||||||
if(op_==MOD && is_float)
|
if(op_==MOD && is_float)
|
||||||
return builder.create_frem(arg, rhs, name);
|
return builder.create_frem(lhs, rhs, name);
|
||||||
if(op_==MOD && is_int && is_signed)
|
if(op_==MOD && is_int && is_signed)
|
||||||
return builder.create_srem(arg, rhs, name);
|
return builder.create_srem(lhs, rhs, name);
|
||||||
if(op_==MOD && is_int && !is_signed)
|
if(op_==MOD && is_int && !is_signed)
|
||||||
return builder.create_urem(arg, rhs, name);
|
return builder.create_urem(lhs, rhs, name);
|
||||||
if(op_==ADD && is_float)
|
if(op_==ADD && is_float)
|
||||||
return builder.create_fadd(arg, rhs, name);
|
return builder.create_fadd(lhs, rhs, name);
|
||||||
if(op_==ADD && is_int)
|
if(op_==ADD && is_int)
|
||||||
return builder.create_add(arg, rhs);
|
return builder.create_add(lhs, rhs);
|
||||||
if(op_==ADD && is_ptr)
|
if(op_==ADD && is_ptr)
|
||||||
return builder.create_gep(arg, {rhs});
|
return builder.create_gep(lhs, {rhs});
|
||||||
if(op_==SUB && is_float)
|
if(op_==SUB && is_float)
|
||||||
return builder.create_fsub(arg, rhs, name);
|
return builder.create_fsub(lhs, rhs, name);
|
||||||
if(op_==SUB && is_int)
|
if(op_==SUB && is_int)
|
||||||
return builder.create_sub(arg, rhs, name);
|
return builder.create_sub(lhs, rhs, name);
|
||||||
if(op_==SUB && is_ptr)
|
if(op_==SUB && is_ptr)
|
||||||
return builder.create_gep(arg, {builder.create_neg(rhs)});
|
return builder.create_gep(lhs, {builder.create_neg(rhs)});
|
||||||
if(op_==LEFT_SHIFT)
|
if(op_==LEFT_SHIFT)
|
||||||
return builder.create_shl(arg, rhs, name);
|
return builder.create_shl(lhs, rhs, name);
|
||||||
if(op_==RIGHT_SHIFT)
|
if(op_==RIGHT_SHIFT)
|
||||||
return builder.create_ashr(arg, rhs, name);
|
return builder.create_ashr(lhs, rhs, name);
|
||||||
if(op_ == LT && is_float)
|
if(op_ == LT && is_float)
|
||||||
return builder.create_fcmpOLT(arg, rhs, name);
|
return builder.create_fcmpOLT(lhs, rhs, name);
|
||||||
if(op_ == LT && is_int && is_signed)
|
if(op_ == LT && is_int && is_signed)
|
||||||
return builder.create_icmpSLT(arg, rhs, name);
|
return builder.create_icmpSLT(lhs, rhs, name);
|
||||||
if(op_ == LT && is_int && !is_signed)
|
if(op_ == LT && is_int && !is_signed)
|
||||||
return builder.create_icmpULT(arg, rhs, name);
|
return builder.create_icmpULT(lhs, rhs, name);
|
||||||
if(op_ == GT && is_float)
|
if(op_ == GT && is_float)
|
||||||
return builder.create_fcmpOGT(arg, rhs, name);
|
return builder.create_fcmpOGT(lhs, rhs, name);
|
||||||
if(op_ == GT && is_int && is_signed)
|
if(op_ == GT && is_int && is_signed)
|
||||||
return builder.create_icmpSGT(arg, rhs, name);
|
return builder.create_icmpSGT(lhs, rhs, name);
|
||||||
if(op_ == GT && is_int && !is_signed)
|
if(op_ == GT && is_int && !is_signed)
|
||||||
return builder.create_icmpUGT(arg, rhs, name);
|
return builder.create_icmpUGT(lhs, rhs, name);
|
||||||
if(op_ == LE && is_float)
|
if(op_ == LE && is_float)
|
||||||
return builder.create_fcmpOLE(arg, rhs, name);
|
return builder.create_fcmpOLE(lhs, rhs, name);
|
||||||
if(op_ == LE && is_int && is_signed)
|
if(op_ == LE && is_int && is_signed)
|
||||||
return builder.create_icmpSLE(arg, rhs, name);
|
return builder.create_icmpSLE(lhs, rhs, name);
|
||||||
if(op_ == LE && is_int && !is_signed)
|
if(op_ == LE && is_int && !is_signed)
|
||||||
return builder.create_icmpULE(arg, rhs, name);
|
return builder.create_icmpULE(lhs, rhs, name);
|
||||||
if(op_ == GE && is_float)
|
if(op_ == GE && is_float)
|
||||||
return builder.create_fcmpOGE(arg, rhs, name);
|
return builder.create_fcmpOGE(lhs, rhs, name);
|
||||||
if(op_ == GE && is_int && is_signed)
|
if(op_ == GE && is_int && is_signed)
|
||||||
return builder.create_icmpSGE(arg, rhs, name);
|
return builder.create_icmpSGE(lhs, rhs, name);
|
||||||
if(op_ == GE && is_int && !is_signed)
|
if(op_ == GE && is_int && !is_signed)
|
||||||
return builder.create_icmpUGE(arg, rhs, name);
|
return builder.create_icmpUGE(lhs, rhs, name);
|
||||||
if(op_ == EQ && is_float)
|
if(op_ == EQ && is_float)
|
||||||
return builder.create_fcmpOEQ(arg, rhs, name);
|
return builder.create_fcmpOEQ(lhs, rhs, name);
|
||||||
if(op_ == EQ && is_int)
|
if(op_ == EQ && is_int)
|
||||||
return builder.create_icmpEQ(arg, rhs, name);
|
return builder.create_icmpEQ(lhs, rhs, name);
|
||||||
if(op_ == NE && is_float)
|
if(op_ == NE && is_float)
|
||||||
return builder.create_fcmpONE(arg, rhs, name);
|
return builder.create_fcmpONE(lhs, rhs, name);
|
||||||
if(op_ == NE && is_int)
|
if(op_ == NE && is_int)
|
||||||
return builder.create_icmpNE(arg, rhs, name);
|
return builder.create_icmpNE(lhs, rhs, name);
|
||||||
if(op_ == AND)
|
if(op_ == AND)
|
||||||
return builder.create_and(arg, rhs, name);
|
return builder.create_and(lhs, rhs, name);
|
||||||
if(op_ == XOR)
|
if(op_ == XOR)
|
||||||
return builder.create_xor(arg, rhs, name);
|
return builder.create_xor(lhs, rhs, name);
|
||||||
if(op_ == OR)
|
if(op_ == OR)
|
||||||
return builder.create_or(arg, rhs, name);
|
return builder.create_or(lhs, rhs, name);
|
||||||
if(op_ == LAND)
|
if(op_ == LAND)
|
||||||
return builder.create_and(arg, rhs, name);
|
return builder.create_and(lhs, rhs, name);
|
||||||
if(op_ == LOR)
|
if(op_ == LOR)
|
||||||
return builder.create_or(arg, rhs, name);
|
return builder.create_or(lhs, rhs, name);
|
||||||
throw std::runtime_error("unreachable");
|
throw std::runtime_error("unreachable");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -105,49 +105,49 @@ Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
|
|||||||
|
|
||||||
|
|
||||||
/* convert ir::instruction to llvm::Instruction */
|
/* convert ir::instruction to llvm::Instruction */
|
||||||
Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, LLVMContext & ctx) {
|
Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, LLVMContext & ctx, IRBuilder<> &builder) {
|
||||||
auto block = [&](ir::basic_block *x) { return bmap_.at(x); };
|
auto block = [&](ir::basic_block *x) { return bmap_.at(x); };
|
||||||
auto type = [&](ir::type *x) { return llvm_type(x, ctx); };
|
auto type = [&](ir::type *x) { return llvm_type(x, ctx); };
|
||||||
if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst)){
|
if(auto* ii = dynamic_cast<ir::cond_branch_inst*>(inst)){
|
||||||
BasicBlock *true_dest = block(ii->get_true_dest());
|
BasicBlock *true_dest = block(ii->get_true_dest());
|
||||||
BasicBlock *false_dest = block(ii->get_false_dest());
|
BasicBlock *false_dest = block(ii->get_false_dest());
|
||||||
Value *cond = value(ii->get_cond());
|
Value *cond = value(ii->get_cond());
|
||||||
return BranchInst::Create(true_dest, false_dest, cond);
|
return builder.CreateCondBr(cond, true_dest, false_dest);
|
||||||
}
|
}
|
||||||
if(auto* ii = dynamic_cast<ir::uncond_branch_inst*>(inst)){
|
if(auto* ii = dynamic_cast<ir::uncond_branch_inst*>(inst)){
|
||||||
BasicBlock *dest = block(ii->get_dest());
|
BasicBlock *dest = block(ii->get_dest());
|
||||||
return BranchInst::Create(dest);
|
return builder.CreateBr(dest);
|
||||||
}
|
}
|
||||||
if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){
|
if(auto* ii = dynamic_cast<ir::phi_node*>(inst)){
|
||||||
Type *ty = type(ii->get_type()->get_scalar_ty());
|
Type *ty = type(ii->get_type()->get_scalar_ty());
|
||||||
unsigned num_ops = ii->get_num_operands();
|
unsigned num_ops = ii->get_num_operands();
|
||||||
return PHINode::Create(ty, num_ops, ii->get_name());
|
return builder.CreatePHI(ty, num_ops);
|
||||||
}
|
}
|
||||||
if(auto* ii = dynamic_cast<ir::return_inst*>(inst)){
|
if(auto* ii = dynamic_cast<ir::return_inst*>(inst)){
|
||||||
ir::value *ret_val = ii->get_return_value();
|
ir::value *ret_val = ii->get_return_value();
|
||||||
return ReturnInst::Create(ctx, ret_val?value(ret_val):nullptr);
|
return builder.CreateRet(ret_val?value(ret_val):nullptr);
|
||||||
}
|
}
|
||||||
if(auto* ii = dynamic_cast<ir::binary_operator*>(inst)){
|
if(auto* ii = dynamic_cast<ir::binary_operator*>(inst)){
|
||||||
Value *lhs = value(ii->get_operand(0));
|
Value *lhs = value(ii->get_operand(0));
|
||||||
Value *rhs = value(ii->get_operand(1));
|
Value *rhs = value(ii->get_operand(1));
|
||||||
return BinaryOperator::Create(ii->get_op(), lhs, rhs, ii->get_name());
|
return builder.Insert(BinaryOperator::Create(ii->get_op(), lhs, rhs));
|
||||||
}
|
}
|
||||||
if(auto* ii = dynamic_cast<ir::icmp_inst*>(inst)){
|
if(auto* ii = dynamic_cast<ir::icmp_inst*>(inst)){
|
||||||
CmpInst::Predicate pred = ii->get_pred();
|
CmpInst::Predicate pred = ii->get_pred();
|
||||||
Value *lhs = value(ii->get_operand(0));
|
Value *lhs = value(ii->get_operand(0));
|
||||||
Value *rhs = value(ii->get_operand(1));
|
Value *rhs = value(ii->get_operand(1));
|
||||||
return CmpInst::Create(Instruction::ICmp, pred, lhs, rhs, ii->get_name());
|
return builder.Insert(CmpInst::Create(Instruction::ICmp, pred, lhs, rhs));
|
||||||
}
|
}
|
||||||
if(auto* ii = dynamic_cast<ir::fcmp_inst*>(inst)){
|
if(auto* ii = dynamic_cast<ir::fcmp_inst*>(inst)){
|
||||||
CmpInst::Predicate pred = ii->get_pred();
|
CmpInst::Predicate pred = ii->get_pred();
|
||||||
Value *lhs = value(ii->get_operand(0));
|
Value *lhs = value(ii->get_operand(0));
|
||||||
Value *rhs = value(ii->get_operand(1));
|
Value *rhs = value(ii->get_operand(1));
|
||||||
return FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs, ii->get_name());
|
return builder.Insert(FCmpInst::Create(Instruction::FCmp, pred, lhs, rhs));
|
||||||
}
|
}
|
||||||
if(auto* ii = dynamic_cast<ir::cast_inst*>(inst)){
|
if(auto* ii = dynamic_cast<ir::cast_inst*>(inst)){
|
||||||
Value *arg = value(ii->get_operand(0));
|
Value *arg = value(ii->get_operand(0));
|
||||||
Type *dst_ty = type(ii->get_type()->get_scalar_ty());
|
Type *dst_ty = type(ii->get_type()->get_scalar_ty());
|
||||||
return CastInst::Create(ii->get_op(), arg, dst_ty, ii->get_name());
|
return builder.Insert(CastInst::Create(ii->get_op(), arg, dst_ty));
|
||||||
}
|
}
|
||||||
if(auto* ii = dynamic_cast<ir::getelementptr_inst*>(inst)){
|
if(auto* ii = dynamic_cast<ir::getelementptr_inst*>(inst)){
|
||||||
std::vector<Value*> idx_vals;
|
std::vector<Value*> idx_vals;
|
||||||
@@ -155,31 +155,31 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
|||||||
[&value](ir::value* x){ return value(x);});
|
[&value](ir::value* x){ return value(x);});
|
||||||
Type *source_ty = type(ii->get_source_elt_ty()->get_scalar_ty());
|
Type *source_ty = type(ii->get_source_elt_ty()->get_scalar_ty());
|
||||||
Value *arg = value(ii->get_operand(0));
|
Value *arg = value(ii->get_operand(0));
|
||||||
return GetElementPtrInst::Create(source_ty, arg, idx_vals, ii->get_name());
|
return builder.Insert(GetElementPtrInst::Create(source_ty, arg, idx_vals));
|
||||||
}
|
}
|
||||||
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
|
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
|
||||||
Value *ptr = value(ii->get_pointer_operand());
|
Value *ptr = value(ii->get_pointer_operand());
|
||||||
return new LoadInst(ptr, ii->get_name());
|
return builder.CreateLoad(ptr);
|
||||||
}
|
}
|
||||||
// unknown instruction
|
// unknown instruction
|
||||||
throw std::runtime_error("unknown conversion from ir::type to Type");
|
throw std::runtime_error("unknown conversion from ir::type to Type");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* convert ir::value to llvm::Value */
|
/* convert ir::value to llvm::Value */
|
||||||
Value* selection::llvm_value(ir::value *v, LLVMContext &ctx) {
|
Value* selection::llvm_value(ir::value *v, LLVMContext &ctx, IRBuilder<> &builder) {
|
||||||
assert(!v->get_type()->is_tile_ty());
|
assert(!v->get_type()->is_tile_ty());
|
||||||
if(vmap_.find(v) != vmap_.end())
|
if(vmap_.find(v) != vmap_.end())
|
||||||
return vmap_.at(v);
|
return vmap_.at(v);
|
||||||
// create operands
|
// create operands
|
||||||
if(auto *uu = dynamic_cast<ir::user*>(v))
|
if(auto *uu = dynamic_cast<ir::user*>(v))
|
||||||
for(ir::value* u: uu->ops())
|
for(ir::value* u: uu->ops())
|
||||||
vmap_[u] = llvm_value(u, ctx);
|
vmap_.insert({u, llvm_value(u, ctx, builder)});
|
||||||
if(auto *cc = dynamic_cast<ir::constant*>(v))
|
if(auto *cc = dynamic_cast<ir::constant*>(v))
|
||||||
return llvm_constant(cc, ctx);
|
return llvm_constant(cc, ctx);
|
||||||
// instruction
|
// instruction
|
||||||
if(auto *ii = dynamic_cast<ir::instruction*>(v)){
|
if(auto *ii = dynamic_cast<ir::instruction*>(v)){
|
||||||
auto value = [&](ir::value *x) { return llvm_value(x, ctx); };
|
auto value = [&](ir::value *x) { return llvm_value(x, ctx, builder); };
|
||||||
return llvm_inst(ii, value, ctx);
|
return llvm_inst(ii, value, ctx, builder);
|
||||||
}
|
}
|
||||||
// unknown value
|
// unknown value
|
||||||
throw std::runtime_error("unknown conversion from ir::value to Value");
|
throw std::runtime_error("unknown conversion from ir::value to Value");
|
||||||
@@ -308,7 +308,14 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
|||||||
else
|
else
|
||||||
axes[d].values = {builder.getInt32(0)};
|
axes[d].values = {builder.getInt32(0)};
|
||||||
}
|
}
|
||||||
tmap_.insert({v, new distributed_tile(ty, shapes, axes)});
|
distributed_tile *T = new distributed_tile(ty, shapes, axes);
|
||||||
|
tmap_.insert({v, T});
|
||||||
|
// constant range
|
||||||
|
if(dynamic_cast<ir::constant*>(v))
|
||||||
|
T->for_each([&](indices_t idx){
|
||||||
|
T->set_value(idx, idx[0]);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,76 +347,88 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder){
|
|||||||
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
|
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
|
||||||
Module *module = builder.GetInsertBlock()->getModule();
|
Module *module = builder.GetInsertBlock()->getModule();
|
||||||
LLVMContext &ctx = builder.getContext();
|
LLVMContext &ctx = builder.getContext();
|
||||||
tile *ti = tmap_[ins];
|
// store
|
||||||
distributed_tile* result = (distributed_tile*)ti;
|
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||||
if(!ins->get_type()->is_tile_ty())
|
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||||
return;
|
tile *value = tmap_.at(x->get_value_operand());
|
||||||
const auto& shapes = ins->get_type()->get_tile_shapes();
|
ptr->for_each([&](indices_t idx){
|
||||||
// global_range
|
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
|
||||||
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
|
|
||||||
static std::array<Intrinsic::ID, 3> ctaid = {
|
|
||||||
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
|
|
||||||
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
|
|
||||||
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
|
|
||||||
};
|
|
||||||
Function *get_group_id = Intrinsic::getDeclaration(module, ctaid[x->get_axis()]);
|
|
||||||
Value *group_id = builder.CreateCall(get_group_id, {});
|
|
||||||
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]), group_id);
|
|
||||||
result->for_each([&](indices_t idx){
|
|
||||||
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
|
|
||||||
result->set_value(idx, builder.CreateAdd(bin->getOperand(1),
|
|
||||||
builder.CreateAdd(bin->getOperand(0), offset)));
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
// reshape
|
|
||||||
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
|
|
||||||
ir::value* in = ins->get_operand(0);
|
|
||||||
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
|
|
||||||
result->for_each([&](indices_t out_idx){
|
|
||||||
indices_t in_idx;
|
|
||||||
for(size_t k = 0; k < shapes.size(); k++){
|
|
||||||
if(shapes[k] > 1)
|
|
||||||
in_idx.push_back(out_idx[k]);
|
|
||||||
}
|
|
||||||
result->set_value(out_idx, in_tile->get_value(in_idx));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
// splat
|
|
||||||
else if(dynamic_cast<ir::splat_inst*>(ins)) {
|
|
||||||
result->for_each([&](indices_t idx) {
|
|
||||||
result->set_value(idx, llvm_value(ins->get_operand(0), ctx));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
// broadcast
|
|
||||||
else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
|
|
||||||
ir::value* in = ins->get_operand(0);
|
|
||||||
const auto& in_shapes = in->get_type()->get_tile_shapes();
|
|
||||||
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
|
|
||||||
result->for_each([&](indices_t out_idx){
|
|
||||||
indices_t in_idx = out_idx;
|
|
||||||
for(size_t k = 0; k < in_idx.size(); k++){
|
|
||||||
if(in_shapes[k] == 1)
|
|
||||||
in_idx[k] = builder.getInt32(0);
|
|
||||||
result->set_value(out_idx, in_tile->get_value(in_idx));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
// copy to shared
|
|
||||||
else if(dynamic_cast<ir::copy_to_shared_inst*>(ins)) {
|
|
||||||
|
|
||||||
}
|
|
||||||
// element-wise
|
|
||||||
else {
|
else {
|
||||||
result->for_each([&](indices_t idx){
|
tile *ti = tmap_[ins];
|
||||||
auto value = [&](ir::value *x) {
|
distributed_tile* result = (distributed_tile*)ti;
|
||||||
if(x->get_type()->is_tile_ty())
|
if(!ins->get_type()->is_tile_ty())
|
||||||
return tmap_.at(x)->get_value(idx);
|
return;
|
||||||
else
|
const auto& shapes = ins->get_type()->get_tile_shapes();
|
||||||
return llvm_value(x, ctx);
|
// global_range
|
||||||
|
if(auto *x = dynamic_cast<ir::get_global_range_inst*>(ins)) {
|
||||||
|
static std::array<Intrinsic::ID, 3> ctaid = {
|
||||||
|
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
|
||||||
|
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
|
||||||
|
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
|
||||||
};
|
};
|
||||||
result->set_value(idx, llvm_inst(ins, value, ctx));
|
Function *get_group_id = Intrinsic::getDeclaration(module, ctaid[x->get_axis()]);
|
||||||
});
|
Value *group_id = builder.CreateCall(get_group_id, {});
|
||||||
|
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]), group_id);
|
||||||
|
result->for_each([&](indices_t idx){
|
||||||
|
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
|
||||||
|
result->set_value(idx, builder.CreateAdd(bin->getOperand(1),
|
||||||
|
builder.CreateAdd(bin->getOperand(0), offset)));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// reshape
|
||||||
|
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
|
||||||
|
ir::value* in = ins->get_operand(0);
|
||||||
|
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
|
||||||
|
result->for_each([&](indices_t out_idx){
|
||||||
|
indices_t in_idx;
|
||||||
|
for(size_t k = 0; k < shapes.size(); k++){
|
||||||
|
if(shapes[k] > 1)
|
||||||
|
in_idx.push_back(out_idx[k]);
|
||||||
|
}
|
||||||
|
result->set_value(out_idx, in_tile->get_value(in_idx));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// splat
|
||||||
|
else if(dynamic_cast<ir::splat_inst*>(ins)) {
|
||||||
|
result->for_each([&](indices_t idx) {
|
||||||
|
result->set_value(idx, llvm_value(ins->get_operand(0), ctx, builder));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// broadcast
|
||||||
|
else if(dynamic_cast<ir::broadcast_inst*>(ins)) {
|
||||||
|
ir::value* in = ins->get_operand(0);
|
||||||
|
const auto& in_shapes = in->get_type()->get_tile_shapes();
|
||||||
|
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
|
||||||
|
result->for_each([&](indices_t out_idx){
|
||||||
|
indices_t in_idx = out_idx;
|
||||||
|
for(size_t k = 0; k < in_idx.size(); k++){
|
||||||
|
if(in_shapes[k] == 1)
|
||||||
|
in_idx[k] = builder.getInt32(0);
|
||||||
|
result->set_value(out_idx, in_tile->get_value(in_idx));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// copy to shared
|
||||||
|
else if(dynamic_cast<ir::copy_to_shared_inst*>(ins)) {
|
||||||
|
|
||||||
|
}
|
||||||
|
// element-wise
|
||||||
|
else {
|
||||||
|
result->for_each([&](indices_t idx){
|
||||||
|
auto value = [&](ir::value *x) {
|
||||||
|
if(x->get_type()->is_tile_ty())
|
||||||
|
return tmap_.at(x)->get_value(idx);
|
||||||
|
else
|
||||||
|
return llvm_value(x, ctx, builder);
|
||||||
|
};
|
||||||
|
result->set_value(idx, llvm_inst(ins, value, ctx, builder));
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
||||||
@@ -418,7 +437,7 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
|||||||
lower_tile_instruction(src, builder);
|
lower_tile_instruction(src, builder);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
Instruction *i = (Instruction*)llvm_value(src, ctx);
|
Instruction *i = (Instruction*)llvm_value(src, ctx, builder);
|
||||||
vmap_[src] = i;
|
vmap_[src] = i;
|
||||||
builder.Insert(i);
|
builder.Insert(i);
|
||||||
}
|
}
|
||||||
@@ -459,7 +478,7 @@ void selection::run(ir::module &src, Module &dst){
|
|||||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
|
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
|
||||||
ir::value *inc_val = phi->get_incoming_value(i);
|
ir::value *inc_val = phi->get_incoming_value(i);
|
||||||
ir::basic_block *inc_block = phi->get_incoming_block(i);
|
ir::basic_block *inc_block = phi->get_incoming_block(i);
|
||||||
Value *llvm_inc_val = llvm_value(inc_val, dst_ctx);
|
Value *llvm_inc_val = llvm_value(inc_val, dst_ctx, dst_builder);
|
||||||
BasicBlock *llvm_block = bmap_[inc_block];
|
BasicBlock *llvm_block = bmap_[inc_block];
|
||||||
dst_phi->addIncoming(llvm_inc_val, llvm_block);
|
dst_phi->addIncoming(llvm_inc_val, llvm_block);
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user