[TEST] Added performance regression tests (#283)
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -24,11 +24,16 @@ jobs:
         cd python
         pip3 install -e .
 
-    - name: Run benchmarks
+    - name: Regression tests
       run: |
-        cd python/bench
-        python3 -m run
+        cd python/test/regression
+        sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
+        sudo nvidia-smi -i 0 --lock-memory-clocks=877,877
+        pytest -vs .
+        sudo nvidia-smi -i 0 -rgc
+        sudo nvidia-smi -i 0 -rmc
 
-    - name: Run unit tests
+    - name: Unit tests
       run: |
-        pytest .
+        cd python/test/unit
+        pytest -vs .
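Note: the workflow locks the SM and memory clocks before the regression tests and resets them afterwards, because the tests below assert on absolute utilization numbers; unlocked clocks (boost and thermal throttling) would make those numbers unstable across runs. A minimal sketch of the same idea as a Python context manager follows; the helper name locked_gpu_clocks is hypothetical and not part of this commit, and it assumes nvidia-smi is on PATH and the process has the privileges that sudo provides in the workflow.

import subprocess
from contextlib import contextmanager

@contextmanager
def locked_gpu_clocks(device=0, sm_mhz=1350, mem_mhz=877):
    # Hypothetical helper: issues the same nvidia-smi calls as the workflow
    # above, wrapped so the clocks are reset even if the benchmark raises.
    dev = str(device)
    subprocess.check_call(['nvidia-smi', '-i', dev, f'--lock-gpu-clocks={sm_mhz},{sm_mhz}'])
    subprocess.check_call(['nvidia-smi', '-i', dev, f'--lock-memory-clocks={mem_mhz},{mem_mhz}'])
    try:
        yield
    finally:
        subprocess.check_call(['nvidia-smi', '-i', dev, '-rgc'])  # reset GPU clocks
        subprocess.check_call(['nvidia-smi', '-i', dev, '-rmc'])  # reset memory clocks

Running pytest inside such a context is equivalent to the lock/test/reset sequence the workflow performs.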
--- /dev/null
+++ b/python/test/regression/test_performance.py
@@ -0,0 +1,108 @@
+from numpy import record
+import torch
+import triton
+import subprocess
+import sys
+import pytest
+
+#######################
+# Utilities
+#######################
+
+def nvsmi(attrs):
+    attrs = ','.join(attrs)
+    cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
+    out = subprocess.check_output(cmd)
+    ret = out.decode(sys.stdout.encoding).split(',')
+    ret = [int(x) for x in ret]
+    return ret
+
+
+#######################
+# Matrix Multiplication
+#######################
+
+matmul_data = {
+    # square
+    (256 , 256 , 256 ) : {'v100': 0.027},
+    (512 , 512 , 512 ) : {'v100': 0.141},
+    (1024, 1024, 1024 ) : {'v100': 0.466},
+    (2048, 2048, 2048 ) : {'v100': 0.680},
+    (4096, 4096, 4096 ) : {'v100': 0.831},
+    (8192, 8192, 8192 ) : {'v100': 0.841},
+    # tall-skinny
+    (16  , 1024, 1024 ) : {'v100': 0.0128},
+    (16  , 4096, 4096 ) : {'v100': 0.0558},
+    (16  , 8192, 8192 ) : {'v100': 0.101},
+    (64  , 1024, 1024 ) : {'v100': 0.049},
+    (64  , 4096, 4096 ) : {'v100': 0.211},
+    (64  , 8192, 8192 ) : {'v100': 0.360},
+    (1024, 64  , 1024 ) : {'v100': 0.0469},
+    (4096, 64  , 4096 ) : {'v100': 0.198},
+    (8192, 64  , 8192 ) : {'v100': 0.323},
+    # # deep reductions
+    # (64  , 64  , 16384) : {'v100': 0.},
+    # (64  , 64  , 65536) : {'v100': 0.},
+    # (256 , 256 , 8192 ) : {'v100': 0.},
+    # (256 , 256 , 32768) : {'v100': 0.},
+}
+@pytest.mark.parametrize('M, N, K', matmul_data.keys())
+def test_matmul(M, N, K):
+    ref_gpu_util = matmul_data[(M, N, K)]['v100']
+    cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
+    ref_sm_clock = 1350
+    max_gpu_perf = 1e-6*80*8*128*cur_sm_clock
+    assert cur_sm_clock == ref_sm_clock, f'GPU SMs must run at {ref_sm_clock} MHz'
+    a = torch.randn((M, K), dtype=torch.float16, device='cuda')
+    b = torch.randn((K, N), dtype=torch.float16, device='cuda')
+    fn = lambda: triton.ops.matmul(a, b)
+    ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=1000)
+    cur_gpu_perf = 2.*M*N*K/ms * 1e-9
+    cur_gpu_util = cur_gpu_perf / max_gpu_perf
+    triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
+
+#######################
+# Element-Wise
+#######################
+import triton.language as tl
+
+@triton.jit
+def _add(x_ptr, y_ptr, output_ptr, n_elements, **meta):
+    BLOCK_SIZE = meta['BLOCK_SIZE']
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+elementwise_data = {
+    1024*16   : {'v100': 0.0219},
+    1024*64   : {'v100': 0.0791},
+    1024*256  : {'v100': 0.243},
+    1024*1024 : {'v100': 0.534},
+    1024*4096 : {'v100': 0.796},
+    1024*16384: {'v100': 0.905},
+    1024*65536: {'v100': 0.939},
+}
+
+@pytest.mark.parametrize('N', elementwise_data.keys())
+def test_elementwise(N):
+    ref_gpu_util = elementwise_data[N]['v100']
+    cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
+    ref_mem_clock = 877
+    max_gpu_perf = 512*2*ref_mem_clock*1e-3
+    assert cur_mem_clock == ref_mem_clock, f'GPU memory must run at {ref_mem_clock} MHz'
+    z = torch.empty((N, ), dtype=torch.float16, device='cuda')
+    x = torch.randn_like(z)
+    y = torch.randn_like(z)
+    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+    fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
+    ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
+    cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
+    cur_gpu_util = cur_gpu_perf / max_gpu_perf
+    triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
+
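Note on the utilization model used by both tests: max_gpu_perf is a peak-throughput estimate for a V100, and each test asserts that the measured fraction of that peak matches the recorded reference to two decimals. My reading of the constants (not spelled out in the commit): 80 SMs x 8 tensor cores per SM x 128 FP16 FLOPs per tensor core per cycle x the SM clock in MHz gives TFLOP/s, and a 4096-bit (512-byte) HBM2 bus x 2 for double data rate x the memory clock in MHz gives GB/s. A small sketch that reproduces both numbers:

def v100_peak_tflops(sm_mhz=1350, num_sms=80, tc_per_sm=8, flops_per_tc_cycle=128):
    # Same expression as max_gpu_perf in test_matmul: 1e-6 * 80 * 8 * 128 * sm_clock
    return 1e-6 * num_sms * tc_per_sm * flops_per_tc_cycle * sm_mhz

def v100_peak_gbps(mem_mhz=877, bus_bytes=512, ddr=2):
    # Same expression as max_gpu_perf in test_elementwise: 512 * 2 * mem_clock * 1e-3
    return bus_bytes * ddr * mem_mhz * 1e-3

print(v100_peak_tflops())  # 110.592 TFLOP/s at the locked 1350 MHz
print(v100_peak_gbps())    # 898.048 GB/s at the locked 877 MHz

On the measurement side, test_matmul counts 2*M*N*K FLOPs per matmul, and test_elementwise counts 3*N*z.element_size() bytes per launch (two fp16 loads plus one fp16 store per element).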
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -110,7 +110,7 @@ def nvsmi(attrs):
     ret = [int(x) for x in ret]
     return ret
 
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8], record_clocks=False):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -146,7 +146,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8],
     # doesn't contain any input data before the run
     start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
     end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    clocks = [None for i in range(n_repeat)]
     cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
     # Warm-up
     for _ in range(n_warmup):
@@ -168,12 +167,11 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8],
     # Record clocks
     torch.cuda.synchronize()
     times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
-    med_ms = torch.median(times).item()
     if percentiles:
         percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
-        return tuple([med_ms] + percentiles)
+        return tuple(percentiles)
     else:
-        return med_ms
+        return torch.mean(times).item()
 
 
 class Benchmark:
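Note: after this change, do_bench defaults to percentiles=[0.5, 0.2, 0.8] and returns only the requested quantiles, in that order, while percentiles=None now yields the mean runtime rather than the median; the regression tests above rely on the latter. A usage sketch under those assumptions (the tensor shapes are arbitrary, chosen only for illustration):

import torch
import triton

a = torch.randn((1024, 1024), dtype=torch.float16, device='cuda')
b = torch.randn((1024, 1024), dtype=torch.float16, device='cuda')
fn = lambda: torch.matmul(a, b)

# Default percentiles=[0.5, 0.2, 0.8]: returns (median, 20th, 80th) in ms.
med_ms, p20_ms, p80_ms = triton.testing.do_bench(fn)

# percentiles=None: a single float, now the mean runtime in ms.
mean_ms = triton.testing.do_bench(fn, percentiles=None)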