[CODEGEN] Fixed over-aggressive division handling in alignment pass (#280)

Philippe Tillet
2021-09-15 00:40:17 -07:00
committed by GitHub
parent da5063d898
commit 313d6488f6
4 changed files with 15 additions and 16 deletions


@@ -401,7 +401,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator
     if(x->is_int_add_sub())
       result[d] = gcd(lhs[d], rhs[d]);
     if(x->is_int_div())
-      result[d] = std::max<unsigned>(lhs[d] / rhs[d], 1);
+      result[d] = 1;
     if(x->is_int_rem() && rhs[d] > 1){
       result[d] = gcd(lhs[d], rhs[d]);
     }
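Why the old rule was over-aggressive (a minimal sketch, assuming result[d] records a value the operand is guaranteed to be a multiple of): knowing that lhs is a multiple of 16 and rhs a multiple of 2 does not make lhs / rhs a multiple of 8, so the only safe starting multiple for an integer division is 1. A quick Python counterexample:

    # hypothetical counterexample; 16 and 2 are the multiples the analysis knows about
    lhs_mult, rhs_mult = 16, 2
    old_claim = max(lhs_mult // rhs_mult, 1)      # old rule would report 8
    lhs, rhs = 16, 6                              # values consistent with those multiples
    assert lhs % lhs_mult == 0 and rhs % rhs_mult == 0
    print((lhs // rhs) % old_claim)               # 2 -> 16 // 6 is not a multiple of 8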


@@ -52,7 +52,7 @@ def test_matmul(M, N, K):
     cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
     ref_sm_clock = 1350
     max_gpu_perf = 1e-6*80*8*128*cur_sm_clock
-    assert cur_sm_clock == ref_sm_clock, f'GPU SMs must run at {ref_sm_clock} MHz'
+    assert abs(cur_sm_clock - ref_sm_clock) < 5, f'GPU SMs must run at {ref_sm_clock} MHz'
     a = torch.randn((M, K), dtype=torch.float16, device='cuda')
     b = torch.randn((K, N), dtype=torch.float16, device='cuda')
     fn = lambda: triton.ops.matmul(a, b)
@@ -95,7 +95,7 @@ def test_elementwise(N):
     cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
     ref_mem_clock = 877
     max_gpu_perf = 512*2*ref_mem_clock*1e-3
-    assert cur_mem_clock == ref_mem_clock, f'GPU memory must run at {ref_mem_clock} MHz'
+    assert abs(cur_mem_clock - ref_mem_clock) < 5, f'GPU memory must run at {ref_mem_clock} MHz'
     z = torch.empty((N, ), dtype=torch.float16, device='cuda')
     x = torch.randn_like(z)
     y = torch.randn_like(z)
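For context, the hard-coded peak numbers behind these bounds are consistent with a V100-class GPU (80 SMs with 8 tensor cores each delivering 128 FLOPs per cycle, and a 4096-bit HBM2 bus); that mapping is an assumption, not something stated in the diff. A quick sketch of the arithmetic at the reference clocks:

    # assumes V100-class constants: 80 SMs, 8 tensor cores/SM, 128 FLOP/cycle, 512 B per memory transfer
    ref_sm_clock, ref_mem_clock = 1350, 877            # MHz
    max_tflops = 1e-6 * 80 * 8 * 128 * ref_sm_clock    # ~110.6 TFLOPS (fp16 tensor-core peak)
    max_gbps = 512 * 2 * ref_mem_clock * 1e-3          # ~898 GB/s (DDR, hence the factor of 2)
    print(max_tflops, max_gbps)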


@@ -30,7 +30,7 @@ def _kernel(
         z = tl.load(header + 0)
         i = tl.load(header + 1 + offlutm)
         j = tl.load(header + 2 + offlutn)
-        AS1 = SDD_K // TZ
+        AS1 = SDD_K
         lockid = tl.where(TZ > 1, 1, 0)
         offka = pid0 * AS1
         offkb = pid0 * AS1
@@ -96,8 +96,8 @@ def _kernel(
     # initialize a, b pointers
     rka = offka + tl.arange(0, TK)
     rkb = offkb + tl.arange(0, TK)
-    pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
-    pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
+    pa = A + pidz * TZ * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
+    pb = B + pidz * TZ * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
     if meta['DDS']:
         checkam = ram[:, None] < DS0
     else:
@@ -113,11 +113,11 @@ def _kernel(
     ## Inner Loop ##
     ## ---------------- ##
     acc = tl.zeros((TM, TN), dtype=tl.float32)
-    for k in range(AS1, 0, -TK):
+    for k in range(AS1, 0, -TK*TZ):
         acc += tl.dot(a, b)
         if meta['SDD']:
-            inc_a = TK * stride_ka
-            inc_b = TK * stride_kb
+            inc_a = TK * TZ * stride_ka
+            inc_b = TK * TZ * stride_kb
         else:
             pinc += 2
         if meta['DSD']:
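The kernel-side consequence of the alignment change shows up here: under the conservative rule the compiler can no longer prove anything useful about values derived from SDD_K // TZ, so the split-K loop is restructured to walk the full K range with a stride of TK * TZ instead of iterating a pre-divided K // TZ chunk by TK. The per-program trip count is unchanged; a plain-Python sanity check with hypothetical sizes (assuming TZ divides K // TK):

    # hypothetical sizes; assumes TZ divides K // TK
    K, TK, TZ = 1024, 32, 4
    old_trips = len(range(K // TZ, 0, -TK))       # old: contiguous K // TZ chunk, step TK
    new_trips = len(range(K, 0, -TK * TZ))        # new: full K, step TK * TZ
    assert old_trips == new_trips == K // (TK * TZ)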


@@ -49,13 +49,12 @@ def _kernel(A, B, C, M, N, K,
     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
     ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
     rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
-    rk = tl.arange(0, BLOCK_K)
+    rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K)
     # pointers
-    K = K // SPLIT_K
-    A = A + (pid_z * K * stride_ak + ram[:, None] * stride_am + rk[None, :] * stride_ak)
-    B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(K, 0, -BLOCK_K):
+    for k in range(K, 0, -BLOCK_K*SPLIT_K):
         if META['EVEN_K']:
             a = tl.load(A)
             b = tl.load(B)
@@ -63,8 +62,8 @@ def _kernel(A, B, C, M, N, K,
             a = tl.load(A, mask=rk[None, :] < k, other=0.)
             b = tl.load(B, mask=rk[:, None] < k, other=0.)
         acc += tl.dot(a, b)
-        A += BLOCK_K * stride_ak
-        B += BLOCK_K * stride_bk
+        A += BLOCK_K * SPLIT_K * stride_ak
+        B += BLOCK_K * SPLIT_K * stride_bk
     acc = acc.to(tl.float16)
     # rematerialize rm and rn to save registers
     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
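The dense matmul kernel gets the same treatment: dropping K = K // SPLIT_K removes the in-kernel division, and each program now starts at rk = pid_z * BLOCK_K and advances its pointers by BLOCK_K * SPLIT_K, so the SPLIT_K programs interleave BLOCK_K-wide slices of the K axis instead of owning contiguous K // SPLIT_K chunks. A small sketch (hypothetical sizes, assuming EVEN_K, i.e. BLOCK_K * SPLIT_K divides K) checking that the interleaved slices cover every k index exactly once:

    # hypothetical sizes; assumes EVEN_K, i.e. BLOCK_K * SPLIT_K divides K
    K, BLOCK_K, SPLIT_K = 1024, 32, 4
    covered = []
    for pid_z in range(SPLIT_K):
        rk = list(range(pid_z * BLOCK_K, (pid_z + 1) * BLOCK_K))    # rk = pid_z*BLOCK_K + arange(0, BLOCK_K)
        for _ in range(K, 0, -BLOCK_K * SPLIT_K):
            covered += rk
            rk = [r + BLOCK_K * SPLIT_K for r in rk]                # A += BLOCK_K * SPLIT_K * stride_ak
    assert sorted(covered) == list(range(K))                        # each k index visited exactly once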