From 58c8889235e343066d48570e6b59c5383bbe7e6e Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 13 Jun 2022 16:21:10 -0700 Subject: [PATCH] [FRONTEND] Fix scanline layout (#548) --- lib/codegen/selection/generator.cc | 15 ++++++--------- python/test/unit/language/test_core.py | 2 +- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index a04949dc5..f88ecf833 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -3113,24 +3113,21 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) { } void generator::visit_layout_scanline(analysis::scanline_layout* layout) { - Value *warp_size = i32(32); - Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); - Value *u_thread_id = urem(u_thread_id_0, warp_size); - Value *u_warp_id = udiv(u_thread_id_0, warp_size); - + Value* u_thread_id = tgt_->get_local_id(mod_, *builder_, 0); auto order = layout->get_order(); const auto& shape = layout->get_shape(); - Value* full_thread_id = add(mul(u_warp_id, i32(32)), u_thread_id); // Delinearize size_t dim = shape.size(); std::vector thread_id(dim); for(unsigned k = 0; k < dim - 1; k++){ Constant *dim_k = i32(layout->mts(order[k])); - Value *rem = urem(full_thread_id, dim_k); - full_thread_id = udiv(full_thread_id, dim_k); + Value *rem = urem(u_thread_id, dim_k); + u_thread_id = udiv(u_thread_id, dim_k); thread_id[order[k]] = rem; } - thread_id[order[dim - 1]] = full_thread_id; + Constant *dim_k = i32(layout->mts(order[dim - 1])); + thread_id[order[dim - 1]] = urem(u_thread_id, dim_k); + // Create axes for(unsigned k = 0; k < dim; k++) { int nts = layout->nts(k); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index c76cbbd95..6ea3ebc9d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -691,7 +691,7 @@ def test_f16_to_f8_rounding(): @pytest.mark.parametrize("dtype_str, shape", [(dtype, shape) for dtype in dtypes - for shape in [128, 512]]) + for shape in [32, 64, 128, 512]]) def test_reduce1d(dtype_str, shape, device='cuda'): # triton kernel