[general] cleaned tensorflow source code generation

This commit is contained in:
Philippe Tillet
2019-08-18 15:39:36 -07:00
parent 457c330f15
commit 0970fe12dd
12 changed files with 162 additions and 152 deletions

View File

@@ -573,7 +573,7 @@ inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw,
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
const auto& shapes = v->get_type()->get_tile_shapes();
size_t dim = shapes.size();
if(params_->get_fragment(v, 0) == analysis::tune::STRIDED_SCAN){
if(params_->get_fragment(v, 0) == analysis::grids::STRIDED_SCAN){
std::vector<unsigned> contiguous(dim);
std::vector<unsigned> block_size(dim);
std::vector<unsigned> warp_size(dim);
@@ -1278,7 +1278,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(params_->get_fragment(dot, 0) == analysis::tune::STRIDED_SCAN)
if(params_->get_fragment(dot, 0) == analysis::grids::STRIDED_SCAN)
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
else
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);