From b0647cfd523626bb157611438b235a594c499344 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sun, 7 Feb 2021 16:46:47 -0500
Subject: [PATCH] [PYTHON] Removed support for dense softmax

Interest seems limited now that it is fused in cross_entropy. Will
likely re-add once it's easier to share code between ops
---
 python/tests/test_softmax.py |  8 --------
 python/triton/ops/softmax.c  |  8 --------
 python/triton/ops/softmax.py | 25 -------------------------
 3 files changed, 41 deletions(-)
 delete mode 100644 python/tests/test_softmax.py
 delete mode 100644 python/triton/ops/softmax.c
 delete mode 100644 python/triton/ops/softmax.py

diff --git a/python/tests/test_softmax.py b/python/tests/test_softmax.py
deleted file mode 100644
index 7ac75c5af..000000000
--- a/python/tests/test_softmax.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import torch
-import triton
-
-def test_op(M = 1024, N = 1024, dtype = torch.float32):
-    x = torch.randn(M, N, dtype=dtype, device='cuda')
-    th_y = torch.softmax(x, dim=-1)
-    tt_y = triton.ops.softmax(x)
-    assert torch.allclose(tt_y, th_y)
\ No newline at end of file
diff --git a/python/triton/ops/softmax.c b/python/triton/ops/softmax.c
deleted file mode 100644
index 3070ed0ba..000000000
--- a/python/triton/ops/softmax.c
+++ /dev/null
@@ -1,8 +0,0 @@
-__global__ void forward(TYPE* X, TYPE* Y) {
-    int pid = get_program_id(0);
-    int off[BLOCK] = pid * BLOCK + 0 ... BLOCK;
-    float x[BLOCK] = *(X + off);
-    float shifted[BLOCK] = exp(x - x[max]);
-    float sum = shifted[+];
-    *(Y + off) = shifted / sum;
-}
\ No newline at end of file
diff --git a/python/triton/ops/softmax.py b/python/triton/ops/softmax.py
deleted file mode 100644
index 5e1075fdf..000000000
--- a/python/triton/ops/softmax.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import torch
-import triton
-import os
-
-fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward'])
-fwd_kernels = dict()
-
-def get_fwd_kernel(block, dtype, device):
-    key = (block, dtype, device)
-    if key not in fwd_kernels:
-        defines = {'BLOCK': block, 'TYPE': dtype}
-        fwd_kernels[key] = triton.kernel(fwd_src, device=device, defines=defines)
-    return fwd_kernels[key]
-
-class _softmax(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x):
-        y = torch.empty_like(x)
-        M, N = x.shape
-        kernel = get_fwd_kernel(N, x.dtype, x.device)
-        grid = lambda opt: (M, )
-        kernel(x.data_ptr(), y.data_ptr(), grid=grid)
-        return y
-
-softmax = _softmax.apply