[PYTHON] Removed support for dense softmax

Interest seems limited now that it is fused in cross_entropy. Will
likely re-add it once it is easier to share code between ops.
This commit is contained in:
Philippe Tillet
2021-02-07 16:46:47 -05:00
parent 682ac4c60e
commit b0647cfd52
3 changed files with 0 additions and 41 deletions

View File

@@ -1,8 +0,0 @@
import torch
import triton
def test_op(M = 1024, N = 1024, dtype = torch.float32):
    x = torch.randn(M, N, dtype=dtype, device='cuda')
    th_y = torch.softmax(x, dim=-1)
    tt_y = triton.ops.softmax(x)
    assert torch.allclose(tt_y, th_y)

View File

@@ -1,8 +0,0 @@
__global__ void forward(TYPE* X, TYPE* Y) {
    int pid = get_program_id(0);
    int off[BLOCK] = pid * BLOCK + 0 ... BLOCK;
    float x[BLOCK] = *(X + off);
    float shifted[BLOCK] = exp(x - x[max]);
    float sum = shifted[+];
    *(Y + off) = shifted / sum;
}
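
For reference, the kernel above uses the legacy Triton-C reduction syntax: x[max] is a horizontal max over the BLOCK axis and shifted[+] is a horizontal sum. A minimal PyTorch sketch of the same numerically stable row-wise softmax (the function name is illustrative, not part of the removed code):

import torch

def softmax_reference(x):
    # Numerically stable row-wise softmax, matching the kernel above:
    # subtract the per-row max before exponentiating, then normalize by the row sum.
    shifted = torch.exp(x - x.max(dim=-1, keepdim=True).values)
    return shifted / shifted.sum(dim=-1, keepdim=True)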

View File

@@ -1,25 +0,0 @@
import torch
import triton
import os
fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
fwd_kernels = dict()
def get_fwd_kernel(block, dtype, device):
    key = (block, dtype, device)
    if key not in fwd_kernels:
        defines = {'BLOCK': block, 'TYPE': dtype}
        fwd_kernels[key] = triton.kernel(fwd_src, device=device, defines=defines)
    return fwd_kernels[key]
class _softmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.empty_like(x)
        M, N = x.shape
        kernel = get_fwd_kernel(N, x.dtype, x.device)
        grid = lambda opt: (M, )
        kernel(x.data_ptr(), y.data_ptr(), grid=grid)
        return y
softmax = _softmax.apply
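
The wrapper compiled one kernel per (block, dtype, device) key: BLOCK was bound to the row length N through preprocessor defines, so each distinct column count triggered a separate compilation, and the grid launched one program instance per row. A minimal usage sketch under the pre-removal API (assumes the definitions above are in scope and a CUDA device is available; no backward was defined, so this was forward-only):

import torch

x = torch.randn(8, 1024, device='cuda')
y = softmax(x)  # _softmax.apply; row-wise softmax over the last dimension
assert torch.allclose(y, torch.softmax(x, dim=-1))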