[FRONTEND] Make the performance model work for int8, tf32, and fp32 (#456)
@@ -6,6 +6,9 @@ import torch
 
 import triton
 import triton.language as tl
+from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
+
+DEVICE_NAME = 'v100'
 
 #######################
 # Utilities
@@ -25,42 +28,76 @@ def nvsmi(attrs):
 # Matrix Multiplication
 #######################
 
+sm_clocks = {'v100': 1350, 'a100': 1350}
+mem_clocks = {'v100': 877, 'a100': 1215}
+
 matmul_data = {
-    # square
-    (256, 256, 256): {'v100': 0.027},
-    (512, 512, 512): {'v100': 0.158},
-    (1024, 1024, 1024): {'v100': 0.466},
-    (2048, 2048, 2048): {'v100': 0.680},
-    (4096, 4096, 4096): {'v100': 0.831},
-    (8192, 8192, 8192): {'v100': 0.849},
-    # tall-skinny
-    (16, 1024, 1024): {'v100': 0.0128},
-    (16, 4096, 4096): {'v100': 0.0883},
-    (16, 8192, 8192): {'v100': 0.101},
-    (64, 1024, 1024): {'v100': 0.073},
-    (64, 4096, 4096): {'v100': 0.270},
-    (64, 8192, 8192): {'v100': 0.360},
-    (1024, 64, 1024): {'v100': 0.0692},
-    (4096, 64, 4096): {'v100': 0.264},
-    (8192, 64, 8192): {'v100': 0.323},
+    'v100': {
+        # square
+        (256, 256, 256): {'float16': 0.027},
+        (512, 512, 512): {'float16': 0.158},
+        (1024, 1024, 1024): {'float16': 0.466},
+        (2048, 2048, 2048): {'float16': 0.680},
+        (4096, 4096, 4096): {'float16': 0.831},
+        (8192, 8192, 8192): {'float16': 0.849},
+        # tall-skinny
+        (16, 1024, 1024): {'float16': 0.0128},
+        (16, 4096, 4096): {'float16': 0.0883},
+        (16, 8192, 8192): {'float16': 0.101},
+        (64, 1024, 1024): {'float16': 0.073},
+        (64, 4096, 4096): {'float16': 0.270},
+        (64, 8192, 8192): {'float16': 0.459},
+        (1024, 64, 1024): {'float16': 0.0692},
+        (4096, 64, 4096): {'float16': 0.264},
+        (8192, 64, 8192): {'float16': 0.452},
+    },
+    'a100': {
+        (256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
+        (512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
+        (1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
+        (2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
+        (4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
+        (8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
+        # tall-skinny
+        (16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
+        (16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
+        (16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
+        (64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
+        (64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
+        (64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
+        (1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
+        (4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
+        (8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
+    }
     # # deep reductions
-    # (64  , 64  , 16384) : {'v100': 0.},
-    # (64  , 64  , 65536) : {'v100': 0.},
-    # (256 , 256 , 8192 ) : {'v100': 0.},
-    # (256 , 256 , 32768) : {'v100': 0.},
+    # (64  , 64  , 16384) : {'a100': 0.},
+    # (64  , 64  , 65536) : {'a100': 0.},
+    # (256 , 256 , 8192 ) : {'a100': 0.},
+    # (256 , 256 , 32768) : {'a100': 0.},
 }
 
 
-@pytest.mark.parametrize('M, N, K', matmul_data.keys())
-def test_matmul(M, N, K):
+@pytest.mark.parametrize('M, N, K, dtype_str',
+                         [(M, N, K, dtype_str)
+                          for M, N, K in matmul_data[DEVICE_NAME].keys()
+                          for dtype_str in ['float16']])
+def test_matmul(M, N, K, dtype_str):
+    if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
+        pytest.skip('Only test float32 & int8 on a100')
+    dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
     torch.manual_seed(0)
-    ref_gpu_util = matmul_data[(M, N, K)]['v100']
+    ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
     cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
-    ref_sm_clock = 1350
-    max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock
+    ref_sm_clock = sm_clocks[DEVICE_NAME]
+    max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
     assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
-    a = torch.randn((M, K), dtype=torch.float16, device='cuda')
-    b = torch.randn((K, N), dtype=torch.float16, device='cuda')
+    if dtype == torch.int8:
+        a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
+        b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
+        b = b.t()  # only test row-col layout
+    else:
+        a = torch.randn((M, K), dtype=dtype, device='cuda')
+        b = torch.randn((K, N), dtype=dtype, device='cuda')
     fn = lambda: triton.ops.matmul(a, b)
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
    cur_gpu_perf = 2. * M * N * K / ms * 1e-9
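Note: a quick arithmetic check (not part of the diff) that the dtype-aware helper reproduces the old hard-coded V100 peak for fp16, so the existing reference utilizations stay meaningful:

# Old expression: 80 SMs * 8 tensor cores/SM * 128 FLOP/cycle each, SM clock in MHz.
old_peak = 1e-6 * 80 * 8 * 128 * 1350              # ~110.6 TFLOPS, fp16-on-V100 only
# New cc < 80 path in get_max_tensorcore_tflops: (80 * 4) sub-cores * 256 ops/sub-core/cycle * clock in kHz.
new_peak = (80 * 4) * 1350e3 * 256 * 1e-9          # same ~110.6 TFLOPS
assert abs(old_peak - new_peak) < 1e-6             # identical for fp16; the helper also covers fp32/tf32/int8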
@@ -87,23 +124,34 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
 
 
 elementwise_data = {
-    1024 * 16: {'v100': 0.0219},
-    1024 * 64: {'v100': 0.0791},
-    1024 * 256: {'v100': 0.243},
-    1024 * 1024: {'v100': 0.534},
-    1024 * 4096: {'v100': 0.796},
-    1024 * 16384: {'v100': 0.905},
-    1024 * 65536: {'v100': 0.939},
+    'v100': {
+        1024 * 16: 0.0219,
+        1024 * 64: 0.0791,
+        1024 * 256: 0.243,
+        1024 * 1024: 0.534,
+        1024 * 4096: 0.796,
+        1024 * 16384: 0.905,
+        1024 * 65536: 0.939,
+    },
+    'a100': {
+        1024 * 16: 0.008,
+        1024 * 64: 0.034,
+        1024 * 256: 0.114,
+        1024 * 1024: 0.315,
+        1024 * 4096: 0.580,
+        1024 * 16384: 0.782,
+        1024 * 65536: 0.850,
+    }
 }
 
 
-@pytest.mark.parametrize('N', elementwise_data.keys())
+@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
 def test_elementwise(N):
     torch.manual_seed(0)
-    ref_gpu_util = elementwise_data[N]['v100']
+    ref_gpu_util = elementwise_data[DEVICE_NAME][N]
     cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
-    ref_mem_clock = 877
-    max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3
+    ref_mem_clock = mem_clocks[DEVICE_NAME]
+    max_gpu_perf = get_dram_gbps()
     assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
     z = torch.empty((N, ), dtype=torch.float16, device='cuda')
     x = torch.randn_like(z)
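Note: the same kind of arithmetic check (not part of the diff) for the bandwidth side; get_dram_gbps reproduces the old hard-coded V100 number while also working on A100:

# Old expression: 512-byte (4096-bit) HBM2 bus, DDR factor 2, 877 MHz memory clock.
old_peak = 512 * 2 * 877 * 1e-3                    # ~898 GB/s, V100 only
# get_dram_gbps formula: mem_clock_khz * bus_width_bits * 2 / 1e6 / 8.
new_peak = 877e3 * 4096 * 2 / 1e6 / 8              # same ~898 GB/s
assert abs(old_peak - new_peak) < 1e-6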
@@ -811,12 +811,12 @@ class Autotuner:
         # prune configs
         if prune_configs_by:
             perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
-            if 'prune_num_stages_by' in prune_configs_by:
-                prune_num_stages_by = prune_configs_by['prune_num_stages_by']
+            if 'early_config_prune' in prune_configs_by:
+                early_config_prune = prune_configs_by['early_config_prune']
         else:
-            perf_model, top_k, prune_num_stages_by = None, None, None
+            perf_model, top_k, early_config_prune = None, None, None
         self.perf_model, self.configs_top_k = perf_model, top_k
-        self.prune_num_stages_by = prune_num_stages_by
+        self.early_config_prune = early_config_prune
 
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
@@ -844,8 +844,8 @@ class Autotuner:
         if key not in self.cache:
             # prune configs
             pruned_configs = self.configs
-            if self.prune_num_stages_by:
-                pruned_configs = self.prune_num_stages_by(self.configs, self.nargs)
+            if self.early_config_prune:
+                pruned_configs = self.early_config_prune(self.configs, self.nargs)
             if self.perf_model:
                 top_k = self.configs_top_k
                 if isinstance(top_k, float) and top_k <= 1.0:
@@ -1096,7 +1096,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
     :param prune_configs_by: a dict of functions that are used to prune configs, fields:
         'perf_model': performance model used to predicate running time with different configs, returns running time
         'top_k': number of configs to bench
-        'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
+        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    """
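Note: the docstring above only names the fields; a self-contained sketch of the selection flow they drive (this helper is illustrative, not the Autotuner's actual code) would be:

def select_configs(configs, nargs, early_config_prune, perf_model, top_k):
    # early_config_prune drops configs that cannot run at all (e.g. too many stages for smem);
    # perf_model then ranks the survivors and only the top_k fastest predictions get benchmarked.
    pruned = early_config_prune(configs, nargs) if early_config_prune else configs
    if perf_model:
        if isinstance(top_k, float) and top_k <= 1.0:
            top_k = int(len(configs) * top_k)  # a fractional top_k selects a share of all configs
        if len(pruned) > top_k:
            est = {c: perf_model(**nargs, **c.kwargs, num_stages=c.num_stages, num_warps=c.num_warps)
                   for c in pruned}
            pruned = sorted(est, key=est.get)[:top_k]
    return pruned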
@@ -2,7 +2,7 @@ import torch
 
 import triton
 import triton.language as tl
-from .matmul_perf_model import estimate_matmul_time, prune_num_stages
+from .matmul_perf_model import early_config_prune, estimate_matmul_time
 
 
 def init_to_zero(name):
@@ -27,7 +27,7 @@ def get_configs_io_bound():
 
 
 @triton.heuristics({
-    'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
+    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
 })
 @triton.autotune(
     configs=[
@@ -41,10 +41,20 @@ def get_configs_io_bound():
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+        # good for int8
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
     ] + get_configs_io_bound(),
     key=['M', 'N', 'K'],
     prune_configs_by={
-        'prune_num_stages_by': prune_num_stages,
+        'early_config_prune': early_config_prune,
         'perf_model': estimate_matmul_time,
         'top_k': 10
     },
@@ -55,7 +65,9 @@ def _kernel(A, B, C, M, N, K,
             stride_bk, stride_bn,
             stride_cm, stride_cn,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
+            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
+            ACC_TYPE: tl.constexpr
+            ):
     # matrix multiplication
     pid = tl.program_id(0)
     pid_z = tl.program_id(1)
@@ -76,7 +88,7 @@ def _kernel(A, B, C, M, N, K,
     # pointers
     A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
     B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
-    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
     for k in range(K, 0, -BLOCK_K * SPLIT_K):
         if EVEN_K:
             a = tl.load(A)
@@ -119,13 +131,15 @@ class _matmul(torch.autograd.Function):
         _, N = b.shape
         # allocates output
         c = torch.empty((M, N), device=device, dtype=a.dtype)
+        # accumulator types
+        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
         # launch kernel
         grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
         _kernel[grid](a, b, c, M, N, K,
                       a.stride(0), a.stride(1),
                       b.stride(0), b.stride(1),
                       c.stride(0), c.stride(1),
-                      GROUP_M=8)
+                      GROUP_M=8, ACC_TYPE=ACC_TYPE)
         return c
 
     @staticmethod
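Note: a usage sketch (shapes illustrative, not from the diff) of what the new accumulator handling enables; with int8 inputs the wrapper above selects ACC_TYPE=tl.int32, and the regression test only exercises the row-major x column-major ("row-col") layout:

import torch
import triton

a = torch.randint(-128, 127, (512, 256), dtype=torch.int8, device='cuda')
b = torch.randint(-128, 127, (384, 256), dtype=torch.int8, device='cuda').t()  # (256, 384), column-major
c = triton.ops.matmul(a, b)  # (512, 384); output dtype follows a.dtype per the allocation above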
@@ -4,20 +4,36 @@ import torch
 
 import triton
 import triton._C.libtriton.triton as _triton
-from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
+from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
 
 
-def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
+def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
     ''' return compute throughput in TOPS '''
     total_warps = num_ctas * min(num_warps, 4)
     num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
-    tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
+    tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
     return tflops
 
 
+def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
+    ''' return compute throughput in TOPS '''
+    total_warps = num_ctas * min(num_warps, 4)
+    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
+    return tflops
+
+
+def get_tflops(backend, device, num_ctas, num_warps, dtype):
+    cc = _triton.runtime.cc(backend, device)
+    if cc < 80 and dtype == torch.float32:
+        return get_simd_tflops()
+    return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
+
+
 def estimate_matmul_time(
     # backend, device,
     num_warps, num_stages,
+    A, B, C,
     M, N, K,
     BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
     debug=False, **kwargs
@@ -26,6 +42,8 @@ def estimate_matmul_time(
         = max(compute, loading) + store '''
     backend = _triton.runtime.backend.CUDA
     device = torch.cuda.current_device()
+    dtype = A.dtype
+    dtsize = A.element_size()
 
     num_cta_m = triton.cdiv(M, BLOCK_M)
     num_cta_n = triton.cdiv(N, BLOCK_N)
@@ -37,7 +55,7 @@ def estimate_matmul_time(
 
     # time to compute
     total_ops = 2 * M * N * K / (1024 * 1024 * 1024)  # GOPS
-    tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
+    tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
     compute_ms = total_ops / tput
 
     # time to load data
@@ -48,10 +66,10 @@ def estimate_matmul_time(
     dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05)  # in GB/s
     l2_bw = dram_bw * 4  # rough estimation (should be 4.7 for A100?)
     # assume 80% of (following) loads are in L2 cache
-    load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1))  # assume dtype=float16 (size==2)
-    load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1)
-    load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1))
-    load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1)
+    load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
+    load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
+    load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
+    load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
     # total
     total_dram = (load_a_dram + load_b_dram) / (1024 * 1024)  # MB
     total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
@@ -60,7 +78,7 @@ def estimate_matmul_time(
 
     # estimate storing time
     store_bw = dram_bw * 0.6  # :o
-    store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024)  # MB
+    store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024)  # MB
     if SPLIT_K == 1:
         store_ms = store_c_dram / store_bw
     else:
@@ -78,14 +96,28 @@ def estimate_matmul_time(
     return total_time_ms
 
 
-def prune_num_stages(configs, named_args):
+def early_config_prune(configs, named_args):
     backend = _triton.runtime.backend.CUDA
     device = torch.cuda.current_device()
     cc = _triton.runtime.cc(backend, device)
     # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
+    dtsize = named_args['A'].element_size()
+    dtype = named_args['A'].dtype
+
+    # 1. make sure we have enough smem
+    pruned_configs = []
+    for config in configs:
+        kw = config.kwargs
+        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
+            kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
+        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
+        if required_shared_memory <= max_shared_memory:
+            pruned_configs.append(config)
+    configs = pruned_configs
 
     # Some dtypes do not allow atomic_add
-    if named_args['A'].dtype == torch.bfloat16:
+    if dtype not in [torch.float16, torch.float32]:
         configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]
 
     # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
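Note: a worked example (my numbers, not from the diff) of the shared-memory filter in early_config_prune; the large int8-oriented tiles added above only survive on GPUs with enough smem:

# BLOCK_M=256, BLOCK_N=128, BLOCK_K=128, num_stages=3, dtsize=1 (int8):
required_shared_memory = (256 + 128) * 128 * 3 * 1    # = 147,456 bytes
# This fits on A100 (~164 KB of shared memory per SM) but is pruned on V100 (96 KB);
# the same tiles in fp16 (dtsize=2) would need ~288 KB and be pruned everywhere.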
@@ -330,18 +330,56 @@ def get_dram_gbps(backend=None, device=None):
         device = torch.cuda.current_device()
     mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
     bus_width = _triton.runtime.global_memory_bus_width(backend, device)
-    bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8  # In GB/s
+    bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # In GB/s
     return bw_gbps
 
 
-def get_max_tensorcore_tflops(backend, device):
+def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None):
+    if not backend:
+        backend = _triton.runtime.backend.CUDA
+    if not device:
+        device = torch.cuda.current_device()
     num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
-    clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
-    # assume fp32 += fp16*fp16
+    if not clock_rate:
+        clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
     cc = _triton.runtime.cc(backend, device)
     if cc < 80:
+        assert dtype == torch.float16
         ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
     else:
-        ops_per_sub_core = 512
-    tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024)
+        if dtype == torch.float32:
+            ops_per_sub_core = 256
+        elif dtype in [torch.float16, torch.bfloat16]:
+            ops_per_sub_core = 512
+        elif dtype == torch.int8:
+            ops_per_sub_core = 1024
+        else:
+            raise RuntimeError("dtype not supported")
+    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
+    return tflops
+
+
+def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
+    if not backend:
+        backend = _triton.runtime.backend.CUDA
+    if not device:
+        device = torch.cuda.current_device()
+    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
+    cc = _triton.runtime.cc(backend, device)
+    if cc < 80:
+        if dtype == torch.float32:
+            ops_per_sub_core = 32  # 2*16
+        elif dtype == torch.float16:
+            ops_per_sub_core = 64
+        else:
+            raise RuntimeError("dtype not supported")
+    else:
+        if dtype == torch.float32:
+            ops_per_sub_core = 32
+        elif dtype in [torch.float16, torch.bfloat16]:
+            ops_per_sub_core = 64
+        else:
+            raise RuntimeError("dtype not supported")
+    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
     return tflops
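Note: illustrative arithmetic (not part of the diff) for the cc >= 80 branch of get_max_tensorcore_tflops on an A100 at its 1410 MHz boost clock, as a sanity check on the per-dtype ops_per_sub_core values:

num_subcores = 108 * 4                                # 108 SMs, 4 sub-cores each
fp16_tflops = num_subcores * 1410e3 * 512 * 1e-9      # ~312 TFLOPS, the advertised dense fp16 peak
int8_tops = num_subcores * 1410e3 * 1024 * 1e-9       # ~624 TOPS, the advertised dense int8 peak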