[PYTHON] Removed support for dense softmax
Interest seems limited now that softmax is fused into cross_entropy. Will likely re-add it once it is easier to share code between ops.
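Note for downstream users: the deleted test below checked triton.ops.softmax against torch.softmax, so a plain PyTorch call is a drop-in replacement for the removed op. A minimal sketch, assuming a 2D CUDA tensor as in that test:

import torch

x = torch.randn(1024, 1024, device='cuda')
# The deleted test used torch.softmax as the reference, so this returns
# the same values the removed triton.ops.softmax produced.
y = torch.softmax(x, dim=-1)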
@@ -1,8 +0,0 @@
-import torch
-import triton
-
-def test_op(M = 1024, N = 1024, dtype = torch.float32):
-    x = torch.randn(M, N, dtype=dtype, device='cuda')
-    th_y = torch.softmax(x, dim=-1)
-    tt_y = triton.ops.softmax(x)
-    assert torch.allclose(tt_y, th_y)
@@ -1,8 +0,0 @@
-__global__ void forward(TYPE* X, TYPE* Y) {
-    int pid = get_program_id(0);
-    int off[BLOCK] = pid * BLOCK + 0 ... BLOCK;
-    float x[BLOCK] = *(X + off);
-    float shifted[BLOCK] = exp(x - x[max]);
-    float sum = shifted[+];
-    *(Y + off) = shifted / sum;
-}
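In the Triton-C kernel above, each program instance handles one row of BLOCK elements: `x[max]` reads as a max-reduction over the row and `shifted[+]` as a sum-reduction, i.e. a numerically stable row-wise softmax. A rough PyTorch equivalent for readers unfamiliar with that reduction syntax (the function name is illustrative, not from this repo):

import torch

def rowwise_softmax_reference(x: torch.Tensor) -> torch.Tensor:
    # Subtract the row max before exponentiating, as exp(x - x[max]) does in the kernel.
    shifted = torch.exp(x - x.max(dim=-1, keepdim=True).values)
    # shifted[+] is the row sum; dividing by it normalizes each row to sum to 1.
    return shifted / shifted.sum(dim=-1, keepdim=True)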
@@ -1,25 +0,0 @@
-import torch
-import triton
-import os
-
-fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
-fwd_kernels = dict()
-
-def get_fwd_kernel(block, dtype, device):
-    key = (block, dtype, device)
-    if key not in fwd_kernels:
-        defines = {'BLOCK': block, 'TYPE': dtype}
-        fwd_kernels[key] = triton.kernel(fwd_src, device=device, defines=defines)
-    return fwd_kernels[key]
-
-class _softmax(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x):
-        y = torch.empty_like(x)
-        M, N = x.shape
-        kernel = get_fwd_kernel(N, x.dtype, x.device)
-        grid = lambda opt: (M, )
-        kernel(x.data_ptr(), y.data_ptr(), grid=grid)
-        return y
-
-softmax = _softmax.apply
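The removed wrapper above cached one specialized kernel per (BLOCK, dtype, device) key, with BLOCK bound to the row length N and a grid of M programs (one per row), and only the forward pass was implemented. A minimal usage sketch of the API as it existed before this commit, taken from the deleted test (it no longer works after this change):

import torch
import triton

x = torch.randn(1024, 1024, device='cuda')
y = triton.ops.softmax(x)  # forward only: _softmax defined no backward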