[OPS] Add performance model for gemm/gemv (#397)

Significantly improves the performance of `triton.ops.matmul` in memory-bound settings via the use of many more block configs coupled with a performance model to drive the auto-tuning process.
This commit is contained in:
daadaada
2021-12-22 01:56:10 +08:00
committed by GitHub
parent 5cdb948c05
commit 39d4bfed83
12 changed files with 289 additions and 27 deletions

View File

@@ -455,7 +455,7 @@ public:
// masked load async // masked load async
class masked_load_async_inst: public load_inst { class masked_load_async_inst: public load_inst {
private: private:
std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); } std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next); const std::string &name, instruction *next);
@@ -728,12 +728,21 @@ public:
class dot_inst: public builtin_inst { class dot_inst: public builtin_inst {
public: public:
enum TransT { NoTrans, Trans }; enum TransT { NoTrans, Trans };
enum DataType {
FP8, FP16, BF16, TF32, FP32,
INT1, INT4, INT8, INT32,
UNKNOWN,
};
private: private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next); dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; } std::string repr_impl() const { return "dot"; }
bool is_prefetched_ = false; bool is_prefetched_ = false;
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
public: public:
bool is_prefetched() const { return is_prefetched_; } bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }

View File

@@ -85,6 +85,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
allocation.run(ir); allocation.run(ir);
prefetch_s.run(ir); prefetch_s.run(ir);
barriers.run(ir); barriers.run(ir);
// ir.print(std::cout);
isel.visit(ir, *llvm); isel.visit(ir, *llvm);
shared_static = allocation.allocated_size(); shared_static = allocation.allocated_size();
return llvm; return llvm;

View File

@@ -292,6 +292,16 @@ void init_triton_runtime(py::module &&m) {
return bin; return bin;
}); });
m.def("cc", [](backend_t backend, uint64_t device) -> int {
if (backend == CUDA) {
CUdevice dev = (CUdevice)device;
int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
return major*10 + minor;
}
return -1;
});
// query maximum shared memory // query maximum shared memory
m.def("max_shared_memory", [](backend_t backend, uint64_t device) { m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
if (backend == HOST) if (backend == HOST)
@@ -303,6 +313,31 @@ void init_triton_runtime(py::module &&m) {
return -1; return -1;
}); });
// query DRAM & L2 cache
m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
return -1;
});
m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
return -1;
});
m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
return -1;
});
// query clock rate (in kilohertz)
m.def("clock_rate", [](backend_t backend, uint64_t device) {
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
return -1;
});
m.def("num_sm", [](backend_t backend, uint64_t device) {
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
return -1;
});
// enqueue // enqueue
m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel, m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel,
uint64_t grid_0, uint64_t grid_1, uint64_t grid_2, uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,

View File

