[PYTHON] Made codebase pep8 compliant

Philippe Tillet
2021-02-07 15:06:57 -05:00
parent 2a02fabdac
commit dffd66bc83
5 changed files with 207 additions and 177 deletions

View File

@@ -6,6 +6,7 @@ import triton._C.libtriton.triton as _triton
 import triton._C.libtriton.torch_utils as _torch_utils
 # Make sure internal C resources are cleaned up upon exit
 import atexit
 @atexit.register
 def cleanup():
     _triton.cleanup()
@@ -50,8 +51,7 @@ def read(path, kernel_names=[]):
     return source
 class kernel:
-    def __init__(self, src, device, defines = dict(), num_warps = 4, autotune_vals = [], autotune_key = []):
+    def __init__(self, src, device, defines=dict(), num_warps=4, autotune_vals=[], autotune_key=[]):
         # check if src is empty
         if src == '':
             raise ValueError('Kernel source code is empty')
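
Note: the __init__ hunk above only normalizes its keyword defaults (defines=dict(), num_warps=4, ...); the triton.kernel constructor itself is unchanged. A minimal usage sketch of that constructor together with triton.read(), assuming a hypothetical Triton-C source file 'add.c' (the file name, kernel name and TYPE value are illustrative, not part of this commit):

    import torch
    import triton

    # triton.read() returns Triton-C source text; kernel_names selects which kernels to keep.
    src = triton.read('add.c', kernel_names=['add'])
    # constructor signature as reformatted above
    kernel = triton.kernel(src, device=torch.device('cuda'), defines={'TYPE': 'float'}, num_warps=4)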

View File

@@ -2,19 +2,13 @@ import triton
 import torch
 import os
-fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'),
-                      kernel_names=['forward'])
+fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
 fwd_kernels = dict()
-bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'),
-                      kernel_names=['backward'])
+bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['backward'])
 bwd_kernels = dict()
 class _softmax(torch.autograd.Function):
     @staticmethod
     def next_power_of_2(n):
         n -= 1
@@ -32,7 +26,7 @@ class _softmax(torch.autograd.Function):
         sizes = _empty.clone()
         # sizes along rows
         for h in range(layout.shape[0]):
-            sizes = torch.cat((sizes, layout[h,:,:].sum(-1)))
+            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
         # offsets in block format
         offsets = torch.zeros_like(sizes)
         offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
@@ -43,24 +37,27 @@ class _softmax(torch.autograd.Function):
         columns = layout.nonzero(as_tuple=False)[:, 2]
         core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
         # construct look-up table
-        offsets = offsets*4 + 2*sizes.numel()
+        offsets = offsets * 4 + 2 * sizes.numel()
         header = torch.stack((sizes, offsets), dim=1).view(-1)
         lut = torch.cat((header, core)).type(torch.int32).to(device)
         return lut, int(sizes.max())
     @staticmethod
-    def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode):
+    def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
+                    kp_mask_mode, attn_mask_mode):
         if max_k >= 32768:
             raise NotImplementedError('Reductions larger than 32768 elements '\
                                       'are not yet implemented')
         num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
         TN = _softmax.next_power_of_2(max_k)
         # just-in-time compile kernel
-        key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
+        key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
+               kp_mask_mode, attn_mask_mode)
         if key not in cache:
-            defines = {'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block,
-                       'INFINITY': {torch.float32: 'F32_INFINITY',
-                                    torch.float16: 'F16_INFINITY'}[dtype]}
+            defines = {
+                'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block, 'INFINITY':
+                {torch.float32: 'F32_INFINITY', torch.float16: 'F16_INFINITY'}[dtype]
+            }
             if apply_scale:
                 defines['APPLY_SCALE'] = True
             if apply_rpe:
@@ -78,8 +75,8 @@ class _softmax(torch.autograd.Function):
         return cache[key]
     @staticmethod
-    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode,
-                spdims, block, lut, maxlut, bench, time):
+    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
+                maxlut, bench, time):
         apply_scale = False if scale == 1.0 else True
         # handle None rpe
@@ -109,17 +106,26 @@ class _softmax(torch.autograd.Function):
             apply_attn_mask = True
             stride_zattnm = attn_mask.stride(0)
         # run kernel
-        kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut*block, x.device, x.dtype, block,
-                                      apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
-                                      kp_mask_mode, attn_mask_mode)
+        kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale,
+                                      apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
         M = x.shape[0]
         grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]
         # run kernel
-        kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),
-               maxlut, x.stride(0), stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
-               grid=grid)
+        kernel(x.data_ptr(),
+               scale,
+               lut.data_ptr(),
+               rpe.data_ptr(),
+               key_padding_mask.data_ptr(),
+               attn_mask.data_ptr(),
+               maxlut,
+               x.stride(0),
+               stride_zrpe,
+               stride_hrpe,
+               stride_srpe,
+               stride_zkpm,
+               stride_zattnm,
+               grid=grid)
         # save to context
         ctx.mark_dirty(x)
@@ -141,7 +147,7 @@ class _softmax(torch.autograd.Function):
         # retrieve from context
         x, lut = ctx.saved_tensors
         # run kernel
-        kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut*ctx.block, x.device, x.dtype, ctx.block,
+        kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut * ctx.block, x.device, x.dtype, ctx.block,
                                       ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
                                       ctx.kp_mask_mode, ctx.attn_mask_mode)
         M = x.shape[0]
@@ -159,15 +165,21 @@ class softmax:
             self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
         return self.lut_cache[key]
-    def __init__(self, layout, block, bench = False):
+    def __init__(self, layout, block, bench=False):
         self.spdims = layout.shape
         self.layout = layout
         self.block = block
         self.bench = bench
         self.lut_cache = dict()
-    def __call__(self, x, scale = 1., rpe = None, key_padding_mask = None, attn_mask = None,
-                 key_padding_mask_mode='add', attn_mask_mode='add'):
+    def __call__(self,
+                 x,
+                 scale=1.,
+                 rpe=None,
+                 key_padding_mask=None,
+                 attn_mask=None,
+                 key_padding_mask_mode='add',
+                 attn_mask_mode='add'):
         time_y = [None]
         if rpe is not None and rpe.dtype != x.dtype:
             raise ValueError('relative position embedding must be %s' % x.dtype)
@@ -176,9 +188,6 @@ class softmax:
         if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
             raise ValueError('Key padding mask must be %s' % x.dtype)
         lut, maxlut = self.make_lut(x.device)
-        x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask,
-                                  key_padding_mask_mode, attn_mask_mode,
-                                  self.spdims, self.block,
-                                  lut,
-                                  maxlut, self.bench, time_y)
+        x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,
                                  self.spdims, self.block, lut, maxlut, self.bench, time_y)
         return x
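
Note: the user-facing entry point in this file is the softmax class reformatted above: it is constructed once from a block-sparsity layout and a block size, lazily builds a per-device look-up table via make_lut, and forwards to _softmax.apply. A minimal call sketch against the reformatted __call__ signature; the import path and tensor shapes below are illustrative assumptions, not taken from this commit:

    import torch
    from triton.ops.blocksparse import softmax  # module path assumed from this file's location

    block = 16
    # causal block mask per head: (heads, block_rows, block_cols)
    layout = torch.tril(torch.ones(4, 4, dtype=torch.int64)).unsqueeze(0).repeat(2, 1, 1)
    sparse_softmax = softmax(layout, block)

    # one (block x block) tile per non-zero layout entry, per batch element (assumed data layout)
    x = torch.rand(1, int(layout.sum()), block, block, dtype=torch.float16, device='cuda')
    y = sparse_softmax(x, scale=0.5, key_padding_mask_mode='add', attn_mask_mode='add')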

View File

@@ -22,36 +22,60 @@ class _conv(torch.autograd.Function):
         # shapes
         Z, CI, H, W = a.shape
         _, R, S, CO = b.shape
-        P = (H + 2*pad[0] - R)//stride[0] + 1
-        Q = (W + 2*pad[1] - S)//stride[1] + 1
+        P = (H + 2 * pad[0] - R) // stride[0] + 1
+        Q = (W + 2 * pad[1] - S) // stride[1] + 1
         # compile kernel
         if (dtype, device) not in _conv.kernel:
             TK = 16
             defines = {
-                'TYPE' : dtype,
-                'TM' : 64,
-                'TN' : 64,
-                'TK' : TK,
-                'TZ' : 1,
-                'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
+                'TYPE': dtype,
+                'TM': 64,
+                'TN': 64,
+                'TK': TK,
+                'TZ': 1,
+                'HH': H,
+                'WW': W,
+                'PP': P,
+                'QQ': Q,
+                'SS': S,
+                'RR': R,
             }
-            idx = torch.arange(CI*R*S)
+            idx = torch.arange(CI * R * S)
             ci, r, s = _conv.unpack(idx, CI, R, S)
             nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
-            delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
+            delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (ns - s) * a.stride(3)
             delta = delta.type(torch.int32).cuda()
             _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
         delta, kernel = _conv.kernel[dtype]
         # allocate output
         c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
         # enqueue
-        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 1., Z*P*Q, CO, CI*R*S,
-               pad[0], pad[1], stride[0], stride[1],
-               delta.data_ptr(),
-               a.stride(0), a.stride(1), a.stride(2), a.stride(3),
-               b.stride(0), b.stride(1), b.stride(2), b.stride(3),
-               c.stride(0), c.stride(1), c.stride(2), c.stride(3),
-               grid = lambda opt: [triton.cdiv(Z*P*Q, opt.TM), triton.cdiv(CO, opt.TN)])
+        kernel(
+            a.data_ptr(),
+            b.data_ptr(),
+            c.data_ptr(),
+            1.,
+            Z * P * Q,
+            CO,
+            CI * R * S,
+            pad[0],
+            pad[1],
+            stride[0],
+            stride[1],
+            delta.data_ptr(),
+            a.stride(0),
+            a.stride(1),
+            a.stride(2),
+            a.stride(3),
+            b.stride(0),
+            b.stride(1),
+            b.stride(2),
+            b.stride(3),
+            c.stride(0),
+            c.stride(1),
+            c.stride(2),
+            c.stride(3),
+            grid=lambda opt: [triton.cdiv(Z * P * Q, opt.TM), triton.cdiv(CO, opt.TN)])
         return c
 conv = _conv.apply
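
Note: the delta table built above stores, for every flattened (ci, r, s) filter position, the pointer increment needed to jump TK positions ahead along the reduction axis, so the kernel advances its input pointers with a single table lookup. _conv.unpack is not part of this hunk; a plausible stand-alone sketch of such a helper, assuming a row-major (CI, R, S) flattening (this body is an assumption for illustration, not code from this repo):

    import torch

    def unpack(idx, CI, R, S):
        # split a flat index over a (CI, R, S) volume, laid out row-major, into coordinates
        s = idx % S
        r = (idx // S) % R
        ci = idx // (R * S)
        return ci, r, s

    # example: with CI=4, R=S=3, flat index 17 maps to (ci=1, r=2, s=2)
    print(unpack(torch.tensor(17), 4, 3, 3))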

View File

@@ -8,16 +8,16 @@ class _matmul(torch.autograd.Function):
     _DEFAULT_CONFIGS = [
         ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
         ({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
-        ({'TM': '128', 'TN': '64' , 'TK': '32', 'TZ': '1'}, 4),
-        ({'TM': '64' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 4),
-        ({'TM': '32' , 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
-        ({'TM': '128', 'TN': '32' , 'TK': '64', 'TZ': '1'}, 4),
-        ({'TM': '64' , 'TN': '32' , 'TK': '64', 'TZ': '1'}, 2),
-        ({'TM': '32' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 2),
-        ({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
-        ({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
-        ({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
-        ({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
+        ({'TM': '128', 'TN': '64', 'TK': '32', 'TZ': '1'}, 4),
+        ({'TM': '64', 'TN': '64', 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '32', 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '128', 'TN': '32', 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '64', 'TN': '32', 'TK': '64', 'TZ': '1'}, 2),
+        ({'TM': '32', 'TN': '64', 'TK': '64', 'TZ': '1'}, 2),
+        ({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
+        ({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
+        ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
+        ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
     ]
     _CONFIGS = _DEFAULT_CONFIGS
@@ -28,9 +28,9 @@ class _matmul(torch.autograd.Function):
         if N % 2 == 0: return 2
         return 1
     _locks = dict()
     _kernels = dict()
     @staticmethod
     def _call(a, b):
         dtype = a.dtype
@@ -55,22 +55,20 @@ class _matmul(torch.autograd.Function):
         key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
         if key not in _matmul._kernels:
             defines = {
-                'TYPE' : dtype,
-                'STRIDE_AM' : 'lda' if is_a_row else '1',
-                'STRIDE_AK' : '1' if is_a_row else 'lda',
-                'STRIDE_BK' : 'ldb' if is_b_row else '1',
-                'STRIDE_BN' : '1' if is_b_row else 'ldb',
-                'LDA_POW2_DIV': lda_pow2_div,
-                'LDB_POW2_DIV': ldb_pow2_div,
-                'LDC_POW2_DIV': ldc_pow2_div,
-                'IS_TK_DIV_K' : int(is_tk_div_k)
+                'TYPE': dtype, 'STRIDE_AM': 'lda' if is_a_row else '1', 'STRIDE_AK': '1' if is_a_row else 'lda',
+                'STRIDE_BK': 'ldb' if is_b_row else '1', 'STRIDE_BN': '1' if is_b_row else 'ldb', 'LDA_POW2_DIV':
+                lda_pow2_div, 'LDB_POW2_DIV': ldb_pow2_div, 'LDC_POW2_DIV': ldc_pow2_div, 'IS_TK_DIV_K':
+                int(is_tk_div_k)
             }
-            _matmul._kernels[key] = triton.kernel(_matmul.src, device, defines=defines,
-                                                  autotune_vals = _matmul._CONFIGS, autotune_key=['M', 'N', 'K'])
+            _matmul._kernels[key] = triton.kernel(_matmul.src,
                                                  device,
                                                  defines=defines,
                                                  autotune_vals=_matmul._CONFIGS,
                                                  autotune_key=['M', 'N', 'K'])
         kernel = _matmul._kernels[key]
         # # locks for split-k
         if device not in _matmul._locks:
-            _matmul._locks[device] = torch.zeros(1024*1024, dtype=torch.int32, device=device)
+            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
         locks = _matmul._locks[device]
         # enqueue
         alpha = 1.
@@ -81,7 +79,7 @@ class _matmul(torch.autograd.Function):
     @staticmethod
     def forward(ctx, a, b):
-        c = _matmul._call(a,b)
+        c = _matmul._call(a, b)
         return c
 matmul = _matmul.apply
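
Note: matmul is exported as _matmul.apply, so it is called like a plain function on two 2-D CUDA tensors; the specialization key above re-compiles per dtype, row-/column-major layout, leading-dimension divisibility, and whether K is divisible by TK. A minimal call sketch (import path assumed from this file's location, shapes illustrative):

    import torch
    import triton.ops

    a = torch.randn(256, 384, device='cuda', dtype=torch.float16)
    b = torch.randn(384, 512, device='cuda', dtype=torch.float16)
    c = triton.ops.matmul(a, b)
    # for a correct build this should agree with cuBLAS up to fp16 rounding
    torch.testing.assert_allclose(c, torch.matmul(a, b), rtol=1e-2, atol=1e-1)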

View File

@@ -4,23 +4,22 @@ import os
 fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
 fwd_kernels = dict()
 def get_fwd_kernel(block, dtype, device):
     key = (block, dtype, device)
     if key not in fwd_kernels:
         defines = {'BLOCK': block, 'TYPE': dtype}
-        fwd_kernels[key] = triton.kernel(fwd_src, device = device, defines = defines)
+        fwd_kernels[key] = triton.kernel(fwd_src, device=device, defines=defines)
     return fwd_kernels[key]
 class _softmax(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         y = torch.empty_like(x)
         M, N = x.shape
         kernel = get_fwd_kernel(N, x.dtype, x.device)
-        kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ])
+        grid = lambda opt: (M, )
+        kernel(x.data_ptr(), y.data_ptr(), grid=grid)
         return y
 softmax = _softmax.apply
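
Note: this dense softmax wrapper specializes one kernel per (BLOCK, dtype, device) triple, binding BLOCK to the row length N, and launches one program instance per row through grid = lambda opt: (M, ). A minimal call sketch (import path assumed from this file's location, shapes illustrative):

    import torch
    import triton.ops

    x = torch.randn(8, 1024, device='cuda', dtype=torch.float32)
    y = triton.ops.softmax(x)
    # for a correct build this should match the reference row-wise softmax
    torch.testing.assert_allclose(y, torch.softmax(x, dim=1), rtol=1e-4, atol=1e-4)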