diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py
index 39299a89a..1df3a0b49 100644
--- a/python/test/regression/test_performance.py
+++ b/python/test/regression/test_performance.py
@@ -6,6 +6,9 @@ import torch
 
 import triton
 import triton.language as tl
+from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
+
+DEVICE_NAME = 'v100'
 
 #######################
 # Utilities
@@ -25,42 +28,76 @@ def nvsmi(attrs):
 # Matrix Multiplication
 #######################
 
+sm_clocks = {'v100': 1350, 'a100': 1350}
+mem_clocks = {'v100': 877, 'a100': 1215}
+
 matmul_data = {
-    # square
-    (256, 256, 256): {'v100': 0.027},
-    (512, 512, 512): {'v100': 0.158},
-    (1024, 1024, 1024): {'v100': 0.466},
-    (2048, 2048, 2048): {'v100': 0.680},
-    (4096, 4096, 4096): {'v100': 0.831},
-    (8192, 8192, 8192): {'v100': 0.849},
-    # tall-skinny
-    (16, 1024, 1024): {'v100': 0.0128},
-    (16, 4096, 4096): {'v100': 0.0883},
-    (16, 8192, 8192): {'v100': 0.101},
-    (64, 1024, 1024): {'v100': 0.073},
-    (64, 4096, 4096): {'v100': 0.270},
-    (64, 8192, 8192): {'v100': 0.360},
-    (1024, 64, 1024): {'v100': 0.0692},
-    (4096, 64, 4096): {'v100': 0.264},
-    (8192, 64, 8192): {'v100': 0.323},
+    'v100': {
+        # square
+        (256, 256, 256): {'float16': 0.027},
+        (512, 512, 512): {'float16': 0.158},
+        (1024, 1024, 1024): {'float16': 0.466},
+        (2048, 2048, 2048): {'float16': 0.680},
+        (4096, 4096, 4096): {'float16': 0.831},
+        (8192, 8192, 8192): {'float16': 0.849},
+        # tall-skinny
+        (16, 1024, 1024): {'float16': 0.0128},
+        (16, 4096, 4096): {'float16': 0.0883},
+        (16, 8192, 8192): {'float16': 0.101},
+        (64, 1024, 1024): {'float16': 0.073},
+        (64, 4096, 4096): {'float16': 0.270},
+        (64, 8192, 8192): {'float16': 0.459},
+        (1024, 64, 1024): {'float16': 0.0692},
+        (4096, 64, 4096): {'float16': 0.264},
+        (8192, 64, 8192): {'float16': 0.452},
+    },
+    'a100': {
+        (256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
+        (512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
+        (1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
+        (2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
+        (4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
+        (8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
+        # tall-skinny
+        (16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
+        (16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
+        (16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
+        (64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
+        (64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
+        (64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
+        (1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
+        (4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
+        (8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
+    }
     # # deep reductions
-    # (64 , 64 , 16384) : {'v100': 0.},
-    # (64 , 64 , 65536) : {'v100': 0.},
-    # (256 , 256 , 8192 ) : {'v100': 0.},
-    # (256 , 256 , 32768) : {'v100': 0.},
+    # (64 , 64 , 16384) : {'a100': 0.},
+    # (64 , 64 , 65536) : {'a100': 0.},
+    # (256 , 256 , 8192 ) : {'a100': 0.},
+    # (256 , 256 , 32768) : {'a100': 0.},
 }
 
 
-@pytest.mark.parametrize('M, N, K', matmul_data.keys())
-def test_matmul(M, N, K):
+@pytest.mark.parametrize('M, N, K, dtype_str',
+                         [(M, N, K, dtype_str)
+                          for M, N, K in matmul_data[DEVICE_NAME].keys()
+                          for dtype_str in ['float16']])
+def test_matmul(M, N, K, dtype_str):
+    if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
+        pytest.skip('Only test float32 & int8 on a100')
+    dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
     torch.manual_seed(0)
-    ref_gpu_util = matmul_data[(M, N, K)]['v100']
+    ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
     cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
-    ref_sm_clock = 1350
-    max_gpu_perf = 1e-6 * 80 * 8 * 128 * cur_sm_clock
+    ref_sm_clock = sm_clocks[DEVICE_NAME]
+    max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
     assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
-    a = torch.randn((M, K), dtype=torch.float16, device='cuda')
-    b = torch.randn((K, N), dtype=torch.float16, device='cuda')
+    if dtype == torch.int8:
+        a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
+        b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
+        b = b.t()  # only test row-col layout
+    else:
+        a = torch.randn((M, K), dtype=dtype, device='cuda')
+        b = torch.randn((K, N), dtype=dtype, device='cuda')
     fn = lambda: triton.ops.matmul(a, b)
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
     cur_gpu_perf = 2. * M * N * K / ms * 1e-9
@@ -87,23 +124,34 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements,
 
 
 elementwise_data = {
-    1024 * 16: {'v100': 0.0219},
-    1024 * 64: {'v100': 0.0791},
-    1024 * 256: {'v100': 0.243},
-    1024 * 1024: {'v100': 0.534},
-    1024 * 4096: {'v100': 0.796},
-    1024 * 16384: {'v100': 0.905},
-    1024 * 65536: {'v100': 0.939},
+    'v100': {
+        1024 * 16: 0.0219,
+        1024 * 64: 0.0791,
+        1024 * 256: 0.243,
+        1024 * 1024: 0.534,
+        1024 * 4096: 0.796,
+        1024 * 16384: 0.905,
+        1024 * 65536: 0.939,
+    },
+    'a100': {
+        1024 * 16: 0.008,
+        1024 * 64: 0.034,
+        1024 * 256: 0.114,
+        1024 * 1024: 0.315,
+        1024 * 4096: 0.580,
+        1024 * 16384: 0.782,
+        1024 * 65536: 0.850,
+    }
 }
 
 
-@pytest.mark.parametrize('N', elementwise_data.keys())
+@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
 def test_elementwise(N):
     torch.manual_seed(0)
-    ref_gpu_util = elementwise_data[N]['v100']
+    ref_gpu_util = elementwise_data[DEVICE_NAME][N]
     cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
-    ref_mem_clock = 877
-    max_gpu_perf = 512 * 2 * ref_mem_clock * 1e-3
+    ref_mem_clock = mem_clocks[DEVICE_NAME]
+    max_gpu_perf = get_dram_gbps()
     assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
     z = torch.empty((N, ), dtype=torch.float16, device='cuda')
     x = torch.randn_like(z)
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index dc2b375b8..894b3f1e3 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -811,12 +811,12 @@ class Autotuner:
         # prune configs
         if prune_configs_by:
             perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
-            if 'prune_num_stages_by' in prune_configs_by:
-                prune_num_stages_by = prune_configs_by['prune_num_stages_by']
+            if 'early_config_prune' in prune_configs_by:
+                early_config_prune = prune_configs_by['early_config_prune']
         else:
-            perf_model, top_k, prune_num_stages_by = None, None, None
+            perf_model, top_k, early_config_prune = None, None, None
         self.perf_model, self.configs_top_k = perf_model, top_k
-        self.prune_num_stages_by = prune_num_stages_by
+        self.early_config_prune = early_config_prune
 
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
@@ -844,8 +844,8 @@ class Autotuner:
         if key not in self.cache:
             # prune configs
             pruned_configs = self.configs
-            if self.prune_num_stages_by:
-                pruned_configs = self.prune_num_stages_by(self.configs, self.nargs)
+            if self.early_config_prune:
+                pruned_configs = self.early_config_prune(self.configs, self.nargs)
             if self.perf_model:
                 top_k = self.configs_top_k
                 if isinstance(top_k, float) and top_k <= 1.0:
@@ -1096,7 +1096,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
     :param prune_configs_by: a dict of functions that are used to prune configs, fields:
         'perf_model': performance model used to predicate running time with different configs, returns running time
         'top_k': number of configs to bench
-        'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
+        'early_config_prune' (optional): a function used to do early pruning (e.g., of num_stages). It takes configs: List[Config] as its input and returns pruned configs.
     :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
     :type reset_to_zero: list[str]
     """
diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py
index 9466b9ba7..f1ac78849 100644
--- a/python/triton/ops/matmul.py
+++ b/python/triton/ops/matmul.py
@@ -2,7 +2,7 @@ import torch
 
 import triton
 import triton.language as tl
-from .matmul_perf_model import estimate_matmul_time, prune_num_stages
+from .matmul_perf_model import early_config_prune, estimate_matmul_time
 
 
 def init_to_zero(name):
@@ -27,7 +27,7 @@ def get_configs_io_bound():
 
 
 @triton.heuristics({
-    'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
+    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
 })
 @triton.autotune(
     configs=[
@@ -41,10 +41,20 @@ def get_configs_io_bound():
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+        # good for int8
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
     ] + get_configs_io_bound(),
     key=['M', 'N', 'K'],
     prune_configs_by={
-        'prune_num_stages_by': prune_num_stages,
+        'early_config_prune': early_config_prune,
         'perf_model': estimate_matmul_time,
         'top_k': 10
     },
@@ -55,7 +65,9 @@ def _kernel(A, B, C, M, N, K,
             stride_bk, stride_bn,
             stride_cm, stride_cn,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
+            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
+            ACC_TYPE: tl.constexpr
+            ):
     # matrix multiplication
     pid = tl.program_id(0)
     pid_z = tl.program_id(1)
@@ -76,7 +88,7 @@ def _kernel(A, B, C, M, N, K,
     # pointers
     A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
     B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
-    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
     for k in range(K, 0, -BLOCK_K * SPLIT_K):
         if EVEN_K:
             a = tl.load(A)
@@ -119,13 +131,15 @@ class _matmul(torch.autograd.Function):
         _, N = b.shape
         # allocates output
         c = torch.empty((M, N), device=device, dtype=a.dtype)
+        # accumulator types
+        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
         # launch kernel
         grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
         _kernel[grid](a, b, c, M, N, K,
                       a.stride(0), a.stride(1),
                       b.stride(0), b.stride(1),
                       c.stride(0), c.stride(1),
-                      GROUP_M=8)
+                      GROUP_M=8, ACC_TYPE=ACC_TYPE)
         return c
 
     @staticmethod
diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py
index 98a85bc85..9c10b88d8 100644
--- a/python/triton/ops/matmul_perf_model.py
+++ b/python/triton/ops/matmul_perf_model.py
@@ -4,20 +4,36 @@ import torch
 
 import triton
 import triton._C.libtriton.triton as _triton
-from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
+from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
 
 
-def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
+def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
     ''' return compute throughput in TOPS '''
     total_warps = num_ctas * min(num_warps, 4)
     num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
-    tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(backend, device)
+    tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
     return tflops
 
 
+def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
+    ''' return compute throughput in TOPS '''
+    total_warps = num_ctas * min(num_warps, 4)
+    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
+    return tflops
+
+
+def get_tflops(backend, device, num_ctas, num_warps, dtype):
+    cc = _triton.runtime.cc(backend, device)
+    if cc < 80 and dtype == torch.float32:
+        return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
+    return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
+
+
 def estimate_matmul_time(
     # backend, device,
     num_warps, num_stages,
+    A, B, C,
     M, N, K,
     BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
     debug=False, **kwargs
@@ -26,6 +42,8 @@ def estimate_matmul_time(
           = max(compute, loading) + store '''
     backend = _triton.runtime.backend.CUDA
     device = torch.cuda.current_device()
+    dtype = A.dtype
+    dtsize = A.element_size()
 
     num_cta_m = triton.cdiv(M, BLOCK_M)
     num_cta_n = triton.cdiv(N, BLOCK_N)
@@ -37,7 +55,7 @@
 
     # time to compute
     total_ops = 2 * M * N * K / (1024 * 1024 * 1024)  # GOPS
-    tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
+    tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
     compute_ms = total_ops / tput
 
     # time to load data
@@ -48,10 +66,10 @@ def estimate_matmul_time(
     dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05)  # in GB/s
     l2_bw = dram_bw * 4  # rough estimation (should be 4.7 for A100?)
     # assume 80% of (following) loads are in L2 cache
-    load_a_dram = M * K * 2 * (1 + 0.2 * (num_cta_n - 1))  # assume dtype=float16 (size==2)
-    load_a_l2 = M * K * 2 * 0.8 * (num_cta_n - 1)
-    load_b_dram = N * K * 2 * (1 + 0.2 * (num_cta_m - 1))
-    load_b_l2 = N * K * 2 * 0.8 * (num_cta_m - 1)
+    load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
+    load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
+    load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
+    load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
     # total
     total_dram = (load_a_dram + load_b_dram) / (1024 * 1024)  # MB
     total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
@@ -60,7 +78,7 @@ def estimate_matmul_time(
 
     # estimate storing time
     store_bw = dram_bw * 0.6  # :o
-    store_c_dram = M * N * 2 * SPLIT_K / (1024 * 1024)  # MB
+    store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024)  # MB
     if SPLIT_K == 1:
         store_ms = store_c_dram / store_bw
     else:
@@ -78,14 +96,28 @@ def estimate_matmul_time(
     return total_time_ms
 
 
-def prune_num_stages(configs, named_args):
+def early_config_prune(configs, named_args):
     backend = _triton.runtime.backend.CUDA
     device = torch.cuda.current_device()
     cc = _triton.runtime.cc(backend, device)
     # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
+    dtsize = named_args['A'].element_size()
+    dtype = named_args['A'].dtype
+
+    # 1. make sure we have enough smem
+    pruned_configs = []
+    for config in configs:
+        kw = config.kwargs
+        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
+            kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
+        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
+        if required_shared_memory <= max_shared_memory:
+            pruned_configs.append(config)
+    configs = pruned_configs
 
     # Some dtypes do not allow atomic_add
-    if named_args['A'].dtype == torch.bfloat16:
+    if dtype not in [torch.float16, torch.float32]:
         configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]
 
     # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
diff --git a/python/triton/testing.py b/python/triton/testing.py
index c720f64cf..fbca719ff 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -330,18 +330,56 @@ def get_dram_gbps(backend=None, device=None):
         device = torch.cuda.current_device()
     mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
     bus_width = _triton.runtime.global_memory_bus_width(backend, device)
-    bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8  # In GB/s
+    bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8  # In GB/s
     return bw_gbps
 
 
-def get_max_tensorcore_tflops(backend, device):
+def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clock_rate=None):
+    if not backend:
+        backend = _triton.runtime.backend.CUDA
+    if not device:
+        device = torch.cuda.current_device()
     num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
-    clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
-    # assume fp32 += fp16*fp16
+    if not clock_rate:
+        clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
     cc = _triton.runtime.cc(backend, device)
     if cc < 80:
+        assert dtype == torch.float16
         ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
     else:
-        ops_per_sub_core = 512
-    tflops = num_subcores * clock_rate * ops_per_sub_core / (1024 * 1024 * 1024)
+        if dtype == torch.float32:
+            ops_per_sub_core = 256
+        elif dtype in [torch.float16, torch.bfloat16]:
+            ops_per_sub_core = 512
+        elif dtype == torch.int8:
+            ops_per_sub_core = 1024
+        else:
+            raise RuntimeError("dtype not supported")
+    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
+    return tflops
+
+
+def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
+    if not backend:
+        backend = _triton.runtime.backend.CUDA
+    if not device:
+        device = torch.cuda.current_device()
+    num_subcores = _triton.runtime.num_sm(backend, device) * 4  # on recent GPUs
+    clock_rate = _triton.runtime.clock_rate(backend, device)  # in kHz
+    cc = _triton.runtime.cc(backend, device)
+    if cc < 80:
+        if dtype == torch.float32:
+            ops_per_sub_core = 32  # 2*16
+        elif dtype == torch.float16:
+            ops_per_sub_core = 64
+        else:
+            raise RuntimeError("dtype not supported")
+    else:
+        if dtype == torch.float32:
+            ops_per_sub_core = 32
+        elif dtype in [torch.float16, torch.bfloat16]:
+            ops_per_sub_core = 64
+        else:
+            raise RuntimeError("dtype not supported")
+    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
     return tflops
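
# --------------------------------------------------------------------------
# Illustrative sketch, not part of the patch above: a minimal example of how
# the 'early_config_prune' / 'perf_model' / 'top_k' fields of prune_configs_by
# introduced in this diff plug into @triton.autotune, mirroring the wiring in
# python/triton/ops/matmul.py. The kernel `_toy_kernel` and the two helpers
# below are hypothetical placeholders, not code from the repository.
# --------------------------------------------------------------------------
import torch

import triton
import triton.language as tl


def toy_early_config_prune(configs, named_args):
    # Called once with all candidate configs and the kernel's named arguments;
    # returns the subset worth benchmarking (here: drop SPLIT_K > 1 whenever
    # the input dtype cannot use atomic_add, as matmul_perf_model.py does).
    if named_args['A'].dtype not in [torch.float16, torch.float32]:
        configs = [c for c in configs if c.kwargs['SPLIT_K'] == 1]
    return configs


def toy_perf_model(A, B, C, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
                   num_warps, num_stages, **kwargs):
    # Returns an estimated running time in ms; only the `top_k` configs with
    # the smallest estimates are actually benchmarked by the autotuner.
    return 2. * M * N * K * 1e-12 / num_warps


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 2}, num_stages=3, num_warps=8),
    ],
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': toy_early_config_prune,  # runs first, on the full config list
        'perf_model': toy_perf_model,                  # then ranks the surviving configs
        'top_k': 1,
    },
)
@triton.jit
def _toy_kernel(A, B, C, M, N, K,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr):
    pass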