[FRONTEND] Make the performance model work for int8, tf32, and fp32 (#456)

This commit is contained in:
daadaada
2022-02-12 14:34:42 +08:00
committed by GitHub
parent 9b100302d3
commit a9dfdcaaa9
5 changed files with 201 additions and 69 deletions

View File

@@ -6,6 +6,9 @@ import torch
import triton
import triton.language as tl
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
DEVICE_NAME = 'v100'
#######################
# Utilities
@@ -25,42 +28,76 @@ def nvsmi(attrs):
# Matrix Multiplication
#######################
sm_clocks = {'v100': 1350, 'a100': 1350}
mem_clocks = {'v100': 877, 'a100': 1215}
matmul_data = {
# square
(256, 256, 256): {'v100': 0.027},
(512, 512, 512): {'v100': 0.158},
(1024, 1024, 1024): {'v100': 0.466},
(2048, 2048, 2048): {'v100': 0.680},
(4096, 4096, 4096): {'v100': 0.831},
(8192, 8192, 8192): {'v100': 0.849},
# tall-skinny
(16, 1024, 1024): {'v100': 0.0128},
(16, 4096, 4096): {'v100': 0.0883},
(16, 8192, 8192): {'v100': 0.101},
(64, 1024, 1024): {'v100': 0.073},
(64, 4096, 4096): {'v100': 0.270},
(64, 8192, 8192): {'v100': 0.360},
(1024, 64, 1024): {'v100': 0.0692},
(4096, 64, 4096): {'v100': 0.264},
(8192, 64, 8192): {'v100': 0.323},
'v100': {
# square
(256, 256, 256): {'float16': 0.027},
(512, 512, 512): {'float16': 0.158},
(1024, 1024, 1024): {'float16': 0.466},
(2048, 2048, 2048): {'float16': 0.680},
(4096, 4096, 4096): {'float16': 0.831},
(8192, 8192, 8192): {'float16': 0.849},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0128},
(16, 4096, 4096): {'float16': 0.0883},
(16, 8192, 8192): {'float16': 0.101},
(64, 1024, 1024): {'float16': 0.073},
(64, 4096, 4096): {'float16': 0.270},
(64, 8192, 8192): {'float16': 0.459},
(1024, 64, 1024): {'float16': 0.0692},
(4096, 64, 4096): {'float16': 0.264},
(8192, 64, 8192): {'float16': 0.452},
},
'a100': {
(256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
(512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
(1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
(4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
(8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
(16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
}
# # deep reductions
# (64 , 64 , 16384) : {'v100': 0.},
# (64 , 64 , 65536) : {'v100': 0.},
# (256 , 256 , 8192 ) : {'v100': 0.},
# (256 , 256 , 32768) : {'v100': 0.},
# (64 , 64 , 16384) : {'a100': 0.},
# (64 , 64 , 65536) : {'a100': 0.},
# (256 , 256 , 8192 ) : {'a100': 0.},
# (256 , 256 , 32768) : {'a100': 0.},
}
@pytest.mark.parametrize('M, N, K', matmul_data.keys())
def test_matmul(M, N, K):
@pytest.mark.parametrize('M, N, K, dtype_str',
[(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
def test_matmul(M, N, K, dtype_str):
if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
pytest.skip('Only test float32 & int8 on a100')
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
torch.manual_seed(0)
ref_gpu_util = matmul_data[(M, N, K)]['v100']
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
ref_sm_clock = 1350
max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock
ref_sm_clock = sm_clocks[DEVICE_NAME]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
a = torch.randn((M, K), dtype=torch.float16, device='cuda')
b = torch.randn((K, N), dtype=torch.float16, device='cuda')
if dtype == torch.int8:
a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
b = b.t() # only test row-col layout
else:
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((K, N), dtype=dtype, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
@@ -87,23 +124,34 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
elementwise_data = {
1024 * 16: {'v100': 0.0219},
1024 * 64: {'v100': 0.0791},
1024 * 256: {'v100': 0.243},
1024 * 1024: {'v100': 0.534},
1024 * 4096: {'v100': 0.796},
1024 * 16384: {'v100': 0.905},
1024 * 65536: {'v100': 0.939},
'v100': {
1024 * 16: 0.0219,
1024 * 64: 0.0791,
1024 * 256: 0.243,
1024 * 1024: 0.534,
1024 * 4096: 0.796,
1024 * 16384: 0.905,
1024 * 65536: 0.939,
},
'a100': {
1024 * 16: 0.008,
1024 * 64: 0.034,
1024 * 256: 0.114,
1024 * 1024: 0.315,
1024 * 4096: 0.580,
1024 * 16384: 0.782,
1024 * 65536: 0.850,
}
}
@pytest.mark.parametrize('N', elementwise_data.keys())
@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
def test_elementwise(N):
torch.manual_seed(0)
ref_gpu_util = elementwise_data[N]['v100']
ref_gpu_util = elementwise_data[DEVICE_NAME][N]
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = 877
max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3
ref_mem_clock = mem_clocks[DEVICE_NAME]
max_gpu_perf = get_dram_gbps()
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)

View File

@@ -811,12 +811,12 @@ class Autotuner:
# prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'prune_num_stages_by' in prune_configs_by:
prune_num_stages_by = prune_configs_by['prune_num_stages_by']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
else:
perf_model, top_k, prune_num_stages_by = None, None, None
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.prune_num_stages_by = prune_num_stages_by
self.early_config_prune = early_config_prune
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
@@ -844,8 +844,8 @@ class Autotuner:
if key not in self.cache:
# prune configs
pruned_configs = self.configs
if self.prune_num_stages_by:
pruned_configs = self.prune_num_stages_by(self.configs, self.nargs)
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
@@ -1096,7 +1096,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""

View File

@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl
from .matmul_perf_model import estimate_matmul_time, prune_num_stages
from .matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
@@ -27,7 +27,7 @@ def get_configs_io_bound():
@triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.autotune(
configs=[
@@ -41,10 +41,20 @@ def get_configs_io_bound():
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'prune_num_stages_by': prune_num_stages,
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
@@ -55,7 +65,9 @@ def _kernel(A, B, C, M, N, K,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
@@ -76,7 +88,7 @@ def _kernel(A, B, C, M, N, K,
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K * SPLIT_K):
if EVEN_K:
a = tl.load(A)
@@ -119,13 +131,15 @@ class _matmul(torch.autograd.Function):
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
# accumulator types
ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8)
GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c
@staticmethod

View File

@@ -4,20 +4,36 @@ import torch
import triton
import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
return tflops
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
return tflops
def get_tflops(backend, device, num_ctas, num_warps, dtype):
cc = _triton.runtime.cc(backend, device)
if cc < 80 and dtype == torch.float32:
return get_simd_tflops()
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
A, B, C,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
@@ -26,6 +42,8 @@ def estimate_matmul_time(
= max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
dtype = A.dtype
dtsize = A.element_size()
num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N)
@@ -37,7 +55,7 @@ def estimate_matmul_time(
# time to compute
total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
compute_ms = total_ops / tput
# time to load data
@@ -48,10 +66,10 @@ def estimate_matmul_time(
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache
load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1)) # assume dtype=float16 (size==2)
load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1)
load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1))
load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1)
load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
# total
total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
@@ -60,7 +78,7 @@ def estimate_matmul_time(
# estimate storing time
store_bw = dram_bw * 0.6 # :o
store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024) # MB
store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB
if SPLIT_K == 1:
store_ms = store_c_dram / store_bw
else:
@@ -78,14 +96,28 @@ def estimate_matmul_time(
return total_time_ms
def prune_num_stages(configs, named_args):
def early_config_prune(configs, named_args):
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
cc = _triton.runtime.cc(backend, device)
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
dtsize = named_args['A'].element_size()
dtype = named_args['A'].dtype
# 1. make sure we have enough smem
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory <= max_shared_memory:
pruned_configs.append(config)
configs = pruned_configs
# Some dtypes do not allow atomic_add
if named_args['A'].dtype == torch.bfloat16:
if dtype not in [torch.float16, torch.float32]:
configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)

View File

@@ -330,18 +330,56 @@ def get_dram_gbps(backend=None, device=None):
device = torch.cuda.current_device()
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
bus_width = _triton.runtime.global_memory_bus_width(backend, device)
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
return bw_gbps
def get_max_tensorcore_tflops(backend, device):
def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None):
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
# assume fp32 += fp16*fp16
if not clock_rate:
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
cc = _triton.runtime.cc(backend, device)
if cc < 80:
assert dtype == torch.float16
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
else:
ops_per_sub_core = 512
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024)
if dtype == torch.float32:
ops_per_sub_core = 256
elif dtype in [torch.float16, torch.bfloat16]:
ops_per_sub_core = 512
elif dtype == torch.int8:
ops_per_sub_core = 1024
else:
raise RuntimeError("dtype not supported")
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
return tflops
def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
cc = _triton.runtime.cc(backend, device)
if cc < 80:
if dtype == torch.float32:
ops_per_sub_core = 32 # 2*16
elif dtype == torch.float16:
ops_per_sub_core = 64
else:
raise RuntimeError("dtype not supported")
else:
if dtype == torch.float32:
ops_per_sub_core = 32
elif dtype in [torch.float16, torch.bfloat16]:
ops_per_sub_core = 64
else:
raise RuntimeError("dtype not supported")
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
return tflops