[PYTHON] Made codebase pep8 compliant

Philippe Tillet
2021-02-07 15:06:57 -05:00
parent 2a02fabdac
commit dffd66bc83
5 changed files with 207 additions and 177 deletions

View File

@@ -6,34 +6,35 @@ import triton._C.libtriton.triton as _triton
import triton._C.libtriton.torch_utils as _torch_utils
# Make sure internal C resources are cleaned up upon exit
import atexit


@atexit.register
def cleanup():
    _triton.cleanup()


codes = {
    _triton.arg_type.int1: 'B',
    _triton.arg_type.int8: 'B',
    _triton.arg_type.int32: 'I',
    _triton.arg_type.int64: 'Q',
    _triton.arg_type.half: 'H',
    _triton.arg_type.float: 'f',
    _triton.arg_type.double: 'd',
    _triton.arg_type.buffer: 'P'
}


def th_to_triton(obj):
    tys = {
        torch.int8: 'char',
        torch.int16: 'short',
        torch.int32: 'int',
        torch.int64: 'long',
        torch.float16: 'half',
        torch.float32: 'float',
        torch.float64: 'double'
    }
    if isinstance(obj, torch.dtype):
        return tys[obj]
    return str(obj)


def cdiv(a, b):
    return (a + b - 1) // b
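As a quick illustrative check (not part of the commit), the two helpers above behave as follows:

    assert cdiv(1000, 128) == 8                      # ceil-division: (1000 + 127) // 128
    assert th_to_triton(torch.float16) == 'half'     # torch dtypes map to Triton-C type names
    assert th_to_triton(64) == '64'                  # anything else is passed through as a string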
@@ -44,46 +45,45 @@ def synchronize(device):
    _torch_utils.synchronize(dev_id)


def read(path, kernel_names=[]):
    with open(path, 'r') as f:
        source = f.read()
    source = _triton.extract_kernels(source, kernel_names)
    return source


class kernel:
    def __init__(self, src, device, defines=dict(), num_warps=4, autotune_vals=[], autotune_key=[]):
        # check if src is empty
        if src == '':
            raise ValueError('Kernel source code is empty')
        self.src = src
        self.opt = _triton.options()
        self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
        self.opt.num_warps = num_warps
        # device
        assert device.type in ['cuda', 'cpu']
        if device.type == 'cuda':
            self.device = torch.cuda.current_device() if device.index is None else device.index
        if device.type == 'cpu':
            self.device = -1
        _torch_utils.register_device(self.device)
        _torch_utils.register_stream(self.device)
        # C++ function wrapper
        self.op_id = _triton.make_op_id()
        _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
        # debug mode
        self.is_debug = 'TRITON_DEBUG' in os.environ
        # signature
        arg_types = _triton.get_fn_signature(self.op_id)
        self.tys = ''.join([codes[x] for x in arg_types])

    def __call__(self, *args, grid):
        _torch_utils.set_device(self.device)
        # pack parameters into a byte buffer
        params = struct.pack(self.tys, *args)
        opt = _triton.autotune(self.op_id, self.device, params, grid)
        # run kernel
        grid = grid(opt)
        grid_0 = grid[0]
        grid_1 = 1 if len(grid) < 2 else grid[1]
        grid_2 = 1 if len(grid) < 3 else grid[2]
        _triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
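For context, kernel.__call__ uses the codes table above to flatten its arguments into the byte buffer that launch_kernel receives. A minimal sketch of that packing step, with a hypothetical 'PPI' signature (two buffers plus an int32) and made-up pointer values, not a real launch:

    import struct

    tys = 'PPI'                                      # e.g. a kernel taking (X*, Y*, int N)
    x_ptr, y_ptr, N = 0x7f0000000000, 0x7f0000001000, 1024   # hypothetical device pointers
    params = struct.pack(tys, x_ptr, y_ptr, N)       # same call as in __call__
    assert len(params) == 20                         # 8 + 8 + 4 bytes on a 64-bit build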

View File

@@ -2,21 +2,15 @@ import triton
import torch
import os

fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
fwd_kernels = dict()

bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['backward'])
bwd_kernels = dict()


class _softmax(torch.autograd.Function):
    @staticmethod
    def next_power_of_2(n):
        n -= 1
        n |= n >> 1
        n |= n >> 2
@@ -24,7 +18,7 @@ class _softmax(torch.autograd.Function):
        n |= n >> 8
        n |= n >> 16
        n += 1
        return n

    @staticmethod
    def make_lut(layout, block, device):
@@ -32,7 +26,7 @@ class _softmax(torch.autograd.Function):
        sizes = _empty.clone()
        # sizes along rows
        for h in range(layout.shape[0]):
            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
        # offsets in block format
        offsets = torch.zeros_like(sizes)
        offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
@@ -41,26 +35,29 @@ class _softmax(torch.autograd.Function):
        head = layout.nonzero(as_tuple=False)[:, 0]
        rows = layout.nonzero(as_tuple=False)[:, 1]
        columns = layout.nonzero(as_tuple=False)[:, 2]
        core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
        # construct look-up table
        offsets = offsets * 4 + 2 * sizes.numel()
        header = torch.stack((sizes, offsets), dim=1).view(-1)
        lut = torch.cat((header, core)).type(torch.int32).to(device)
        return lut, int(sizes.max())

    @staticmethod
    def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
                    kp_mask_mode, attn_mask_mode):
        if max_k >= 32768:
            raise NotImplementedError('Reductions larger than 32768 elements '\
                                      'are not yet implemented')
        num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
        TN = _softmax.next_power_of_2(max_k)
        # just-in-time compile kernel
        key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
               kp_mask_mode, attn_mask_mode)
        if key not in cache:
            defines = {
                'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block, 'INFINITY':
                {torch.float32: 'F32_INFINITY', torch.float16: 'F16_INFINITY'}[dtype]
            }
            if apply_scale:
                defines['APPLY_SCALE'] = True
            if apply_rpe:
@@ -73,13 +70,13 @@ class _softmax(torch.autograd.Function):
                defines['APPLY_ATTN_MASK'] = True
                if attn_mask_mode == 'mul':
                    defines['ATTN_MASK_MUL'] = True
            kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
            cache[key] = kernel
        return cache[key]

    @staticmethod
    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
                maxlut, bench, time):
        apply_scale = False if scale == 1.0 else True
        # handle None rpe
@@ -109,17 +106,26 @@ class _softmax(torch.autograd.Function):
            apply_attn_mask = True
            stride_zattnm = attn_mask.stride(0)
        # run kernel
        kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale,
                                      apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
        M = x.shape[0]
        grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]
        # run kernel
        kernel(x.data_ptr(),
               scale,
               lut.data_ptr(),
               rpe.data_ptr(),
               key_padding_mask.data_ptr(),
               attn_mask.data_ptr(),
               maxlut,
               x.stride(0),
               stride_zrpe,
               stride_hrpe,
               stride_srpe,
               stride_zkpm,
               stride_zattnm,
               grid=grid)
        # save to context
        ctx.mark_dirty(x)
@@ -135,39 +141,45 @@ class _softmax(torch.autograd.Function):
        ctx.kp_mask_mode = kp_mask_mode
        ctx.attn_mask_mode = attn_mask_mode
        return x

    @staticmethod
    def backward(ctx, dx):
        # retrieve from context
        x, lut = ctx.saved_tensors
        # run kernel
        kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut * ctx.block, x.device, x.dtype, ctx.block,
                                      ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
                                      ctx.kp_mask_mode, ctx.attn_mask_mode)
        M = x.shape[0]
        grid = lambda opt: [triton.cdiv(ctx.spdims[0] * ctx.spdims[1] * ctx.block, opt.TM), M]
        kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
        return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None


class softmax:
    apply_softmax = _softmax.apply

    def make_lut(self, device):
        key = (device, )
        if key not in self.lut_cache:
            self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
        return self.lut_cache[key]

    def __init__(self, layout, block, bench=False):
        self.spdims = layout.shape
        self.layout = layout
        self.block = block
        self.bench = bench
        self.lut_cache = dict()

    def __call__(self,
                 x,
                 scale=1.,
                 rpe=None,
                 key_padding_mask=None,
                 attn_mask=None,
                 key_padding_mask_mode='add',
                 attn_mask_mode='add'):
        time_y = [None]
        if rpe is not None and rpe.dtype != x.dtype:
            raise ValueError('relative position embedding must be %s' % x.dtype)
@@ -176,9 +188,6 @@ class softmax:
        if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
            raise ValueError('Key padding mask must be %s' % x.dtype)
        lut, maxlut = self.make_lut(x.device)
        x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,
                                  self.spdims, self.block, lut, maxlut, self.bench, time_y)
        return x
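A hedged usage sketch of the block-sparse softmax class above. The import path and the (batch, nnz, block, block) storage of the sparse input are assumptions not shown in this diff; the constructor and __call__ arguments follow the signatures above:

    import torch
    import triton.ops.blocksparse as tb              # assumed module path

    block = 16
    layout = torch.randint(0, 2, (2, 4, 4))          # (heads, row-blocks, col-blocks); 1 = block is kept
    sparse_sm = tb.softmax(layout, block)

    nnz = int(layout.sum())
    x = torch.randn(3, nnz, block, block, dtype=torch.float16, device='cuda')  # assumed sparse storage
    y = sparse_sm(x, scale=0.5)                      # modifies x in place (mark_dirty) and returns it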

View File

@@ -8,50 +8,74 @@ class _conv(torch.autograd.Function):
    @staticmethod
    def unpack(IDX, CI, R, S):
        s = IDX % S
        cr = IDX // S
        r = cr % R
        ci = cr // R
        return ci, r, s

    @staticmethod
    def forward(ctx, a, b, pad, stride):
        # create kernel if necessary
        dtype = a.dtype
        device = a.device
        # shapes
        Z, CI, H, W = a.shape
        _, R, S, CO = b.shape
        P = (H + 2 * pad[0] - R) // stride[0] + 1
        Q = (W + 2 * pad[1] - S) // stride[1] + 1
        # compile kernel
        if (dtype, device) not in _conv.kernel:
            TK = 16
            defines = {
                'TYPE': dtype,
                'TM': 64,
                'TN': 64,
                'TK': TK,
                'TZ': 1,
                'HH': H,
                'WW': W,
                'PP': P,
                'QQ': Q,
                'SS': S,
                'RR': R,
            }
            idx = torch.arange(CI * R * S)
            ci, r, s = _conv.unpack(idx, CI, R, S)
            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
            delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (ns - s) * a.stride(3)
            delta = delta.type(torch.int32).cuda()
            _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
        delta, kernel = _conv.kernel[dtype]
        # allocate output
        c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
        # enqueue
        kernel(
            a.data_ptr(),
            b.data_ptr(),
            c.data_ptr(),
            1.,
            Z * P * Q,
            CO,
            CI * R * S,
            pad[0],
            pad[1],
            stride[0],
            stride[1],
            delta.data_ptr(),
            a.stride(0),
            a.stride(1),
            a.stride(2),
            a.stride(3),
            b.stride(0),
            b.stride(1),
            b.stride(2),
            b.stride(3),
            c.stride(0),
            c.stride(1),
            c.stride(2),
            c.stride(3),
            grid=lambda opt: [triton.cdiv(Z * P * Q, opt.TM), triton.cdiv(CO, opt.TN)])
        return c


conv = _conv.apply
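A hedged usage sketch of the conv op above. The shapes follow the unpacking in forward() (a is NCHW, b is (CI, R, S, CO)); importing `conv` from this module and running on a CUDA fp16 build are assumptions:

    import torch

    a = torch.randn(8, 64, 32, 32, dtype=torch.float16, device='cuda')    # Z, CI, H, W
    b = torch.randn(64, 3, 3, 128, dtype=torch.float16, device='cuda')    # CI, R, S, CO
    c = conv(a, b, (1, 1), (1, 1))                                        # pad, stride -> (Z, CO, P, Q) = (8, 128, 32, 32)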

View File

@@ -7,17 +7,17 @@ class _matmul(torch.autograd.Function):
    _DEFAULT_CONFIGS = [
        ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
        ({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
        ({'TM': '128', 'TN': '64', 'TK': '32', 'TZ': '1'}, 4),
        ({'TM': '64', 'TN': '64', 'TK': '64', 'TZ': '1'}, 4),
        ({'TM': '32', 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
        ({'TM': '128', 'TN': '32', 'TK': '64', 'TZ': '1'}, 4),
        ({'TM': '64', 'TN': '32', 'TK': '64', 'TZ': '1'}, 2),
        ({'TM': '32', 'TN': '64', 'TK': '64', 'TZ': '1'}, 2),
        ({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
        ({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
        ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
        ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
    ]
    _CONFIGS = _DEFAULT_CONFIGS
@@ -28,9 +28,9 @@ class _matmul(torch.autograd.Function):
        if N % 2 == 0: return 2
        return 1

    _locks = dict()
    _kernels = dict()

    @staticmethod
    def _call(a, b):
        dtype = a.dtype
@@ -51,26 +51,24 @@ class _matmul(torch.autograd.Function):
        lda_pow2_div = _matmul.largest_pow2_divisor(lda)
        ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
        ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
        is_tk_div_k = K % 64 == 0
        key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
        if key not in _matmul._kernels:
            defines = {
                'TYPE': dtype, 'STRIDE_AM': 'lda' if is_a_row else '1', 'STRIDE_AK': '1' if is_a_row else 'lda',
                'STRIDE_BK': 'ldb' if is_b_row else '1', 'STRIDE_BN': '1' if is_b_row else 'ldb', 'LDA_POW2_DIV':
                lda_pow2_div, 'LDB_POW2_DIV': ldb_pow2_div, 'LDC_POW2_DIV': ldc_pow2_div, 'IS_TK_DIV_K':
                int(is_tk_div_k)
            }
            _matmul._kernels[key] = triton.kernel(_matmul.src,
                                                  device,
                                                  defines=defines,
                                                  autotune_vals=_matmul._CONFIGS,
                                                  autotune_key=['M', 'N', 'K'])
        kernel = _matmul._kernels[key]
        # # locks for split-k
        if device not in _matmul._locks:
            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
        locks = _matmul._locks[device]
        # enqueue
        alpha = 1.
@@ -81,7 +79,7 @@ class _matmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        c = _matmul._call(a, b)
        return c


matmul = _matmul.apply
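A hedged usage sketch of the matmul op above (assuming `matmul` is importable from this module and a CUDA device is available); the first call for a given (device, dtype, layout, divisibility) key JIT-compiles and autotunes over _DEFAULT_CONFIGS:

    import torch

    a = torch.randn(512, 384, dtype=torch.float16, device='cuda')
    b = torch.randn(384, 256, dtype=torch.float16, device='cuda')
    c = matmul(a, b)             # (512, 256); the strides feed the STRIDE_* / *_POW2_DIV defines above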

View File

@@ -4,23 +4,22 @@ import os
fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
fwd_kernels = dict()


def get_fwd_kernel(block, dtype, device):
    key = (block, dtype, device)
    if key not in fwd_kernels:
        defines = {'BLOCK': block, 'TYPE': dtype}
        fwd_kernels[key] = triton.kernel(fwd_src, device=device, defines=defines)
    return fwd_kernels[key]


class _softmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.empty_like(x)
        M, N = x.shape
        kernel = get_fwd_kernel(N, x.dtype, x.device)
        grid = lambda opt: (M, )
        kernel(x.data_ptr(), y.data_ptr(), grid=grid)
        return y


softmax = _softmax.apply
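A hedged usage sketch of the dense softmax above; the comparison against torch.softmax assumes the Triton-C 'forward' kernel computes a row-wise softmax, which this diff does not show:

    import torch

    x = torch.randn(64, 512, dtype=torch.float16, device='cuda')
    y = softmax(x)                                   # one program per row: grid = (M, )
    assert torch.allclose(y, torch.softmax(x, dim=-1), atol=1e-2)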