[LANG] Fixed semantics of NaN in float comparisons (#281)
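Float comparisons in the IR dispatcher previously lowered to unordered fcmp predicates (UGT, UGE, ULT, ULE, UEQ), which evaluate to true whenever either operand is NaN. They now lower to the ordered predicates (OGT, OGE, OLT, OLE, OEQ), so any comparison involving NaN evaluates to false, matching the PyTorch reference results the unit tests check against. The compare-op tests gain mode_x/mode_y parameters that fill the inputs with NaN, and triton.testing picks up an nvsmi helper plus do_bench fixes that derive the warmup and repeat counts from the initial runtime estimate and only compute percentiles when requested.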
@@ -275,7 +275,7 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float > float
   if (scalar_ty->is_floating_point_ty())
-    return builder->create_fcmpUGT(input, other);
+    return builder->create_fcmpOGT(input, other);
   // int > int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpSGT(input, other);
@@ -287,7 +287,7 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float >= float
   if (scalar_ty->is_floating_point_ty())
-    return builder->create_fcmpUGE(input, other);
+    return builder->create_fcmpOGE(input, other);
   // int >= int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpSGE(input, other);
@@ -299,7 +299,7 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float < float
   if (scalar_ty->is_floating_point_ty())
-    return builder->create_fcmpULT(input, other);
+    return builder->create_fcmpOLT(input, other);
   // int < int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpSLT(input, other);
@@ -311,7 +311,7 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float < float
   if (scalar_ty->is_floating_point_ty())
-    return builder->create_fcmpULE(input, other);
+    return builder->create_fcmpOLE(input, other);
   // int < int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpSLE(input, other);
@@ -323,7 +323,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float == float
   if (scalar_ty->is_floating_point_ty())
-    return builder->create_fcmpUEQ(input, other);
+    return builder->create_fcmpOEQ(input, other);
   // int == int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpEQ(input, other);
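For context on the predicate change above: ordered (O*) float comparisons are false whenever either operand is NaN, while the unordered (U*) variants are true in that case. The tests below use PyTorch as the reference, whose comparison operators follow the ordered IEEE 754 behavior. A minimal sketch of the semantics being matched (plain PyTorch, for illustration only):

    import torch

    x = torch.full((4,), float('nan'))
    y = torch.zeros(4)

    # Ordered comparisons involving NaN are all False; this is what the
    # fcmpO* predicates now produce as well.
    print(x > y)    # tensor([False, False, False, False])
    print(x == y)   # tensor([False, False, False, False])
    # Inequality is the one comparison that is True for NaN operands.
    print(x != y)   # tensor([True, True, True, True])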
@@ -69,7 +69,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
     triton.testing.assert_almost_equal(z_ref, z_tri)


-def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
+def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
@@ -84,6 +84,8 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
     # inputs
     x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
     y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
+    if mode_x == 'nan': x[:] = float('nan')
+    if mode_y == 'nan': y[:] = float('nan')
     # reference result
     z_ref = eval(expr)
     # triton result
@@ -126,14 +128,25 @@ def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
 # ---------------
 # test compare ops
 # ---------------
-@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
-    (dtype_x, dtype_y, f' x {op} y') \
-    for op in ['==', '!=', '>', '<', '>=', '<='] \
+ops = ['==', '!=', '>', '<', '>=', '<=']
+@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \
+# real
+[
+    (dtype_x, dtype_y, f' x {op} y', 'real', 'real') \
+    for op in ops \
     for dtype_x in dtypes \
     for dtype_y in dtypes
+] + \
+# NaNs
+[('float32', 'float32', f' x {op} y', mode_x, mode_y) \
+    for op in ops
+    for mode_x, mode_y in [('nan' , 'real'),
+                           ('real', 'nan'),
+                           ('nan' , 'nan')]
+
 ])
-def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
-    _test_binary(dtype_x, dtype_y, expr, device=device)
+def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
+    _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)


 # ---------------
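The NaN parametrizations above reduce to the sketch below: the reference is computed with PyTorch, so for every operator except != the expected output is all False once an input is filled with NaN, and the Triton kernel has to reproduce that. Hypothetical stand-ins are used here (torch.rand instead of triton.testing.random), purely to show the expected reference:

    # One generated case: dtype_x = dtype_y = 'float32',
    # expr = ' x == y', mode_x = 'nan', mode_y = 'real'.
    import torch

    x = torch.rand(128, dtype=torch.float32)
    x[:] = float('nan')                       # mode_x == 'nan'
    y = torch.rand(128, dtype=torch.float32)  # mode_y == 'real'

    z_ref = x == y          # all False: NaN never compares equal
    assert not z_ref.any()  # the kernel output is compared against z_ref

The remaining hunks touch the benchmarking and testing utilities.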
@@ -1,6 +1,9 @@
 import torch
 import os
 from .code_gen import OutOfResources
+import subprocess
+import sys
+

 try:
     import triton._C.libtriton.cutlass as _cutlass
@@ -99,7 +102,15 @@ def random(shape, dtype, device):
     raise RuntimeError(f'Unknown dtype {dtype}')


-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
+def nvsmi(attrs):
+    attrs = ','.join(attrs)
+    cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
+    out = subprocess.check_output(cmd)
+    ret = out.decode(sys.stdout.encoding).split(',')
+    ret = [int(x) for x in ret]
+    return ret
+
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8], record_clocks=False):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
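The new nvsmi helper simply shells out to nvidia-smi and returns the queried attributes of GPU 0 as integers. A usage sketch; the query field names below are standard nvidia-smi --query-gpu attributes assumed for illustration (they are not taken from this diff, and they must be integer-valued for the int() conversion to succeed):

    # Read the current SM and memory clocks (in MHz) of GPU 0.
    # Requires nvidia-smi on PATH; field names are illustrative assumptions.
    from triton.testing import nvsmi

    sm_clock, mem_clock = nvsmi(['clocks.current.sm', 'clocks.current.memory'])
    print(f'SM clock: {sm_clock} MHz, memory clock: {mem_clock} MHz')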
@@ -127,17 +138,21 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
     end_event.record()
     torch.cuda.synchronize()
     estimate_ms = start_event.elapsed_time(end_event) / 5
+    # compute number of warmup and repeat
+    n_warmup = max(1, int(warmup/estimate_ms))
+    n_repeat = max(1, int(rep/estimate_ms))
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
     # doesn't contain any input data before the run
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
+    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+    clocks = [None for i in range(n_repeat)]
     cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
     # Warm-up
-    for _ in range(int(warmup / estimate_ms)):
+    for _ in range(n_warmup):
         fn()
     # Benchmark
-    for i in range(rep):
+    for i in range(n_repeat):
         # we don't want `fn` to accumulate gradient values
         # if it contains a backward pass. So we clear the
         # provided gradients
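With this change rep, like warmup, acts as a time budget in milliseconds rather than a fixed iteration count: the number of warmup and timed calls is derived from the initial runtime estimate (estimate_ms), with a floor of one. A quick worked example using the default arguments:

    # Suppose the initial five-call estimate yields 0.05 ms per call.
    estimate_ms = 0.05
    warmup, rep = 25, 100                          # millisecond budgets (defaults)
    n_warmup = max(1, int(warmup / estimate_ms))   # 500 warmup calls
    n_repeat = max(1, int(rep / estimate_ms))      # 2000 timed calls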
@@ -150,11 +165,12 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
         start_event[i].record()
         fn()
         end_event[i].record()
+        # Record clocks
     torch.cuda.synchronize()
     times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
-    percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
     med_ms = torch.median(times).item()
     if percentiles:
+        percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
         return tuple([med_ms] + percentiles)
     else:
         return med_ms