[FRONTEND] Fix scanline layout (#548)

This commit is contained in:
Philippe Tillet
2022-06-13 16:21:10 -07:00
committed by GitHub
parent 7094657aa9
commit 58c8889235
2 changed files with 7 additions and 10 deletions

View File

@@ -3113,24 +3113,21 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
 }
 void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
-  Value *warp_size = i32(32);
-  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
-  Value *u_thread_id = urem(u_thread_id_0, warp_size);
-  Value *u_warp_id = udiv(u_thread_id_0, warp_size);
+  Value* u_thread_id = tgt_->get_local_id(mod_, *builder_, 0);
   auto order = layout->get_order();
   const auto& shape = layout->get_shape();
-  Value* full_thread_id = add(mul(u_warp_id, i32(32)), u_thread_id);
   // Delinearize
   size_t dim = shape.size();
   std::vector<Value*> thread_id(dim);
   for(unsigned k = 0; k < dim - 1; k++){
     Constant *dim_k = i32(layout->mts(order[k]));
-    Value *rem = urem(full_thread_id, dim_k);
-    full_thread_id = udiv(full_thread_id, dim_k);
+    Value *rem = urem(u_thread_id, dim_k);
+    u_thread_id = udiv(u_thread_id, dim_k);
     thread_id[order[k]] = rem;
   }
-  thread_id[order[dim - 1]] = full_thread_id;
+  Constant *dim_k = i32(layout->mts(order[dim - 1]));
+  thread_id[order[dim - 1]] = urem(u_thread_id, dim_k);
   // Create axes
   for(unsigned k = 0; k < dim; k++) {
     int nts = layout->nts(k);

View File

@@ -691,7 +691,7 @@ def test_f16_to_f8_rounding():
 @pytest.mark.parametrize("dtype_str, shape",
                          [(dtype, shape)
                           for dtype in dtypes
-                          for shape in [128, 512]])
+                          for shape in [32, 64, 128, 512]])
 def test_reduce1d(dtype_str, shape, device='cuda'):
     # triton kernel