[OPS] Add performance model for gemm/gemv (#397)
Significantly improves the performance of `triton.ops.matmul` in memory-bound settings by adding many more block configurations, coupled with a performance model that drives the auto-tuning process.
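Concretely, `triton.autotune` gains an optional `prune_configs_by` argument. The sketch below mirrors how `triton.ops.matmul` wires it up in this change; the `matmul_autotuner` name is hypothetical and the config list is elided:

    import triton
    from triton.ops.matmul_perf_model import estimate_matmul_time, prune_num_stages

    matmul_autotuner = triton.autotune(
        configs=[],                    # compute-bound configs + get_configs_io_bound(), see matmul.py below
        key=['M', 'N', 'K'],
        prune_configs_by={
            'perf_model': estimate_matmul_time,       # analytical estimate of each config's running time
            'top_k': 10,                              # only the 10 best-predicted configs are benchmarked
            'prune_num_stages_by': prune_num_stages,  # optional: collapse the num_stages axis first
        },
    )
    # the returned decorator is applied on top of @triton.jit, exactly as in matmul.py below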
@@ -455,7 +455,7 @@ public:
 // masked load async
 class masked_load_async_inst: public load_inst {
 private:
-  std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); }
+  std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
   masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
                          const std::string &name, instruction *next);

@@ -728,12 +728,21 @@ public:
 class dot_inst: public builtin_inst {
 public:
   enum TransT { NoTrans, Trans };
+  enum DataType {
+    FP8, FP16, BF16, TF32, FP32,
+    INT1, INT4, INT8, INT32,
+    UNKNOWN,
+  };

 private:
   dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
   std::string repr_impl() const { return "dot"; }

   bool is_prefetched_ = false;
+  DataType C_type_ = DataType::FP32;
+  DataType A_type_ = DataType::FP16;
+  DataType B_type_ = DataType::FP16;
+
 public:
   bool is_prefetched() const { return is_prefetched_; }
   void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
@@ -85,6 +85,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
   allocation.run(ir);
   prefetch_s.run(ir);
   barriers.run(ir);
+  // ir.print(std::cout);
   isel.visit(ir, *llvm);
   shared_static = allocation.allocated_size();
   return llvm;
@@ -292,6 +292,16 @@ void init_triton_runtime(py::module &&m) {
     return bin;
   });

+  m.def("cc", [](backend_t backend, uint64_t device) -> int {
+    if (backend == CUDA) {
+      CUdevice dev = (CUdevice)device;
+      int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
+      int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
+      return major*10 + minor;
+    }
+    return -1;
+  });
+
   // query maximum shared memory
   m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
     if (backend == HOST)
@@ -303,6 +313,31 @@ void init_triton_runtime(py::module &&m) {
     return -1;
   });

+  // query DRAM & L2 cache
+  m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
+    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
+    return -1;
+  });
+  m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
+    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
+    return -1;
+  });
+  m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
+    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
+    return -1;
+  });
+
+  // query clock rate (in kilohertz)
+  m.def("clock_rate", [](backend_t backend, uint64_t device) {
+    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
+    return -1;
+  });
+
+  m.def("num_sm", [](backend_t backend, uint64_t device) {
+    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
+    return -1;
+  });
+
   // enqueue
   m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel,
                       uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
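These new bindings are what the Python-side performance model consumes; a minimal sketch of querying them directly, mirroring the calls made in matmul_perf_model.py and triton/testing.py below:

    import torch
    import triton._C.libtriton.triton as _triton

    backend = _triton.runtime.backend.CUDA
    device = torch.cuda.current_device()
    cc = _triton.runtime.cc(backend, device)            # compute capability, e.g. 70 on V100, 80 on A100
    num_sm = _triton.runtime.num_sm(backend, device)
    mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
    bus_width = _triton.runtime.global_memory_bus_width(backend, device)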
@@ -25,7 +25,7 @@ def nvsmi(attrs):
 matmul_data = {
     # square
     (256 , 256 , 256 ) : {'v100': 0.027},
-    (512 , 512 , 512 ) : {'v100': 0.141},
+    (512 , 512 , 512 ) : {'v100': 0.158},
     (1024, 1024, 1024 ) : {'v100': 0.466},
     (2048, 2048, 2048 ) : {'v100': 0.680},
     (4096, 4096, 4096 ) : {'v100': 0.831},
@@ -35,10 +35,10 @@ matmul_data = {
     (16 , 4096, 4096 ) : {'v100': 0.0883},
     (16 , 8192, 8192 ) : {'v100': 0.101},
     (64 , 1024, 1024 ) : {'v100': 0.073},
-    (64 , 4096, 4096 ) : {'v100': 0.228},
+    (64 , 4096, 4096 ) : {'v100': 0.270},
     (64 , 8192, 8192 ) : {'v100': 0.360},
     (1024, 64 , 1024 ) : {'v100': 0.0692},
-    (4096, 64 , 4096 ) : {'v100': 0.223},
+    (4096, 64 , 4096 ) : {'v100': 0.264},
     (8192, 64 , 8192 ) : {'v100': 0.323},
     # # deep reductions
     # (64 , 64 , 16384) : {'v100': 0.},
@@ -17,6 +17,10 @@ import triton
 import triton._C.libtriton.triton as _triton
 from filelock import FileLock
 import dbm
+import tempfile
+from typing import Optional, Dict
+import time
+


 class CodeGenerator(ast.NodeVisitor):
@@ -508,6 +512,7 @@ class LoadedBinary:
                                   device)
         self.bin = bin
         self.asm = bin.asm
+        self.sass = ''
         self.module = module
         self.kernel = kernel
         self.device = device
@@ -519,6 +524,19 @@ class LoadedBinary:
                            self.bin.num_warps * 32, 1, 1,
                            args, self.bin.shared_mem)

+    def get_sass(self, fun=None):
+        if self.sass:
+            return self.sass
+        fd, path = tempfile.mkstemp()
+        try:
+            with open(fd, 'wb') as cubin:
+                cubin.write(self.asm['cubin'])
+            self.sass = extract(path, fun)
+        finally:
+            os.remove(path)
+        self.asm['sass'] = self.sass
+        return self.sass
+

 class CompilationError(Exception):
     def __init__(self, src, node):
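The new `get_sass` helper disassembles the cubin once and caches the text on the binary; a hedged usage sketch, where `binary` stands for any already-compiled LoadedBinary:

    sass = binary.get_sass()            # dumps the cubin to a temp file, disassembles it, caches the result
    assert binary.asm['sass'] == sass   # later calls return the cached text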
@@ -530,8 +548,8 @@ class CompilationError(Exception):

 class OutOfResources(Exception):
     def __init__(self, required, limit, name):
-        self.message = f'out of resource: {name}'\
-                       f'Required: {required}'\
+        self.message = f'out of resource: {name}, '\
+                       f'Required: {required}, '\
                        f'Hardware limit: {limit}'
         super().__init__(self.message)
         self.args = (required, limit, name)
@@ -727,7 +745,13 @@ class Launcher:


 class Autotuner:
-    def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
+    def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
+        '''
+        :param prune_configs_by: a dict of functions used to prune configs; fields:
+            'perf_model': performance model used to predict the running time of different configs; returns the running time
+            'top_k': number of configs to bench
+            'prune_num_stages_by' (optional): a function used to prune num_stages. It takes configs: List[Config] as input and returns pruned configs.
+        '''
         if not configs:
             self.configs = [Config(dict(), num_warps=4, num_stages=2)]
         else:
@@ -744,6 +768,15 @@ class Autotuner:
                     args[i].zero_()
             self.hook = _hook
         self.arg_names = arg_names
+        # prune configs
+        if prune_configs_by:
+            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
+            if 'prune_num_stages_by' in prune_configs_by:
+                prune_num_stages_by = prune_configs_by['prune_num_stages_by']
+        else:
+            perf_model, top_k, prune_num_stages_by = None, None, None
+        self.perf_model, self.configs_top_k = perf_model, top_k
+        self.prune_num_stages_by = prune_num_stages_by

     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
@@ -768,13 +801,29 @@ class Autotuner:
         if len(self.configs) > 1:
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
+                # prune configs
+                pruned_configs = self.configs
+                if self.prune_num_stages_by:
+                    pruned_configs = self.prune_num_stages_by(self.configs)
+                if self.perf_model:
+                    top_k = self.configs_top_k
+                    if isinstance(top_k, float) and top_k <= 1.0:
+                        top_k = int(len(self.configs) * top_k)
+                    if len(pruned_configs) > top_k:
+                        est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
+                        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+                bench_start = time.time()
                 timings = {config: self._bench(*args, config=config, **kwargs) \
-                           for config in self.configs}
+                           for config in pruned_configs}
+                bench_end = time.time()
+                self.bench_time = bench_end - bench_start
                 self.cache[key] = builtins.min(timings, key=timings.get)
                 self.hook(args)
+                self.configs_timings = timings
             config = self.cache[key]
         else:
             config = self.configs[0]
+        self.best_config = config
         if config.pre_hook != None:
             config.pre_hook(self.nargs)
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
@@ -832,6 +881,8 @@ class DependenciesFinder(ast.NodeVisitor):
         module = inspect.getmodule(func)
         if module and module.__name__.startswith('triton.'):
             return
+        if inspect.isbuiltin(func):
+            return
         if not hasattr(func, 'hash'):
             src = textwrap.dedent(inspect.getsource(func))
             tree = ast.parse(src)
@@ -957,8 +1008,16 @@ class Config:
         self.num_stages = num_stages
         self.pre_hook = pre_hook

+    def __str__(self):
+        res = []
+        for k, v in self.kwargs.items():
+            res.append(f'{k}: {v}')
+        res.append(f'num_warps: {self.num_warps}')
+        res.append(f'num_stages: {self.num_stages}')
+        return ', '.join(res)
+

-def autotune(configs, key, reset_to_zero=None):
+def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.

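`Config.__str__` makes tuning logs readable; for example, the output below follows directly from the method above:

    import triton
    cfg = triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, num_stages=3)
    print(cfg)  # BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 32, SPLIT_K: 4, num_warps: 2, num_stages: 3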
@@ -985,12 +1044,16 @@ def autotune(configs, key, reset_to_zero=None):
     :type configs: list[triton.Config]
     :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
     :type key: list[str]
+    :param prune_configs_by: a dict of functions used to prune configs; fields:
+        'perf_model': performance model used to predict the running time of different configs; returns the running time
+        'top_k': number of configs to bench
+        'prune_num_stages_by' (optional): a function used to prune num_stages. It takes configs: List[Config] as input and returns pruned configs.
     :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
     :type reset_to_zero: list[str]
     """
     def decorator(fn):
         def wrapper(kernel):
-            return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
+            return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)

         fn.kernel_decorators.append(wrapper)
         return fn
@@ -1023,7 +1086,6 @@ def heuristics(values):
                     assert v not in meta
                     meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
                 return kernel(*args, **meta)
-
             return fun

         fn.kernel_decorators.append(wrapper)
@@ -1,15 +1,33 @@
 import torch
 import triton.language as tl
 import triton
+from .matmul_perf_model import *

 def init_to_zero(name):
     return lambda nargs: nargs[name].zero_()

+def get_configs_io_bound():
+    configs = []
+    for num_stages in [2, 3, 4, 5, 6]:
+        for block_m in [16, 32]:
+            for block_k in [32, 64]:
+                for block_n in [32, 64, 128, 256]:
+                    num_warps = 2 if block_n <= 64 else 4
+                    configs.append(
+                        triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
+                                      num_stages=num_stages, num_warps=num_warps))
+                    # split_k
+                    for split_k in [2, 4, 8, 16]:
+                        configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
+                                                     num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+    return configs
+

 @triton.heuristics({
     'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
 })
 @triton.autotune(
     configs=[
+        # basic configs for compute-bound matmuls
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
@@ -19,17 +37,13 @@ def init_to_zero(name):
         triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
-    ],
+    ] + get_configs_io_bound(),
     key=['M', 'N', 'K'],
+    prune_configs_by={
+        'prune_num_stages_by' : prune_num_stages,
+        'perf_model': estimate_matmul_time,
+        'top_k': 10
+    },
 )
 @triton.jit
 def _kernel(A, B, C, M, N, K,
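For scale: by my count the io-bound generator enumerates 5 num_stages × 2 BLOCK_M × 2 BLOCK_K × 4 BLOCK_N = 80 block shapes, each emitted with SPLIT_K in {1, 2, 4, 8, 16}, i.e. 400 candidate configs on top of the hand-written compute-bound ones; `prune_num_stages` collapses the num_stages axis and the perf model then keeps only the `top_k = 10` configs that are actually benchmarked.

    # quick check of the count implied by the loops in get_configs_io_bound()
    assert len(get_configs_io_bound()) == 5 * 2 * 2 * 4 * 5 == 400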
python/triton/ops/matmul_perf_model.py (new file, 116 lines)
@@ -0,0 +1,116 @@
+import torch
+import triton
+import triton._C.libtriton.triton as _triton
+from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
+import heapq
+
+def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
+    ''' return compute throughput in TOPS '''
+    total_warps = num_ctas * min(num_warps, 4)
+    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device)
+    return tflops
+
+def estimate_matmul_time(
+    # backend, device,
+    num_warps, num_stages,
+    M, N, K,
+    BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
+    debug=False, **kwargs
+):
+    ''' return estimated running time in ms
+          = max(compute, loading) + store '''
+    backend = _triton.runtime.backend.CUDA
+    device = torch.cuda.current_device()
+
+    num_cta_m = triton.cdiv(M, BLOCK_M)
+    num_cta_n = triton.cdiv(N, BLOCK_N)
+    num_cta_k = SPLIT_K
+    num_ctas = num_cta_m * num_cta_n * num_cta_k
+
+    # If the input is smaller than the block size
+    M, N = max(M, BLOCK_M), max(N, BLOCK_N)
+
+    # time to compute
+    total_ops = 2*M*N*K / (1024*1024*1024)  # GOPS
+    tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
+    compute_ms = total_ops / tput
+
+    # time to load data
+    num_sm = _triton.runtime.num_sm(backend, device)
+    active_cta_ratio = min(1, num_ctas/num_sm)
+    active_cta_ratio_bw1 = min(1, num_ctas/32)  # 32 active ctas are enough to saturate
+    active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0)  # 32-108, remaining 5%
+    dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05)  # in GB/s
+    l2_bw = dram_bw * 4  # rough estimation (should be 4.7 for A100?)
+    # assume 80% of (following) loads are in L2 cache
+    load_a_dram = M*K*2*(1+0.2*(num_cta_n-1))  # assume dtype=float16 (size==2)
+    load_a_l2 = M*K*2*0.8*(num_cta_n-1)
+    load_b_dram = N*K*2*(1+0.2*(num_cta_m-1))
+    load_b_l2 = N*K*2*0.8*(num_cta_m-1)
+    # total
+    total_dram = (load_a_dram + load_b_dram) / (1024*1024)  # MB
+    total_l2 = (load_a_l2 + load_b_l2) / (1024*1024)
+    # loading time in ms
+    load_ms = total_dram/dram_bw + total_l2/l2_bw
+
+    # estimate storing time
+    store_bw = dram_bw * 0.6  # :o
+    store_c_dram = M*N*2*SPLIT_K / (1024*1024)  # MB
+    if SPLIT_K == 1:
+        store_ms = store_c_dram / store_bw
+    else:
+        reduce_bw = store_bw
+        store_ms = store_c_dram/reduce_bw
+        # c.zero_()
+        zero_ms = M*N*2/(1024*1024)/store_bw
+        store_ms += zero_ms
+
+    total_time_ms = max(compute_ms, load_ms) + store_ms
+    if debug:
+        print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
+              f'loading time: {load_ms}ms, store time: {store_ms}ms, '
+              f'Activate CTAs: {active_cta_ratio*100}%')
+    return total_time_ms
+
+def prune_num_stages(configs):
+    backend = _triton.runtime.backend.CUDA
+    device = torch.cuda.current_device()
+    cc = _triton.runtime.cc(backend, device)
+    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
+
+    # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
+    configs_map = {}
+    for config in configs:
+        kw = config.kwargs
+        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
+            kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
+
+        key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
+        if key in configs_map:
+            configs_map[key].append((config, num_stages))
+        else:
+            configs_map[key] = [(config, num_stages)]
+
+    pruned_configs = []
+    for k, v in configs_map.items():
+        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
+        if cc >= 80:
+            # compute cycles (only works for ampere GPUs)
+            mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16)
+            mma_cycles = mmas/min(4, num_warps) * 8
+
+            ldgsts_latency = 300  # Does this matter?
+            optimal_num_stages = ldgsts_latency/mma_cycles
+
+            # nearest stages, prefer large #stages
+            nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \
+                                      if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
+
+            for n in nearest:
+                pruned_configs.append(n[0])
+        else:  # Volta & Turing only supports num_stages <= 2
+            random_config = v[0][0]
+            random_config.num_stages = 2
+            pruned_configs.append(random_config)
+    return pruned_configs
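For a sense of how the model behaves outside the autotuner, here is a hedged sketch that scores one io-bound config directly on the current CUDA device; the problem sizes are illustrative only:

    import triton
    from triton.ops.matmul_perf_model import estimate_matmul_time, prune_num_stages

    cfg = triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 4},
                        num_stages=3, num_warps=2)
    # predicted running time (ms) for a skinny, memory-bound matmul
    t_ms = estimate_matmul_time(M=64, N=4096, K=4096,
                                num_warps=cfg.num_warps, num_stages=cfg.num_stages,
                                **cfg.kwargs)
    # keep only the most promising num_stages per block shape
    survivors = prune_num_stages([cfg])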
@@ -1,5 +1,6 @@
 import torch
 import os
+import triton._C.libtriton.triton as _triton
 from .code_gen import OutOfResources
 import subprocess
 import sys
@@ -320,3 +321,27 @@ def perf_report(benchmarks):
     """
     wrapper = lambda fn: Mark(fn, benchmarks)
     return wrapper
+
+def get_dram_gbps(backend=None, device=None):
+    ''' return DRAM bandwidth in GB/s '''
+    # assert backend == CUDA
+    if not backend:
+        backend = _triton.runtime.backend.CUDA
+    if not device:
+        device = torch.cuda.current_device()
+    mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
+    bus_width = _triton.runtime.global_memory_bus_width(backend, device)
+    bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8  # In GB/s
+    return bw_gbps
+
+def get_max_tensorcore_tflops(backend, device):
+    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
+    # assume fp32 += fp16*fp16
+    cc = _triton.runtime.cc(backend, device)
+    if cc < 80:
+        ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
+    else:
+        ops_per_sub_core = 512
+    tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024)
+    return tflops
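As a rough sanity check on the bandwidth formula, plugging in illustrative numbers close to a V100 (877 MHz HBM2 memory clock, 4096-bit bus) gives an estimate in the right ballpark of the ~900 GB/s spec figure:

    # illustrative numbers only, not queried from a device
    mem_clock_khz, bus_width = 877000, 4096
    bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8
    print(bw_gbps)  # -> 856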