diff --git a/include/triton/codegen/transform/cts.h b/include/triton/codegen/transform/cts.h
index dcc5f36c2..70fbc474b 100644
--- a/include/triton/codegen/transform/cts.h
+++ b/include/triton/codegen/transform/cts.h
@@ -33,4 +33,4 @@ private:
 }
 }
 
-#endif
+#endif
\ No newline at end of file
diff --git a/include/triton/codegen/transform/peephole.h b/include/triton/codegen/transform/peephole.h
index 1b015fb41..0e1ed222e 100644
--- a/include/triton/codegen/transform/peephole.h
+++ b/include/triton/codegen/transform/peephole.h
@@ -35,7 +35,7 @@ private:
   bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
   bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
   bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
-
+  
 public:
   peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
   void run(ir::module &mod);
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index 28ff9f3d6..699d22257 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -455,7 +455,7 @@ public:
 // masked load async
 class masked_load_async_inst: public load_inst {
 private:
-  std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); }
+  std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
   masked_load_async_inst(value *ptr, value *mask, value *false_value,
                          load_inst::CACHE_MODIFIER cache,
                          const std::string &name, instruction *next);
@@ -728,12 +728,21 @@ public:
 class dot_inst: public builtin_inst {
 public:
   enum TransT { NoTrans, Trans };
+  enum DataType {
+    FP8, FP16, BF16, TF32, FP32,
+    INT1, INT4, INT8, INT32,
+    UNKNOWN,
+  };
 
 private:
   dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
   std::string repr_impl() const { return "dot"; }
   bool is_prefetched_ = false;
+  DataType C_type_ = DataType::FP32;
+  DataType A_type_ = DataType::FP16;
+  DataType B_type_ = DataType::FP16;
+
 public:
   bool is_prefetched() const { return is_prefetched_; }
   void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index 845e2e36d..d38d81a9c 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -52,7 +52,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
   peephole.run(ir);
   dce.run(ir);
   pipeline.run(ir);
-  dce.run(ir);
+  dce.run(ir); 
   disassociate.run(ir);
   dce.run(ir);
   align.run(ir);
@@ -85,6 +85,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
   allocation.run(ir);
   prefetch_s.run(ir);
   barriers.run(ir);
+  // ir.print(std::cout);
   isel.visit(ir, *llvm);
   shared_static = allocation.allocated_size();
   return llvm;
diff --git a/lib/codegen/transform/cts.cc b/lib/codegen/transform/cts.cc
index 2641dad53..c223d2413 100644
--- a/lib/codegen/transform/cts.cc
+++ b/lib/codegen/transform/cts.cc
@@ -94,4 +94,4 @@ void cts::run(ir::module &mod) {
 }
 }
 
-}
+}
\ No newline at end of file
diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc
index cc7835bbc..bc249841b 100644
--- a/lib/codegen/transform/pipeline.cc
+++ b/lib/codegen/transform/pipeline.cc
@@ -327,4 +327,4 @@ void pipeline::run(ir::module &mod) {
 }
 }
 
-}
+}
\ No newline at end of file
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 01ad402aa..ce56d9c26 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -292,6 +292,16 @@ void init_triton_runtime(py::module &&m) {
     return bin;
   });
 
+  m.def("cc", [](backend_t backend, uint64_t device) -> int {
+    if (backend == CUDA) {
+      CUdevice dev = (CUdevice)device;
+      int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
+      int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
+      return major*10 + minor;
+    }
+    return -1;
+  });
+
   // query maximum shared memory
   m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
     if (backend == HOST)
@@ -303,6 +313,31 @@ void init_triton_runtime(py::module &&m) {
     return -1;
   });
 
+  // query DRAM & L2 cache
+  m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
+    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
+    return -1;
+  });
+  m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
+    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
+    return -1;
+  });
+  m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
+    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
+    return -1;
+  });
+
+  // query clock rate (in kilohertz)
+  m.def("clock_rate", [](backend_t backend, uint64_t device) {
+    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
+    return -1;
+  });
+
+  m.def("num_sm", [](backend_t backend, uint64_t device) {
+    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
+    return -1;
+  });
+
   // enqueue
   m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel,
                       uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py
index eff21fdfd..ce93786b8 100644
--- a/python/test/regression/test_performance.py
+++ b/python/test/regression/test_performance.py
@@ -25,7 +25,7 @@ def nvsmi(attrs):
 matmul_data = {
     # square
     (256 , 256 , 256 ) : {'v100': 0.027},
-    (512 , 512 , 512 ) : {'v100': 0.141},
+    (512 , 512 , 512 ) : {'v100': 0.158},
     (1024, 1024, 1024 ) : {'v100': 0.466},
     (2048, 2048, 2048 ) : {'v100': 0.680},
     (4096, 4096, 4096 ) : {'v100': 0.831},
@@ -35,10 +35,10 @@ matmul_data = {
     (16 , 4096, 4096 ) : {'v100': 0.0883},
     (16 , 8192, 8192 ) : {'v100': 0.101},
     (64 , 1024, 1024 ) : {'v100': 0.073},
-    (64 , 4096, 4096 ) : {'v100': 0.228},
+    (64 , 4096, 4096 ) : {'v100': 0.270},
     (64 , 8192, 8192 ) : {'v100': 0.360},
     (1024, 64 , 1024 ) : {'v100': 0.0692},
-    (4096, 64 , 4096 ) : {'v100': 0.223},
+    (4096, 64 , 4096 ) : {'v100': 0.264},
     (8192, 64 , 8192 ) : {'v100': 0.323},
     # # deep reductions
     # (64 , 64 , 16384) : {'v100': 0.},
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index e1091eff7..2f6ddf3c1 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -17,6 +17,10 @@ import triton
 import triton._C.libtriton.triton as _triton
 from filelock import FileLock
 import dbm
+import tempfile
+from typing import Optional, Dict
+import time
+
 
 
 class CodeGenerator(ast.NodeVisitor):
@@ -508,6 +512,7 @@ class LoadedBinary:
                                                       device)
         self.bin = bin
         self.asm = bin.asm
+        self.sass = ''
         self.module = module
         self.kernel = kernel
         self.device = device
@@ -519,6 +524,19 @@ class LoadedBinary:
                                 self.bin.num_warps * 32, 1, 1,
                                 args, self.bin.shared_mem)
 
+    def get_sass(self, fun=None):
+        if self.sass:
+            return self.sass
+        fd, path = tempfile.mkstemp()
+        try:
+            with open(fd, 'wb') as cubin:
+                cubin.write(self.asm['cubin'])
+            self.sass = extract(path, fun)
+        finally:
+            os.remove(path)
+        self.asm['sass'] = self.sass
+        return self.sass
+
 
 class CompilationError(Exception):
     def __init__(self, src, node):
@@ -530,8 +548,8 @@ class OutOfResources(Exception):
     def __init__(self, required, limit, name):
-        self.message = f'out of resource: {name}'\
-                       f'Required: {required}'\
+        self.message = f'out of resource: {name}, '\
+                       f'Required: {required}, '\
                        f'Hardware limit: {limit}'
         super().__init__(self.message)
         self.args = (required, limit, name)
@@ -727,7 +745,13 @@ class Launcher:
 
 
 class Autotuner:
-    def __init__(self, kernel, arg_names, configs, key, reset_to_zero):
+    def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict=None):
+        '''
+        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+            'perf_model': performance model used to predict running time with different configs, returns running time
+            'top_k': number of configs to bench
+            'prune_num_stages_by' (optional): a function used to prune num_stages. It takes configs: List[Config] as its input, and returns pruned configs.
+        '''
         if not configs:
             self.configs = [Config(dict(), num_warps=4, num_stages=2)]
         else:
@@ -744,7 +768,16 @@ class Autotuner:
                 args[i].zero_()
             self.hook = _hook
         self.arg_names = arg_names
-
+        # prune configs
+        if prune_configs_by:
+            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
+            if 'prune_num_stages_by' in prune_configs_by:
+                prune_num_stages_by = prune_configs_by['prune_num_stages_by']
+        else:
+            perf_model, top_k, prune_num_stages_by = None, None, None
+        self.perf_model, self.configs_top_k = perf_model, top_k
+        self.prune_num_stages_by = prune_num_stages_by
+
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
         # as kwargs and by the autotuner
@@ -768,13 +801,29 @@ class Autotuner:
         if len(self.configs) > 1:
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
+                # prune configs
+                pruned_configs = self.configs
+                if self.prune_num_stages_by:
+                    pruned_configs = self.prune_num_stages_by(self.configs)
+                if self.perf_model:
+                    top_k = self.configs_top_k
+                    if isinstance(top_k, float) and top_k <= 1.0:
+                        top_k = int(len(self.configs) * top_k)
+                    if len(pruned_configs) > top_k:
+                        est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
+                        pruned_configs = sorted(est_timing.keys(), key=lambda x:est_timing[x])[:top_k]
+                bench_start = time.time()
                 timings = {config: self._bench(*args, config=config, **kwargs) \
-                        for config in self.configs}
+                        for config in pruned_configs}
+                bench_end = time.time()
+                self.bench_time = bench_end - bench_start
                 self.cache[key] = builtins.min(timings, key=timings.get)
                 self.hook(args)
+                self.configs_timings = timings
             config = self.cache[key]
         else:
             config = self.configs[0]
+        self.best_config = config
         if config.pre_hook != None:
             config.pre_hook(self.nargs)
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
@@ -832,6 +881,8 @@ class DependenciesFinder(ast.NodeVisitor):
         module = inspect.getmodule(func)
         if module and module.__name__.startswith('triton.'):
             return
+        if inspect.isbuiltin(func):
+            return
         if not hasattr(func, 'hash'):
             src = textwrap.dedent(inspect.getsource(func))
             tree = ast.parse(src)
@@ -957,8 +1008,16 @@ class Config:
         self.num_stages = num_stages
         self.pre_hook = pre_hook
 
+    def __str__(self):
+        res = []
+        for k, v in self.kwargs.items():
+            res.append(f'{k}: {v}')
+        res.append(f'num_warps: {self.num_warps}')
+        res.append(f'num_stages: {self.num_stages}')
+        return ', '.join(res)
 
-def autotune(configs, key, reset_to_zero=None):
+
+def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -985,12 +1044,16 @@ def autotune(configs, key, reset_to_zero=None):
     :type configs: list[triton.Config]
     :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
     :type key: list[str]
+    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+        'perf_model': performance model used to predict running time with different configs, returns running time
+        'top_k': number of configs to bench
+        'prune_num_stages_by' (optional): a function used to prune num_stages. It takes configs: List[Config] as its input, and returns pruned configs.
     :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
     :type reset_to_zero: list[str]
     """
     def decorator(fn):
         def wrapper(kernel):
-            return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero)
+            return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
 
         fn.kernel_decorators.append(wrapper)
         return fn
@@ -1023,7 +1086,6 @@ def heuristics(values):
                 assert v not in meta
                 meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
             return kernel(*args, **meta)
-
             return fun
 
         fn.kernel_decorators.append(wrapper)
diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py
index ae404b8d6..8b7299a8b 100644
--- a/python/triton/ops/matmul.py
+++ b/python/triton/ops/matmul.py
@@ -1,15 +1,33 @@
 import torch
 import triton.language as tl
 import triton
+from .matmul_perf_model import *
 
 def init_to_zero(name):
     return lambda nargs: nargs[name].zero_()
 
+def get_configs_io_bound():
+    configs = []
+    for num_stages in [2, 3, 4, 5, 6]:
+        for block_m in [16, 32]:
+            for block_k in [32, 64]:
+                for block_n in [32, 64, 128, 256]:
+                    num_warps = 2 if block_n <= 64 else 4
+                    configs.append(
+                        triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
+                                      num_stages=num_stages, num_warps=num_warps))
+                    # split_k
+                    for split_k in [2, 4, 8, 16]:
+                        configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
+                                                     num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+    return configs
+
 @triton.heuristics({
     'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
 })
 @triton.autotune(
     configs=[
+        # basic configs for compute-bound matmuls
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
@@ -19,17 +37,13 @@ def init_to_zero(name):
         triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
-        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
-    ],
+    ] + get_configs_io_bound(),
     key=['M', 'N', 'K'],
+    prune_configs_by={
+        'prune_num_stages_by' : prune_num_stages,
+        'perf_model': estimate_matmul_time,
+        'top_k': 10
+    },
 )
 @triton.jit
 def _kernel(A, B, C, M, N, K,
diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py
new file mode 100644
index 000000000..16667a7b1
--- /dev/null
+++ b/python/triton/ops/matmul_perf_model.py
@@ -0,0 +1,116 @@
+import torch
+import triton
+import triton._C.libtriton.triton as _triton
+from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
+import heapq
+
+def get_tensorcore_tflops(backend, device, num_ctas, num_warps):
+    ''' return compute throughput in TOPS '''
+    total_warps = num_ctas * min(num_warps, 4)
+    num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
+    tflops = min(num_subcores, total_warps)/num_subcores * get_max_tensorcore_tflops(backend, device)
+    return tflops
+
+def estimate_matmul_time(
+    # backend, device,
+    num_warps, num_stages,
+    M, N, K,
+    BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
+    debug=False, **kwargs
+):
+    ''' return estimated running time in ms
+          = max(compute, loading) + store '''
+    backend = _triton.runtime.backend.CUDA
+    device = torch.cuda.current_device()
+
+    num_cta_m = triton.cdiv(M, BLOCK_M)
+    num_cta_n = triton.cdiv(N, BLOCK_N)
+    num_cta_k = SPLIT_K
+    num_ctas = num_cta_m * num_cta_n * num_cta_k
+
+    # If the input is smaller than the block size
+    M, N = max(M, BLOCK_M), max(N, BLOCK_N)
+
+    # time to compute
+    total_ops = 2*M*N*K / (1024*1024*1024) # GOPS
+    tput = get_tensorcore_tflops(backend, device, num_ctas, num_warps)
+    compute_ms = total_ops / tput
+
+    # time to load data
+    num_sm = _triton.runtime.num_sm(backend, device)
+    active_cta_ratio = min(1, num_ctas/num_sm)
+    active_cta_ratio_bw1 = min(1, num_ctas/32) # 32 active ctas are enough to saturate
+    active_cta_ratio_bw2 = max(min(1, (num_ctas-32)/(108-32)), 0) # 32-108, remaining 5%
+    dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1*0.95 + active_cta_ratio_bw2*0.05) # in GB/s
+    l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
+    # assume 80% of (following) loads are in L2 cache
+    load_a_dram = M*K*2*(1+0.2*(num_cta_n-1)) # assume dtype=float16 (size==2)
+    load_a_l2 = M*K*2*0.8*(num_cta_n-1)
+    load_b_dram = N*K*2*(1+0.2*(num_cta_m-1))
+    load_b_l2 = N*K*2*0.8*(num_cta_m-1)
+    # total
+    total_dram = (load_a_dram + load_b_dram) / (1024*1024) # MB
+    total_l2 = (load_a_l2 + load_b_l2) / (1024*1024)
+    # loading time in ms
+    load_ms = total_dram/dram_bw + total_l2/l2_bw
+
+    # estimate storing time
+    store_bw = dram_bw * 0.6 # :o
+    store_c_dram = M*N*2*SPLIT_K / (1024*1024) # MB
+    if SPLIT_K == 1:
+        store_ms = store_c_dram /store_bw
+    else:
+        reduce_bw = store_bw
+        store_ms = store_c_dram/reduce_bw
+        # c.zero_()
+        zero_ms = M*N*2/(1024*1024)/store_bw
+        store_ms += zero_ms
+
+    total_time_ms = max(compute_ms, load_ms) + store_ms
+    if debug:
+        print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
+              f'loading time: {load_ms}ms, store time: {store_ms}ms, '
+              f'Active CTAs: {active_cta_ratio*100}%')
+    return total_time_ms
+
+def prune_num_stages(configs):
+    backend = _triton.runtime.backend.CUDA
+    device = torch.cuda.current_device()
+    cc = _triton.runtime.cc(backend, device)
+    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
+
+    # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
+    configs_map = {}
+    for config in configs:
+        kw = config.kwargs
+        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
+            kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
+
+        key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
+        if key in configs_map:
+            configs_map[key].append((config, num_stages))
+        else:
+            configs_map[key] = [(config, num_stages)]
+
+    pruned_configs = []
+    for k, v in configs_map.items():
+        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
+        if cc >= 80:
+            # compute cycles (only works for ampere GPUs)
+            mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16*8*16)
+            mma_cycles = mmas/min(4, num_warps) * 8
+
+            ldgsts_latency = 300 # Does this matter?
+            optimal_num_stages = ldgsts_latency/mma_cycles
+
+            # nearest stages, prefer large #stages
+            nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) \
+                                      if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
+
+            for n in nearest:
+                pruned_configs.append(n[0])
+        else: # Volta & Turing only support num_stages <= 2
+            random_config = v[0][0]
+            random_config.num_stages = 2
+            pruned_configs.append(random_config)
+    return pruned_configs
\ No newline at end of file
diff --git a/python/triton/testing.py b/python/triton/testing.py
index 051a8f378..f274e808f 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -1,5 +1,6 @@
 import torch
 import os
+import triton._C.libtriton.triton as _triton
 from .code_gen import OutOfResources
 import subprocess
 import sys
@@ -320,3 +321,27 @@ def perf_report(benchmarks):
     """
     wrapper = lambda fn: Mark(fn, benchmarks)
     return wrapper
+
+def get_dram_gbps(backend=None, device=None):
+    ''' return DRAM bandwidth in GB/s '''
+    # assert backend == CUDA
+    if not backend:
+        backend = _triton.runtime.backend.CUDA
+    if not device:
+        device = torch.cuda.current_device()
+    mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
+    bus_width = _triton.runtime.global_memory_bus_width(backend, device)
+    bw_gbps = mem_clock_khz * bus_width * 2 // 1024 // 1024 // 8 # In GB/s
+    return bw_gbps
+
+def get_max_tensorcore_tflops(backend, device):
+    num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
+    clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
+    # assume fp32 += fp16*fp16
+    cc = _triton.runtime.cc(backend, device)
+    if cc < 80:
+        ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
+    else:
+        ops_per_sub_core = 512
+    tflops = num_subcores * clock_rate * ops_per_sub_core / (1024*1024*1024)
+    return tflops
\ No newline at end of file
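
Usage sketch for the new hardware-query bindings and the matmul performance model, driven the same way the pruned autotuner uses them internally. It assumes a CUDA device with PyTorch installed; the problem size, tile shape, num_warps and num_stages below are illustrative values, not taken from this patch:

import torch
import triton._C.libtriton.triton as _triton
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
from triton.ops.matmul_perf_model import estimate_matmul_time

backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()

# Peak figures the model is built on.
print('DRAM bandwidth (GB/s):', get_dram_gbps(backend, device))
print('Tensor-core throughput estimate:', get_max_tensorcore_tflops(backend, device))

# Predicted running time (ms) of a 4096x4096x4096 fp16 matmul tiled as
# 128x128x32 without split-K; debug=True also prints the compute/load/store breakdown.
t_ms = estimate_matmul_time(num_warps=8, num_stages=3,
                            M=4096, N=4096, K=4096,
                            BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, SPLIT_K=1,
                            debug=True)

These are the same two hooks that prune_configs_by wires into triton.ops.matmul above: estimate_matmul_time ranks the candidate configs so that only the top_k best predictions are actually benchmarked, while prune_num_stages drops pipeline depths that do not suit the detected compute capability.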