Added a Softmax Xent Op (#53)

Also includes a bugfix in kernel.py to set the device before registering the C++ function object.

committed by Philippe Tillet
parent dffd66bc83
commit 682ac4c60e
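
For orientation, a minimal usage sketch of the new op (not part of the diff; the call signature follows the test below, the shapes are illustrative):

    import torch
    import triton

    # logits: (batch, vocab) scores; indices: (batch,) int64 class ids
    logits = torch.randn(8, 1000, dtype=torch.float32, device='cuda', requires_grad=True)
    indices = torch.randint(0, 1000, (8,), dtype=torch.int64, device='cuda')

    # per-row negative log-likelihood, matching
    # torch.nn.CrossEntropyLoss(reduction="none")(logits, indices)
    loss = triton.ops.cross_entropy(logits, indices)
    loss.sum().backward()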
python/tests/test_cross_entropy.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import torch
import triton
import pytest


@pytest.mark.parametrize("M, N, dtype, mode",
    [
        (M, N, dtype, mode) for M in [1024, 821]
        for N in [512, 857, 1871, 2089, 8573, 31000]
        for dtype in ['float16', 'float32']
        for mode in ['forward', 'backward']
    ]
)
def test_op(M, N, dtype, mode):
    dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
    # forward pass
    tt_y = triton.ops.cross_entropy(x, idx)
    th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
    if mode == 'forward':
        assert torch.allclose(th_y, tt_y, atol=1e-3, rtol=1e-2)
    # backward pass
    elif mode == 'backward':
        dy = torch.randn_like(tt_y)
        # triton backward
        tt_y.backward(dy)
        tt_dx = x.grad.clone()
        # torch backward
        x.grad.zero_()
        th_y.backward(dy)
        th_dx = x.grad.clone()
        assert torch.allclose(th_dx, tt_dx, atol=1e-3, rtol=1e-2)
python/triton/kernel.py
@@ -12,25 +12,14 @@ def cleanup():
     _triton.cleanup()

 codes = {
-    _triton.arg_type.int1: 'B',
-    _triton.arg_type.int8: 'B',
-    _triton.arg_type.int32: 'I',
-    _triton.arg_type.int64: 'Q',
-    _triton.arg_type.half: 'H',
-    _triton.arg_type.float: 'f',
-    _triton.arg_type.double: 'd',
-    _triton.arg_type.buffer: 'P'
+    _triton.arg_type.int1: 'B', _triton.arg_type.int8: 'B', _triton.arg_type.int32: 'I', _triton.arg_type.int64: 'Q',
+    _triton.arg_type.half: 'H', _triton.arg_type.float: 'f', _triton.arg_type.double: 'd', _triton.arg_type.buffer: 'P'
 }

 def th_to_triton(obj):
     tys = {
-        torch.int8: 'char',
-        torch.int16: 'short',
-        torch.int32: 'int',
-        torch.int64: 'long',
-        torch.float16: 'half',
-        torch.float32: 'float',
-        torch.float64: 'double'
+        torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long', torch.float16: 'half',
+        torch.float32: 'float', torch.float64: 'double'
     }
     if isinstance(obj, torch.dtype):
         return tys[obj]
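
The single-letter values in the codes dict are standard Python struct format characters ('B' unsigned char, 'H' unsigned short, 'I' unsigned int, 'Q' unsigned long long, 'f' float, 'd' double, 'P' pointer-sized integer), presumably used to pack kernel launch arguments into a byte buffer. A small illustrative sketch, not part of the diff (pack_args is a hypothetical helper):

    import struct

    def pack_args(fmt, values):
        # e.g. an int32 scalar, an int64 scalar and a device pointer for one launch
        return struct.pack(fmt, *values)

    blob = pack_args('IQP', [128, 2**40, 0x7f0000000000])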
@@ -69,6 +58,7 @@ class kernel:
         _torch_utils.register_stream(self.device)
         # C++ function wrapper
         self.op_id = _triton.make_op_id()
+        _torch_utils.set_device(self.device)
         _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
         # debug mode
         self.is_debug = 'TRITON_DEBUG' in os.environ
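
This one-line addition is the bugfix from the commit message: the kernel's device is made current before the C++ function object is registered, so registration does not silently target whichever CUDA device happens to be current. A generic PyTorch sketch of the pitfall (illustrative only, not Triton internals):

    import torch

    def build_state_for(device_index):
        # anything allocated below binds to the *current* device, so select
        # the intended device first, mirroring the set_device-before-register fix
        torch.cuda.set_device(device_index)
        return torch.empty(1024, device='cuda')  # lands on device_index, not device 0

    state = build_state_for(1 if torch.cuda.device_count() > 1 else 0)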
python/triton/ops/__init__.py
@@ -1,4 +1,5 @@
 from .conv import _conv, conv
 from .matmul import _matmul, matmul
 from .softmax import _softmax, softmax
+from .cross_entropy import _cross_entropy, cross_entropy
 from . import blocksparse
python/triton/ops/cross_entropy.c (new file, 42 lines)
@@ -0,0 +1,42 @@
__global__ void forward(TYPE *logit,
                        TYPE *modified_logit,
                        long *indices __readonly,
                        TYPE *result,
                        int n_cols) {
    int row = get_program_id(0);

    bool check[TILE] = ((0 ... TILE) < n_cols);
    int offset[TILE] = row * n_cols + 0 ... TILE;
    TYPE *px[TILE] = logit + offset;
    TYPE *pmodified[TILE] = modified_logit + offset;
    long local_ind = *(indices + row);

    TYPE F16[TILE] = check ? *px : -INFINITY;
    float shifted_logit[TILE] = F16 - F16[max];
    float neg_logprob[TILE] = log(exp(shifted_logit)[+]) - shifted_logit;
    *? (check)pmodified = neg_logprob;
    __debug_barrier();
    *(result + row) = *(modified_logit + (local_ind + n_cols * row));
}

__global__ void backward(TYPE *neg_logprobs,
                         long *indices,
                         TYPE *dneg_logprobs,
                         int n_cols) {

    int row = get_program_id(0);
    // pointer arithmetic
    bool check[TILE] = ((0 ... TILE) < n_cols);
    int offset[TILE] = row * n_cols + 0 ... TILE;
    TYPE *px[TILE] = neg_logprobs + offset;
    long local_ind = *(indices + row);
    TYPE local_dn = *(dneg_logprobs + row);
    // we know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k]
    // and we have -log(p[k]) stored, so this is easy
    TYPE intermediate[TILE] = check ? exp(-(float[TILE]) *? (check)px) : 0;
    // find_one marks the selected logit index for our token
    bool find_one[TILE] = ((0 ... TILE) == local_ind);
    intermediate = intermediate - ((TYPE[TILE])find_one);
    // multiply by dneg_logprobs
    *? (check)px = intermediate * local_dn;
}
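
For readers unfamiliar with Triton-C, here is a plain PyTorch sketch of what the two kernels compute for a single row, under my reading of the code above (reference only, not part of the diff):

    import torch

    def forward_row(logit_row, index):
        # numerically stable log-softmax: shift by the row max, then
        # neg_logprob[k] = log(sum_j exp(shifted[j])) - shifted[k] = -log p[k]
        shifted = logit_row.float() - logit_row.float().max()
        neg_logprob = shifted.exp().sum().log() - shifted
        return neg_logprob, neg_logprob[index]   # per-row loss is -log p[index]

    def backward_row(neg_logprob_row, index, dneg_logprob):
        # d(-log p[i]) / d logit[k] = p[k] - 1[k == i], with p[k] = exp(-neg_logprob[k])
        grad = (-neg_logprob_row).exp()
        grad[index] -= 1
        return grad * dneg_logprob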
python/triton/ops/cross_entropy.py (new file, 85 lines)
@@ -0,0 +1,85 @@
import os
import triton
import torch


def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n


def make_kernel(device, dtype, n_cols, cache, name):
    rounded = next_power_of_2(n_cols)
    key = (dtype, rounded)
    if key not in cache:
        fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
        src = triton.read(fname, kernel_names=[name])
        infinities = {
            torch.float16: "F16_INFINITY",
            torch.float32: "F32_INFINITY",
        }
        defines = {
            "TILE": rounded,
            "TYPE": dtype,
            "INFINITY": infinities[dtype],
        }
        cache[key] = triton.kernel(src, device=device, defines=defines)
    return cache[key]


# forward kernel
fwd_kernels = dict()
make_fwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, fwd_kernels, "forward")

# backward kernel
bwd_kernels = dict()
make_bwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, bwd_kernels, "backward")


class _cross_entropy(torch.autograd.Function):
    @classmethod
    def forward(cls, ctx, logits, indices):
        # make sure we can use triton
        assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
        # make kernel
        device, dtype = logits.device, logits.dtype
        n_cols = logits.shape[-1]
        kernel = make_fwd_kernel(device, dtype, n_cols)
        # run the kernel
        result = torch.empty_like(indices, dtype=dtype, device=device)
        neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
        kernel(logits.data_ptr(),
               neg_logprobs.data_ptr(),
               indices.data_ptr(),
               result.data_ptr(),
               n_cols,
               grid=lambda opt: (logits.numel() // n_cols, ))
        # save for backward
        ctx.save_for_backward(neg_logprobs, indices)
        return result

    @classmethod
    def backward(cls, ctx, dneg_logprobs):
        """We know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k]
        so we initialize the gradient as neg_logprobs, so we can just exponentiate
        to get p[k], which is most of what we need... neg_logprobs will be
        modified in place to become the gradient we want
        """
        # load saved tensors
        neg_logprobs, indices = ctx.saved_tensors
        # make kernel
        device, dtype = neg_logprobs.device, neg_logprobs.dtype
        n_cols = neg_logprobs.shape[-1]
        kernel = make_bwd_kernel(device, dtype, n_cols)
        # run the kernel
        # neg_logprobs will be modified in place to become our gradient:
        kernel(neg_logprobs.data_ptr(),
               indices.data_ptr(),
               dneg_logprobs.data_ptr(),
               n_cols,
               grid=lambda opt: (neg_logprobs.numel() // n_cols, ))
        return neg_logprobs, None


cross_entropy = _cross_entropy.apply
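
A note on the caching scheme above (my reading, not part of the diff): kernels are cached per (dtype, next_power_of_2(n_cols)), so column counts that round up to the same power of two reuse one compiled specialization. Assuming next_power_of_2 from the file above is in scope, the column counts exercised in the test map as follows:

    # TILE chosen for each n_cols in the test suite
    for n in (512, 857, 1871, 2089, 8573, 31000):
        print(n, next_power_of_2(n))   # 512, 1024, 2048, 4096, 16384, 32768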