diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index ce2a7d579..585c7b8e0 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -6,34 +6,35 @@ import triton._C.libtriton.triton as _triton
 import triton._C.libtriton.torch_utils as _torch_utils
 
 # Make sure internal C resources are cleaned up upon exit
 import atexit
+
 @atexit.register
 def cleanup():
-  _triton.cleanup()
+    _triton.cleanup()
 
 codes = {
-  _triton.arg_type.int1: 'B',
-  _triton.arg_type.int8: 'B',
-  _triton.arg_type.int32: 'I',
-  _triton.arg_type.int64: 'Q',
-  _triton.arg_type.half: 'H',
-  _triton.arg_type.float: 'f',
-  _triton.arg_type.double: 'd',
-  _triton.arg_type.buffer: 'P'
+    _triton.arg_type.int1: 'B',
+    _triton.arg_type.int8: 'B',
+    _triton.arg_type.int32: 'I',
+    _triton.arg_type.int64: 'Q',
+    _triton.arg_type.half: 'H',
+    _triton.arg_type.float: 'f',
+    _triton.arg_type.double: 'd',
+    _triton.arg_type.buffer: 'P'
 }
 
 def th_to_triton(obj):
-  tys = {
-    torch.int8: 'char',
-    torch.int16: 'short',
-    torch.int32: 'int',
-    torch.int64: 'long',
-    torch.float16: 'half',
-    torch.float32: 'float',
-    torch.float64: 'double'
-  }
-  if isinstance(obj, torch.dtype):
-    return tys[obj]
-  return str(obj)
+    tys = {
+        torch.int8: 'char',
+        torch.int16: 'short',
+        torch.int32: 'int',
+        torch.int64: 'long',
+        torch.float16: 'half',
+        torch.float32: 'float',
+        torch.float64: 'double'
+    }
+    if isinstance(obj, torch.dtype):
+        return tys[obj]
+    return str(obj)
 
 def cdiv(a, b):
     return (a + b - 1) // b
@@ -44,46 +45,45 @@ def synchronize(device):
     _torch_utils.synchronize(dev_id)
 
 def read(path, kernel_names=[]):
-  with open(path, 'r') as f:
-    source = f.read()
-  source = _triton.extract_kernels(source, kernel_names)
-  return source
+    with open(path, 'r') as f:
+        source = f.read()
+    source = _triton.extract_kernels(source, kernel_names)
+    return source
 
 class kernel:
+    def __init__(self, src, device, defines=dict(), num_warps=4, autotune_vals=[], autotune_key=[]):
+        # check if src is empty
+        if src == '':
+            raise ValueError('Kernel source code is empty')
+        self.src = src
+        self.opt = _triton.options()
+        self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
+        self.opt.num_warps = num_warps
+        # device
+        assert device.type in ['cuda', 'cpu']
+        if device.type == 'cuda':
+            self.device = torch.cuda.current_device() if device.index is None else device.index
+        if device.type == 'cpu':
+            self.device = -1
+        _torch_utils.register_device(self.device)
+        _torch_utils.register_stream(self.device)
+        # C++ function wrapper
+        self.op_id = _triton.make_op_id()
+        _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
+        # debug mode
+        self.is_debug = 'TRITON_DEBUG' in os.environ
+        # signature
+        arg_types = _triton.get_fn_signature(self.op_id)
+        self.tys = ''.join([codes[x] for x in arg_types])
 
-  def __init__(self, src, device, defines = dict(), num_warps = 4, autotune_vals = [], autotune_key = []):
-    # check if src is empty
-    if src == '':
-      raise ValueError('Kernel source code is empty')
-    self.src = src
-    self.opt = _triton.options()
-    self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
-    self.opt.num_warps = num_warps
-    # device
-    assert device.type in ['cuda', 'cpu']
-    if device.type == 'cuda':
-      self.device = torch.cuda.current_device() if device.index is None else device.index
-    if device.type == 'cpu':
-      self.device = -1
-    _torch_utils.register_device(self.device)
-    _torch_utils.register_stream(self.device)
-    # C++ function wrapper
-    self.op_id = _triton.make_op_id()
-    _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
-    # debug mode
-    self.is_debug = 'TRITON_DEBUG' in os.environ
-    # signature
-    arg_types = _triton.get_fn_signature(self.op_id)
-    self.tys = ''.join([codes[x] for x in arg_types])
-
-  def __call__(self, *args, grid):
-    _torch_utils.set_device(self.device)
-    # pack parameters into a byte buffer
-    params = struct.pack(self.tys, *args)
-    opt = _triton.autotune(self.op_id, self.device, params, grid)
-    # run kernel
-    grid = grid(opt)
-    grid_0 = grid[0]
-    grid_1 = 1 if len(grid) < 2 else grid[1]
-    grid_2 = 1 if len(grid) < 3 else grid[2]
-    _triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
\ No newline at end of file
+    def __call__(self, *args, grid):
+        _torch_utils.set_device(self.device)
+        # pack parameters into a byte buffer
+        params = struct.pack(self.tys, *args)
+        opt = _triton.autotune(self.op_id, self.device, params, grid)
+        # run kernel
+        grid = grid(opt)
+        grid_0 = grid[0]
+        grid_1 = 1 if len(grid) < 2 else grid[1]
+        grid_2 = 1 if len(grid) < 3 else grid[2]
+        _triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py
index 7d636bc27..fd3faeade 100644
--- a/python/triton/ops/blocksparse/softmax.py
+++ b/python/triton/ops/blocksparse/softmax.py
@@ -2,21 +2,15 @@ import triton
 import torch
 import os
 
-fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'),
-                      kernel_names=['forward'])
+fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
 fwd_kernels = dict()
 
-
-bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'),
-                      kernel_names=['backward'])
+bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['backward'])
 bwd_kernels = dict()
-
 
 
 class _softmax(torch.autograd.Function):
-
-    @staticmethod
-    def next_power_of_2(n):
+    def next_power_of_2(n):
         n -= 1
         n |= n >> 1
         n |= n >> 2
@@ -24,7 +18,7 @@ class _softmax(torch.autograd.Function):
         n |= n >> 8
         n |= n >> 16
         n += 1
-        return n
+        return n
 
     @staticmethod
     def make_lut(layout, block, device):
@@ -32,7 +26,7 @@ class _softmax(torch.autograd.Function):
         sizes = _empty.clone()
         # sizes along rows
         for h in range(layout.shape[0]):
-            sizes = torch.cat((sizes, layout[h,:,:].sum(-1)))
+            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
         # offsets in block format
         offsets = torch.zeros_like(sizes)
         offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
@@ -41,26 +35,29 @@ class _softmax(torch.autograd.Function):
         head = layout.nonzero(as_tuple=False)[:, 0]
         rows = layout.nonzero(as_tuple=False)[:, 1]
         columns = layout.nonzero(as_tuple=False)[:, 2]
-        core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
+        core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
         # construct look-up table
-        offsets = offsets*4 + 2*sizes.numel()
+        offsets = offsets * 4 + 2 * sizes.numel()
         header = torch.stack((sizes, offsets), dim=1).view(-1)
         lut = torch.cat((header, core)).type(torch.int32).to(device)
         return lut, int(sizes.max())
 
     @staticmethod
-    def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode):
+    def make_kernel(cache, src, max_k, device, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
+                    kp_mask_mode, attn_mask_mode):
         if max_k >= 32768:
-            raise NotImplementedError('Reductions larger than 32768 elements '\
-                                      'are not yet implemented')
+            raise NotImplementedError('Reductions larger than 32768 elements '\
+                                      'are not yet implemented')
         num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
         TN = _softmax.next_power_of_2(max_k)
         # just-in-time compile kernel
-        key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
+        key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
+               kp_mask_mode, attn_mask_mode)
         if key not in cache:
-            defines = {'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block,
-                       'INFINITY': {torch.float32: 'F32_INFINITY',
-                                    torch.float16: 'F16_INFINITY'}[dtype]}
+            defines = {
+                'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block, 'INFINITY':
+                {torch.float32: 'F32_INFINITY', torch.float16: 'F16_INFINITY'}[dtype]
+            }
             if apply_scale:
                 defines['APPLY_SCALE'] = True
             if apply_rpe:
@@ -73,13 +70,13 @@ class _softmax(torch.autograd.Function):
                 defines['APPLY_ATTN_MASK'] = True
             if attn_mask_mode == 'mul':
                 defines['ATTN_MASK_MUL'] = True
-            kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
+            kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps)
             cache[key] = kernel
         return cache[key]
 
     @staticmethod
-    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode,
-                spdims, block, lut, maxlut, bench, time):
+    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
+                maxlut, bench, time):
         apply_scale = False if scale == 1.0 else True
 
         # handle None rpe
@@ -109,17 +106,26 @@ class _softmax(torch.autograd.Function):
             apply_attn_mask = True
             stride_zattnm = attn_mask.stride(0)
 
-        # run kernel
-        kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut*block, x.device, x.dtype, block,
-                                      apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
-                                      kp_mask_mode, attn_mask_mode)
+        kernel = _softmax.make_kernel(fwd_kernels, fwd_src, maxlut * block, x.device, x.dtype, block, apply_scale,
+                                      apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
 
         M = x.shape[0]
         grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]
 
         # run kernel
-        kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),
-               maxlut, x.stride(0), stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
+        kernel(x.data_ptr(),
+               scale,
+               lut.data_ptr(),
+               rpe.data_ptr(),
+               key_padding_mask.data_ptr(),
+               attn_mask.data_ptr(),
+               maxlut,
+               x.stride(0),
+               stride_zrpe,
+               stride_hrpe,
+               stride_srpe,
+               stride_zkpm,
+               stride_zattnm,
               grid=grid)
         # save to context
         ctx.mark_dirty(x)
@@ -135,39 +141,45 @@ class _softmax(torch.autograd.Function):
         ctx.kp_mask_mode = kp_mask_mode
         ctx.attn_mask_mode = attn_mask_mode
         return x
-
+
     @staticmethod
     def backward(ctx, dx):
         # retrieve from context
         x, lut = ctx.saved_tensors
         # run kernel
-        kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut*ctx.block, x.device, x.dtype, ctx.block,
-                                      ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
-                                      ctx.kp_mask_mode, ctx.attn_mask_mode)
+        kernel = _softmax.make_kernel(bwd_kernels, bwd_src, ctx.maxlut * ctx.block, x.device, x.dtype, ctx.block,
+                                      ctx.apply_scale, ctx.apply_rpe, ctx.apply_kp_mask, ctx.apply_attn_mask,
+                                      ctx.kp_mask_mode, ctx.attn_mask_mode)
         M = x.shape[0]
         grid = lambda opt: [triton.cdiv(ctx.spdims[0] * ctx.spdims[1] * ctx.block, opt.TM), M]
         kernel(x.data_ptr(), ctx.scale, dx.data_ptr(), lut.data_ptr(), ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
         return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
 
 class softmax:
-
+
     apply_softmax = _softmax.apply
 
     def make_lut(self, device):
         key = (device, )
         if key not in self.lut_cache:
-            self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
+            self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
         return self.lut_cache[key]
 
-    def __init__(self, layout, block, bench = False):
+    def __init__(self, layout, block, bench=False):
         self.spdims = layout.shape
         self.layout = layout
         self.block = block
         self.bench = bench
         self.lut_cache = dict()
-
-    def __call__(self, x, scale = 1., rpe = None, key_padding_mask = None, attn_mask = None,
-                 key_padding_mask_mode='add', attn_mask_mode='add'):
+
+    def __call__(self,
+                 x,
+                 scale=1.,
+                 rpe=None,
+                 key_padding_mask=None,
+                 attn_mask=None,
+                 key_padding_mask_mode='add',
+                 attn_mask_mode='add'):
         time_y = [None]
         if rpe is not None and rpe.dtype != x.dtype:
             raise ValueError('relative position embedding must be %s' % x.dtype)
@@ -176,9 +188,6 @@ class softmax:
         if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
             raise ValueError('Key padding mask must be %s' % x.dtype)
         lut, maxlut = self.make_lut(x.device)
-        x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask,
-                                  key_padding_mask_mode, attn_mask_mode,
-                                  self.spdims, self.block,
-                                  lut,
-                                  maxlut, self.bench, time_y)
+        x = softmax.apply_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,
+                                  self.spdims, self.block, lut, maxlut, self.bench, time_y)
         return x
\ No newline at end of file
diff --git a/python/triton/ops/conv.py b/python/triton/ops/conv.py
index 95a1ad201..1725d9ca0 100644
--- a/python/triton/ops/conv.py
+++ b/python/triton/ops/conv.py
@@ -8,50 +8,74 @@ class _conv(torch.autograd.Function):
 
     @staticmethod
     def unpack(IDX, CI, R, S):
-      s = IDX % S
-      cr = IDX // S
-      r = cr % R
-      ci = cr // R
-      return ci, r, s
+        s = IDX % S
+        cr = IDX // S
+        r = cr % R
+        ci = cr // R
+        return ci, r, s
 
     @staticmethod
     def forward(ctx, a, b, pad, stride):
-      # create kernel if necessary
-      dtype = a.dtype
-      device = a.device
-      # shapes
-      Z, CI, H, W = a.shape
-      _, R, S, CO = b.shape
-      P = (H + 2*pad[0] - R)//stride[0] + 1
-      Q = (W + 2*pad[1] - S)//stride[1] + 1
-      # compile kernel
-      if (dtype, device) not in _conv.kernel:
-        TK = 16
-        defines = {
-          'TYPE' : dtype,
-          'TM' : 64,
-          'TN' : 64,
-          'TK' : TK,
-          'TZ' : 1,
-          'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
-        }
-        idx = torch.arange(CI*R*S)
-        ci, r, s = _conv.unpack(idx, CI, R, S)
-        nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
-        delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
-        delta = delta.type(torch.int32).cuda()
-        _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
-      delta, kernel = _conv.kernel[dtype]
-      # allocate output
-      c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
-      # enqueue
-      kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 1., Z*P*Q, CO, CI*R*S,
-            pad[0], pad[1], stride[0], stride[1],
+        # create kernel if necessary
+        dtype = a.dtype
+        device = a.device
+        # shapes
+        Z, CI, H, W = a.shape
+        _, R, S, CO = b.shape
+        P = (H + 2 * pad[0] - R) // stride[0] + 1
+        Q = (W + 2 * pad[1] - S) // stride[1] + 1
+        # compile kernel
+        if (dtype, device) not in _conv.kernel:
+            TK = 16
+            defines = {
+                'TYPE': dtype,
+                'TM': 64,
+                'TN': 64,
+                'TK': TK,
+                'TZ': 1,
+                'HH': H,
+                'WW': W,
+                'PP': P,
+                'QQ': Q,
+                'SS': S,
+                'RR': R,
+            }
+            idx = torch.arange(CI * R * S)
+            ci, r, s = _conv.unpack(idx, CI, R, S)
+            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
+            delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (ns - s) * a.stride(3)
+            delta = delta.type(torch.int32).cuda()
+            _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines))
+        delta, kernel = _conv.kernel[dtype]
+        # allocate output
+        c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
+        # enqueue
+        kernel(
+            a.data_ptr(),
+            b.data_ptr(),
+            c.data_ptr(),
+            1.,
+            Z * P * Q,
+            CO,
+            CI * R * S,
+            pad[0],
+            pad[1],
+            stride[0],
+            stride[1],
            delta.data_ptr(),
-            a.stride(0), a.stride(1), a.stride(2), a.stride(3),
-            b.stride(0), b.stride(1), b.stride(2), b.stride(3),
-            c.stride(0), c.stride(1), c.stride(2), c.stride(3),
-            grid = lambda opt: [triton.cdiv(Z*P*Q, opt.TM), triton.cdiv(CO, opt.TN)])
-      return c
+            a.stride(0),
+            a.stride(1),
+            a.stride(2),
+            a.stride(3),
+            b.stride(0),
+            b.stride(1),
+            b.stride(2),
+            b.stride(3),
+            c.stride(0),
+            c.stride(1),
+            c.stride(2),
+            c.stride(3),
+            grid=lambda opt: [triton.cdiv(Z * P * Q, opt.TM), triton.cdiv(CO, opt.TN)])
+        return c
 
 conv = _conv.apply
\ No newline at end of file
diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py
index 7912badec..ce8b908b4 100644
--- a/python/triton/ops/matmul.py
+++ b/python/triton/ops/matmul.py
@@ -7,17 +7,17 @@ class _matmul(torch.autograd.Function):
 
     _DEFAULT_CONFIGS = [
         ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
-        ({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
-        ({'TM': '128', 'TN': '64' , 'TK': '32', 'TZ': '1'}, 4),
-        ({'TM': '64' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 4),
-        ({'TM': '32' , 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
-        ({'TM': '128', 'TN': '32' , 'TK': '64', 'TZ': '1'}, 4),
-        ({'TM': '64' , 'TN': '32' , 'TK': '64', 'TZ': '1'}, 2),
-        ({'TM': '32' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 2),
-        ({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
-        ({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
-        ({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
-        ({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
+        ({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
+        ({'TM': '128', 'TN': '64', 'TK': '32', 'TZ': '1'}, 4),
+        ({'TM': '64', 'TN': '64', 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '32', 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '128', 'TN': '32', 'TK': '64', 'TZ': '1'}, 4),
+        ({'TM': '64', 'TN': '32', 'TK': '64', 'TZ': '1'}, 2),
+        ({'TM': '32', 'TN': '64', 'TK': '64', 'TZ': '1'}, 2),
+        ({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
+        ({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
+        ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
+        ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
     ]
     _CONFIGS = _DEFAULT_CONFIGS
 
@@ -28,9 +28,9 @@ class _matmul(torch.autograd.Function):
         if N % 2 == 0: return 2
         return 1
 
-
     _locks = dict()
     _kernels = dict()
+
     @staticmethod
     def _call(a, b):
         dtype = a.dtype
@@ -51,26 +51,24 @@ class _matmul(torch.autograd.Function):
         lda_pow2_div = _matmul.largest_pow2_divisor(lda)
         ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
         ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
-        is_tk_div_k = K % 64 == 0
+        is_tk_div_k = K % 64 == 0
         key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
         if key not in _matmul._kernels:
             defines = {
-                'TYPE' : dtype,
-                'STRIDE_AM' : 'lda' if is_a_row else '1',
-                'STRIDE_AK' : '1' if is_a_row else 'lda',
-                'STRIDE_BK' : 'ldb' if is_b_row else '1',
-                'STRIDE_BN' : '1' if is_b_row else 'ldb',
-                'LDA_POW2_DIV': lda_pow2_div,
-                'LDB_POW2_DIV': ldb_pow2_div,
-                'LDC_POW2_DIV': ldc_pow2_div,
-                'IS_TK_DIV_K' : int(is_tk_div_k)
+                'TYPE': dtype, 'STRIDE_AM': 'lda' if is_a_row else '1', 'STRIDE_AK': '1' if is_a_row else 'lda',
+                'STRIDE_BK': 'ldb' if is_b_row else '1', 'STRIDE_BN': '1' if is_b_row else 'ldb', 'LDA_POW2_DIV':
+                lda_pow2_div, 'LDB_POW2_DIV': ldb_pow2_div, 'LDC_POW2_DIV': ldc_pow2_div, 'IS_TK_DIV_K':
+                int(is_tk_div_k)
             }
-            _matmul._kernels[key] = triton.kernel(_matmul.src, device, defines=defines,
-                                                  autotune_vals = _matmul._CONFIGS, autotune_key=['M', 'N', 'K'])
+            _matmul._kernels[key] = triton.kernel(_matmul.src,
+                                                  device,
+                                                  defines=defines,
+                                                  autotune_vals=_matmul._CONFIGS,
+                                                  autotune_key=['M', 'N', 'K'])
         kernel = _matmul._kernels[key]
         # # locks for split-k
         if device not in _matmul._locks:
-            _matmul._locks[device] = torch.zeros(1024*1024, dtype=torch.int32, device=device)
+            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
         locks = _matmul._locks[device]
         # enqueue
         alpha = 1.
@@ -81,7 +79,7 @@ class _matmul(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, a, b):
-        c = _matmul._call(a,b)
+        c = _matmul._call(a, b)
         return c
 
 matmul = _matmul.apply
diff --git a/python/triton/ops/softmax.py b/python/triton/ops/softmax.py
index e00254561..5e1075fdf 100644
--- a/python/triton/ops/softmax.py
+++ b/python/triton/ops/softmax.py
@@ -4,23 +4,22 @@ import os
 
 fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
 fwd_kernels = dict()
+
 def get_fwd_kernel(block, dtype, device):
     key = (block, dtype, device)
     if key not in fwd_kernels:
         defines = {'BLOCK': block, 'TYPE': dtype}
-        fwd_kernels[key] = triton.kernel(fwd_src, device = device, defines = defines)
+        fwd_kernels[key] = triton.kernel(fwd_src, device=device, defines=defines)
     return fwd_kernels[key]
-
 
 class _softmax(torch.autograd.Function):
     @staticmethod
    def forward(ctx, x):
         y = torch.empty_like(x)
         M, N = x.shape
         kernel = get_fwd_kernel(N, x.dtype, x.device)
-        kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ])
+        grid = lambda opt: (M, )
+        kernel(x.data_ptr(), y.data_ptr(), grid=grid)
         return y
 
 softmax = _softmax.apply
-
-
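The files above all drive the same `triton.kernel` API, so a compact usage sketch may help readers who only see the reformatted call sites. It is an illustration under stated assumptions, not part of the change: the source path, the input shape and the BLOCK/num_warps values are made up; only the constructor keywords (`device`, `defines`, `num_warps`), the `data_ptr()` arguments and the `grid` callable follow the code in the diff.

import os
import torch
import triton

# read a Triton-C kernel from disk, as ops/softmax.py does (the path is a placeholder here)
src = triton.read(os.path.join('kernels', 'softmax.c'), kernel_names=['forward'])

x = torch.randn(16, 128, device='cuda', dtype=torch.float16)
y = torch.empty_like(x)

# one compiled kernel per (BLOCK, TYPE, device); defines become preprocessor macros in the C source
kernel = triton.kernel(src, device=x.device, defines={'BLOCK': x.shape[1], 'TYPE': x.dtype}, num_warps=4)

# scalars and pointers are packed into a byte buffer with struct.pack() inside kernel.__call__;
# the grid callable receives the autotuned options (opt.TM, opt.TN, ...) and returns the launch grid
grid = lambda opt: (x.shape[0], )
kernel(x.data_ptr(), y.data_ptr(), grid=grid)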
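Two of the small helpers touched above are pure integer utilities that can be sanity-checked in isolation. The snippet below copies their logic for illustration only (the `% 8` and `% 4` branches of `largest_pow2_divisor` fall outside the hunk shown and are assumed from the visible pattern); it imports nothing from the package.

def next_power_of_2(n):
    # bit-smearing round-up to the next power of two, as in blocksparse/softmax.py;
    # it sets the TN tile so that one program covers the whole reduction
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n

def largest_pow2_divisor(N):
    # as used in ops/matmul.py to specialize the compiled kernel on stride alignment
    if N % 8 == 0: return 8
    if N % 4 == 0: return 4
    if N % 2 == 0: return 2
    return 1

assert next_power_of_2(1000) == 1024
assert largest_pow2_divisor(768) == 8
assert largest_pow2_divisor(6) == 2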
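The least obvious arithmetic in the reformatted conv.py is the `delta` table: `unpack` splits a flat index over the (CI, R, S) axes, and `delta[k]` is how many input elements the pointer must advance when the reduction loop jumps from position k to k + TK. A small worked example with made-up sizes (CPU only, so the `.cuda()` cast from the original is omitted):

import torch

def unpack(idx, CI, R, S):
    # innermost axis is S, then R, then CI -- same layout as _conv.unpack
    s = idx % S
    cr = idx // S
    r = cr % R
    ci = cr // R
    return ci, r, s

CI, R, S, TK = 4, 3, 3, 16          # illustrative sizes only
a = torch.empty(2, CI, 8, 8)        # NCHW input, strides = (256, 64, 8, 1)
idx = torch.arange(CI * R * S)
ci, r, s = unpack(idx, CI, R, S)
nci, nr, ns = unpack(idx + TK, CI, R, S)
delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (ns - s) * a.stride(3)
print(delta[:4])                    # per-position pointer increments handed to the kernel as int32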