triton/python/triton/ops/cross_entropy.py
Jared Kaplan · db55ef4fa7 · Added a Softmax Xent Op (#53)
Also includes a bugfix in kernel.py to set the device before registering the C++ function object.
2021-02-07 12:53:42 -08:00


import os
import triton
import torch


def next_power_of_2(n):
    """Round a 32-bit integer n up to the nearest power of 2 via
    bit-twiddling: smear the highest set bit into every lower
    position, then add 1."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n
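
# Illustrative values for next_power_of_2 (a sketch, not part of the
# original file):
#   next_power_of_2(1)    == 1
#   next_power_of_2(1000) == 1024
#   next_power_of_2(1024) == 1024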


def make_kernel(device, dtype, n_cols, cache, name):
    """Compile (or fetch from `cache`) the Triton kernel named `name`,
    specialized for `dtype` and a tile width of n_cols rounded up to a
    power of 2."""
    rounded = next_power_of_2(n_cols)
    key = (dtype, rounded)
    if key not in cache:
        fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
        src = triton.read(fname, kernel_names=[name])
        infinities = {
            torch.float16: "F16_INFINITY",
            torch.float32: "F32_INFINITY",
        }
        defines = {
            "TILE": rounded,
            "TYPE": dtype,
            "INFINITY": infinities[dtype],
        }
        cache[key] = triton.kernel(src, device=device, defines=defines)
    return cache[key]
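
# Cache behavior (illustrative): with float16 logits, any n_cols in
# (512, 1024] rounds to a 1024-wide tile, so all such calls share the
# (torch.float16, 1024) cache entry and the kernel compiles only once.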

# forward kernel
fwd_kernels = dict()
make_fwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, fwd_kernels, "forward")

# backward kernel
bwd_kernels = dict()
make_bwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, bwd_kernels, "backward")
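
# A hypothetical pure-PyTorch reference for what the two Triton kernels
# compute together (a readability sketch, not part of the original file;
# `_reference_cross_entropy` is an invented name and assumes 2-D logits):
def _reference_cross_entropy(logits, indices):
    # per-row negative log-probabilities via a numerically stable log-softmax
    neg_logprobs = -torch.log_softmax(logits, dim=-1)
    rows = torch.arange(logits.shape[0], device=logits.device)
    # the loss for row i is -log(p[indices[i]])
    return neg_logprobs[rows, indices]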


class _cross_entropy(torch.autograd.Function):
    @classmethod
    def forward(cls, ctx, logits, indices):
        # make sure we can use triton
        assert indices.dtype == torch.int64, "Indices are expected to be of type long."
        # make kernel
        device, dtype = logits.device, logits.dtype
        n_cols = logits.shape[-1]
        kernel = make_fwd_kernel(device, dtype, n_cols)
        # run the kernel: one program instance per row of logits
        result = torch.empty_like(indices, dtype=dtype, device=device)
        neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
        kernel(logits.data_ptr(),
               neg_logprobs.data_ptr(),
               indices.data_ptr(),
               result.data_ptr(),
               n_cols,
               grid=lambda opt: (logits.numel() // n_cols, ))
        # save for backward
        ctx.save_for_backward(neg_logprobs, indices)
        return result
    @classmethod
    def backward(cls, ctx, dneg_logprobs):
        """We know d(-log(p[i]))/d(logit[k]) = p[k] - id_mat[i, k],
        so we initialize the gradient as neg_logprobs: exponentiating it
        gives p[k], which is most of what we need. neg_logprobs is then
        modified in place to become the gradient we want.
        """
        # load saved tensors
        neg_logprobs, indices = ctx.saved_tensors
        # make kernel
        device, dtype = neg_logprobs.device, neg_logprobs.dtype
        n_cols = neg_logprobs.shape[-1]
        kernel = make_bwd_kernel(device, dtype, n_cols)
        # run the kernel: neg_logprobs is modified in place to become our gradient
        kernel(neg_logprobs.data_ptr(),
               indices.data_ptr(),
               dneg_logprobs.data_ptr(),
               n_cols,
               grid=lambda opt: (neg_logprobs.numel() // n_cols, ))
        return neg_logprobs, None


cross_entropy = _cross_entropy.apply
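

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). Assumes a CUDA
    # device, 2-D float32 logits, and that the cross_entropy.c kernels
    # alongside this file compile for it.
    n_rows, n_cols = 128, 1000
    logits = torch.randn(n_rows, n_cols, device="cuda", requires_grad=True)
    indices = torch.randint(0, n_cols, (n_rows, ), device="cuda")
    loss = cross_entropy(logits, indices)
    # the op returns one loss per row, matching torch's reduction="none"
    ref = torch.nn.functional.cross_entropy(logits, indices, reduction="none")
    print("forward matches:", torch.allclose(loss, ref, atol=1e-4))
    # with an all-ones upstream gradient, dloss/dlogits = softmax(logits) - onehot(indices)
    loss.sum().backward()
    grad_ref = torch.softmax(logits.detach(), dim=-1)
    grad_ref[torch.arange(n_rows, device="cuda"), indices] -= 1.0
    print("backward matches:", torch.allclose(logits.grad, grad_ref, atol=1e-4))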