[PYTHON] Made codebase pep8 compliant
@@ -6,34 +6,35 @@ import triton._C.libtriton.triton as _triton
 import triton._C.libtriton.torch_utils as _torch_utils
 # Make sure internal C resources are cleaned up upon exit
 import atexit

 @atexit.register
 def cleanup():
     _triton.cleanup()

 codes = {
     _triton.arg_type.int1: 'B',
     _triton.arg_type.int8: 'B',
     _triton.arg_type.int32: 'I',
     _triton.arg_type.int64: 'Q',
     _triton.arg_type.half: 'H',
     _triton.arg_type.float: 'f',
     _triton.arg_type.double: 'd',
     _triton.arg_type.buffer: 'P'
 }

 def th_to_triton(obj):
     tys = {
         torch.int8: 'char',
         torch.int16: 'short',
         torch.int32: 'int',
         torch.int64: 'long',
         torch.float16: 'half',
         torch.float32: 'float',
         torch.float64: 'double'
     }
     if isinstance(obj, torch.dtype):
         return tys[obj]
     return str(obj)

 def cdiv(a, b):
     return (a + b - 1) // b
@@ -44,46 +45,45 @@ def synchronize(device):
     _torch_utils.synchronize(dev_id)

 def read(path, kernel_names=[]):
     with open(path, 'r') as f:
         source = f.read()
     source = _triton.extract_kernels(source, kernel_names)
     return source

 class kernel:
-    def __init__(self, src, device, defines = dict(), num_warps = 4, autotune_vals = [], autotune_key = []):
+    def __init__(self, src, device, defines=dict(), num_warps=4, autotune_vals=[], autotune_key=[]):
         # check if src is empty
         if src == '':
             raise ValueError('Kernel source code is empty')
         self.src = src
         self.opt = _triton.options()
         self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
         self.opt.num_warps = num_warps
         # device
         assert device.type in ['cuda', 'cpu']
         if device.type == 'cuda':
             self.device = torch.cuda.current_device() if device.index is None else device.index
         if device.type == 'cpu':
             self.device = -1
         _torch_utils.register_device(self.device)
         _torch_utils.register_stream(self.device)
         # C++ function wrapper
         self.op_id = _triton.make_op_id()
         _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
         # debug mode
         self.is_debug = 'TRITON_DEBUG' in os.environ
         # signature
         arg_types = _triton.get_fn_signature(self.op_id)
         self.tys = ''.join([codes[x] for x in arg_types])

     def __call__(self, *args, grid):
         _torch_utils.set_device(self.device)
         # pack parameters into a byte buffer
         params = struct.pack(self.tys, *args)
         opt = _triton.autotune(self.op_id, self.device, params, grid)
         # run kernel
         grid = grid(opt)
         grid_0 = grid[0]
         grid_1 = 1 if len(grid) < 2 else grid[1]
         grid_2 = 1 if len(grid) < 3 else grid[2]
         _triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)

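Note: the `codes` table and the `struct.pack` call in `__call__` are what turn Python-level kernel arguments into the raw byte buffer handed to `_triton.launch_kernel`. A minimal standalone sketch of that packing step, using only the standard `struct` module (the signature string and values below are invented for illustration, not taken from this commit):

    import struct

    # hypothetical signature: an int32 size, a float scalar, two buffer pointers,
    # using the same format characters as the `codes` table ('I', 'f', 'P')
    tys = 'IfPP'
    n, alpha = 1024, 0.5
    x_ptr, y_ptr = 0x7f0000001000, 0x7f0000002000  # stand-ins for tensor.data_ptr()
    params = struct.pack(tys, n, alpha, x_ptr, y_ptr)
    print(len(params))  # 24 bytes with native alignment: 4 + 4 + 8 + 8
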
@@ -2,21 +2,15 @@ import triton
 import torch
 import os

-fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'),
-                      kernel_names=['forward'])
+fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
 fwd_kernels = dict()

-bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'),
-                      kernel_names=['backward'])
+bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['backward'])
 bwd_kernels = dict()


 class _softmax(torch.autograd.Function):

     @staticmethod
     def next_power_of_2(n):
         n -= 1
         n |= n >> 1
         n |= n >> 2
@@ -24,7 +18,7 @@ class _softmax(torch.autograd.Function):
         n |= n >> 8
         n |= n >> 16
         n += 1
         return n

     @staticmethod
     def make_lut(layout, block, device):
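Note: `next_power_of_2` rounds up by smearing the highest set bit of n - 1 into every lower position and then adding one (the `n |= n >> 4` step sits in the unshown context between the two hunks). A quick standalone check of the same trick in plain Python:

    def next_power_of_2(n):
        # OR-shift cascade: after these steps, all bits below the MSB of (n - 1) are set
        n -= 1
        n |= n >> 1
        n |= n >> 2
        n |= n >> 4
        n |= n >> 8
        n |= n >> 16
        n += 1
        return n

    assert [next_power_of_2(k) for k in (1, 3, 100, 1024, 33000)] == [1, 4, 128, 1024, 65536]
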
@@ -32,7 +26,7 @@ class _softmax(torch.autograd.Function):
         sizes = _empty.clone()
         # sizes along rows
         for h in range(layout.shape[0]):
-            sizes = torch.cat((sizes, layout[h,:,:].sum(-1)))
+            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
         # offsets in block format
         offsets = torch.zeros_like(sizes)
         offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
@@ -41,26 +35,29 @@ class _softmax(torch.autograd.Function):
         head = layout.nonzero(as_tuple=False)[:, 0]
         rows = layout.nonzero(as_tuple=False)[:, 1]
         columns = layout.nonzero(as_tuple=False)[:, 2]
         core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
         # construct look-up table
-        offsets = offsets*4 + 2*sizes.numel()
+        offsets = offsets * 4 + 2 * sizes.numel()
         header = torch.stack((sizes, offsets), dim=1).view(-1)
         lut = torch.cat((header, core)).type(torch.int32).to(device)
         return lut, int(sizes.max())

     @staticmethod
-    def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode):
+    def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
+                    kp_mask_mode, attn_mask_mode):
         if max_k >= 32768:
             raise NotImplementedError('Reductions larger than 32768 elements '\
                                       'are not yet implemented')
         num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
         TN = _softmax.next_power_of_2(max_k)
         # just-in-time compile kernel
-        key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
+        key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
+               kp_mask_mode, attn_mask_mode)
         if key not in cache:
-            defines = {'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block,
-                       'INFINITY': {torch.float32: 'F32_INFINITY',
-                                    torch.float16: 'F16_INFINITY'}[dtype]}
+            defines = {
+                'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block, 'INFINITY':
+                {torch.float32: 'F32_INFINITY', torch.float16: 'F16_INFINITY'}[dtype]
+            }
             if apply_scale:
                 defines['APPLY_SCALE'] = True
             if apply_rpe:
@@ -73,13 +70,13 @@ class _softmax(torch.autograd.Function):
                 defines['APPLY_ATTN_MASK'] = True
                 if attn_mask_mode == 'mul':
                     defines['ATTN_MASK_MUL'] = True
             kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
             cache[key] = kernel
         return cache[key]

     @staticmethod
-    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode,
-                spdims, block, lut, maxlut, bench, time):
+    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
+                maxlut, bench, time):
         apply_scale = False if scale == 1.0 else True

         # handle None rpe
@@ -109,17 +106,26 @@ class _softmax(torch.autograd.Function):
             apply_attn_mask = True
             stride_zattnm = attn_mask.stride(0)


         # run kernel
-        kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut*block, x.device, x.dtype, block,
-                                      apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
-                                      kp_mask_mode, attn_mask_mode)
+        kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale,
+                                      apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
         M = x.shape[0]
         grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]

         # run kernel
-        kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),
-               maxlut, x.stride(0), stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
-               grid=grid)
+        kernel(x.data_ptr(),
+               scale,
+               lut.data_ptr(),
+               rpe.data_ptr(),
+               key_padding_mask.data_ptr(),
+               attn_mask.data_ptr(),
+               maxlut,
+               x.stride(0),
+               stride_zrpe,
+               stride_hrpe,
+               stride_srpe,
+               stride_zkpm,
+               stride_zattnm,
+               grid=grid)
         # save to context
         ctx.mark_dirty(x)
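Note: the `grid` lambda defers the launch-shape computation until the autotuner has chosen `opt.TM`; `triton.cdiv` is the ceiling-division helper from the first file in this commit. A small arithmetic sketch of what the lambda evaluates to (sizes invented for illustration):

    def cdiv(a, b):
        # same ceiling division as triton.cdiv
        return (a + b - 1) // b

    spdims, block, M = (4, 16, 16), 32, 512  # hypothetical layout shape, block size, row count
    TM = 1                                   # stand-in for the autotuned opt.TM
    grid = lambda TM: [cdiv(spdims[0] * spdims[1] * block, TM), M]
    print(grid(TM))  # [2048, 512] for these made-up sizes
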
@@ -135,39 +141,45 @@ class _softmax(torch.autograd.Function):
         ctx.kp_mask_mode = kp_mask_mode
         ctx.attn_mask_mode = attn_mask_mode
         return x

     @staticmethod
     def backward(ctx, dx):
         # retrieve from context
         x, lut = ctx.saved_tensors
         # run kernel
-        kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut*ctx.block, x.device, x.dtype, ctx.block,
+        kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut * ctx.block, x.device, x.dtype, ctx.block,
                                       ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
                                       ctx.kp_mask_mode, ctx.attn_mask_mode)
         M = x.shape[0]
         grid = lambda opt: [triton.cdiv(ctx.spdims[0] * ctx.spdims[1] * ctx.block, opt.TM), M]
         kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
         return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None


 class softmax:

     apply_softmax = _softmax.apply

     def make_lut(self, device):
         key = (device, )
         if key not in self.lut_cache:
             self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
         return self.lut_cache[key]

-    def __init__(self, layout, block, bench = False):
+    def __init__(self, layout, block, bench=False):
         self.spdims = layout.shape
         self.layout = layout
         self.block = block
         self.bench = bench
         self.lut_cache = dict()

-    def __call__(self, x, scale = 1., rpe = None, key_padding_mask = None, attn_mask = None,
-                 key_padding_mask_mode='add', attn_mask_mode='add'):
+    def __call__(self,
+                 x,
+                 scale=1.,
+                 rpe=None,
+                 key_padding_mask=None,
+                 attn_mask=None,
+                 key_padding_mask_mode='add',
+                 attn_mask_mode='add'):
         time_y = [None]
         if rpe is not None and rpe.dtype != x.dtype:
             raise ValueError('relative position embedding must be %s' % x.dtype)
@@ -176,9 +188,6 @@ class softmax:
         if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
             raise ValueError('Key padding mask must be %s' % x.dtype)
         lut, maxlut = self.make_lut(x.device)
-        x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask,
-                                  key_padding_mask_mode, attn_mask_mode,
-                                  self.spdims, self.block,
-                                  lut,
-                                  maxlut, self.bench, time_y)
+        x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,
+                                  self.spdims, self.block, lut, maxlut, self.bench, time_y)
         return x
@@ -8,50 +8,74 @@ class _conv(torch.autograd.Function):

     @staticmethod
     def unpack(IDX, CI, R, S):
         s = IDX % S
         cr = IDX // S
         r = cr % R
         ci = cr // R
         return ci, r, s

     @staticmethod
     def forward(ctx, a, b, pad, stride):
         # create kernel if necessary
         dtype = a.dtype
         device = a.device
         # shapes
         Z, CI, H, W = a.shape
         _, R, S, CO = b.shape
-        P = (H + 2*pad[0] - R)//stride[0] + 1
-        Q = (W + 2*pad[1] - S)//stride[1] + 1
+        P = (H + 2 * pad[0] - R) // stride[0] + 1
+        Q = (W + 2 * pad[1] - S) // stride[1] + 1
         # compile kernel
         if (dtype, device) not in _conv.kernel:
             TK = 16
             defines = {
-                'TYPE' : dtype,
-                'TM' : 64,
-                'TN' : 64,
-                'TK' : TK,
-                'TZ' : 1,
-                'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
+                'TYPE': dtype,
+                'TM': 64,
+                'TN': 64,
+                'TK': TK,
+                'TZ': 1,
+                'HH': H,
+                'WW': W,
+                'PP': P,
+                'QQ': Q,
+                'SS': S,
+                'RR': R,
             }
-            idx = torch.arange(CI*R*S)
+            idx = torch.arange(CI * R * S)
             ci, r, s = _conv.unpack(idx, CI, R, S)
             nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
-            delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
+            delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (ns - s) * a.stride(3)
             delta = delta.type(torch.int32).cuda()
             _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
         delta, kernel = _conv.kernel[dtype]
         # allocate output
         c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
         # enqueue
-        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 1., Z*P*Q, CO, CI*R*S,
-               pad[0], pad[1], stride[0], stride[1],
-               delta.data_ptr(),
-               a.stride(0), a.stride(1), a.stride(2), a.stride(3),
-               b.stride(0), b.stride(1), b.stride(2), b.stride(3),
-               c.stride(0), c.stride(1), c.stride(2), c.stride(3),
-               grid = lambda opt: [triton.cdiv(Z*P*Q, opt.TM), triton.cdiv(CO, opt.TN)])
+        kernel(
+            a.data_ptr(),
+            b.data_ptr(),
+            c.data_ptr(),
+            1.,
+            Z * P * Q,
+            CO,
+            CI * R * S,
+            pad[0],
+            pad[1],
+            stride[0],
+            stride[1],
+            delta.data_ptr(),
+            a.stride(0),
+            a.stride(1),
+            a.stride(2),
+            a.stride(3),
+            b.stride(0),
+            b.stride(1),
+            b.stride(2),
+            b.stride(3),
+            c.stride(0),
+            c.stride(1),
+            c.stride(2),
+            c.stride(3),
+            grid=lambda opt: [triton.cdiv(Z * P * Q, opt.TM), triton.cdiv(CO, opt.TN)])
         return c


 conv = _conv.apply
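Note: `_conv.unpack` decomposes a flattened filter index into its (ci, r, s) coordinates, and `delta` measures how far the input pointer moves when that index advances by TK. A standalone round-trip check of the index arithmetic (plain Python, sizes invented):

    def unpack(idx, CI, R, S):
        # same arithmetic as _conv.unpack, applied to a single integer index
        s = idx % S
        cr = idx // S
        r = cr % R
        ci = cr // R
        return ci, r, s

    CI, R, S = 8, 3, 3
    for idx in range(CI * R * S):
        ci, r, s = unpack(idx, CI, R, S)
        assert (ci * R + r) * S + s == idx  # flattening back recovers the index
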
@@ -7,17 +7,17 @@ class _matmul(torch.autograd.Function):

     _DEFAULT_CONFIGS = [
         ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
         ({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
-        ({'TM': '128', 'TN': '64' , 'TK': '32', 'TZ': '1'}, 4),
-        ({'TM': '64' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 4),
-        ({'TM': '32' , 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
-        ({'TM': '128', 'TN': '32' , 'TK': '64', 'TZ': '1'}, 4),
-        ({'TM': '64' , 'TN': '32' , 'TK': '64', 'TZ': '1'}, 2),
-        ({'TM': '32' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 2),
-        ({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
-        ({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
-        ({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
-        ({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
+        ({'TM': '128', 'TN': '64', 'TK': '32', 'TZ': '1'}, 4),
+        ({'TM': '64', 'TN': '64', 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '32', 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '128', 'TN': '32', 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '64', 'TN': '32', 'TK': '64', 'TZ': '1'}, 2),
+        ({'TM': '32', 'TN': '64', 'TK': '64', 'TZ': '1'}, 2),
+        ({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
+        ({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
+        ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
+        ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
     ]
     _CONFIGS = _DEFAULT_CONFIGS

@@ -28,9 +28,9 @@ class _matmul(torch.autograd.Function):
         if N % 2 == 0: return 2
         return 1

     _locks = dict()
     _kernels = dict()

     @staticmethod
     def _call(a, b):
         dtype = a.dtype
@@ -51,26 +51,24 @@ class _matmul(torch.autograd.Function):
         lda_pow2_div = _matmul.largest_pow2_divisor(lda)
         ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
         ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
         is_tk_div_k = K % 64 == 0
         key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
         if key not in _matmul._kernels:
             defines = {
-                'TYPE' : dtype,
-                'STRIDE_AM' : 'lda' if is_a_row else '1',
-                'STRIDE_AK' : '1' if is_a_row else 'lda',
-                'STRIDE_BK' : 'ldb' if is_b_row else '1',
-                'STRIDE_BN' : '1' if is_b_row else 'ldb',
-                'LDA_POW2_DIV': lda_pow2_div,
-                'LDB_POW2_DIV': ldb_pow2_div,
-                'LDC_POW2_DIV': ldc_pow2_div,
-                'IS_TK_DIV_K' : int(is_tk_div_k)
+                'TYPE': dtype, 'STRIDE_AM': 'lda' if is_a_row else '1', 'STRIDE_AK': '1' if is_a_row else 'lda',
+                'STRIDE_BK': 'ldb' if is_b_row else '1', 'STRIDE_BN': '1' if is_b_row else 'ldb', 'LDA_POW2_DIV':
+                lda_pow2_div, 'LDB_POW2_DIV': ldb_pow2_div, 'LDC_POW2_DIV': ldc_pow2_div, 'IS_TK_DIV_K':
+                int(is_tk_div_k)
             }
-            _matmul._kernels[key] = triton.kernel(_matmul.src, device, defines=defines,
-                                                  autotune_vals = _matmul._CONFIGS, autotune_key=['M', 'N', 'K'])
+            _matmul._kernels[key] = triton.kernel(_matmul.src,
+                                                  device,
+                                                  defines=defines,
+                                                  autotune_vals=_matmul._CONFIGS,
+                                                  autotune_key=['M', 'N', 'K'])
         kernel = _matmul._kernels[key]
         # # locks for split-k
         if device not in _matmul._locks:
-            _matmul._locks[device] = torch.zeros(1024*1024, dtype=torch.int32, device=device)
+            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
         locks = _matmul._locks[device]
         # enqueue
         alpha = 1.
@@ -81,7 +79,7 @@ class _matmul(torch.autograd.Function):

     @staticmethod
     def forward(ctx, a, b):
-        c = _matmul._call(a,b)
+        c = _matmul._call(a, b)
         return c


 matmul = _matmul.apply

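Note: only the tail of `largest_pow2_divisor` is visible in the hunk above; the kernel-cache key uses its result so that strides sharing the same power-of-two divisibility reuse one compiled kernel. A plausible shape for the whole helper, consistent with that tail (the cap of 16 is an assumption, not taken from the diff):

    def largest_pow2_divisor(N):
        # assumed cap; only the last two branches appear in the diff above
        if N % 16 == 0: return 16
        if N % 8 == 0: return 8
        if N % 4 == 0: return 4
        if N % 2 == 0: return 2
        return 1

    print([largest_pow2_divisor(n) for n in (24, 40, 64, 7)])  # [8, 8, 16, 1]
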
@@ -4,23 +4,22 @@ import os

 fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
 fwd_kernels = dict()


 def get_fwd_kernel(block, dtype, device):
     key = (block, dtype, device)
     if key not in fwd_kernels:
         defines = {'BLOCK': block, 'TYPE': dtype}
-        fwd_kernels[key] = triton.kernel(fwd_src, device = device, defines = defines)
+        fwd_kernels[key] = triton.kernel(fwd_src, device=device, defines=defines)
     return fwd_kernels[key]


 class _softmax(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         y = torch.empty_like(x)
         M, N = x.shape
         kernel = get_fwd_kernel(N, x.dtype, x.device)
-        kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ])
+        grid = lambda opt: (M, )
+        kernel(x.data_ptr(), y.data_ptr(), grid=grid)
         return y


 softmax = _softmax.apply
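Note: the element-wise softmax wrapper above is called like any other autograd function. A minimal usage sketch, assuming a Triton build from the era of this commit and a CUDA device (not part of the diff; the import path is an assumption):

    import torch
    # from triton.ops import softmax  # assumed import path for the `_softmax.apply` alias above

    x = torch.rand(64, 512, device='cuda', dtype=torch.float32)
    y = softmax(x)  # builds get_fwd_kernel(512, torch.float32, x.device) and launches grid (64, )
    # softmax.c itself is not shown in this diff; y is expected to match a
    # row-wise softmax, i.e. torch.softmax(x, dim=1)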