[FRONTEND] Make the performance model work for int8, tf32, and fp32 (#456)
@@ -6,6 +6,9 @@ import torch

import triton
import triton.language as tl
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops

DEVICE_NAME = 'v100'

#######################
# Utilities
@@ -25,42 +28,76 @@ def nvsmi(attrs):
# Matrix Multiplication
#######################

sm_clocks = {'v100': 1350, 'a100': 1350}
mem_clocks = {'v100': 877, 'a100': 1215}

matmul_data = {
    # square
    (256, 256, 256): {'v100': 0.027},
    (512, 512, 512): {'v100': 0.158},
    (1024, 1024, 1024): {'v100': 0.466},
    (2048, 2048, 2048): {'v100': 0.680},
    (4096, 4096, 4096): {'v100': 0.831},
    (8192, 8192, 8192): {'v100': 0.849},
    # tall-skinny
    (16, 1024, 1024): {'v100': 0.0128},
    (16, 4096, 4096): {'v100': 0.0883},
    (16, 8192, 8192): {'v100': 0.101},
    (64, 1024, 1024): {'v100': 0.073},
    (64, 4096, 4096): {'v100': 0.270},
    (64, 8192, 8192): {'v100': 0.360},
    (1024, 64, 1024): {'v100': 0.0692},
    (4096, 64, 4096): {'v100': 0.264},
    (8192, 64, 8192): {'v100': 0.323},
    'v100': {
        # square
        (256, 256, 256): {'float16': 0.027},
        (512, 512, 512): {'float16': 0.158},
        (1024, 1024, 1024): {'float16': 0.466},
        (2048, 2048, 2048): {'float16': 0.680},
        (4096, 4096, 4096): {'float16': 0.831},
        (8192, 8192, 8192): {'float16': 0.849},
        # tall-skinny
        (16, 1024, 1024): {'float16': 0.0128},
        (16, 4096, 4096): {'float16': 0.0883},
        (16, 8192, 8192): {'float16': 0.101},
        (64, 1024, 1024): {'float16': 0.073},
        (64, 4096, 4096): {'float16': 0.270},
        (64, 8192, 8192): {'float16': 0.459},
        (1024, 64, 1024): {'float16': 0.0692},
        (4096, 64, 4096): {'float16': 0.264},
        (8192, 64, 8192): {'float16': 0.452},
    },
    'a100': {
        (256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
        (512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
        (1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
        (2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
        (4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
        (8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
        # tall-skinny
        (16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
        (16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
        (16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
        (64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
        (64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
        (64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
        (1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
        (4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
        (8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
    }
    # # deep reductions
    # (64 , 64 , 16384) : {'v100': 0.},
    # (64 , 64 , 65536) : {'v100': 0.},
    # (256 , 256 , 8192 ) : {'v100': 0.},
    # (256 , 256 , 32768) : {'v100': 0.},
    # (64 , 64 , 16384) : {'a100': 0.},
    # (64 , 64 , 65536) : {'a100': 0.},
    # (256 , 256 , 8192 ) : {'a100': 0.},
    # (256 , 256 , 32768) : {'a100': 0.},
}


@pytest.mark.parametrize('M, N, K', matmul_data.keys())
def test_matmul(M, N, K):
@pytest.mark.parametrize('M, N, K, dtype_str',
                         [(M, N, K, dtype_str)
                          for M, N, K in matmul_data[DEVICE_NAME].keys()
                          for dtype_str in ['float16']])
def test_matmul(M, N, K, dtype_str):
    if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
        pytest.skip('Only test float32 & int8 on a100')
    dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
    torch.manual_seed(0)
    ref_gpu_util = matmul_data[(M, N, K)]['v100']
    ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
    cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
    ref_sm_clock = 1350
    max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock
    ref_sm_clock = sm_clocks[DEVICE_NAME]
    max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
    assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
    a = torch.randn((M, K), dtype=torch.float16, device='cuda')
    b = torch.randn((K, N), dtype=torch.float16, device='cuda')
    if dtype == torch.int8:
        a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
        b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
        b = b.t() # only test row-col layout
    else:
        a = torch.randn((M, K), dtype=dtype, device='cuda')
        b = torch.randn((K, N), dtype=dtype, device='cuda')
    fn = lambda: triton.ops.matmul(a, b)
    ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
    cur_gpu_perf = 2. * M * N * K / ms * 1e-9
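The hunk is truncated here; presumably the test then turns the measured throughput into a utilization and compares it with the table above. A minimal sketch of that last step (illustrative only, not the committed code):

    cur_gpu_util = cur_gpu_perf / max_gpu_perf  # achieved TFLOPS vs. dtype-aware peak
    triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)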
@@ -87,23 +124,34 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,

elementwise_data = {
    1024 * 16: {'v100': 0.0219},
    1024 * 64: {'v100': 0.0791},
    1024 * 256: {'v100': 0.243},
    1024 * 1024: {'v100': 0.534},
    1024 * 4096: {'v100': 0.796},
    1024 * 16384: {'v100': 0.905},
    1024 * 65536: {'v100': 0.939},
    'v100': {
        1024 * 16: 0.0219,
        1024 * 64: 0.0791,
        1024 * 256: 0.243,
        1024 * 1024: 0.534,
        1024 * 4096: 0.796,
        1024 * 16384: 0.905,
        1024 * 65536: 0.939,
    },
    'a100': {
        1024 * 16: 0.008,
        1024 * 64: 0.034,
        1024 * 256: 0.114,
        1024 * 1024: 0.315,
        1024 * 4096: 0.580,
        1024 * 16384: 0.782,
        1024 * 65536: 0.850,
    }
}


@pytest.mark.parametrize('N', elementwise_data.keys())
@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
def test_elementwise(N):
    torch.manual_seed(0)
    ref_gpu_util = elementwise_data[N]['v100']
    ref_gpu_util = elementwise_data[DEVICE_NAME][N]
    cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
    ref_mem_clock = 877
    max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3
    ref_mem_clock = mem_clocks[DEVICE_NAME]
    max_gpu_perf = get_dram_gbps()
    assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
    z = torch.empty((N, ), dtype=torch.float16, device='cuda')
    x = torch.randn_like(z)
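This hunk is truncated as well; the bandwidth check presumably mirrors the matmul one. An add kernel touches 3 * N * element_size bytes (load x, load y, store z), so the ending could look roughly like this (y, grid and BLOCK_SIZE are illustrative here, not the committed code):

    y = torch.randn_like(z)
    grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
    ms = triton.testing.do_bench(lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024), percentiles=None, warmup=25, rep=1000)
    cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6  # GB/s actually moved
    triton.testing.assert_almost_equal(cur_gpu_perf / max_gpu_perf, ref_gpu_util, decimal=2)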
@@ -811,12 +811,12 @@ class Autotuner:
        # prune configs
        if prune_configs_by:
            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
            if 'prune_num_stages_by' in prune_configs_by:
                prune_num_stages_by = prune_configs_by['prune_num_stages_by']
            if 'early_config_prune' in prune_configs_by:
                early_config_prune = prune_configs_by['early_config_prune']
        else:
            perf_model, top_k, prune_num_stages_by = None, None, None
            perf_model, top_k, early_config_prune = None, None, None
        self.perf_model, self.configs_top_k = perf_model, top_k
        self.prune_num_stages_by = prune_num_stages_by
        self.early_config_prune = early_config_prune

    def _bench(self, *args, config, **meta):
        # check for conflicts, i.e. meta-parameters both provided
@@ -844,8 +844,8 @@ class Autotuner:
        if key not in self.cache:
            # prune configs
            pruned_configs = self.configs
            if self.prune_num_stages_by:
                pruned_configs = self.prune_num_stages_by(self.configs, self.nargs)
            if self.early_config_prune:
                pruned_configs = self.early_config_prune(self.configs, self.nargs)
            if self.perf_model:
                top_k = self.configs_top_k
                if isinstance(top_k, float) and top_k <= 1.0:
@@ -1096,7 +1096,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predicate running time with different configs, returns running time
        'top_k': number of configs to bench
        'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    """
@@ -2,7 +2,7 @@ import torch

import triton
import triton.language as tl
from .matmul_perf_model import estimate_matmul_time, prune_num_stages
from .matmul_perf_model import early_config_prune, estimate_matmul_time


def init_to_zero(name):
@@ -27,7 +27,7 @@ def get_configs_io_bound():


@triton.heuristics({
    'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.autotune(
    configs=[
@@ -41,10 +41,20 @@ def get_configs_io_bound():
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'prune_num_stages_by': prune_num_stages,
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
@@ -55,7 +65,9 @@ def _kernel(A, B, C, M, N, K,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
            ACC_TYPE: tl.constexpr
            ):
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
@@ -76,7 +88,7 @@ def _kernel(A, B, C, M, N, K,
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K * SPLIT_K):
        if EVEN_K:
            a = tl.load(A)
@@ -119,13 +131,15 @@ class _matmul(torch.autograd.Function):
        _, N = b.shape
        # allocates output
        c = torch.empty((M, N), device=device, dtype=a.dtype)
        # accumulator types
        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # launch kernel
        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
        _kernel[grid](a, b, c, M, N, K,
                      a.stride(0), a.stride(1),
                      b.stride(0), b.stride(1),
                      c.stride(0), c.stride(1),
                      GROUP_M=8)
                      GROUP_M=8, ACC_TYPE=ACC_TYPE)
        return c

    @staticmethod
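With the accumulator type now chosen per input dtype, the op can be called on int8 tensors directly; a small usage sketch matching the layout exercised by the regression test above (sizes arbitrary):

    M, N, K = 512, 512, 512
    a = torch.randint(-128, 127, (M, K), dtype=torch.int8, device='cuda')
    b = torch.randint(-128, 127, (N, K), dtype=torch.int8, device='cuda').t()  # row-col layout, as in the test
    c = triton.ops.matmul(a, b)  # accumulates in tl.int32, returns an (M, N) tensor with the inputs' dtype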
@@ -4,20 +4,36 @@ import torch

import triton
import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops


def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
    ''' return compute throughput in TOPS '''
    total_warps = num_ctas * min(num_warps, 4)
    num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
    tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
    tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
    return tflops


def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
    ''' return compute throughput in TOPS '''
    total_warps = num_ctas * min(num_warps, 4)
    num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
    tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
    return tflops


def get_tflops(backend, device, num_ctas, num_warps, dtype):
    cc = _triton.runtime.cc(backend, device)
    if cc < 80 and dtype == torch.float32:
        return get_simd_tflops()
    return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
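get_tflops sends fp32 on pre-Ampere parts (cc < 80) to the SIMD estimate, since those GPUs have no fp32/tf32 tensor-core path; every other case uses the tensor-core estimate, scaled by how many of the chip's sub-cores the launch can keep busy. A worked example of that scaling, assuming A100 figures (108 SMs, i.e. 432 sub-cores):

    # 1024 CTAs * min(4, 4) warps = 4096 warps -> min(432, 4096) / 432 = 1.00 -> full peak
    #   64 CTAs * min(4, 4) warps =  256 warps -> min(432,  256) / 432 ~ 0.59 -> ~59% of peak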


def estimate_matmul_time(
    # backend, device,
    num_warps, num_stages,
    A, B, C,
    M, N, K,
    BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
    debug=False, **kwargs
@@ -26,6 +42,8 @@ def estimate_matmul_time(
    = max(compute, loading) + store '''
    backend = _triton.runtime.backend.CUDA
    device = torch.cuda.current_device()
    dtype = A.dtype
    dtsize = A.element_size()

    num_cta_m = triton.cdiv(M, BLOCK_M)
    num_cta_n = triton.cdiv(N, BLOCK_N)
@@ -37,7 +55,7 @@
    # time to compute
    total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
    tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
    tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
    compute_ms = total_ops / tput

    # time to load data
@@ -48,10 +66,10 @@
    dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s
    l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
    # assume 80% of (following) loads are in L2 cache
    load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1)) # assume dtype=float16 (size==2)
    load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1)
    load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1))
    load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1)
    load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
    load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
    load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
    load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
    # total
    total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
    total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
@@ -60,7 +78,7 @@

    # estimate storing time
    store_bw = dram_bw * 0.6 # :o
    store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024) # MB
    store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB
    if SPLIT_K == 1:
        store_ms = store_c_dram / store_bw
    else:
@@ -78,14 +96,28 @@
    return total_time_ms


def prune_num_stages(configs, named_args):
def early_config_prune(configs, named_args):
    backend = _triton.runtime.backend.CUDA
    device = torch.cuda.current_device()
    cc = _triton.runtime.cc(backend, device)
    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
    dtsize = named_args['A'].element_size()
    dtype = named_args['A'].dtype

    # 1. make sure we have enough smem
    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
            kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory <= max_shared_memory:
            pruned_configs.append(config)
    configs = pruned_configs

    # Some dtypes do not allow atomic_add
    if named_args['A'].dtype == torch.bfloat16:
    if dtype not in [torch.float16, torch.float32]:
        configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]

    # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
@@ -330,18 +330,56 @@ def get_dram_gbps(backend=None, device=None):
        device = torch.cuda.current_device()
    mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
    bus_width = _triton.runtime.global_memory_bus_width(backend, device)
    bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
    bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
    return bw_gbps


def get_max_tensorcore_tflops(backend, device):
def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None):
    if not backend:
        backend = _triton.runtime.backend.CUDA
    if not device:
        device = torch.cuda.current_device()
    num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
    clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
    # assume fp32 += fp16*fp16
    if not clock_rate:
        clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
    cc = _triton.runtime.cc(backend, device)
    if cc < 80:
        assert dtype == torch.float16
        ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
    else:
        ops_per_sub_core = 512
    tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024)
    if dtype == torch.float32:
        ops_per_sub_core = 256
    elif dtype in [torch.float16, torch.bfloat16]:
        ops_per_sub_core = 512
    elif dtype == torch.int8:
        ops_per_sub_core = 1024
    else:
        raise RuntimeError("dtype not supported")
    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
    return tflops


def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
    if not backend:
        backend = _triton.runtime.backend.CUDA
    if not device:
        device = torch.cuda.current_device()
    num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
    clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
    cc = _triton.runtime.cc(backend, device)
    if cc < 80:
        if dtype == torch.float32:
            ops_per_sub_core = 32 # 2*16
        elif dtype == torch.float16:
            ops_per_sub_core = 64
        else:
            raise RuntimeError("dtype not supported")
    else:
        if dtype == torch.float32:
            ops_per_sub_core = 32
        elif dtype in [torch.float16, torch.bfloat16]:
            ops_per_sub_core = 64
        else:
            raise RuntimeError("dtype not supported")
    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
    return tflops
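Same sanity check for the SIMD path: with the A100 figures above, float32 gives 432 * 1_410_000 * 32 * 1e-9 ~ 19.5 TFLOPS, the familiar non-tensor-core fp32 peak; fp16 is modeled at twice that (64 ops per sub-core per clock).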