import pytest
import torch

import triton
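
# This file exercises triton.ops.cross_entropy against PyTorch's
# torch.nn.CrossEntropyLoss as a reference, checking both the forward
# losses and the input gradients across several shapes and dtypes.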


@pytest.mark.parametrize("M, N, dtype, mode",
                         [
                             (M, N, dtype, mode) for M in [1024, 821]
                             for N in [512, 857, 1871, 2089, 8573, 31000]
                             for dtype in ['bfloat16', 'float16', 'float32']
                             for mode in ['forward', 'backward']
                         ]
                         )
def test_op(M, N, dtype, mode):
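    """Compare triton.ops.cross_entropy with torch.nn.CrossEntropyLoss.

    `mode` selects whether the forward losses or the input gradients
    are compared.
    """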
    # bfloat16 requires compute capability >= 8.0 (sm_80, Ampere or newer)
    capability = torch.cuda.get_device_capability()
    if capability[0] < 8 and dtype == "bfloat16":
        pytest.skip("Only test bfloat16 on devices with sm >= 80")
    dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    # every row is labeled with target class 5 (4 + 1)
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
    # forward pass; reduction="none" keeps the per-row losses so the torch
    # output has the same shape as triton's
    tt_y = triton.ops.cross_entropy(x, idx)
    th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
    if mode == 'forward':
        triton.testing.assert_almost_equal(th_y, tt_y)
    # backward pass
    elif mode == 'backward':
        dy = torch.randn_like(tt_y)
        # triton backward
        tt_y.backward(dy)
        tt_dx = x.grad.clone()
        # torch backward
        x.grad.zero_()
        th_y.backward(dy)
        th_dx = x.grad.clone()
        triton.testing.assert_almost_equal(th_dx, tt_dx)
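

# For reference, a minimal sketch of the math being verified, assuming the
# textbook definition of per-row cross-entropy: a row-wise log-softmax followed
# by the negated log-probability of the target class. This mirrors torch's
# semantics with reduction="none"; it is illustration only, not Triton's fused
# kernel, and the helper name is ours rather than part of triton.ops.
def reference_cross_entropy(x, idx):
    # compute in float32 for numerical stability regardless of input dtype
    log_probs = torch.log_softmax(x.float(), dim=-1)
    return -log_probs[torch.arange(x.shape[0], device=x.device), idx]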