diff --git a/python/tests/test_cross_entropy.py b/python/tests/test_cross_entropy.py
new file mode 100644
index 000000000..969774a11
--- /dev/null
+++ b/python/tests/test_cross_entropy.py
@@ -0,0 +1,33 @@
+import torch
+import triton
+import pytest
+
+@pytest.mark.parametrize("M, N, dtype, mode",
+                         [
+                             (M, N, dtype, mode) for M in [1024, 821]
+                             for N in [512, 857, 1871, 2089, 8573, 31000]
+                             for dtype in ['float16', 'float32']
+                             for mode in ['forward', 'backward']
+                         ]
+                         )
+def test_op(M, N, dtype, mode):
+    dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
+    # create inputs
+    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
+    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
+    # forward pass
+    tt_y = triton.ops.cross_entropy(x, idx)
+    th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
+    if mode == 'forward':
+        assert torch.allclose(th_y, tt_y, atol=1e-3, rtol=1e-2)
+    # backward pass
+    elif mode == 'backward':
+        dy = torch.randn_like(tt_y)
+        # triton backward
+        tt_y.backward(dy)
+        tt_dx = x.grad.clone()
+        # torch backward
+        x.grad.zero_()
+        th_y.backward(dy)
+        th_dx = x.grad.clone()
+        assert torch.allclose(th_dx, tt_dx, atol=1e-3, rtol=1e-2)
\ No newline at end of file
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 585c7b8e0..02c6f8f46 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -12,25 +12,14 @@ def cleanup():
     _triton.cleanup()
 
 codes = {
-    _triton.arg_type.int1: 'B',
-    _triton.arg_type.int8: 'B',
-    _triton.arg_type.int32: 'I',
-    _triton.arg_type.int64: 'Q',
-    _triton.arg_type.half: 'H',
-    _triton.arg_type.float: 'f',
-    _triton.arg_type.double: 'd',
-    _triton.arg_type.buffer: 'P'
+    _triton.arg_type.int1: 'B', _triton.arg_type.int8: 'B', _triton.arg_type.int32: 'I', _triton.arg_type.int64: 'Q',
+    _triton.arg_type.half: 'H', _triton.arg_type.float: 'f', _triton.arg_type.double: 'd', _triton.arg_type.buffer: 'P'
 }
 
 def th_to_triton(obj):
     tys = {
-        torch.int8: 'char',
-        torch.int16: 'short',
-        torch.int32: 'int',
-        torch.int64: 'long',
-        torch.float16: 'half',
-        torch.float32: 'float',
-        torch.float64: 'double'
+        torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long', torch.float16: 'half',
+        torch.float32: 'float', torch.float64: 'double'
     }
     if isinstance(obj, torch.dtype):
         return tys[obj]
@@ -69,6 +58,7 @@ class kernel:
         _torch_utils.register_stream(self.device)
         # C++ function wrapper
         self.op_id = _triton.make_op_id()
+        _torch_utils.set_device(self.device)
         _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
         # debug mode
         self.is_debug = 'TRITON_DEBUG' in os.environ
diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py
index b2426c678..ef4974751 100644
--- a/python/triton/ops/__init__.py
+++ b/python/triton/ops/__init__.py
@@ -1,4 +1,5 @@
 from .conv import _conv, conv
 from .matmul import _matmul, matmul
 from .softmax import _softmax, softmax
+from .cross_entropy import _cross_entropy, cross_entropy
 from . import blocksparse
\ No newline at end of file
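Note: the test above checks triton.ops.cross_entropy elementwise against torch.nn.CrossEntropyLoss(reduction="none"). For readers skimming the patch, here is a minimal pure-PyTorch sketch of what the op computes per row; the name reference_cross_entropy is illustrative and not part of this patch:

    import torch

    def reference_cross_entropy(logits, indices):
        # per-row negative log-probability of the target class, unreduced
        neg_logprobs = -torch.log_softmax(logits.float(), dim=-1)
        rows = torch.arange(logits.shape[0], device=logits.device)
        return neg_logprobs[rows, indices].to(logits.dtype)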
diff --git a/python/triton/ops/cross_entropy.c b/python/triton/ops/cross_entropy.c
new file mode 100644
index 000000000..118f5d145
--- /dev/null
+++ b/python/triton/ops/cross_entropy.c
@@ -0,0 +1,42 @@
+__global__ void forward(TYPE *logit,
+                        TYPE *modified_logit,
+                        long *indices __readonly,
+                        TYPE *result,
+                        int n_cols) {
+  int row = get_program_id(0);
+
+  bool check[TILE] = ((0 ... TILE) < n_cols);
+  int offset[TILE] = row * n_cols + 0 ... TILE;
+  TYPE *px[TILE] = logit + offset;
+  TYPE *pmodified[TILE] = modified_logit + offset;
+  long local_ind = *(indices + row);
+
+  TYPE F16[TILE] = check ? *px : -INFINITY;
+  float shifted_logit[TILE] = F16 - F16[max];
+  float neg_logprob[TILE] = log(exp(shifted_logit)[+]) - shifted_logit;
+  *? (check)pmodified = neg_logprob;
+  __debug_barrier();
+  *(result + row) = *(modified_logit + (local_ind + n_cols * row));
+}
+
+__global__ void backward(TYPE *neg_logprobs,
+                         long *indices,
+                         TYPE *dneg_logprobs,
+                         int n_cols) {
+
+  int row = get_program_id(0);
+  // pointer arithmetic
+  bool check[TILE] = ((0 ... TILE) < n_cols);
+  int offset[TILE] = row * n_cols + 0 ... TILE;
+  TYPE *px[TILE] = neg_logprobs + offset;
+  long local_ind = *(indices + row);
+  TYPE local_dn = *(dneg_logprobs + row);
+  // We know that d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k]
+  // and we have -log(p[k]) stored, so this is easy
+  TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * ? (check)px) : 0;
+  // find_one is a one-hot mask set at the target index for this row
+  bool find_one[TILE] = ((0 ... TILE) == local_ind);
+  intermediate = intermediate - ((TYPE[TILE])find_one);
+  // multiply by dneg_logprobs
+  *? (check)px = intermediate * local_dn;
+}
\ No newline at end of file
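For reference, a minimal PyTorch sketch of the math the two Triton-C kernels above implement, assuming 2D row-major logits; the names forward_rowwise and backward_rowwise are illustrative only. The forward pass uses the usual max-shift for numerical stability; the backward pass uses d(-log(p[i]))/dlogit[k] = p[k] - 1[k == i], recovering p[k] by exponentiating the stored -log(p[k]):

    import torch

    def forward_rowwise(x):
        # shifted_logit = x - max(x); neg_logprob = log(sum(exp(shifted))) - shifted
        shifted = x - x.max(dim=-1, keepdim=True).values
        return torch.log(torch.exp(shifted).sum(dim=-1, keepdim=True)) - shifted

    def backward_rowwise(neg_logprobs, indices, dneg_logprobs):
        # p[k] = exp(-neg_logprobs[k]); the gradient is (p - one_hot) scaled by
        # the incoming per-row gradient, matching `intermediate * local_dn` above
        p = torch.exp(-neg_logprobs)
        one_hot = torch.nn.functional.one_hot(indices, p.shape[-1]).to(p.dtype)
        return (p - one_hot) * dneg_logprobs.unsqueeze(-1)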
diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py
new file mode 100644
index 000000000..05a512238
--- /dev/null
+++ b/python/triton/ops/cross_entropy.py
@@ -0,0 +1,85 @@
+import os
+import triton
+import torch
+
+def next_power_of_2(n):
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n += 1
+    return n
+
+def make_kernel(device, dtype, n_cols, cache, name):
+    rounded = next_power_of_2(n_cols)
+    key = (dtype, rounded)
+    if key not in cache:
+        fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
+        src = triton.read(fname, kernel_names=[name])
+        infinities = {
+            torch.float16: "F16_INFINITY",
+            torch.float32: "F32_INFINITY",
+        }
+        defines = {
+            "TILE": rounded,
+            "TYPE": dtype,
+            "INFINITY": infinities[dtype],
+        }
+        cache[key] = triton.kernel(src, device=device, defines=defines)
+    return cache[key]
+
+# forward kernel
+fwd_kernels = dict()
+make_fwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, fwd_kernels, "forward")

+# backward kernel
+bwd_kernels = dict()
+make_bwd_kernel = lambda device, dtype, n_cols: make_kernel(device, dtype, n_cols, bwd_kernels, "backward")
+
+class _cross_entropy(torch.autograd.Function):
+    @classmethod
+    def forward(cls, ctx, logits, indices):
+        # make sure we can use triton
+        assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
+        # make kernel
+        device, dtype = logits.device, logits.dtype
+        n_cols = logits.shape[-1]
+        kernel = make_fwd_kernel(device, dtype, n_cols)
+        # run the kernel
+        result = torch.empty_like(indices, dtype=dtype, device=device)
+        neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
+        kernel(logits.data_ptr(),
+               neg_logprobs.data_ptr(),
+               indices.data_ptr(),
+               result.data_ptr(),
+               n_cols,
+               grid=lambda opt: (logits.numel() // n_cols, ))
+        # save for backward
+        ctx.save_for_backward(neg_logprobs, indices)
+        return result
+
+    @classmethod
+    def backward(cls, ctx, dneg_logprobs):
+        """We know that d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k],
+        so we initialize the gradient as neg_logprobs: exponentiating
+        gives p[k], which is most of what we need. neg_logprobs is
+        modified in place to become the gradient we want.
+        """
+        # load saved tensors
+        neg_logprobs, indices = ctx.saved_tensors
+        # make kernel
+        device, dtype = neg_logprobs.device, neg_logprobs.dtype
+        n_cols = neg_logprobs.shape[-1]
+        kernel = make_bwd_kernel(device, dtype, n_cols)
+        # run the kernel
+        # neg_logprobs will be modified in place to become our gradient:
+        kernel(neg_logprobs.data_ptr(),
+               indices.data_ptr(),
+               dneg_logprobs.data_ptr(),
+               n_cols,
+               grid=lambda opt: (neg_logprobs.numel() // n_cols, ))
+        return neg_logprobs, None
+
+cross_entropy = _cross_entropy.apply
\ No newline at end of file
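Usage sketch (assuming a CUDA device; this mirrors the test at the top of the patch):

    import torch
    import triton

    logits = torch.randn(128, 1000, device='cuda', requires_grad=True)
    targets = torch.randint(0, 1000, (128, ), device='cuda')  # int64, as the op asserts
    loss = triton.ops.cross_entropy(logits, targets)  # unreduced, shape (128,)
    loss.backward(torch.ones_like(loss))              # populates logits.grad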