diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc
index b972f6185..e92d3b6ee 100644
--- a/lib/codegen/analysis/align.cc
+++ b/lib/codegen/analysis/align.cc
@@ -401,7 +401,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operato
     if(x->is_int_add_sub())
       result[d] = gcd(lhs[d], rhs[d]);
     if(x->is_int_div())
-      result[d] = std::max(lhs[d] / rhs[d], 1);
+      result[d] = 1;
     if(x->is_int_rem() && rhs[d] > 1){
       result[d] = gcd(lhs[d], rhs[d]);
     }
diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py
index d8ae65742..a2c9f08ca 100644
--- a/python/test/regression/test_performance.py
+++ b/python/test/regression/test_performance.py
@@ -52,7 +52,7 @@ def test_matmul(M, N, K):
     cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
     ref_sm_clock = 1350
     max_gpu_perf = 1e-6*80*8*128*cur_sm_clock
-    assert cur_sm_clock == ref_sm_clock, f'GPU SMs must run at {ref_sm_clock} MHz'
+    assert abs(cur_sm_clock - ref_sm_clock) < 5, f'GPU SMs must run at {ref_sm_clock} MHz'
     a = torch.randn((M, K), dtype=torch.float16, device='cuda')
     b = torch.randn((K, N), dtype=torch.float16, device='cuda')
     fn = lambda: triton.ops.matmul(a, b)
@@ -95,7 +95,7 @@ def test_elementwise(N):
     cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
     ref_mem_clock = 877
     max_gpu_perf = 512*2*ref_mem_clock*1e-3
-    assert cur_mem_clock == ref_mem_clock, f'GPU memmory must run at {ref_mem_clock} MHz'
+    assert abs(cur_mem_clock - ref_mem_clock) < 5, f'GPU memmory must run at {ref_mem_clock} MHz'
     z = torch.empty((N, ), dtype=torch.float16, device='cuda')
     x = torch.randn_like(z)
     y = torch.randn_like(z)
diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py
index abafed584..d9528133c 100644
--- a/python/triton/ops/blocksparse/matmul.py
+++ b/python/triton/ops/blocksparse/matmul.py
@@ -30,7 +30,7 @@ def _kernel(
         z = tl.load(header + 0)
         i = tl.load(header + 1 + offlutm)
         j = tl.load(header + 2 + offlutn)
-        AS1 = SDD_K // TZ
+        AS1 = SDD_K
         lockid = tl.where(TZ > 1, 1, 0)
         offka = pid0 * AS1
         offkb = pid0 * AS1
@@ -96,8 +96,8 @@ def _kernel(
     # initialize a, b pointers
     rka = offka + tl.arange(0, TK)
     rkb = offkb + tl.arange(0, TK)
-    pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
-    pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
+    pa = A + pidz * TZ * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
+    pb = B + pidz * TZ * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
     if meta['DDS']:
         checkam = ram[:, None] < DS0
     else:
@@ -113,11 +113,11 @@ def _kernel(
     ##    Inner Loop    ##
     ## ---------------- ##
     acc = tl.zeros((TM, TN), dtype=tl.float32)
-    for k in range(AS1, 0, -TK):
+    for k in range(AS1, 0, -TK*TZ):
         acc += tl.dot(a, b)
         if meta['SDD']:
-            inc_a = TK * stride_ka
-            inc_b = TK * stride_kb
+            inc_a = TK * TZ * stride_ka
+            inc_b = TK * TZ * stride_kb
         else:
             pinc += 2
         if meta['DSD']:
diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py
index 96c458285..22f5f6cc2 100644
--- a/python/triton/ops/matmul.py
+++ b/python/triton/ops/matmul.py
@@ -49,13 +49,12 @@ def _kernel(A, B, C, M, N, K,
     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
     ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
     rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
-    rk = tl.arange(0, BLOCK_K)
+    rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K)
     # pointers
-    K = K // SPLIT_K
-    A = A + (pid_z * K * stride_ak + ram[:, None] * stride_am + rk[None, :] * stride_ak)
-    B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(K, 0, -BLOCK_K):
+    for k in range(K, 0, -BLOCK_K*SPLIT_K):
         if META['EVEN_K']:
             a = tl.load(A)
             b = tl.load(B)
@@ -63,8 +62,8 @@ def _kernel(A, B, C, M, N, K,
             a = tl.load(A, mask=rk[None, :] < k, other=0.)
             b = tl.load(B, mask=rk[:, None] < k, other=0.)
         acc += tl.dot(a, b)
-        A += BLOCK_K * stride_ak
-        B += BLOCK_K * stride_bk
+        A += BLOCK_K * SPLIT_K * stride_ak
+        B += BLOCK_K * SPLIT_K * stride_bk
     acc = acc.to(tl.float16)
     # rematerialize rm and rn to save registers
     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
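
Note on the last hunk: under the new scheme each of the SPLIT_K programs starts its K offsets at pid_z*BLOCK_K and advances by BLOCK_K*SPLIT_K, instead of owning a contiguous K//SPLIT_K slice of the reduction as before. The standalone NumPy sketch below is my own illustration (not part of the patch); the sizes K, BLOCK_K and SPLIT_K are arbitrary example values, with K assumed to be a multiple of BLOCK_K*SPLIT_K. It checks that the interleaved indexing still visits every k index exactly once across the SPLIT_K programs.

import numpy as np

K, BLOCK_K, SPLIT_K = 64, 8, 4                   # made-up example sizes
visited = []
for pid_z in range(SPLIT_K):                     # one program per K-slice
    rk = pid_z * BLOCK_K + np.arange(BLOCK_K)    # rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K)
    for k in range(K, 0, -BLOCK_K * SPLIT_K):    # same loop bound as the kernel
        visited.extend(rk.tolist())              # K indices this program reduces over
        rk += BLOCK_K * SPLIT_K                  # pointer advance: BLOCK_K * SPLIT_K
assert sorted(visited) == list(range(K))         # every k index covered exactly once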