[OPS] Add performance model for gemm/gemv (#397)
Significantly improves the performance of `triton.ops.matmul` in memory-bound settings via the use of many more block configs coupled with a performance model to drive the auto-tuning process.
This commit is contained in:
@@ -292,6 +292,16 @@ void init_triton_runtime(py::module &&m) {
|
||||
return bin;
|
||||
});
|
||||
|
||||
m.def("cc", [](backend_t backend, uint64_t device) -> int {
|
||||
if (backend == CUDA) {
|
||||
CUdevice dev = (CUdevice)device;
|
||||
int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
|
||||
int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
|
||||
return major*10 + minor;
|
||||
}
|
||||
return -1;
|
||||
});
|
||||
|
||||
// query maximum shared memory
|
||||
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
|
||||
if (backend == HOST)
|
||||
@@ -303,6 +313,31 @@ void init_triton_runtime(py::module &&m) {
|
||||
return -1;
|
||||
});
|
||||
|
||||
// query DRAM & L2 cache
|
||||
m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
|
||||
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
|
||||
return -1;
|
||||
});
|
||||
m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
|
||||
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
|
||||
return -1;
|
||||
});
|
||||
m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
|
||||
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
|
||||
return -1;
|
||||
});
|
||||
|
||||
// query clock rate (in kilohertz)
|
||||
m.def("clock_rate", [](backend_t backend, uint64_t device) {
|
||||
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
|
||||
return -1;
|
||||
});
|
||||
|
||||
m.def("num_sm", [](backend_t backend, uint64_t device) {
|
||||
if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
|
||||
return -1;
|
||||
});
|
||||
|
||||
// enqueue
|
||||
m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel,
|
||||
uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
|
||||
|
@@ -25,7 +25,7 @@ def nvsmi(attrs):
|
||||
matmul_data = {
|
||||
# square
|
||||
(256 , 256 , 256 ) : {'v100': 0.027},
|
||||
(512 , 512 , 512 ) : {'v100': 0.141},
|
||||
(512 , 512 , 512 ) : {'v100': 0.158},
|
||||
(1024, 1024, 1024 ) : {'v100': 0.466},
|
||||
(2048, 2048, 2048 ) : {'v100': 0.680},
|
||||
(4096, 4096, 4096 ) : {'v100': 0.831},
|
||||
@@ -35,10 +35,10 @@ matmul_data = {
|
||||
(16 , 4096, 4096 ) : {'v100': 0.0883},
|
||||
(16 , 8192, 8192 ) : {'v100': 0.101},
|
||||
(64 , 1024, 1024 ) : {'v100': 0.073},
|
||||
(64 , 4096, 4096 ) : {'v100': 0.228},
|
||||
(64 , 4096, 4096 ) : {'v100': 0.270},
|
||||
(64 , 8192, 8192 ) : {'v100': 0.360},
|
||||
(1024, 64 , 1024 ) : {'v100': 0.0692},
|
||||
(4096, 64 , 4096 ) : {'v100': 0.223},
|
||||
(4096, 64 , 4096 ) : {'v100': 0.264},
|
||||
(8192, 64 , 8192 ) : {'v100': 0.323},
|
||||
# # deep reductions
|
||||
# (64 , 64 , 16384) : {'v100': 0.},
|
||||
|
@@ -17,6 +17,10 @@ import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from filelock import FileLock
|
||||
import dbm
|
||||
import tempfile
|
||||
from typing import Optional, Dict
|
||||
import time
|
||||
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
@@ -508,6 +512,7 @@ class LoadedBinary:
|
||||
device)
|
||||
self.bin = bin
|
||||
self.asm = bin.asm
|
||||
self.sass = ''
|
||||
self.module = module
|
||||
self.kernel = kernel
|
||||
self.device = device
|
||||
@@ -519,6 +524,19 @@ class LoadedBinary:
|
||||
self.bin.num_warps * 32, 1, 1,
|
||||
args, self.bin.shared_mem)
|
||||
|
||||
def get_sass(self, fun=None):
|
||||
if self.sass:
|
||||
return self.sass
|
||||
fd, path = tempfile.mkstemp()
|
||||
try:
|
||||
with open(fd, 'wb') as cubin:
|
||||
cubin.write(self.asm['cubin'])
|
||||
self.sass = extract(path, fun)
|
||||
finally:
|
||||
os.remove(path)
|
||||
self.asm['sass'] = self.sass
|
||||
return self.sass
|
||||
|
||||
|
||||
class CompilationError(Exception):
|
||||
def __init__(self, src, node):
|
||||
@@ -530,8 +548,8 @@ class CompilationError(Exception):
|
||||
|
||||
class OutOfResources(Exception):
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}'\
|
||||
f'Required: {required}'\
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
super().__init__(self.message)
|
||||
self.args = (required, limit, name)
|
||||
@@ -727,7 +745,13 @@ class Launcher:
|
||||
|
||||
|
||||
class Autotuner:
|
||||
def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
|
||||
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None):
|
||||
'''
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
|
||||
'''
|
||||
if not configs:
|
||||
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
|
||||
else:
|
||||
@@ -744,7 +768,16 @@ class Autotuner:
|
||||
args[i].zero_()
|
||||
self.hook = _hook
|
||||
self.arg_names = arg_names
|
||||
|
||||
# prune configs
|
||||
if prune_configs_by:
|
||||
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
|
||||
if 'prune_num_stages_by' in prune_configs_by:
|
||||
prune_num_stages_by = prune_configs_by['prune_num_stages_by']
|
||||
else:
|
||||
perf_model, top_k, prune_num_stages_by = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.prune_num_stages_by = prune_num_stages_by
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
# as kwargs and by the autotuner
|
||||
@@ -768,13 +801,29 @@ class Autotuner:
|
||||
if len(self.configs) > 1:
|
||||
key = tuple([args[i] for i in self.key_idx])
|
||||
if key not in self.cache:
|
||||
# prune configs
|
||||
pruned_configs = self.configs
|
||||
if self.prune_num_stages_by:
|
||||
pruned_configs = self.prune_num_stages_by(self.configs)
|
||||
if self.perf_model:
|
||||
top_k = self.configs_top_k
|
||||
if isinstance(top_k, float) and top_k <= 1.0:
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k]
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs) \
|
||||
for config in self.configs}
|
||||
for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
self.hook(args)
|
||||
self.configs_timings = timings
|
||||
config = self.cache[key]
|
||||
else:
|
||||
config = self.configs[0]
|
||||
self.best_config = config
|
||||
if config.pre_hook != None:
|
||||
config.pre_hook(self.nargs)
|
||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
@@ -832,6 +881,8 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
module = inspect.getmodule(func)
|
||||
if module and module.__name__.startswith('triton.'):
|
||||
return
|
||||
if inspect.isbuiltin(func):
|
||||
return
|
||||
if not hasattr(func, 'hash'):
|
||||
src = textwrap.dedent(inspect.getsource(func))
|
||||
tree = ast.parse(src)
|
||||
@@ -957,8 +1008,16 @@ class Config:
|
||||
self.num_stages = num_stages
|
||||
self.pre_hook = pre_hook
|
||||
|
||||
def __str__(self):
|
||||
res = []
|
||||
for k, v in self.kwargs.items():
|
||||
res.append(f'{k}: {v}')
|
||||
res.append(f'num_warps: {self.num_warps}')
|
||||
res.append(f'num_stages: {self.num_stages}')
|
||||
return ', '.join(res)
|
||||
|
||||
def autotune(configs, key, reset_to_zero=None):
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
@@ -985,12 +1044,16 @@ def autotune(configs, key, reset_to_zero=None):
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
:type key: list[str]
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predicate running time with different configs, returns running time
|
||||
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
"""
|
||||
def decorator(fn):
|
||||
def wrapper(kernel):
|
||||
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
|
||||
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
|
||||
|
||||
fn.kernel_decorators.append(wrapper)
|
||||
return fn
|
||||
@@ -1023,7 +1086,6 @@ def heuristics(values):
|
||||
assert v not in meta
|
||||
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
|
||||
return kernel(*args, **meta)
|
||||
|
||||
return fun
|
||||
|
||||
fn.kernel_decorators.append(wrapper)
|
||||
|
@@ -1,15 +1,33 @@
|
||||
import torch
|
||||
import triton.language as tl
|
||||
import triton
|
||||
from .matmul_perf_model import *
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5, 6]:
|
||||
for block_m in [16, 32]:
|
||||
for block_k in [32, 64]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
@@ -19,17 +37,13 @@ def init_to_zero(name):
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
|
||||
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
|
||||
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
|
||||
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
|
||||
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
|
||||
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
|
||||
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
|
||||
triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
|
||||
],
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'prune_num_stages_by' : prune_num_stages,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
|
116
python/triton/ops/matmul_perf_model.py
Normal file
116
python/triton/ops/matmul_perf_model.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
|
||||
import heapq
|
||||
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device)
|
||||
return tflops
|
||||
|
||||
def estimate_matmul_time(
|
||||
# backend, device,
|
||||
num_warps, num_stages,
|
||||
M, N, K,
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
|
||||
debug=False, **kwargs
|
||||
):
|
||||
''' return estimated running time in ms
|
||||
= max(compute, loading) + store '''
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
num_cta_m = triton.cdiv(M, BLOCK_M)
|
||||
num_cta_n = triton.cdiv(N, BLOCK_N)
|
||||
num_cta_k = SPLIT_K
|
||||
num_ctas = num_cta_m * num_cta_n * num_cta_k
|
||||
|
||||
# If the input is smaller than the block size
|
||||
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
|
||||
|
||||
# time to compute
|
||||
total_ops = 2*M*N*K / (1024*1024*1024) # GOPS
|
||||
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
|
||||
compute_ms = total_ops / tput
|
||||
|
||||
# time to load data
|
||||
num_sm = _triton.runtime.num_sm(backend, device)
|
||||
active_cta_ratio = min(1, num_ctas/num_sm)
|
||||
active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate
|
||||
active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5%
|
||||
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s
|
||||
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
|
||||
# assume 80% of (following) loads are in L2 cache
|
||||
load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2)
|
||||
load_a_l2 = M*K*2*0.8*(num_cta_n-1)
|
||||
load_b_dram = N*K*2*(1+0.2*(num_cta_m-1))
|
||||
load_b_l2 = N*K*2*0.8*(num_cta_m-1)
|
||||
# total
|
||||
total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB
|
||||
total_l2 = (load_a_l2 + load_b_l2) / (1024*1024)
|
||||
# loading time in ms
|
||||
load_ms = total_dram/dram_bw + total_l2/l2_bw
|
||||
|
||||
# estimate storing time
|
||||
store_bw = dram_bw * 0.6 # :o
|
||||
store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB
|
||||
if SPLIT_K == 1:
|
||||
store_ms = store_c_dram /store_bw
|
||||
else:
|
||||
reduce_bw = store_bw
|
||||
store_ms = store_c_dram/reduce_bw
|
||||
# c.zero_()
|
||||
zero_ms = M*N*2/(1024*1024)/store_bw
|
||||
store_ms += zero_ms
|
||||
|
||||
total_time_ms = max(compute_ms, load_ms) + store_ms
|
||||
if debug:
|
||||
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
|
||||
f'loading time: {load_ms}ms, store time: {store_ms}ms, '
|
||||
f'Activate CTAs: {active_cta_ratio*100}%')
|
||||
return total_time_ms
|
||||
|
||||
def prune_num_stages(configs):
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
device = torch.cuda.current_device()
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
|
||||
|
||||
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
|
||||
configs_map = {}
|
||||
for config in configs:
|
||||
kw = config.kwargs
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
|
||||
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
|
||||
|
||||
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
|
||||
if key in configs_map:
|
||||
configs_map[key].append((config, num_stages))
|
||||
else:
|
||||
configs_map[key] = [(config, num_stages)]
|
||||
|
||||
pruned_configs = []
|
||||
for k, v in configs_map.items():
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
|
||||
if cc >= 80:
|
||||
# compute cycles (only works for ampere GPUs)
|
||||
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16)
|
||||
mma_cycles = mmas/min(4, num_warps) * 8
|
||||
|
||||
ldgsts_latency = 300 # Does this matter?
|
||||
optimal_num_stages = ldgsts_latency/mma_cycles
|
||||
|
||||
# nearest stages, prefer large #stages
|
||||
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \
|
||||
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
|
||||
|
||||
for n in nearest:
|
||||
pruned_configs.append(n[0])
|
||||
else: # Volta & Turing only supports num_stages <= 2
|
||||
random_config = v[0][0]
|
||||
random_config.num_stages = 2
|
||||
pruned_configs.append(random_config)
|
||||
return pruned_configs
|
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import os
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .code_gen import OutOfResources
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -320,3 +321,27 @@ def perf_report(benchmarks):
|
||||
"""
|
||||
wrapper = lambda fn: Mark(fn, benchmarks)
|
||||
return wrapper
|
||||
|
||||
def get_dram_gbps(backend=None, device=None):
|
||||
''' return DRAM bandwidth in GB/s '''
|
||||
# assert backend == CUDA
|
||||
if not backend:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
|
||||
bus_width = _triton.runtime.global_memory_bus_width(backend, device)
|
||||
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
|
||||
return bw_gbps
|
||||
|
||||
def get_max_tensorcore_tflops(backend, device):
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
|
||||
# assume fp32 += fp16*fp16
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
if cc < 80:
|
||||
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
|
||||
else:
|
||||
ops_per_sub_core = 512
|
||||
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024)
|
||||
return tflops
|
Reference in New Issue
Block a user