[python] added basic tensorflow support

This commit is contained in:
Philippe Tillet
2019-08-17 18:18:26 -07:00
parent 078f0052fe
commit b4a9ed9663
24 changed files with 341 additions and 173 deletions

View File

@@ -411,7 +411,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
Value *else_value = value(ii->get_operand(2));
return builder.Insert(SelectInst::Create(pred, if_value, else_value));
}
if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
if(ir::get_program_id_inst* ii = dynamic_cast<ir::get_program_id_inst*>(inst)){
Value *result = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
return (Instruction*)result;
}
@@ -837,7 +837,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
T->set_value(idx, idx[0]);
});
}
if(is_inserted && dynamic_cast<ir::nv_static_range_idx*>(v)){
if(is_inserted && dynamic_cast<ir::nv_static_program_idx*>(v)){
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
@@ -996,7 +996,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
}
}
void selection::lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
void selection::lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
result->for_each([&](indices_t idx){
assert(idx.size() == 1);
@@ -1418,8 +1418,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
lower_downcast(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::reduce_inst*>(ins))
lower_reduce(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins))
lower_dynamic_range_idx(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(ins))
lower_dynamic_program_idx(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::reshape_inst*>(ins))
lower_reshape(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::splat_inst*>(ins))