[FRONTEND] Make the performance model work for int8, tf32, and fp32 (#456)

This commit is contained in:
daadaada
2022-02-12 14:34:42 +08:00
committed by GitHub
parent 9b100302d3
commit a9dfdcaaa9
5 changed files with 201 additions and 69 deletions

View File

@@ -6,6 +6,9 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
DEVICE_NAME = 'v100'
####################### #######################
# Utilities # Utilities
@@ -25,42 +28,76 @@ def nvsmi(attrs):
# Matrix Multiplication # Matrix Multiplication
####################### #######################
sm_clocks = {'v100': 1350, 'a100': 1350}
mem_clocks = {'v100': 877, 'a100': 1215}
matmul_data = { matmul_data = {
# square 'v100': {
(256, 256, 256): {'v100': 0.027}, # square
(512, 512, 512): {'v100': 0.158}, (256, 256, 256): {'float16': 0.027},
(1024, 1024, 1024): {'v100': 0.466}, (512, 512, 512): {'float16': 0.158},
(2048, 2048, 2048): {'v100': 0.680}, (1024, 1024, 1024): {'float16': 0.466},
(4096, 4096, 4096): {'v100': 0.831}, (2048, 2048, 2048): {'float16': 0.680},
(8192, 8192, 8192): {'v100': 0.849}, (4096, 4096, 4096): {'float16': 0.831},
# tall-skinny (8192, 8192, 8192): {'float16': 0.849},
(16, 1024, 1024): {'v100': 0.0128}, # tall-skinny
(16, 4096, 4096): {'v100': 0.0883}, (16, 1024, 1024): {'float16': 0.0128},
(16, 8192, 8192): {'v100': 0.101}, (16, 4096, 4096): {'float16': 0.0883},
(64, 1024, 1024): {'v100': 0.073}, (16, 8192, 8192): {'float16': 0.101},
(64, 4096, 4096): {'v100': 0.270}, (64, 1024, 1024): {'float16': 0.073},
(64, 8192, 8192): {'v100': 0.360}, (64, 4096, 4096): {'float16': 0.270},
(1024, 64, 1024): {'v100': 0.0692}, (64, 8192, 8192): {'float16': 0.459},
(4096, 64, 4096): {'v100': 0.264}, (1024, 64, 1024): {'float16': 0.0692},
(8192, 64, 8192): {'v100': 0.323}, (4096, 64, 4096): {'float16': 0.264},
(8192, 64, 8192): {'float16': 0.452},
},
'a100': {
(256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
(512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
(1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
(4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
(8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
(16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
}
# # deep reductions # # deep reductions
# (64 , 64 , 16384) : {'v100': 0.}, # (64 , 64 , 16384) : {'a100': 0.},
# (64 , 64 , 65536) : {'v100': 0.}, # (64 , 64 , 65536) : {'a100': 0.},
# (256 , 256 , 8192 ) : {'v100': 0.}, # (256 , 256 , 8192 ) : {'a100': 0.},
# (256 , 256 , 32768) : {'v100': 0.}, # (256 , 256 , 32768) : {'a100': 0.},
} }
@pytest.mark.parametrize('M, N, K', matmul_data.keys()) @pytest.mark.parametrize('M, N, K, dtype_str',
def test_matmul(M, N, K): [(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
def test_matmul(M, N, K, dtype_str):
if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
pytest.skip('Only test float32 & int8 on a100')
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
torch.manual_seed(0) torch.manual_seed(0)
ref_gpu_util = matmul_data[(M, N, K)]['v100'] ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
cur_sm_clock = nvsmi(['clocks.current.sm'])[0] cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
ref_sm_clock = 1350 ref_sm_clock = sm_clocks[DEVICE_NAME]
max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz' assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
a = torch.randn((M, K), dtype=torch.float16, device='cuda') if dtype == torch.int8:
b = torch.randn((K, N), dtype=torch.float16, device='cuda') a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
b = b.t() # only test row-col layout
else:
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((K, N), dtype=dtype, device='cuda')
fn = lambda: triton.ops.matmul(a, b) fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000) ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9 cur_gpu_perf = 2. * M * N * K / ms * 1e-9
@@ -87,23 +124,34 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
elementwise_data = { elementwise_data = {
1024 * 16: {'v100': 0.0219}, 'v100': {
1024 * 64: {'v100': 0.0791}, 1024 * 16: 0.0219,
1024 * 256: {'v100': 0.243}, 1024 * 64: 0.0791,
1024 * 1024: {'v100': 0.534}, 1024 * 256: 0.243,
1024 * 4096: {'v100': 0.796}, 1024 * 1024: 0.534,
1024 * 16384: {'v100': 0.905}, 1024 * 4096: 0.796,
1024 * 65536: {'v100': 0.939}, 1024 * 16384: 0.905,
1024 * 65536: 0.939,
},
'a100': {
1024 * 16: 0.008,
1024 * 64: 0.034,
1024 * 256: 0.114,
1024 * 1024: 0.315,
1024 * 4096: 0.580,
1024 * 16384: 0.782,
1024 * 65536: 0.850,
}
} }
@pytest.mark.parametrize('N', elementwise_data.keys()) @pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
def test_elementwise(N): def test_elementwise(N):
torch.manual_seed(0) torch.manual_seed(0)
ref_gpu_util = elementwise_data[N]['v100'] ref_gpu_util = elementwise_data[DEVICE_NAME][N]
cur_mem_clock = nvsmi(['clocks.current.memory'])[0] cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = 877 ref_mem_clock = mem_clocks[DEVICE_NAME]
max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3 max_gpu_perf = get_dram_gbps()
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz' assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda') z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z) x = torch.randn_like(z)

View File

@@ -811,12 +811,12 @@ class Autotuner:
# prune configs # prune configs
if prune_configs_by: if prune_configs_by:
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
if 'prune_num_stages_by' in prune_configs_by: if 'early_config_prune' in prune_configs_by:
prune_num_stages_by = prune_configs_by['prune_num_stages_by'] early_config_prune = prune_configs_by['early_config_prune']
else: else:
perf_model, top_k, prune_num_stages_by = None, None, None perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k self.perf_model, self.configs_top_k = perf_model, top_k
self.prune_num_stages_by = prune_num_stages_by self.early_config_prune = early_config_prune
def _bench(self, *args, config, **meta): def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided # check for conflicts, i.e. meta-parameters both provided
@@ -844,8 +844,8 @@ class Autotuner:
if key not in self.cache: if key not in self.cache:
# prune configs # prune configs
pruned_configs = self.configs pruned_configs = self.configs
if self.prune_num_stages_by: if self.early_config_prune:
pruned_configs = self.prune_num_stages_by(self.configs, self.nargs) pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model: if self.perf_model:
top_k = self.configs_top_k top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0: if isinstance(top_k, float) and top_k <= 1.0:
@@ -1096,7 +1096,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
:param prune_configs_by: a dict of functions that are used to prune configs, fields: :param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time 'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench 'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str] :type reset_to_zero: list[str]
""" """

View File

@@ -2,7 +2,7 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from .matmul_perf_model import estimate_matmul_time, prune_num_stages from .matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name): def init_to_zero(name):
@@ -27,7 +27,7 @@ def get_configs_io_bound():
@triton.heuristics({ @triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0, 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
}) })
@triton.autotune( @triton.autotune(
configs=[ configs=[
@@ -41,10 +41,20 @@ def get_configs_io_bound():
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(), ] + get_configs_io_bound(),
key=['M', 'N', 'K'], key=['M', 'N', 'K'],
prune_configs_by={ prune_configs_by={
'prune_num_stages_by': prune_num_stages, 'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time, 'perf_model': estimate_matmul_time,
'top_k': 10 'top_k': 10
}, },
@@ -55,7 +65,9 @@ def _kernel(A, B, C, M, N, K,
stride_bk, stride_bn, stride_bk, stride_bn,
stride_cm, stride_cn, stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr): GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
# matrix multiplication # matrix multiplication
pid = tl.program_id(0) pid = tl.program_id(0)
pid_z = tl.program_id(1) pid_z = tl.program_id(1)
@@ -76,7 +88,7 @@ def _kernel(A, B, C, M, N, K,
# pointers # pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K * SPLIT_K): for k in range(K, 0, -BLOCK_K * SPLIT_K):
if EVEN_K: if EVEN_K:
a = tl.load(A) a = tl.load(A)
@@ -119,13 +131,15 @@ class _matmul(torch.autograd.Function):
_, N = b.shape _, N = b.shape
# allocates output # allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype) c = torch.empty((M, N), device=device, dtype=a.dtype)
# accumulator types
ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch kernel # launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K, _kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1), a.stride(0), a.stride(1),
b.stride(0), b.stride(1), b.stride(0), b.stride(1),
c.stride(0), c.stride(1), c.stride(0), c.stride(1),
GROUP_M=8) GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c return c
@staticmethod @staticmethod

View File

@@ -4,20 +4,36 @@ import torch
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
def get_tensorcore_tflops(backend, device, num_ctas, num_warps): def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS ''' ''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4) total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device) tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
return tflops return tflops
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
return tflops
def get_tflops(backend, device, num_ctas, num_warps, dtype):
cc = _triton.runtime.cc(backend, device)
if cc < 80 and dtype == torch.float32:
return get_simd_tflops()
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
def estimate_matmul_time( def estimate_matmul_time(
# backend, device, # backend, device,
num_warps, num_stages, num_warps, num_stages,
A, B, C,
M, N, K, M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs debug=False, **kwargs
@@ -26,6 +42,8 @@ def estimate_matmul_time(
= max(compute, loading) + store ''' = max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device() device = torch.cuda.current_device()
dtype = A.dtype
dtsize = A.element_size()
num_cta_m = triton.cdiv(M, BLOCK_M) num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N) num_cta_n = triton.cdiv(N, BLOCK_N)
@@ -37,7 +55,7 @@ def estimate_matmul_time(
# time to compute # time to compute
total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps) tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
compute_ms = total_ops / tput compute_ms = total_ops / tput
# time to load data # time to load data
@@ -48,10 +66,10 @@ def estimate_matmul_time(
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache # assume 80% of (following) loads are in L2 cache
load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1)) # assume dtype=float16 (size==2) load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1) load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1)) load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1) load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
# total # total
total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
@@ -60,7 +78,7 @@ def estimate_matmul_time(
# estimate storing time # estimate storing time
store_bw = dram_bw * 0.6 # :o store_bw = dram_bw * 0.6 # :o
store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024) # MB store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB
if SPLIT_K == 1: if SPLIT_K == 1:
store_ms = store_c_dram / store_bw store_ms = store_c_dram / store_bw
else: else:
@@ -78,14 +96,28 @@ def estimate_matmul_time(
return total_time_ms return total_time_ms
def prune_num_stages(configs, named_args): def early_config_prune(configs, named_args):
backend = _triton.runtime.backend.CUDA backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device() device = torch.cuda.current_device()
cc = _triton.runtime.cc(backend, device) cc = _triton.runtime.cc(backend, device)
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
dtsize = named_args['A'].element_size()
dtype = named_args['A'].dtype
# 1. make sure we have enough smem
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory <= max_shared_memory:
pruned_configs.append(config)
configs = pruned_configs
# Some dtypes do not allow atomic_add # Some dtypes do not allow atomic_add
if named_args['A'].dtype == torch.bfloat16: if dtype not in [torch.float16, torch.float32]:
configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1] configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)

View File

@@ -330,18 +330,56 @@ def get_dram_gbps(backend=None, device=None):
device = torch.cuda.current_device() device = torch.cuda.current_device()
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device) mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
bus_width = _triton.runtime.global_memory_bus_width(backend, device) bus_width = _triton.runtime.global_memory_bus_width(backend, device)
bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
return bw_gbps return bw_gbps
def get_max_tensorcore_tflops(backend, device): def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None):
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz if not clock_rate:
# assume fp32 += fp16*fp16 clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
cc = _triton.runtime.cc(backend, device) cc = _triton.runtime.cc(backend, device)
if cc < 80: if cc < 80:
assert dtype == torch.float16
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
else: else:
ops_per_sub_core = 512 if dtype == torch.float32:
tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024) ops_per_sub_core = 256
elif dtype in [torch.float16, torch.bfloat16]:
ops_per_sub_core = 512
elif dtype == torch.int8:
ops_per_sub_core = 1024
else:
raise RuntimeError("dtype not supported")
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
return tflops
def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
if not backend:
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
cc = _triton.runtime.cc(backend, device)
if cc < 80:
if dtype == torch.float32:
ops_per_sub_core = 32 # 2*16
elif dtype == torch.float16:
ops_per_sub_core = 64
else:
raise RuntimeError("dtype not supported")
else:
if dtype == torch.float32:
ops_per_sub_core = 32
elif dtype in [torch.float16, torch.bfloat16]:
ops_per_sub_core = 64
else:
raise RuntimeError("dtype not supported")
tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
return tflops return tflops