[CODEGEN] Fixed over-aggressive division handling in alignment pass (#280)

Author: Philippe Tillet
Date: 2021-09-15 00:40:17 -07:00 (committed by GitHub)
Parent: da5063d898
Commit: 313d6488f6
4 changed files with 15 additions and 16 deletions

@@ -401,7 +401,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator
   if(x->is_int_add_sub())
     result[d] = gcd(lhs[d], rhs[d]);
   if(x->is_int_div())
-    result[d] = std::max<unsigned>(lhs[d] / rhs[d], 1);
+    result[d] = 1;
   if(x->is_int_rem() && rhs[d] > 1){
     result[d] = gcd(lhs[d], rhs[d]);
   }
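
The rule that was removed assumed that if lhs[d] is a known multiple of L and rhs[d] a known multiple of R, then lhs[d] / rhs[d] is a multiple of L / R. That inference is unsound for integer division, hence the conservative fallback to 1. A minimal counterexample (illustrative only, not part of the commit):

# all the analysis knows: lhs % 16 == 0 and rhs % 2 == 0
lhs_multiple, rhs_multiple = 16, 2
lhs, rhs = 16, 4                  # values consistent with those facts
assert lhs % lhs_multiple == 0 and rhs % rhs_multiple == 0
quotient = lhs // rhs             # = 4
assert quotient % (lhs_multiple // rhs_multiple) != 0  # 4 is not a multiple of 8
# so 1 is the only starting multiple that is always safe for a quotient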

@@ -52,7 +52,7 @@ def test_matmul(M, N, K):
     cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
     ref_sm_clock = 1350
     max_gpu_perf = 1e-6*80*8*128*cur_sm_clock
-    assert cur_sm_clock == ref_sm_clock, f'GPU SMs must run at {ref_sm_clock} MHz'
+    assert abs(cur_sm_clock - ref_sm_clock) < 5, f'GPU SMs must run at {ref_sm_clock} MHz'
     a = torch.randn((M, K), dtype=torch.float16, device='cuda')
     b = torch.randn((K, N), dtype=torch.float16, device='cuda')
     fn = lambda: triton.ops.matmul(a, b)
@@ -95,7 +95,7 @@ def test_elementwise(N):
     cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
     ref_mem_clock = 877
     max_gpu_perf = 512*2*ref_mem_clock*1e-3
-    assert cur_mem_clock == ref_mem_clock, f'GPU memory must run at {ref_mem_clock} MHz'
+    assert abs(cur_mem_clock - ref_mem_clock) < 5, f'GPU memory must run at {ref_mem_clock} MHz'
     z = torch.empty((N, ), dtype=torch.float16, device='cuda')
     x = torch.randn_like(z)
     y = torch.randn_like(z)
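
Both assertions are relaxed from exact equality to a ±5 MHz window, presumably because nvidia-smi can report small fluctuations even when clocks are pinned. The hard-coded constants in the max_gpu_perf formulas look like V100-class numbers; a rough sanity check of the arithmetic (my reading, the diff itself does not explain them):

sm_clock_mhz = 1350
peak_tflops = 1e-6 * 80 * 8 * 128 * sm_clock_mhz  # 80 SMs x 8 tensor cores x 128 FLOP/clock
print(peak_tflops)                                # ~110.6 TFLOPS of FP16 tensor-core math

mem_clock_mhz = 877
peak_gbps = 512 * 2 * mem_clock_mhz * 1e-3        # 4096-bit HBM2 bus = 512 B/clock, DDR doubles it
print(peak_gbps)                                  # ~898 GB/s of memory bandwidth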

@@ -30,7 +30,7 @@ def _kernel(
     z = tl.load(header + 0)
     i = tl.load(header + 1 + offlutm)
     j = tl.load(header + 2 + offlutn)
-    AS1 = SDD_K // TZ
+    AS1 = SDD_K
     lockid = tl.where(TZ > 1, 1, 0)
     offka = pid0 * AS1
     offkb = pid0 * AS1
@@ -96,8 +96,8 @@ def _kernel(
     # initialize a, b pointers
     rka = offka + tl.arange(0, TK)
     rkb = offkb + tl.arange(0, TK)
-    pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
-    pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
+    pa = A + pidz * TZ * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
+    pb = B + pidz * TZ * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
     if meta['DDS']:
         checkam = ram[:, None] < DS0
     else:
@@ -113,11 +113,11 @@ def _kernel(
     ## Inner Loop ##
     ## ---------------- ##
     acc = tl.zeros((TM, TN), dtype=tl.float32)
-    for k in range(AS1, 0, -TK):
+    for k in range(AS1, 0, -TK*TZ):
         acc += tl.dot(a, b)
         if meta['SDD']:
-            inc_a = TK * stride_ka
-            inc_b = TK * stride_kb
+            inc_a = TK * TZ * stride_ka
+            inc_b = TK * TZ * stride_kb
         else:
             pinc += 2
             if meta['DSD']:
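
Judging by the loop bounds and pointer increments, these hunks switch the blocksparse kernel from chunked split-K, where each of the TZ programs reduces a contiguous SDD_K // TZ slice (a division the fixed alignment pass can no longer reason about), to strided split-K, where programs walk the full K range in steps of TK * TZ. A toy sketch of why the two schedules cover the same K tiles (illustrative; names borrowed from the kernel):

K, TK, TZ = 64, 8, 4
chunked = [range(z * (K // TZ), (z + 1) * (K // TZ), TK) for z in range(TZ)]  # old schedule
strided = [range(z * TK, K, TK * TZ) for z in range(TZ)]                      # new schedule
tiles = lambda schedule: sorted(k for prog in schedule for k in prog)
assert tiles(chunked) == tiles(strided) == list(range(0, K, TK))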

@@ -49,13 +49,12 @@ def _kernel(A, B, C, M, N, K,
     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
     ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
     rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
-    rk = tl.arange(0, BLOCK_K)
+    rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K)
     # pointers
-    K = K // SPLIT_K
-    A = A + (pid_z * K * stride_ak + ram[:, None] * stride_am + rk[None, :] * stride_ak)
-    B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(K, 0, -BLOCK_K):
+    for k in range(K, 0, -BLOCK_K*SPLIT_K):
         if META['EVEN_K']:
             a = tl.load(A)
             b = tl.load(B)
@@ -63,8 +62,8 @@ def _kernel(A, B, C, M, N, K,
             a = tl.load(A, mask=rk[None, :] < k, other=0.)
             b = tl.load(B, mask=rk[:, None] < k, other=0.)
         acc += tl.dot(a, b)
-        A += BLOCK_K * stride_ak
-        B += BLOCK_K * stride_bk
+        A += BLOCK_K * SPLIT_K * stride_ak
+        B += BLOCK_K * SPLIT_K * stride_bk
     acc = acc.to(tl.float16)
     # rematerialize rm and rn to save registers
     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
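
The dense kernel is restructured the same way: rather than shrinking K to K // SPLIT_K and offsetting each program by pid_z * K, every program starts at rk = pid_z * BLOCK_K and advances by BLOCK_K * SPLIT_K, so the kernel no longer divides by SPLIT_K at all. A NumPy sketch checking that the SPLIT_K partial accumulators still sum to the full product (illustrative; the real kernel reduces the partial results through machinery this diff does not show):

import numpy as np

M, N, K, BLOCK_K, SPLIT_K = 4, 4, 32, 4, 2
A = np.random.randn(M, K).astype(np.float32)
B = np.random.randn(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)
for pid_z in range(SPLIT_K):                      # one program (accumulator) per split
    for k0 in range(pid_z * BLOCK_K, K, BLOCK_K * SPLIT_K):
        rk = np.arange(k0, k0 + BLOCK_K)          # this program's K slice
        C += A[:, rk] @ B[rk, :]
np.testing.assert_allclose(C, A @ B, rtol=1e-4, atol=1e-4)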