[OPS] Add performance model for gemm/gemv (#397)

Significantly improves the performance of `triton.ops.matmul` in memory-bound settings via the use of many more block configs coupled with a performance model to drive the auto-tuning process.
This commit is contained in:
daadaada
2021-12-22 01:56:10 +08:00
committed by GitHub
parent 5cdb948c05
commit 39d4bfed83
12 changed files with 289 additions and 27 deletions

View File

@@ -292,6 +292,16 @@ void init_triton_runtime(py::module &&m) {
return bin;
});
m.def("cc", [](backend_t backend, uint64_t device) -> int {
if (backend == CUDA) {
CUdevice dev = (CUdevice)device;
int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
return major*10 + minor;
}
return -1;
});
// query maximum shared memory
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
if (backend == HOST)
@@ -303,6 +313,31 @@ void init_triton_runtime(py::module &&m) {
return -1;
});
// query DRAM & L2 cache
m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
return -1;
});
m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
return -1;
});
m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
return -1;
});
// query clock rate (in kilohertz)
m.def("clock_rate", [](backend_t backend, uint64_t device) {
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
return -1;
});
m.def("num_sm", [](backend_t backend, uint64_t device) {
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
return -1;
});
// enqueue
m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel,
uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,

View File

@@ -25,7 +25,7 @@ def nvsmi(attrs):
matmul_data = {
# square
(256 , 256 , 256 ) : {'v100': 0.027},
(512 , 512 , 512 ) : {'v100': 0.141},
(512 , 512 , 512 ) : {'v100': 0.158},
(1024, 1024, 1024 ) : {'v100': 0.466},
(2048, 2048, 2048 ) : {'v100': 0.680},
(4096, 4096, 4096 ) : {'v100': 0.831},
@@ -35,10 +35,10 @@ matmul_data = {
(16 , 4096, 4096 ) : {'v100': 0.0883},
(16 , 8192, 8192 ) : {'v100': 0.101},
(64 , 1024, 1024 ) : {'v100': 0.073},
(64 , 4096, 4096 ) : {'v100': 0.228},
(64 , 4096, 4096 ) : {'v100': 0.270},
(64 , 8192, 8192 ) : {'v100': 0.360},
(1024, 64 , 1024 ) : {'v100': 0.0692},
(4096, 64 , 4096 ) : {'v100': 0.223},
(4096, 64 , 4096 ) : {'v100': 0.264},
(8192, 64 , 8192 ) : {'v100': 0.323},
# # deep reductions
# (64 , 64 , 16384) : {'v100': 0.},

View File

@@ -17,6 +17,10 @@ import triton
import triton._C.libtriton.triton as _triton
from filelock import FileLock
import dbm
import tempfile
from typing import Optional, Dict
import time
class CodeGenerator(ast.NodeVisitor):
@@ -508,6 +512,7 @@ class LoadedBinary:
device)
self.bin = bin
self.asm = bin.asm
self.sass = ''
self.module = module
self.kernel = kernel
self.device = device
@@ -519,6 +524,19 @@ class LoadedBinary:
self.bin.num_warps * 32, 1, 1,
args, self.bin.shared_mem)
def get_sass(self, fun=None):
if self.sass:
return self.sass
fd, path = tempfile.mkstemp()
try:
with open(fd, 'wb') as cubin:
cubin.write(self.asm['cubin'])
self.sass = extract(path, fun)
finally:
os.remove(path)
self.asm['sass'] = self.sass
return self.sass
class CompilationError(Exception):
def __init__(self, src, node):
@@ -530,8 +548,8 @@ class CompilationError(Exception):
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}'\
f'Required: {required}'\
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
super().__init__(self.message)
self.args = (required, limit, name)
@@ -727,7 +745,13 @@ class Launcher:
class Autotuner:
def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
'''
if not configs:
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
else:
@@ -744,7 +768,16 @@ class Autotuner:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'prune_num_stages_by' in prune_configs_by:
prune_num_stages_by = prune_configs_by['prune_num_stages_by']
else:
perf_model, top_k, prune_num_stages_by = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.prune_num_stages_by = prune_num_stages_by
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
@@ -768,13 +801,29 @@ class Autotuner:
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
if key not in self.cache:
# prune configs
pruned_configs = self.configs
if self.prune_num_stages_by:
pruned_configs = self.prune_num_stages_by(self.configs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k]
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) \
for config in self.configs}
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook != None:
config.pre_hook(self.nargs)
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
@@ -832,6 +881,8 @@ class DependenciesFinder(ast.NodeVisitor):
module = inspect.getmodule(func)
if module and module.__name__.startswith('triton.'):
return
if inspect.isbuiltin(func):
return
if not hasattr(func, 'hash'):
src = textwrap.dedent(inspect.getsource(func))
tree = ast.parse(src)
@@ -957,8 +1008,16 @@ class Config:
self.num_stages = num_stages
self.pre_hook = pre_hook
def __str__(self):
res = []
for k, v in self.kwargs.items():
res.append(f'{k}: {v}')
res.append(f'num_warps: {self.num_warps}')
res.append(f'num_stages: {self.num_stages}')
return ', '.join(res)
def autotune(configs, key, reset_to_zero=None):
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -985,12 +1044,16 @@ def autotune(configs, key, reset_to_zero=None):
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
def wrapper(kernel):
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
fn.kernel_decorators.append(wrapper)
return fn
@@ -1023,7 +1086,6 @@ def heuristics(values):
assert v not in meta
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
return kernel(*args, **meta)
return fun
fn.kernel_decorators.append(wrapper)

View File

@@ -1,15 +1,33 @@
import torch
import triton.language as tl
import triton
from .matmul_perf_model import *
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
})
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
@@ -19,17 +37,13 @@ def init_to_zero(name):
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
],
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'prune_num_stages_by' : prune_num_stages,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.jit
def _kernel(A, B, C, M, N, K,

View File

@@ -0,0 +1,116 @@
import torch
import triton
import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
import heapq
def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device)
return tflops
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
):
''' return estimated running time in ms
= max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N)
num_cta_k = SPLIT_K
num_ctas = num_cta_m * num_cta_n * num_cta_k
# If the input is smaller than the block size
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
# time to compute
total_ops = 2*M*N*K / (1024*1024*1024) # GOPS
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
compute_ms = total_ops / tput
# time to load data
num_sm = _triton.runtime.num_sm(backend, device)
active_cta_ratio = min(1, num_ctas/num_sm)
active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5%
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache
load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2)
load_a_l2 = M*K*2*0.8*(num_cta_n-1)
load_b_dram = N*K*2*(1+0.2*(num_cta_m-1))
load_b_l2 = N*K*2*0.8*(num_cta_m-1)
# total
total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024*1024)
# loading time in ms
load_ms = total_dram/dram_bw + total_l2/l2_bw
# estimate storing time
store_bw = dram_bw * 0.6 # :o
store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB
if SPLIT_K == 1:
store_ms = store_c_dram /store_bw
else:
reduce_bw = store_bw
store_ms = store_c_dram/reduce_bw
# c.zero_()
zero_ms = M*N*2/(1024*1024)/store_bw
store_ms += zero_ms
total_time_ms = max(compute_ms, load_ms) + store_ms
if debug:
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
f'loading time: {load_ms}ms, store time: {store_ms}ms, '
f'Activate CTAs: {active_cta_ratio*100}%')
return total_time_ms
def prune_num_stages(configs):
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
cc = _triton.runtime.cc(backend, device)
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
configs_map = {}
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
if key in configs_map:
configs_map[key].append((config, num_stages))
else:
configs_map[key] = [(config, num_stages)]
pruned_configs = []
for k, v in configs_map.items():
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
if cc >= 80:
# compute cycles (only works for ampere GPUs)
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16)
mma_cycles = mmas/min(4, num_warps) * 8
ldgsts_latency = 300 # Does this matter?
optimal_num_stages = ldgsts_latency/mma_cycles
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest:
pruned_configs.append(n[0])
else: # Volta & Turing only supports num_stages <= 2
random_config = v[0][0]
random_config.num_stages = 2
pruned_configs.append(random_config)
return pruned_configs

View File

@@ -1,5 +1,6 @@
import torch
import os
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
import subprocess
import sys
@@ -320,3 +321,27 @@ def perf_report(benchmarks):
"""
wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper
def get_dram_gbps(backend=None, device=None):
''' return DRAM bandwidth in GB/s '''
# assert backend == CUDA
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
bus_width = _triton.runtime.global_memory_bus_width(backend, device)
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
return bw_gbps
def get_max_tensorcore_tflops(backend, device):
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
# assume fp32 += fp16*fp16
cc = _triton.runtime.cc(backend, device)
if cc < 80:
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
else:
ops_per_sub_core = 512
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024)
return tflops