[general] cleaned tensorflow source code generation
This commit is contained in:
@@ -573,7 +573,7 @@ inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw,
|
||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
size_t dim = shapes.size();
|
||||
if(params_->get_fragment(v, 0) == analysis::tune::STRIDED_SCAN){
|
||||
if(params_->get_fragment(v, 0) == analysis::grids::STRIDED_SCAN){
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
std::vector<unsigned> block_size(dim);
|
||||
std::vector<unsigned> warp_size(dim);
|
||||
@@ -1278,7 +1278,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
|
||||
if(NK != 1) {
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
if(params_->get_fragment(dot, 0) == analysis::tune::STRIDED_SCAN)
|
||||
if(params_->get_fragment(dot, 0) == analysis::grids::STRIDED_SCAN)
|
||||
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
|
||||
else
|
||||
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
|
||||
|
Reference in New Issue
Block a user