@@ -25,7 +25,7 @@ def nvsmi(attrs):
matmul_data = { matmul_data = {
# square # square
(256 , 256 , 256 ) : {'v100': 0.027}, (256 , 256 , 256 ) : {'v100': 0.027},
(512 , 512 , 512 ) : {'v100': 0.141}, (512 , 512 , 512 ) : {'v100': 0.158},
(1024, 1024, 1024 ) : {'v100': 0.466}, (1024, 1024, 1024 ) : {'v100': 0.466},
(2048, 2048, 2048 ) : {'v100': 0.680}, (2048, 2048, 2048 ) : {'v100': 0.680},
(4096, 4096, 4096 ) : {'v100': 0.831}, (4096, 4096, 4096 ) : {'v100': 0.831},
@@ -35,10 +35,10 @@ matmul_data = {
(16 , 4096, 4096 ) : {'v100': 0.0883}, (16 , 4096, 4096 ) : {'v100': 0.0883},
(16 , 8192, 8192 ) : {'v100': 0.101}, (16 , 8192, 8192 ) : {'v100': 0.101},
(64 , 1024, 1024 ) : {'v100': 0.073}, (64 , 1024, 1024 ) : {'v100': 0.073},
(64 , 4096, 4096 ) : {'v100': 0.228}, (64 , 4096, 4096 ) : {'v100': 0.270},
(64 , 8192, 8192 ) : {'v100': 0.360}, (64 , 8192, 8192 ) : {'v100': 0.360},
(1024, 64 , 1024 ) : {'v100': 0.0692}, (1024, 64 , 1024 ) : {'v100': 0.0692},
(4096, 64 , 4096 ) : {'v100': 0.223}, (4096, 64 , 4096 ) : {'v100': 0.264},
(8192, 64 , 8192 ) : {'v100': 0.323}, (8192, 64 , 8192 ) : {'v100': 0.323},
# # deep reductions # # deep reductions
# (64 , 64 , 16384) : {'v100': 0.}, # (64 , 64 , 16384) : {'v100': 0.},

View File

@@ -17,6 +17,10 @@ import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from filelock import FileLock from filelock import FileLock
import dbm import dbm
import tempfile
from typing import Optional, Dict
import time
class CodeGenerator(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor):
@@ -508,6 +512,7 @@ class LoadedBinary:
device) device)
self.bin = bin self.bin = bin
self.asm = bin.asm self.asm = bin.asm
self.sass = ''
self.module = module self.module = module
self.kernel = kernel self.kernel = kernel
self.device = device self.device = device
@@ -519,6 +524,19 @@ class LoadedBinary:
self.bin.num_warps * 32, 1, 1, self.bin.num_warps * 32, 1, 1,
args, self.bin.shared_mem) args, self.bin.shared_mem)
def get_sass(self, fun=None):
if self.sass:
return self.sass
fd, path = tempfile.mkstemp()
try:
with open(fd, 'wb') as cubin:
cubin.write(self.asm['cubin'])
self.sass = extract(path, fun)
finally:
os.remove(path)
self.asm['sass'] = self.sass
return self.sass
class CompilationError(Exception): class CompilationError(Exception):
def __init__(self, src, node): def __init__(self, src, node):
@@ -530,8 +548,8 @@ class CompilationError(Exception):
class OutOfResources(Exception): class OutOfResources(Exception):
def __init__(self, required, limit, name): def __init__(self, required, limit, name):
self.message = f'out of resource: {name}'\ self.message = f'out of resource: {name}, '\
f'Required: {required}'\ f'Required: {required}, '\
f'Hardware limit: {limit}' f'Hardware limit: {limit}'
super().__init__(self.message) super().__init__(self.message)
self.args = (required, limit, name) self.args = (required, limit, name)
@@ -727,7 +745,13 @@ class Launcher:
class Autotuner: class Autotuner:
def __init__(self, kernel, arg_names, configs, key, reset_to_zero): def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
'''
if not configs: if not configs:
self.configs = [Config(dict(), num_warps=4, num_stages=2)] self.configs = [Config(dict(), num_warps=4, num_stages=2)]
else: else:
@@ -744,6 +768,15 @@ class Autotuner:
args[i].zero_() args[i].zero_()
self.hook = _hook self.hook = _hook
self.arg_names = arg_names self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'prune_num_stages_by' in prune_configs_by:
prune_num_stages_by = prune_configs_by['prune_num_stages_by']
else:
perf_model, top_k, prune_num_stages_by = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.prune_num_stages_by = prune_num_stages_by
def _bench(self, *args, config, **meta): def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided # check for conflicts, i.e. meta-parameters both provided
@@ -768,13 +801,29 @@ class Autotuner:
if len(self.configs) > 1: if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx]) key = tuple([args[i] for i in self.key_idx])
if key not in self.cache: if key not in self.cache:
# prune configs
pruned_configs = self.configs
if self.prune_num_stages_by:
pruned_configs = self.prune_num_stages_by(self.configs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k]
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) \ timings = {config: self._bench(*args, config=config, **kwargs) \
for config in self.configs} for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get) self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args) self.hook(args)
self.configs_timings = timings
config = self.cache[key] config = self.cache[key]
else: else:
config = self.configs[0] config = self.configs[0]
self.best_config = config
if config.pre_hook != None: if config.pre_hook != None:
config.pre_hook(self.nargs) config.pre_hook(self.nargs)
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
@@ -832,6 +881,8 @@ class DependenciesFinder(ast.NodeVisitor):
module = inspect.getmodule(func) module = inspect.getmodule(func)
if module and module.__name__.startswith('triton.'): if module and module.__name__.startswith('triton.'):
return return
if inspect.isbuiltin(func):
return
if not hasattr(func, 'hash'): if not hasattr(func, 'hash'):
src = textwrap.dedent(inspect.getsource(func)) src = textwrap.dedent(inspect.getsource(func))
tree = ast.parse(src) tree = ast.parse(src)
@@ -957,8 +1008,16 @@ class Config:
self.num_stages = num_stages self.num_stages = num_stages
self.pre_hook = pre_hook self.pre_hook = pre_hook
def __str__(self):
res = []
for k, v in self.kwargs.items():
res.append(f'{k}: {v}')
res.append(f'num_warps: {self.num_warps}')
res.append(f'num_stages: {self.num_stages}')
return ', '.join(res)
def autotune(configs, key, reset_to_zero=None):
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
""" """
Decorator for auto-tuning a :code:`triton.jit`'d function. Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -985,12 +1044,16 @@ def autotune(configs, key, reset_to_zero=None):
:type configs: list[triton.Config] :type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str] :type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str] :type reset_to_zero: list[str]
""" """
def decorator(fn): def decorator(fn):
def wrapper(kernel): def wrapper(kernel):
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero) return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
fn.kernel_decorators.append(wrapper) fn.kernel_decorators.append(wrapper)
return fn return fn
@@ -1023,7 +1086,6 @@ def heuristics(values):
assert v not in meta assert v not in meta
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta}) meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
return kernel(*args, **meta) return kernel(*args, **meta)
return fun return fun
fn.kernel_decorators.append(wrapper) fn.kernel_decorators.append(wrapper)

View File

@@ -1,15 +1,33 @@
import torch import torch
import triton.language as tl import triton.language as tl
import triton import triton
from .matmul_perf_model import *
def init_to_zero(name): def init_to_zero(name):
return lambda nargs: nargs[name].zero_() return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.heuristics({ @triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0, 'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
}) })
@triton.autotune( @triton.autotune(
configs=[ configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
@@ -19,17 +37,13 @@ def init_to_zero(name):
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), ] + get_configs_io_bound(),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
],
key=['M', 'N', 'K'], key=['M', 'N', 'K'],
prune_configs_by={
'prune_num_stages_by' : prune_num_stages,
'perf_model': estimate_matmul_time,
'top_k': 10
},
) )
@triton.jit @triton.jit
def _kernel(A, B, C, M, N, K, def _kernel(A, B, C, M, N, K,

View File

@@ -0,0 +1,116 @@
import torch
import triton
import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
import heapq
def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device)
return tflops
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
):
''' return estimated running time in ms
= max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N)
num_cta_k = SPLIT_K
num_ctas = num_cta_m * num_cta_n * num_cta_k
# If the input is smaller than the block size
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
# time to compute
total_ops = 2*M*N*K / (1024*1024*1024) # GOPS
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
compute_ms = total_ops / tput
# time to load data
num_sm = _triton.runtime.num_sm(backend, device)
active_cta_ratio = min(1, num_ctas/num_sm)
active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5%
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache
load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2)
load_a_l2 = M*K*2*0.8*(num_cta_n-1)
load_b_dram = N*K*2*(1+0.2*(num_cta_m-1))
load_b_l2 = N*K*2*0.8*(num_cta_m-1)
# total
total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024*1024)
# loading time in ms
load_ms = total_dram/dram_bw + total_l2/l2_bw
# estimate storing time
store_bw = dram_bw * 0.6 # :o
store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB
if SPLIT_K == 1:
store_ms = store_c_dram /store_bw
else:
reduce_bw = store_bw
store_ms = store_c_dram/reduce_bw
# c.zero_()
zero_ms = M*N*2/(1024*1024)/store_bw
store_ms += zero_ms
total_time_ms = max(compute_ms, load_ms) + store_ms
if debug:
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
f'loading time: {load_ms}ms, store time: {store_ms}ms, '
f'Activate CTAs: {active_cta_ratio*100}%')
return total_time_ms
def prune_num_stages(configs):
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
cc = _triton.runtime.cc(backend, device)
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
configs_map = {}
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
if key in configs_map:
configs_map[key].append((config, num_stages))
else:
configs_map[key] = [(config, num_stages)]
pruned_configs = []
for k, v in configs_map.items():
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
if cc >= 80:
# compute cycles (only works for ampere GPUs)
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16)
mma_cycles = mmas/min(4, num_warps) * 8
ldgsts_latency = 300 # Does this matter?
optimal_num_stages = ldgsts_latency/mma_cycles
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest:
pruned_configs.append(n[0])
else: # Volta & Turing only supports num_stages <= 2
random_config = v[0][0]
random_config.num_stages = 2
pruned_configs.append(random_config)
return pruned_configs

View File

@@ -1,5 +1,6 @@
import torch import torch
import os import os
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources from .code_gen import OutOfResources
import subprocess import subprocess
import sys import sys
@@ -320,3 +321,27 @@ def perf_report(benchmarks):
""" """
wrapper = lambda fn: Mark(fn, benchmarks) wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper return wrapper
def get_dram_gbps(backend=None, device=None):
''' return DRAM bandwidth in GB/s '''
# assert backend == CUDA
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
bus_width = _triton.runtime.global_memory_bus_width(backend, device)
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
return bw_gbps
def get_max_tensorcore_tflops(backend, device):
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
# assume fp32 += fp16*fp16
cc = _triton.runtime.cc(backend, device)
if cc < 80:
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
else:
ops_per_sub_core = 512
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024)
return tflops