[LANG] Fixed semantics of NaN in float comparisons (#281)

This commit is contained in:
Philippe Tillet
2021-09-13 15:06:29 -07:00
committed by GitHub
parent cecca90bea
commit 3e395bc84e
8 changed files with 46 additions and 17 deletions

View File

@@ -69,7 +69,7 @@ def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
triton.testing.assert_almost_equal(z_ref, z_tri)
def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
def _test_binary(dtype_x, dtype_y, expr, mode_x='real', mode_y='real', device='cuda'):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
@@ -84,6 +84,8 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
y = triton.testing.random(SIZE, dtype=cvt[dtype_y], device=device)
if mode_x == 'nan': x[:] = float('nan')
if mode_y == 'nan': y[:] = float('nan')
# reference result
z_ref = eval(expr)
# triton result
@@ -126,14 +128,25 @@ def test_bitwise_op(dtype_x, dtype_y, expr, device='cuda'):
# ---------------
# test compare ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, expr", [
(dtype_x, dtype_y, f' x {op} y') \
for op in ['==', '!=', '>', '<', '>=', '<='] \
ops = ['==', '!=', '>', '<', '>=', '<=']
@pytest.mark.parametrize("dtype_x, dtype_y, expr, mode_x, mode_y", \
# real
[
(dtype_x, dtype_y, f' x {op} y', 'real', 'real') \
for op in ops \
for dtype_x in dtypes \
for dtype_y in dtypes
] + \
# NaNs
[('float32', 'float32', f' x {op} y', mode_x, mode_y) \
for op in ops
for mode_x, mode_y in [('nan' , 'real'),
('real', 'nan'),
('nan' , 'nan')]
])
def test_compare_op(dtype_x, dtype_y, expr, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, device=device)
def test_compare_op(dtype_x, dtype_y, expr, mode_x, mode_y, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
# ---------------

View File

@@ -1,6 +1,9 @@
import torch
import os
from .code_gen import OutOfResources
import subprocess
import sys
try:
import triton._C.libtriton.cutlass as _cutlass
@@ -99,7 +102,15 @@ def random(shape, dtype, device):
raise RuntimeError(f'Unknown dtype {dtype}')
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
out = subprocess.check_output(cmd)
ret = out.decode(sys.stdout.encoding).split(',')
ret = [int(x) for x in ret]
return ret
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8], record_clocks=False):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -127,17 +138,21 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup/estimate_ms))
n_repeat = max(1, int(rep/estimate_ms))
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
clocks = [None for i in range(n_repeat)]
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Warm-up
for _ in range(int(warmup / estimate_ms)):
for _ in range(n_warmup):
fn()
# Benchmark
for i in range(rep):
for i in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
@@ -150,11 +165,12 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
start_event[i].record()
fn()
end_event[i].record()
# Record clocks
torch.cuda.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
med_ms = torch.median(times).item()
if percentiles:
percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
return tuple([med_ms] + percentiles)
else:
return med_ms