Added a Softmax Xent Op (#53)

Also includes a bugfix in kernel.py to set the device before registering the C++ function object.
Jared Kaplan
2021-02-07 15:53:42 -05:00
committed by Philippe Tillet
parent dffd66bc83
commit 682ac4c60e
5 changed files with 166 additions and 15 deletions
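
For context, the new op is exposed as triton.ops.cross_entropy and matches torch.nn.CrossEntropyLoss(reduction="none"), as the test below checks. A minimal usage sketch (assuming a CUDA device; the shapes here are illustrative):

import torch
import triton

# a batch of 8 rows of 1000 logits and one int64 target per row
logits = torch.randn(8, 1000, dtype=torch.float16, device='cuda', requires_grad=True)
targets = torch.randint(0, 1000, (8, ), dtype=torch.int64, device='cuda')

# per-row negative log-likelihood; backward works through the custom autograd.Function
loss = triton.ops.cross_entropy(logits, targets)
loss.sum().backward()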

View File: new test for triton.ops.cross_entropy

@@ -0,0 +1,33 @@
import torch
import triton
import pytest


@pytest.mark.parametrize("M, N, dtype, mode",
                         [
                             (M, N, dtype, mode) for M in [1024, 821]
                             for N in [512, 857, 1871, 2089, 8573, 31000]
                             for dtype in ['float16', 'float32']
                             for mode in ['forward', 'backward']
                         ]
                         )
def test_op(M, N, dtype, mode):
    dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
    # forward pass
    tt_y = triton.ops.cross_entropy(x, idx)
    th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
    if mode == 'forward':
        assert torch.allclose(th_y, tt_y, atol=1e-3, rtol=1e-2)
    # backward pass
    elif mode == 'backward':
        dy = torch.randn_like(tt_y)
        # triton backward
        tt_y.backward(dy)
        tt_dx = x.grad.clone()
        # torch backward
        x.grad.zero_()
        th_y.backward(dy)
        th_dx = x.grad.clone()
        assert torch.allclose(th_dx, tt_dx, atol=1e-3, rtol=1e-2)

View File: kernel.py

@@ -12,25 +12,14 @@ def cleanup():
    _triton.cleanup()

codes = {
-    _triton.arg_type.int1: 'B',
-    _triton.arg_type.int8: 'B',
-    _triton.arg_type.int32: 'I',
-    _triton.arg_type.int64: 'Q',
-    _triton.arg_type.half: 'H',
-    _triton.arg_type.float: 'f',
-    _triton.arg_type.double: 'd',
-    _triton.arg_type.buffer: 'P'
+    _triton.arg_type.int1: 'B', _triton.arg_type.int8: 'B', _triton.arg_type.int32: 'I', _triton.arg_type.int64: 'Q',
+    _triton.arg_type.half: 'H', _triton.arg_type.float: 'f', _triton.arg_type.double: 'd', _triton.arg_type.buffer: 'P'
}

def th_to_triton(obj):
    tys = {
-        torch.int8: 'char',
-        torch.int16: 'short',
-        torch.int32: 'int',
-        torch.int64: 'long',
-        torch.float16: 'half',
-        torch.float32: 'float',
-        torch.float64: 'double'
+        torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long', torch.float16: 'half',
+        torch.float32: 'float', torch.float64: 'double'
    }
    if isinstance(obj, torch.dtype):
        return tys[obj]
@@ -69,6 +58,7 @@ class kernel:
        _torch_utils.register_stream(self.device)
        # C++ function wrapper
        self.op_id = _triton.make_op_id()
+        _torch_utils.set_device(self.device)
        _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
        # debug mode
        self.is_debug = 'TRITON_DEBUG' in os.environ
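
The single added line above, _torch_utils.set_device(self.device), is the bugfix from the commit message: the C++ function object must be registered while the kernel's target device is the active CUDA device. The general pattern, sketched in plain PyTorch terms (register_fn here is a stand-in for any per-device registration step, not Triton's API):

import torch

def register_on(device, register_fn):
    # temporarily make `device` the active CUDA device, register, then restore
    with torch.cuda.device(device):
        register_fn()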

View File: triton/ops/__init__.py

@@ -1,4 +1,5 @@
from .conv import _conv, conv
from .matmul import _matmul, matmul
from .softmax import _softmax, softmax
+from .cross_entropy import _cross_entropy, cross_entropy
from . import blocksparse

View File: triton/ops/cross_entropy.c

@@ -0,0 +1,42 @@
__global__ void forward(TYPE *logit,
                        TYPE *modified_logit,
                        long *indices __readonly,
                        TYPE *result,
                        int n_cols) {
    int row = get_program_id(0);
    // pointer arithmetic
    bool check[TILE] = ((0 ... TILE) < n_cols);
    int offset[TILE] = row * n_cols + 0 ... TILE;
    TYPE *px[TILE] = logit + offset;
    TYPE *pmodified[TILE] = modified_logit + offset;
    long local_ind = *(indices + row);
    // numerically stable -log(softmax): subtract the row max,
    // then take log-sum-exp minus the shifted logit
    TYPE F16[TILE] = check ? *px : -INFINITY;
    float shifted_logit[TILE] = F16 - F16[max];
    float neg_logprob[TILE] = log(exp(shifted_logit)[+]) - shifted_logit;
    *? (check)pmodified = neg_logprob;
    __debug_barrier();
    // the loss for this row is the negative log-probability of its target index
    *(result + row) = *(modified_logit + (local_ind + n_cols * row));
}

__global__ void backward(TYPE *neg_logprobs,
                         long *indices,
                         TYPE *dneg_logprobs,
                         int n_cols) {
    int row = get_program_id(0);
    // pointer arithmetic
    bool check[TILE] = ((0 ... TILE) < n_cols);
    int offset[TILE] = row * n_cols + 0 ... TILE;
    TYPE *px[TILE] = neg_logprobs + offset;
    long local_ind = *(indices + row);
    TYPE local_dn = *(dneg_logprobs + row);
    // We know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k]
    // and we have -log(p[k]) stored, so this is easy
    TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * ? (check)px) : 0;
    // find_one is 1 at the selected logit index for our token, 0 elsewhere
    bool find_one[TILE] = ((0 ... TILE) == local_ind);
    intermediate = intermediate - ((TYPE[TILE])find_one);
    // multiply by dneg_logprobs and write the gradient back in place
    *? (check)px = intermediate * local_dn;
}
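
The forward kernel's F16 - F16[max] / log(exp(shifted_logit)[+]) sequence is the standard numerically stable -log(softmax). A plain PyTorch sketch of what a single program instance computes for its row (x is one row of logits, label its target index; a readable reference, not the Triton code path):

import torch

def reference_row_loss(x: torch.Tensor, label: int) -> torch.Tensor:
    shifted = x - x.max()                               # shifted_logit = F16 - F16[max]
    neg_logprob = shifted.exp().sum().log() - shifted   # log(exp(...)[+]) - shifted_logit
    return neg_logprob[label]                           # *(result + row) = modified_logit[local_ind]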

View File: triton/ops/cross_entropy.py

@@ -0,0 +1,85 @@
import os
import triton
import torch


def next_power_of_2(n):
    # smallest power of two >= n (classic bit-smearing trick)
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n


def make_kernel(device, dtype, n_cols, cache, name):
    # kernels are specialized on (dtype, next power of two >= n_cols),
    # so compile once per such pair and cache the result
    rounded = next_power_of_2(n_cols)
    key = (dtype, rounded)
    if key not in cache:
        fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
        src = triton.read(fname, kernel_names=[name])
        infinities = {
            torch.float16: "F16_INFINITY",
            torch.float32: "F32_INFINITY",
        }
        defines = {
            "TILE": rounded,
            "TYPE": dtype,
            "INFINITY": infinities[dtype],
        }
        cache[key] = triton.kernel(src, device=device, defines=defines)
    return cache[key]


# forward kernel
fwd_kernels = dict()
make_fwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, fwd_kernels, "forward")

# backward kernel
bwd_kernels = dict()
make_bwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, bwd_kernels, "backward")


class _cross_entropy(torch.autograd.Function):
    @classmethod
    def forward(cls, ctx, logits, indices):
        # make sure we can use triton
        assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
        # make kernel
        device, dtype = logits.device, logits.dtype
        n_cols = logits.shape[-1]
        kernel = make_fwd_kernel(device, dtype, n_cols)
        # run the kernel: one program instance per row of logits
        result = torch.empty_like(indices, dtype=dtype, device=device)
        neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
        kernel(logits.data_ptr(),
               neg_logprobs.data_ptr(),
               indices.data_ptr(),
               result.data_ptr(),
               n_cols,
               grid=lambda opt: (logits.numel() // n_cols, ))
        # save for backward
        ctx.save_for_backward(neg_logprobs, indices)
        return result

    @classmethod
    def backward(cls, ctx, dneg_logprobs):
        """We know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k],
        so we initialize the gradient as neg_logprobs and can just exponentiate
        to get p[k], which is most of what we need; neg_logprobs will be
        modified in place to become the gradient we want.
        """
        # load saved tensors
        neg_logprobs, indices = ctx.saved_tensors
        # make kernel
        device, dtype = neg_logprobs.device, neg_logprobs.dtype
        n_cols = neg_logprobs.shape[-1]
        kernel = make_bwd_kernel(device, dtype, n_cols)
        # run the kernel
        # neg_logprobs will be modified in place to become our gradient:
        kernel(neg_logprobs.data_ptr(),
               indices.data_ptr(),
               dneg_logprobs.data_ptr(),
               n_cols,
               grid=lambda opt: (neg_logprobs.numel() // n_cols, ))
        return neg_logprobs, None


cross_entropy = _cross_entropy.apply
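
The backward pass relies on the identity d(-log p[i])/d logit[k] = p[k] - 1[k == i], scaled by the incoming gradient dneg_logprobs, which is exactly what the backward kernel writes in place. A small self-contained check of that identity against torch.autograd (pure PyTorch, runs on CPU; shapes are arbitrary):

import torch

torch.manual_seed(0)
logits = torch.randn(4, 7, dtype=torch.float64, requires_grad=True)
idx = torch.randint(0, 7, (4, ))

# autograd gradient of the per-row negative log-likelihood
loss = torch.nn.functional.cross_entropy(logits, idx, reduction='none')
loss.sum().backward()

# closed form: softmax(logits) minus a one-hot at the target index
with torch.no_grad():
    expected = torch.softmax(logits, dim=-1) - torch.nn.functional.one_hot(idx, 7).to(logits.dtype)
assert torch.allclose(logits.grad, expected)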