[FRONTEND] Fix scanline layout (#548)
This commit is contained in:
@@ -3113,24 +3113,21 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
|
void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
|
||||||
Value *warp_size = i32(32);
|
Value* u_thread_id = tgt_->get_local_id(mod_, *builder_, 0);
|
||||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
|
||||||
Value *u_thread_id = urem(u_thread_id_0, warp_size);
|
|
||||||
Value *u_warp_id = udiv(u_thread_id_0, warp_size);
|
|
||||||
|
|
||||||
auto order = layout->get_order();
|
auto order = layout->get_order();
|
||||||
const auto& shape = layout->get_shape();
|
const auto& shape = layout->get_shape();
|
||||||
Value* full_thread_id = add(mul(u_warp_id, i32(32)), u_thread_id);
|
|
||||||
// Delinearize
|
// Delinearize
|
||||||
size_t dim = shape.size();
|
size_t dim = shape.size();
|
||||||
std::vector<Value*> thread_id(dim);
|
std::vector<Value*> thread_id(dim);
|
||||||
for(unsigned k = 0; k < dim - 1; k++){
|
for(unsigned k = 0; k < dim - 1; k++){
|
||||||
Constant *dim_k = i32(layout->mts(order[k]));
|
Constant *dim_k = i32(layout->mts(order[k]));
|
||||||
Value *rem = urem(full_thread_id, dim_k);
|
Value *rem = urem(u_thread_id, dim_k);
|
||||||
full_thread_id = udiv(full_thread_id, dim_k);
|
u_thread_id = udiv(u_thread_id, dim_k);
|
||||||
thread_id[order[k]] = rem;
|
thread_id[order[k]] = rem;
|
||||||
}
|
}
|
||||||
thread_id[order[dim - 1]] = full_thread_id;
|
Constant *dim_k = i32(layout->mts(order[dim - 1]));
|
||||||
|
thread_id[order[dim - 1]] = urem(u_thread_id, dim_k);
|
||||||
|
|
||||||
// Create axes
|
// Create axes
|
||||||
for(unsigned k = 0; k < dim; k++) {
|
for(unsigned k = 0; k < dim; k++) {
|
||||||
int nts = layout->nts(k);
|
int nts = layout->nts(k);
|
||||||
|
@@ -691,7 +691,7 @@ def test_f16_to_f8_rounding():
|
|||||||
@pytest.mark.parametrize("dtype_str, shape",
|
@pytest.mark.parametrize("dtype_str, shape",
|
||||||
[(dtype, shape)
|
[(dtype, shape)
|
||||||
for dtype in dtypes
|
for dtype in dtypes
|
||||||
for shape in [128, 512]])
|
for shape in [32, 64, 128, 512]])
|
||||||
def test_reduce1d(dtype_str, shape, device='cuda'):
|
def test_reduce1d(dtype_str, shape, device='cuda'):
|
||||||
|
|
||||||
# triton kernel
|
# triton kernel
|
||||||
|
Reference in New Issue
Block a user