[LANG] Fixed semantics of NaN in float comparisons (#281)
@@ -275,7 +275,7 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float > float
   if (scalar_ty->is_floating_point_ty())
-    return builder->create_fcmpUGT(input, other);
+    return builder->create_fcmpOGT(input, other);
   // int > int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpSGT(input, other);
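Note on the predicate change: LLVM's unordered comparisons (the U* predicates used previously) evaluate to true whenever either operand is NaN, whereas the ordered O* predicates evaluate to false. The ordered variants match the usual Python/NumPy convention, illustrated by this small standalone sketch (plain Python floats, nothing from the diff is required):

    nan = float('nan')
    one = 1.0
    print(one > nan)    # False -> what ordered greater-than (OGT) returns
    print(nan > one)    # False
    print(nan >= nan)   # False
    # The unordered predicates (UGT/UGE/...) would instead yield True whenever
    # either operand is NaN, which is the behaviour this commit fixes.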
@@ -287,7 +287,7 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float >= float
   if (scalar_ty->is_floating_point_ty())
-    return builder->create_fcmpUGE(input, other);
+    return builder->create_fcmpOGE(input, other);
   // int >= int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpSGE(input, other);
@@ -299,7 +299,7 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float < float
   if (scalar_ty->is_floating_point_ty())
-    return builder->create_fcmpULT(input, other);
+    return builder->create_fcmpOLT(input, other);
   // int < int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpSLT(input, other);
@@ -311,7 +311,7 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float < float
   if (scalar_ty->is_floating_point_ty())
-    return builder->create_fcmpULE(input, other);
+    return builder->create_fcmpOLE(input, other);
   // int < int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpSLE(input, other);
@@ -323,7 +323,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
   // float == float
   if (scalar_ty->is_floating_point_ty())
-    return builder->create_fcmpUEQ(input, other);
+    return builder->create_fcmpOEQ(input, other);
   // int == int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpEQ(input, other);
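Equality follows the same rule: IEEE-754 defines NaN as unequal to everything, including itself, so ordered-equal (OEQ) is the predicate that reproduces the reference behaviour checked by the tests below. A quick illustration in plain Python:

    nan = float('nan')
    print(nan == nan)   # False -> matches ordered-equal (OEQ)
    print(nan != nan)   # True  -> inequality is the only comparison that holds for NaN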
@@ -69,7 +69,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
     triton.testing.assert_almost_equal(z_ref, z_tri)


-def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
+def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
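The new mode_x/mode_y keyword arguments default to 'real', so existing call sites are unaffected; a NaN case is opted into explicitly. A hypothetical call, mirroring how test_compare_op forwards the modes further down:

    _test_binary('float32', 'float32', ' x > y')                  # unchanged behaviour
    _test_binary('float32', 'float32', ' x > y', mode_x='nan')    # x filled with NaNs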
@@ -84,6 +84,8 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
     # inputs
     x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
     y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
+    if mode_x == 'nan': x[:] = float('nan')
+    if mode_y == 'nan': y[:] = float('nan')
     # reference result
     z_ref = eval(expr)
     # triton result
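Overwriting the freshly generated inputs in place is all that is needed to exercise the NaN paths: the reference result computed by eval(expr) then comes out all-False for any comparison operator. A minimal standalone sketch of that effect (PyTorch only, on CPU so it runs anywhere):

    import torch

    SIZE = 128
    x = torch.rand(SIZE)
    y = torch.rand(SIZE)
    x[:] = float('nan')          # what mode_x == 'nan' does

    z_ref = eval('x > y')        # same eval-based reference as in the test
    assert not z_ref.any()       # every comparison against NaN is False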
@@ -126,14 +128,25 @@ def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
 # ---------------
 # test compare ops
 # ---------------
-@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
-    (dtype_x, dtype_y, f' x {op} y') \
-    for op in ['==', '!=', '>', '<', '>=', '<='] \
+ops = ['==', '!=', '>', '<', '>=', '<=']
+@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \
+# real
+[
+    (dtype_x, dtype_y, f' x {op} y', 'real', 'real') \
+    for op in ops \
     for dtype_x in dtypes \
     for dtype_y in dtypes
+] + \
+# NaNs
+[('float32', 'float32', f' x {op} y', mode_x, mode_y) \
+    for op in ops
+    for mode_x, mode_y in [('nan' , 'real'),
+                           ('real', 'nan'),
+                           ('nan' , 'nan')]
+
 ])
-def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
-    _test_binary(dtype_x, dtype_y, expr, device=device)
+def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
+    _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)


 # ---------------
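The parametrization is now the concatenation of two lists: the original all-dtypes sweep with plain inputs, plus a float32-only sweep over the three NaN placements. A standalone sketch of how the case list expands (a shortened dtypes list is assumed here; the test module defines its own):

    dtypes = ['float32', 'int32']                       # illustrative subset
    ops = ['==', '!=', '>', '<', '>=', '<=']

    cases = [
        (dtype_x, dtype_y, f' x {op} y', 'real', 'real')
        for op in ops
        for dtype_x in dtypes
        for dtype_y in dtypes
    ] + [
        ('float32', 'float32', f' x {op} y', mode_x, mode_y)
        for op in ops
        for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]
    ]

    print(len(cases))   # 6*2*2 + 6*3 = 42 cases for this reduced dtype list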
@@ -1,6 +1,9 @@
 import torch
 import os
 from .code_gen import OutOfResources
+import subprocess
+import sys
+

 try:
     import triton._C.libtriton.cutlass as _cutlass
@@ -99,7 +102,15 @@ def random(shape, dtype, device):
     raise RuntimeError(f'Unknown dtype {dtype}')


-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
+def nvsmi(attrs):
+    attrs = ','.join(attrs)
+    cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
+    out = subprocess.check_output(cmd)
+    ret = out.decode(sys.stdout.encoding).split(',')
+    ret = [int(x) for x in ret]
+    return ret
+
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8], record_clocks=False):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
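nvsmi shells out to nvidia-smi for GPU 0 and returns the queried attributes as integers. A hypothetical call (attribute names as documented by `nvidia-smi --help-query-gpu`; assumes an NVIDIA GPU and driver are present):

    sm_clock, mem_clock = nvsmi(['clocks.current.sm', 'clocks.current.memory'])
    print(f'SM: {sm_clock} MHz, MEM: {mem_clock} MHz')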
@@ -127,17 +138,21 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
     end_event.record()
     torch.cuda.synchronize()
     estimate_ms = start_event.elapsed_time(end_event) / 5
+    # compute number of warmup and repeat
+    n_warmup = max(1, int(warmup/estimate_ms))
+    n_repeat = max(1, int(rep/estimate_ms))
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
     # doesn't contain any input data before the run
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
+    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event   = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+    clocks      = [None for i in range(n_repeat)]
     cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
     # Warm-up
-    for _ in range(int(warmup / estimate_ms)):
+    for _ in range(n_warmup):
         fn()
     # Benchmark
-    for i in range(rep):
+    for i in range(n_repeat):
         # we don't want `fn` to accumulate gradient values
         # if it contains a backward pass. So we clear the
         # provided gradients
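With this change, warmup and rep act as time budgets in milliseconds rather than raw iteration counts: a quick estimate of the per-call latency converts them into n_warmup and n_repeat. A worked example with the default budgets:

    warmup, rep = 25, 100                         # budgets in milliseconds (the defaults)
    estimate_ms = 0.8                             # assume one call of fn takes ~0.8 ms

    n_warmup = max(1, int(warmup / estimate_ms))  # 31 warm-up calls
    n_repeat = max(1, int(rep / estimate_ms))     # 125 timed calls
    print(n_warmup, n_repeat)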
@@ -150,11 +165,12 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
         start_event[i].record()
         fn()
         end_event[i].record()
+    # Record clocks
     torch.cuda.synchronize()
     times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
-    percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
     med_ms = torch.median(times).item()
     if percentiles:
+        percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
         return tuple([med_ms] + percentiles)
     else:
         return med_ms
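Moving the quantile computation inside the if-branch means the percentile summary is only produced when it was asked for, so passing an empty percentiles list now returns just the median. A small standalone sketch of the summary step, assuming a tensor of made-up timings:

    import torch

    times = torch.tensor([1.02, 0.98, 1.10, 0.95, 1.40, 1.00])   # fake timings in ms
    percentiles = [0.2, 0.8]

    med_ms = torch.median(times).item()
    if percentiles:
        lo, hi = torch.quantile(times, torch.tensor(percentiles)).tolist()
        print(med_ms, lo, hi)
    else:
        print(med_ms)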