Added a Softmax Xent Op (#53)
Also includes a bugfix in kernel.py to set the device before registering the c++ function object
This commit is contained in:
committed by
Philippe Tillet
parent
dffd66bc83
commit
682ac4c60e
33
python/tests/test_cross_entropy.py
Normal file
33
python/tests/test_cross_entropy.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
import triton
|
||||
import pytest
|
||||
|
||||
@pytest.mark.parametrize("M, N, dtype, mode",
|
||||
[
|
||||
(M, N, dtype, mode) for M in [1024, 821]
|
||||
for N in [512, 857, 1871, 2089, 8573, 31000]
|
||||
for dtype in ['float16', 'float32']\
|
||||
for mode in ['forward', 'backward']
|
||||
]
|
||||
)
|
||||
def test_op(M, N, dtype, mode):
|
||||
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
|
||||
# create inputs
|
||||
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
|
||||
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
|
||||
# forward pass
|
||||
tt_y = triton.ops.cross_entropy(x, idx)
|
||||
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
|
||||
if mode == 'forward':
|
||||
assert torch.allclose(th_y, tt_y, atol=1e-3, rtol=1e-2)
|
||||
# backward pass
|
||||
elif mode == 'backward':
|
||||
dy = torch.randn_like(tt_y)
|
||||
# triton backward
|
||||
tt_y.backward(dy)
|
||||
tt_dx = x.grad.clone()
|
||||
# torch backward
|
||||
x.grad.zero_()
|
||||
th_y.backward(dy)
|
||||
th_dx = x.grad.clone()
|
||||
assert torch.allclose(th_dx, tt_dx, atol=1e-3, rtol=1e-2)
|
@@ -12,25 +12,14 @@ def cleanup():
|
||||
_triton.cleanup()
|
||||
|
||||
codes = {
|
||||
_triton.arg_type.int1: 'B',
|
||||
_triton.arg_type.int8: 'B',
|
||||
_triton.arg_type.int32: 'I',
|
||||
_triton.arg_type.int64: 'Q',
|
||||
_triton.arg_type.half: 'H',
|
||||
_triton.arg_type.float: 'f',
|
||||
_triton.arg_type.double: 'd',
|
||||
_triton.arg_type.buffer: 'P'
|
||||
_triton.arg_type.int1: 'B', _triton.arg_type.int8: 'B', _triton.arg_type.int32: 'I', _triton.arg_type.int64: 'Q',
|
||||
_triton.arg_type.half: 'H', _triton.arg_type.float: 'f', _triton.arg_type.double: 'd', _triton.arg_type.buffer: 'P'
|
||||
}
|
||||
|
||||
def th_to_triton(obj):
|
||||
tys = {
|
||||
torch.int8: 'char',
|
||||
torch.int16: 'short',
|
||||
torch.int32: 'int',
|
||||
torch.int64: 'long',
|
||||
torch.float16: 'half',
|
||||
torch.float32: 'float',
|
||||
torch.float64: 'double'
|
||||
torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long', torch.float16: 'half',
|
||||
torch.float32: 'float', torch.float64: 'double'
|
||||
}
|
||||
if isinstance(obj, torch.dtype):
|
||||
return tys[obj]
|
||||
@@ -69,6 +58,7 @@ class kernel:
|
||||
_torch_utils.register_stream(self.device)
|
||||
# C++ function wrapper
|
||||
self.op_id = _triton.make_op_id()
|
||||
_torch_utils.set_device(self.device)
|
||||
_triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
|
||||
# debug mode
|
||||
self.is_debug = 'TRITON_DEBUG' in os.environ
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from .conv import _conv, conv
|
||||
from .matmul import _matmul, matmul
|
||||
from .softmax import _softmax, softmax
|
||||
from .cross_entropy import _cross_entropy, cross_entropy
|
||||
from . import blocksparse
|
42
python/triton/ops/cross_entropy.c
Normal file
42
python/triton/ops/cross_entropy.c
Normal file
@@ -0,0 +1,42 @@
|
||||
__global__ void forward(TYPE *logit,
|
||||
TYPE *modified_logit,
|
||||
long *indices __readonly,
|
||||
TYPE *result,
|
||||
int n_cols) {
|
||||
int row = get_program_id(0);
|
||||
|
||||
bool check[TILE] = ((0 ... TILE) < n_cols);
|
||||
int offset[TILE] = row * n_cols + 0 ... TILE;
|
||||
TYPE *px[TILE] = logit + offset;
|
||||
TYPE *pmodified[TILE] = modified_logit + offset;
|
||||
long local_ind = *(indices + row);
|
||||
|
||||
TYPE F16[TILE] = check ? *px : -INFINITY;
|
||||
float shifted_logit[TILE] = F16 - F16[max];
|
||||
float neg_logprob[TILE] = log(exp(shifted_logit)[+]) - shifted_logit;
|
||||
*? (check)pmodified = neg_logprob;
|
||||
__debug_barrier();
|
||||
*(result + row) = *(modified_logit + (local_ind + n_cols * row));
|
||||
}
|
||||
|
||||
__global__ void backward(TYPE *neg_logprobs,
|
||||
long *indices,
|
||||
TYPE *dneg_logprobs,
|
||||
int n_cols) {
|
||||
|
||||
int row = get_program_id(0);
|
||||
// pointer arithmetic
|
||||
bool check[TILE] = ((0 ... TILE) < n_cols);
|
||||
int offset[TILE] = row * n_cols + 0 ... TILE;
|
||||
TYPE *px[TILE] = neg_logprobs + offset;
|
||||
long local_ind = *(indices + row);
|
||||
TYPE local_dn = *(dneg_logprobs + row);
|
||||
// We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
|
||||
// and we have -log(p[k]) stored, so this is easy
|
||||
TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * ? (check)px) : 0;
|
||||
// selected_logit_idx is selected logit index for our token
|
||||
bool find_one[TILE] = ((0 ... TILE) == local_ind);
|
||||
intermediate = intermediate - ((TYPE[TILE])find_one);
|
||||
// multiply by dneg_logprobs
|
||||
*? (check)px = intermediate * local_dn;
|
||||
}
|
85
python/triton/ops/cross_entropy.py
Normal file
85
python/triton/ops/cross_entropy.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
import triton
|
||||
import torch
|
||||
|
||||
def next_power_of_2(n):
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
n |= n >> 2
|
||||
n |= n >> 4
|
||||
n |= n >> 8
|
||||
n |= n >> 16
|
||||
n += 1
|
||||
return n
|
||||
|
||||
def make_kernel(device, dtype, n_cols, cache, name):
|
||||
rounded = next_power_of_2(n_cols)
|
||||
key = (dtype, rounded)
|
||||
if key not in cache:
|
||||
fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
|
||||
src = triton.read(fname, kernel_names=[name])
|
||||
infinities = {
|
||||
torch.float16: "F16_INFINITY",
|
||||
torch.float32: "F32_INFINITY",
|
||||
}
|
||||
defines = {
|
||||
"TILE": rounded,
|
||||
"TYPE": dtype,
|
||||
"INFINITY": infinities[dtype],
|
||||
}
|
||||
cache[key] = triton.kernel(src, device=device, defines=defines)
|
||||
return cache[key]
|
||||
|
||||
# forward kernel
|
||||
fwd_kernels = dict()
|
||||
make_fwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, fwd_kernels, "forward")
|
||||
|
||||
# backward kernel
|
||||
bwd_kernels = dict()
|
||||
make_bwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, bwd_kernels, "backward")
|
||||
|
||||
class _cross_entropy(torch.autograd.Function):
|
||||
@classmethod
|
||||
def forward(cls, ctx, logits, indices):
|
||||
# make sure we can use triton
|
||||
assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
|
||||
# make kernel
|
||||
device, dtype = logits.device, logits.dtype
|
||||
n_cols = logits.shape[-1]
|
||||
kernel = make_fwd_kernel(device, dtype, n_cols)
|
||||
# run the kernel
|
||||
result = torch.empty_like(indices, dtype=dtype, device=device)
|
||||
neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
|
||||
kernel(logits.data_ptr(),
|
||||
neg_logprobs.data_ptr(),
|
||||
indices.data_ptr(),
|
||||
result.data_ptr(),
|
||||
n_cols,
|
||||
grid=lambda opt: (logits.numel() // n_cols, ))
|
||||
# save for backward
|
||||
ctx.save_for_backward(neg_logprobs, indices)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def backward(cls, ctx, dneg_logprobs):
|
||||
"""We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
|
||||
so we initialize the gradient as neg_logprobs, so we can just exponentiate
|
||||
to get p[k], which is most of what we need... neg_logprobs will be
|
||||
modified in place to become the gradient we want
|
||||
"""
|
||||
# load saved tensors
|
||||
neg_logprobs, indices = ctx.saved_tensors
|
||||
# make kernel
|
||||
device, dtype = neg_logprobs.device, neg_logprobs.dtype
|
||||
n_cols = neg_logprobs.shape[-1]
|
||||
kernel = make_bwd_kernel(device, dtype, n_cols)
|
||||
# run the kernel
|
||||
# neg_logprobs will be modified in place to become our gradient:
|
||||
kernel(neg_logprobs.data_ptr(),
|
||||
indices.data_ptr(),
|
||||
dneg_logprobs.data_ptr(),
|
||||
n_cols,
|
||||
grid=lambda opt: (neg_logprobs.numel() // n_cols, ))
|
||||
return neg_logprobs, None
|
||||
|
||||
cross_entropy = _cross_entropy.apply
|
Reference in New Issue
Block a user