[FRONTEND] Fix scanline layout (#548)
This commit is contained in:
@@ -3113,24 +3113,21 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
|
||||
}
|
||||
|
||||
void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
|
||||
Value *warp_size = i32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value *u_thread_id = urem(u_thread_id_0, warp_size);
|
||||
Value *u_warp_id = udiv(u_thread_id_0, warp_size);
|
||||
|
||||
Value* u_thread_id = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
auto order = layout->get_order();
|
||||
const auto& shape = layout->get_shape();
|
||||
Value* full_thread_id = add(mul(u_warp_id, i32(32)), u_thread_id);
|
||||
// Delinearize
|
||||
size_t dim = shape.size();
|
||||
std::vector<Value*> thread_id(dim);
|
||||
for(unsigned k = 0; k < dim - 1; k++){
|
||||
Constant *dim_k = i32(layout->mts(order[k]));
|
||||
Value *rem = urem(full_thread_id, dim_k);
|
||||
full_thread_id = udiv(full_thread_id, dim_k);
|
||||
Value *rem = urem(u_thread_id, dim_k);
|
||||
u_thread_id = udiv(u_thread_id, dim_k);
|
||||
thread_id[order[k]] = rem;
|
||||
}
|
||||
thread_id[order[dim - 1]] = full_thread_id;
|
||||
Constant *dim_k = i32(layout->mts(order[dim - 1]));
|
||||
thread_id[order[dim - 1]] = urem(u_thread_id, dim_k);
|
||||
|
||||
// Create axes
|
||||
for(unsigned k = 0; k < dim; k++) {
|
||||
int nts = layout->nts(k);
|
||||
|
@@ -691,7 +691,7 @@ def test_f16_to_f8_rounding():
|
||||
@pytest.mark.parametrize("dtype_str, shape",
|
||||
[(dtype, shape)
|
||||
for dtype in dtypes
|
||||
for shape in [128, 512]])
|
||||
for shape in [32, 64, 128, 512]])
|
||||
def test_reduce1d(dtype_str, shape, device='cuda'):
|
||||
|
||||
# triton kernel
|
||||
|
Reference in New Issue
Block a user