[PYTHON] Removed support for dense softmax
Interest seems limited now that it is fused in cross_entropy. Will likely re-add once it's easier to share code between ops
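For context, dense softmax shows up mostly as the first half of a cross-entropy loss, and fusing the two means the full probability matrix never has to be materialized. A minimal PyTorch sketch of the identity being exploited (illustrative only, not part of this commit):

import torch

x = torch.randn(4, 8)
labels = torch.randint(0, 8, (4,))

# Unfused: materialize the softmax, then gather and take -log.
p = torch.softmax(x, dim=-1)
loss_unfused = -p[torch.arange(4), labels].log()

# Fused view: cross_entropy(x, label) = logsumexp(x) - x[label],
# so no dense probability matrix is needed.
loss_fused = torch.logsumexp(x, dim=-1) - x[torch.arange(4), labels]

assert torch.allclose(loss_unfused, loss_fused, atol=1e-6)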
@@ -1,8 +0,0 @@
import torch
import triton

def test_op(M = 1024, N = 1024, dtype = torch.float32):
    x = torch.randn(M, N, dtype=dtype, device='cuda')
    th_y = torch.softmax(x, dim=-1)
    tt_y = triton.ops.softmax(x)
    assert torch.allclose(tt_y, th_y)
@@ -1,8 +0,0 @@
__global__ void forward(TYPE* X, TYPE* Y) {
    int pid = get_program_id(0);
    int off[BLOCK] = pid * BLOCK + 0 ... BLOCK;
    float x[BLOCK] = *(X + off);
    float shifted[BLOCK] = exp(x - x[max]);
    float sum = shifted[+];
    *(Y + off) = shifted / sum;
}
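In the legacy Triton-C dialect above, 0 ... BLOCK builds the vector of column offsets, x[max] is a max-reduction over the block, and shifted[+] is a sum-reduction, so each program instance computes a numerically stable softmax over one row of BLOCK elements. A rough Python equivalent of what a single program instance does (a sketch for illustration, not the removed code):

import torch

def forward_reference(X, pid, BLOCK):
    # off[BLOCK] = pid * BLOCK + 0 ... BLOCK
    off = pid * BLOCK + torch.arange(BLOCK)
    # x[BLOCK] = *(X + off)
    x = X.view(-1)[off]
    # exp(x - x[max]): subtract the row max before exponentiating, for stability
    shifted = torch.exp(x - x.max())
    # shifted[+] is the sum-reduction; dividing by it normalizes the row
    return shifted / shifted.sum()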
@@ -1,25 +0,0 @@
import torch
import triton
import os

fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
fwd_kernels = dict()

def get_fwd_kernel(block, dtype, device):
    key = (block, dtype, device)
    if key not in fwd_kernels:
        defines = {'BLOCK': block, 'TYPE': dtype}
        fwd_kernels[key] = triton.kernel(fwd_src, device=device, defines=defines)
    return fwd_kernels[key]

class _softmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.empty_like(x)
        M, N = x.shape
        kernel = get_fwd_kernel(N, x.dtype, x.device)
        grid = lambda opt: (M, )
        kernel(x.data_ptr(), y.data_ptr(), grid=grid)
        return y

softmax = _softmax.apply
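Kernels are specialized through the BLOCK and TYPE preprocessor defines and cached per (block, dtype, device) key, so the first call for a given row length pays the compilation cost and later calls reuse the compiled binary; note that only forward is defined on the autograd.Function, so the removed op did not provide gradients on its own. A hedged sketch of the caching behaviour (illustrative, assuming CUDA tensors as in the test above):

import torch

a = torch.randn(64, 256, device='cuda')
b = torch.randn(64, 512, device='cuda')
softmax(a)   # compiles a BLOCK=256 specialization, cached under (256, float32, cuda)
softmax(b)   # different row length -> separate BLOCK=512 compilation
softmax(a)   # cache hit: reuses the BLOCK=256 kernel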