From da5063d89800f1a36d0cb9377ab386979541613d Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 14 Sep 2021 01:46:32 -0700
Subject: [PATCH] [TEST] Added performance regression tests (#283)

---
 .github/workflows/integration-tests.yml    |  15 ++-
 python/test/regression/test_performance.py | 107 +++++++++++++++++++++
 python/triton/testing.py                   |   8 +-
 3 files changed, 120 insertions(+), 10 deletions(-)
 create mode 100644 python/test/regression/test_performance.py

diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index 21f3f3ea3..74ae590c4 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -24,11 +24,16 @@ jobs:
         cd python
         pip3 install -e .
 
-    - name: Run benchmarks
+    - name: Regression tests
       run: |
-        cd python/bench
-        python3 -m run
+        cd python/test/regression
+        sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
+        sudo nvidia-smi -i 0 --lock-memory-clocks=877,877
+        pytest -vs .
+        sudo nvidia-smi -i 0 -rgc
+        sudo nvidia-smi -i 0 -rmc
 
-    - name: Run unit tests
+    - name: Unit tests
       run: |
-        pytest .
\ No newline at end of file
+        cd python/test/unit
+        pytest -vs .
\ No newline at end of file
diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py
new file mode 100644
index 000000000..d8ae65742
--- /dev/null
+++ b/python/test/regression/test_performance.py
@@ -0,0 +1,107 @@
+import torch
+import triton
+import triton.language as tl
+import subprocess
+import sys
+import pytest
+
+#######################
+# Utilities
+#######################
+
+def nvsmi(attrs):
+    attrs = ','.join(attrs)
+    cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
+    out = subprocess.check_output(cmd)
+    ret = out.decode(sys.stdout.encoding).split(',')
+    ret = [int(x) for x in ret]
+    return ret
+
+
+#######################
+# Matrix Multiplication
+#######################
+
+matmul_data = {
+    # square
+    (256 , 256 , 256 ) : {'v100': 0.027},
+    (512 , 512 , 512 ) : {'v100': 0.141},
+    (1024, 1024, 1024 ) : {'v100': 0.466},
+    (2048, 2048, 2048 ) : {'v100': 0.680},
+    (4096, 4096, 4096 ) : {'v100': 0.831},
+    (8192, 8192, 8192 ) : {'v100': 0.841},
+    # tall-skinny
+    (16  , 1024, 1024 ) : {'v100': 0.0128},
+    (16  , 4096, 4096 ) : {'v100': 0.0558},
+    (16  , 8192, 8192 ) : {'v100': 0.101},
+    (64  , 1024, 1024 ) : {'v100': 0.049},
+    (64  , 4096, 4096 ) : {'v100': 0.211},
+    (64  , 8192, 8192 ) : {'v100': 0.360},
+    (1024, 64  , 1024 ) : {'v100': 0.0469},
+    (4096, 64  , 4096 ) : {'v100': 0.198},
+    (8192, 64  , 8192 ) : {'v100': 0.323},
+#   # deep reductions
+#   (64  , 64  , 16384) : {'v100': 0.},
+#   (64  , 64  , 65536) : {'v100': 0.},
+#   (256 , 256 , 8192 ) : {'v100': 0.},
+#   (256 , 256 , 32768) : {'v100': 0.},
+}
+@pytest.mark.parametrize('M, N, K', matmul_data.keys())
+def test_matmul(M, N, K):
+    ref_gpu_util = matmul_data[(M, N, K)]['v100']
+    cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
+    ref_sm_clock = 1350
+    max_gpu_perf = 1e-6*80*8*128*cur_sm_clock
+    assert cur_sm_clock == ref_sm_clock, f'GPU SMs must run at {ref_sm_clock} MHz'
+    a = torch.randn((M, K), dtype=torch.float16, device='cuda')
+    b = torch.randn((K, N), dtype=torch.float16, device='cuda')
+    fn = lambda: triton.ops.matmul(a, b)
+    ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=1000)
+    cur_gpu_perf = 2.*M*N*K/ms * 1e-9
+    cur_gpu_util = cur_gpu_perf / max_gpu_perf
+    triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
+
+#######################
+# Element-Wise
+#######################
+
+@triton.jit
+def _add(x_ptr, y_ptr, output_ptr, n_elements, **meta):
+    BLOCK_SIZE = meta['BLOCK_SIZE']
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+elementwise_data = {
+    1024*16   : {'v100': 0.0219},
+    1024*64   : {'v100': 0.0791},
+    1024*256  : {'v100': 0.243},
+    1024*1024 : {'v100': 0.534},
+    1024*4096 : {'v100': 0.796},
+    1024*16384: {'v100': 0.905},
+    1024*65536: {'v100': 0.939},
+}
+
+@pytest.mark.parametrize('N', elementwise_data.keys())
+def test_elementwise(N):
+    ref_gpu_util = elementwise_data[N]['v100']
+    cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
+    ref_mem_clock = 877
+    max_gpu_perf = 512*2*ref_mem_clock*1e-3
+    assert cur_mem_clock == ref_mem_clock, f'GPU memory must run at {ref_mem_clock} MHz'
+    z = torch.empty((N, ), dtype=torch.float16, device='cuda')
+    x = torch.randn_like(z)
+    y = torch.randn_like(z)
+    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+    fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
+    ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
+    cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
+    cur_gpu_util = cur_gpu_perf / max_gpu_perf
+    triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
+
diff --git a/python/triton/testing.py b/python/triton/testing.py
index d41c46734..08ad62580 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -110,7 +110,7 @@ def nvsmi(attrs):
     ret = [int(x) for x in ret]
     return ret
 
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8], record_clocks=False):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime
     of :code:`fn` along with the 20-th and 80-th performance percentile.
@@ -146,7 +146,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8],
     # doesn't contain any input data before the run
     start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
     end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    clocks = [None for i in range(n_repeat)]
     cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
     # Warm-up
     for _ in range(n_warmup):
@@ -168,12 +167,11 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8],
     # Record clocks
     torch.cuda.synchronize()
     times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
-    med_ms = torch.median(times).item()
     if percentiles:
         percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
-        return tuple([med_ms] + percentiles)
+        return tuple(percentiles)
     else:
-        return med_ms
+        return torch.mean(times).item()
 
 
 class Benchmark:
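
The utilization references hard-coded in matmul_data and elementwise_data are
fractions of a V100 roofline: the matmul test divides measured TFLOPS by the
fp16 tensor-core peak implied by the SM clock, and the element-wise test
divides measured GB/s by the HBM2 peak implied by the memory clock. The sketch
below spells out that arithmetic as standalone Python. It is an illustration,
not code from the patch; the helper names are invented here, and the hardware
figures (80 SMs, 8 tensor cores per SM, 128 FLOP/cycle per tensor core, a
512-byte-per-clock double-data-rate HBM2 bus) are the V100 numbers the tests
appear to assume.

    def v100_peak_tflops(sm_clock_mhz):
        # fp16 tensor-core peak: 80 SMs * 8 tensor cores/SM * 128 FLOP/cycle,
        # scaled by the SM clock in MHz; the 1e-6 factor yields TFLOPS.
        return 1e-6 * 80 * 8 * 128 * sm_clock_mhz

    def v100_peak_gbps(mem_clock_mhz):
        # HBM2 peak bandwidth: 512 bytes/clock * 2 (double data rate),
        # scaled by the memory clock in MHz; the 1e-3 factor yields GB/s.
        return 512 * 2 * mem_clock_mhz * 1e-3

    def matmul_util(M, N, K, ms, sm_clock_mhz=1350):
        # A matmul performs 2*M*N*K floating-point ops; FLOP/ms * 1e-9 = TFLOPS.
        return (2. * M * N * K / ms * 1e-9) / v100_peak_tflops(sm_clock_mhz)

    def add_util(N, ms, mem_clock_mhz=877, elt_bytes=2):
        # z = x + y reads 2*N and writes N fp16 elements, i.e. 3*N*2 bytes;
        # bytes/ms * 1e-6 = GB/s.
        return (3. * N * elt_bytes / ms * 1e-6) / v100_peak_gbps(mem_clock_mhz)

As a sanity check, an 8192x8192x8192 fp16 matmul measured at roughly 11.8 ms
with the SMs locked at 1350 MHz gives matmul_util(8192, 8192, 8192, 11.8)
~= 0.84, in line with the 0.841 reference in matmul_data; this is also why the
tests assert that the locked clocks are in effect before benchmarking.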