diff --git a/docs/index.rst b/docs/index.rst index c722d85bd..a73c52569 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,7 +1,7 @@ Welcome to Triton's documentation! ================================== -Triton is an imperative language and compiler for parallel programming. It aims to provide a programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware. +Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware. Getting Started --------------- @@ -17,18 +17,22 @@ Getting Started getting-started/installation getting-started/tutorials/index -Language Reference +Python API ------------------- -- Checkout the :doc:`Python API Documentation <language-reference/python-api/index>` +- :doc:`triton <python-api/triton>` +- :doc:`triton.language <python-api/triton.language>` +- :doc:`triton.testing <python-api/triton.testing>` .. toctree:: :maxdepth: 1 - :caption: Language Reference + :caption: Python API :hidden: - language-reference/python-api/index + python-api/triton + python-api/triton.language + python-api/triton.testing Going Further diff --git a/docs/language-reference/python-api/index.rst b/docs/python-api/triton.language.rst similarity index 94% rename from docs/language-reference/python-api/index.rst rename to docs/python-api/triton.language.rst index 152b7bd2b..dd1404c69 100644 --- a/docs/language-reference/python-api/index.rst +++ b/docs/python-api/triton.language.rst @@ -1,7 +1,7 @@ -Python API -=========== +triton.language +================ -.. currentmodule:: triton +.. currentmodule:: triton.language Programming Model diff --git a/docs/python-api/triton.rst b/docs/python-api/triton.rst new file mode 100644 index 000000000..2db99da77 --- /dev/null +++ b/docs/python-api/triton.rst @@ -0,0 +1,10 @@ +triton +======== + +.. currentmodule:: triton + +.. autosummary:: + :toctree: generated + :nosignatures: + + jit \ No newline at end of file diff --git a/docs/python-api/triton.testing.rst b/docs/python-api/triton.testing.rst new file mode 100644 index 000000000..c45e9981b --- /dev/null +++ b/docs/python-api/triton.testing.rst @@ -0,0 +1,12 @@ +triton.testing +================ + +.. currentmodule:: triton.testing + +..
autosummary:: + :toctree: generated + :nosignatures: + + do_bench + Benchmark + perf_report \ No newline at end of file diff --git a/python/bench/bench_blocksparse.py b/python/bench/bench_blocksparse.py index 313cef108..6954aa315 100644 --- a/python/bench/bench_blocksparse.py +++ b/python/bench/bench_blocksparse.py @@ -10,9 +10,9 @@ square_confs = [ triton.testing.Benchmark( x_names = ['M', 'N', 'K'], x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], - y_name = 'block', - y_vals = [16, 32, 64], - y_lines = ['Block16', 'Block32', 'Block64'], + line_arg = 'block', + line_vals = [16, 32, 64], + line_names = ['Block16', 'Block32', 'Block64'], ylabel = 'TFLOPS', plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', args = {'layout_mode': layout_mode, 'op_mode': op_mode, @@ -60,9 +60,9 @@ square_confs = [ triton.testing.Benchmark( x_names = ['M', 'N'], x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], - y_name = 'block', - y_vals = [16, 32, 64], - y_lines = ['Block16', 'Block32', 'Block64'], + line_arg = 'block', + line_vals = [16, 32, 64], + line_names = ['Block16', 'Block32', 'Block64'], ylabel = 'GBPS', plot_name = f'{layout_mode}-square', args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} diff --git a/python/bench/bench_cross_entropy.py b/python/bench/bench_cross_entropy.py index f8ac18a1d..2c4d61d9c 100644 --- a/python/bench/bench_cross_entropy.py +++ b/python/bench/bench_cross_entropy.py @@ -5,9 +5,9 @@ confs = [ triton.testing.Benchmark( x_names = ['N'], x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192], - y_name = 'provider', - y_vals = ['triton', 'torch'], - y_lines = ['Triton', 'Torch'], + line_arg = 'provider', + line_vals = ['triton', 'torch'], + line_names = ['Triton', 'Torch'], ylabel = 'GBPS', plot_name = f'{mode}-2048', args = {'M': 2048, 'dtype': torch.float16, 'mode': mode} diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py index 3c8c7907a..7e912be31 100644 --- a/python/bench/bench_matmul.py +++ b/python/bench/bench_matmul.py @@ -16,9 +16,9 @@ square_confs = [ triton.testing.Benchmark( x_names=["M", "N", "K"], x_vals=rounded_linspace(512, 8192, 32, 128), - y_name="provider", - y_vals=["cublas", "triton", "cutlass"], - y_lines=["cuBLAS", "Triton", "CUTLASS"], + line_arg="provider", + line_vals=["cublas", "triton", "cutlass"], + line_names=["cuBLAS", "Triton", "CUTLASS"], ylabel="TFLOPS", plot_name=f"matmul-square-{nt[AT]}{nt[BT]}", args={"AT": AT, "BT": BT, "dtype": torch.float16}, @@ -30,9 +30,9 @@ transformer_confs = [ triton.testing.Benchmark( x_names=[x], x_vals = rounded_linspace(NK//16, NK, 32, 128), - y_name="provider", - y_vals=["cublas", "triton", "cutlass"], - y_lines=["cuBLAS", "Triton", "CUTLASS"], + line_arg="provider", + line_vals=["cublas", "triton", "cutlass"], + line_names=["cuBLAS", "Triton", "CUTLASS"], ylabel="TFLOPS", plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}", args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16} diff --git a/python/test/test_code_gen.py b/python/test/test_language.py similarity index 91% rename from python/test/test_code_gen.py rename to python/test/test_language.py index 140105c27..fbf5af3b1 100644 --- a/python/test/test_code_gen.py +++ b/python/test/test_language.py @@ -1,5 +1,6 @@ import torch import triton +import triton.language as tl import copy import pytest import ast @@ -37,10 +38,10 @@ def _test_unary(dtype_x, expr, device='cuda'): # define the kernel / launch-grid @triton.jit def kernel(Z, X, **meta): 
- off = triton.arange(0, meta['SIZE']) - x = triton.load(X + off) + off = tl.arange(0, meta['SIZE']) + x = tl.load(X + off) z = GENERATE_TEST_HERE - triton.store(Z + off, z) + tl.store(Z + off, z) kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) # inputs @@ -59,11 +60,11 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'): # define the kernel / launch-grid @triton.jit def kernel(Z, X, Y, **meta): - off = triton.arange(0, meta['SIZE']) - x = triton.load(X + off) - y = triton.load(Y + off) + off = tl.arange(0, meta['SIZE']) + x = tl.load(X + off) + y = tl.load(Y + off) z = GENERATE_TEST_HERE - triton.store(Z + off, z) + tl.store(Z + off, z) kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) # inputs @@ -144,7 +145,7 @@ def make_ptr_str(name, shape): stride = 1 for i in reversed(range(rank)): idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) - offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}'] + offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}'] stride *= shape[i] return f"{name} + {' + '.join(offsets)}" @@ -164,11 +165,11 @@ def test_index1d(expr, device='cuda'): @triton.jit def kernel(Z, X, **meta): SIZE = meta['SIZE'] - m = triton.arange(0, SIZE) - n = triton.arange(0, SIZE) - x = triton.load(X_PTR_EXPR) + m = tl.arange(0, SIZE) + n = tl.arange(0, SIZE) + x = tl.load(X_PTR_EXPR) z = GENERATE_TEST_HERE - triton.store(Z_PTR_EXPR, z) + tl.store(Z_PTR_EXPR, z) to_replace = { 'X_PTR_EXPR': make_ptr_str('X', shape_x), diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 663a9c2df..9c1df2839 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -2,9 +2,9 @@ # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from .code_gen import jit, autotune, heuristics, Config, Autotuner -from .core import * +from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner +from . import language from . import code_gen from . import testing from . 
import ops diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index a357eb5b0..27ae177f4 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -26,21 +26,21 @@ class CodeGenerator(ast.NodeVisitor): ret = self.builtins[name] else: raise ValueError(f'{name} is not defined') - if isinstance(ret, triton.block): + if isinstance(ret, triton.language.block): handle = self.module.get_value(name) - return triton.block(handle) + return triton.language.block(handle) return ret def set_value(self, name, value): if isinstance(value, _triton.ir.value): - value = triton.block(value) - if isinstance(value, triton.block): + value = triton.language.block(value) + if isinstance(value, triton.language.block): self.module.set_value(name, value.handle) self.module.scope.set_type(name, value.handle.type) self.lscope[name] = value def is_triton_object(self, value): - return isinstance(value, triton.block) + return isinstance(value, triton.language.block) def visit_compound_statement(self, stmts, add_scope=False): if add_scope: @@ -63,7 +63,14 @@ class CodeGenerator(ast.NodeVisitor): self.constants = constants self.kwargs = kwargs self.last_node = None - self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr} + self.builtins = { + 'range': range, + 'min': triton.language.minimum, + 'float': float, + 'int': int, + 'print': print, + 'getattr': getattr, + } def visit_Module(self, node): self.module.add_new_scope() @@ -303,7 +310,7 @@ class CodeGenerator(ast.NodeVisitor): pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]]) neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]]) pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)]) - build_cond = lambda: triton.where(self.visit(pos_step_node),\ + build_cond = lambda: triton.language.where(self.visit(pos_step_node),\ self.visit(pos_cond_node),\ self.visit(neg_cond_node),\ builder=self.builder) @@ -359,7 +366,7 @@ class CodeGenerator(ast.NodeVisitor): if isinstance(fn, JITFunction): return fn(*args, generator=self, **kws) if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ - sys.modules[fn.__module__] is triton.core: + sys.modules[fn.__module__] is triton.language: return fn(*args, builder=self.builder, **kws) return fn(*args, **kws) @@ -613,6 +620,11 @@ class JITFunction: raise e raise CompilationError(self.src, node, e) + def __setattr__(self, name, value): + if name == 'kernel_decorators': + self.kernel = None + super(JITFunction, self).__setattr__(name, value) + def _init_kernel(self): if self.kernel is None: self.kernel = Kernel(self) @@ -659,4 +671,23 @@ def heuristics(values): def jit(fn): + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method. + + :note: This function will be compiled and run on the GPU. 
It will only have access to: + + * python primitives, + * objects within the triton.language package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ return JITFunction(fn) + + +def cdiv(x, y): + return (x + y - 1) // y diff --git a/python/triton/core.py b/python/triton/language.py similarity index 98% rename from python/triton/core.py rename to python/triton/language.py index a84435df5..a9d8b75db 100644 --- a/python/triton/core.py +++ b/python/triton/language.py @@ -1,6 +1,6 @@ +import triton from triton._C.libtriton.triton import ir from triton._C.libtriton.triton import frontend -import triton from functools import wraps @@ -25,7 +25,7 @@ def _patch(fn): if x.type.is_void(): return None return block(x) - return x + return tl def wrapper(*args, **kwargs): builder = args[-1] @@ -547,7 +547,7 @@ def minimum(x, y): :param other: the second input block :type other: Block """ - return triton.where(x < y, x, y) + return triton.language.where(x < y, x, y) @triton.jit @@ -560,7 +560,7 @@ def maximum(x, y): :param other: the second input block :type other: Block """ - return triton.where(x > y, x, y) + return triton.language.where(x > y, x, y) @triton.jit @@ -571,7 +571,7 @@ def sigmoid(x): :param x: the input block :type x: Block """ - return 1 / (1 + np.exp(-x)) + return 1 / (1 + triton.language.exp(-x)) @triton.jit @@ -582,9 +582,9 @@ def softmax(x): :param x: the input block :type x: Block """ - z = x - triton.max(x, 0) - num = triton.exp(z) - den = triton.sum(num, 0) + z = x - triton.language.max(x, 0) + num = triton.language.exp(z) + den = triton.language.sum(num, 0) return num / den @@ -596,8 +596,4 @@ def ravel(x): :param x: the input block :type x: Block """ - return triton.reshape(x, [x.type.numel]) - - -def cdiv(x, y): - return (x + y - 1) // y + return triton.language.reshape(x, [x.type.numel]) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index f8e7a0388..926239b40 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -1,4 +1,5 @@ import triton +import triton.language as tl import triton._C.libtriton as libtriton import torch @@ -16,21 +17,21 @@ def _kernel( #------------# #- Prologue -# #------------# - pid0 = triton.program_id(0) - pid1 = triton.program_id(1) - pidz = triton.program_id(2) + pid0 = tl.program_id(0) + pid1 = tl.program_id(1) + pidz = tl.program_id(2) if meta['SDD']: pid1 = pid1 + SDD_off_width - blockidm = triton.arange(0, TM) // BLOCK - blockidn = triton.arange(0, TN) // BLOCK + blockidm = tl.arange(0, TM) // BLOCK + blockidn = tl.arange(0, TN) // BLOCK offlutm = blockidm * (TN // BLOCK) * 4 offlutn = blockidn * 4 header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4 - z = triton.load(header + 0) - i = triton.load(header + 1 + offlutm) - j = triton.load(header + 2 + offlutn) + z = tl.load(header + 0) + i = tl.load(header + 1 + offlutm) + j = tl.load(header + 2 + offlutn) AS1 = SDD_K // TZ - lockid = triton.where(TZ > 1, 1, 0) + lockid = tl.where(TZ > 1, 1, 0) offka = pid0 * AS1 offkb = pid0 * AS1 offmc = 0 @@ -41,16 +42,16 @@ def _kernel( offhc = 0 offha = z offhb = z - ram = i * BLOCK + (triton.arange(0, TM) % BLOCK) - rbn = j * BLOCK + (triton.arange(0, TN) % BLOCK) + ram = i * BLOCK + (tl.arange(0, TM) % BLOCK) + rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK) else: header = lut + pid0 * 6 - offset = triton.load(header + 0) - AS1 = triton.load(header + 1) - column = triton.load(header + 
2) - depth = triton.load(header + 3) - lockid = triton.load(header + 4) - maxid = triton.load(header + 5) + offset = tl.load(header + 0) + AS1 = tl.load(header + 1) + column = tl.load(header + 2) + depth = tl.load(header + 3) + lockid = tl.load(header + 4) + maxid = tl.load(header + 5) pinc = lut + offset offhc = depth if meta['DSD']: @@ -60,14 +61,14 @@ def _kernel( offpc = 0 # dense input offset offnb = pid1 * TN - offkb = triton.load(pinc) - offkb = triton.multiple_of(offkb, 8) # compiler hint + offkb = tl.load(pinc) + offkb = tl.multiple_of(offkb, 8) # compiler hint offpb = 0 # sparse input offset offma = 0 offka = 0 - offpa = triton.load(pinc + 1) - offpa = triton.multiple_of(offpa, 8) # compiler hint + offpa = tl.load(pinc + 1) + offpa = tl.multiple_of(offpa, 8) # compiler hint offpa = offpa * BLOCK * BLOCK offha = 0 offhb = depth @@ -78,23 +79,23 @@ def _kernel( offpc = 0 # dense input offset offma = pid1 * TM - offka = triton.load(pinc) - offka = triton.multiple_of(offka, 8) # compiler hint + offka = tl.load(pinc) + offka = tl.multiple_of(offka, 8) # compiler hint offpa = 0 # sparse input offset offnb = 0 offkb = 0 - offpb = triton.load(pinc + 1) - offpb = triton.multiple_of(offpb, 8) # compiler hint + offpb = tl.load(pinc + 1) + offpb = tl.multiple_of(offpb, 8) # compiler hint offpb = offpb * BLOCK * BLOCK offha = depth offhb = 0 - ram = offma + triton.arange(0, TM) - rbn = offnb + triton.arange(0, TN) + ram = offma + tl.arange(0, TM) + rbn = offnb + tl.arange(0, TN) # initialize a, b pointers - rka = offka + triton.arange(0, TK) - rkb = offkb + triton.arange(0, TK) + rka = offka + tl.arange(0, TK) + rkb = offkb + tl.arange(0, TK) pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb if meta['DDS']: @@ -105,31 +106,31 @@ def _kernel( checkbn = rbn[None, :] < DS0 else: checkbn = AS1 > 0 - a = triton.load(pa, mask=checkam, other=0.) - b = triton.load(pb, mask=checkbn, other=0.) + a = tl.load(pa, mask=checkam, other=0.) + b = tl.load(pb, mask=checkbn, other=0.) 
## ---------------- ## ## Inner Loop ## ## ---------------- ## - acc = triton.zeros((TM, TN), dtype=triton.float32) + acc = tl.zeros((TM, TN), dtype=tl.float32) for k in range(AS1, 0, -TK): - acc += triton.dot(a, b) + acc += tl.dot(a, b) if meta['SDD']: inc_a = TK * stride_ka inc_b = TK * stride_kb else: pinc += 2 if meta['DSD']: - inc_b = triton.load(pinc) - inc_a = triton.load(pinc + 1) - inc_b = triton.multiple_of(inc_b, 8) - inc_a = triton.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_a = tl.load(pinc + 1) + inc_b = tl.multiple_of(inc_b, 8) + inc_a = tl.multiple_of(inc_a, 8) inc_b = inc_b * stride_kb if meta['DDS']: - inc_a = triton.load(pinc) - inc_b = triton.load(pinc + 1) - inc_a = triton.multiple_of(inc_a, 8) - inc_b = triton.multiple_of(inc_b, 8) + inc_a = tl.load(pinc) + inc_b = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.multiple_of(inc_b, 8) inc_a = inc_a * stride_ka pa += inc_a pb += inc_b @@ -138,24 +139,24 @@ def _kernel( checkbk = k > TK checka = checkam & checkak checkb = checkbn & checkbk - a = triton.load(pa, mask=checka) - b = triton.load(pb, mask=checkb) + a = tl.load(pa, mask=checka) + b = tl.load(pb, mask=checkb) c = acc.to(C.dtype.element_ty) if meta['SDD']: checkc = True - rr_blockidm = triton.arange(0, TM) // BLOCK - rr_blockidn = triton.arange(0, TN) // BLOCK + rr_blockidm = tl.arange(0, TM) // BLOCK + rr_blockidn = tl.arange(0, TN) // BLOCK rr_offlutm = rr_blockidm * (TN // BLOCK) * 4 rr_offlutn = rr_blockidn * 4 off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :] - bkid = triton.load(header + off_bkid) + bkid = tl.load(header + off_bkid) offpc = bkid * BLOCK * BLOCK - rcm = triton.arange(0, TM) % BLOCK - rcn = triton.arange(0, TN) % BLOCK + rcm = tl.arange(0, TM) % BLOCK + rcn = tl.arange(0, TN) % BLOCK else: - rcm = offmc + triton.arange(0, TM) - rcn = offnc + triton.arange(0, TN) + rcm = offmc + tl.arange(0, TM) + rcn = offnc + tl.arange(0, TN) if meta['DSD']: checkc = rcn[None, :] < DS0 if meta['DDS']: @@ -164,21 +165,21 @@ def _kernel( pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc # write-back directly if lockid == 0: - triton.store(pc, c, mask=checkc) + tl.store(pc, c, mask=checkc) # accumulate partial results using spin-locks else: - plock = locks + triton.program_id(2) * nlocks * triton.num_programs(1) + triton.program_id(1) * nlocks + lockid - 1 - pcount = plock + triton.num_programs(2) * triton.num_programs(1) * nlocks - while triton.atomic_cas(plock, 0, 1) == 1: + plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1 + pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks + while tl.atomic_cas(plock, 0, 1) == 1: pass - count = triton.load(pcount) + count = tl.load(pcount) if count == 0: - triton.store(pc, c, mask=checkc) + tl.store(pc, c, mask=checkc) else: - d = triton.load(pc, mask=checkc) - triton.store(pc, d + c, mask=checkc) - triton.atomic_xchg(pcount, (count + 1) % maxid) - triton.atomic_xchg(plock, 0) + d = tl.load(pc, mask=checkc) + tl.store(pc, d + c, mask=checkc) + tl.atomic_xchg(pcount, (count + 1) % maxid) + tl.atomic_xchg(plock, 0) ############## diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index 55d86bbc0..e7fbe1fd8 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -1,3 +1,4 @@ +import triton.language as tl import triton import torch import os @@ -31,86 +32,86 @@ def _forward( 
): TN = meta['TN'] BLOCK = meta['BLOCK'] - pidhm = triton.program_id(0) - pidz = triton.program_id(1) + pidhm = tl.program_id(0) + pidz = tl.program_id(1) # create index ranges rxm = pidhm % BLOCK rbm = pidhm // BLOCK - rxn = triton.arange(0, TN) % BLOCK - rbn = triton.arange(0, TN) // BLOCK + rxn = tl.arange(0, TN) % BLOCK + rbn = tl.arange(0, TN) // BLOCK # extract information from LUT header = LUT + rbm * 2 - size = triton.load(header + 0) - offset = triton.load(header + 1) + size = tl.load(header + 0) + offset = tl.load(header + 1) check = rbn < size - rbmn = triton.where(check, rbn, size - 1) + rbmn = tl.where(check, rbn, size - 1) # block id and column id - blockid = triton.load(LUT + offset + rbmn * 4 + 0) - columnid = triton.load(LUT + offset + rbmn * 4 + 1) - rowid = triton.load(LUT + offset + rbmn * 4 + 2) - headid = triton.load(LUT + offset + rbmn * 4 + 3) + blockid = tl.load(LUT + offset + rbmn * 4 + 0) + columnid = tl.load(LUT + offset + rbmn * 4 + 1) + rowid = tl.load(LUT + offset + rbmn * 4 + 2) + headid = tl.load(LUT + offset + rbmn * 4 + 3) # pointers to X px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn - x = triton.load(px, mask=check, other=-float('inf')) - x = x.to(triton.float32) + x = tl.load(px, mask=check, other=-float('inf')) + x = x.to(tl.float32) # apply scale if meta['APPLY_SCALE']: x = x * scale # apply RPE if meta['APPLY_RPE']: prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn - rpe = triton.load(prpe, mask=check, other=0) + rpe = tl.load(prpe, mask=check, other=0) x = x + rpe # apply key-padding mask if meta['APPLY_KP_MASK']: pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn - kp_m = triton.load(pkp_m, mask=check, other=-float('inf')) + kp_m = tl.load(pkp_m, mask=check, other=-float('inf')) if meta['KP_MASK_MUL']: - kp_m = triton.where(kp_m == 0, -float('inf'), 0.) + kp_m = tl.where(kp_m == 0, -float('inf'), 0.) x = x + kp_m # apply attention mask if meta['APPLY_ATTN_MASK']: pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn - attn_m = triton.load(pattn_m, mask=check, other=-float('inf')) + attn_m = tl.load(pattn_m, mask=check, other=-float('inf')) if meta['ATTN_MASK_MUL']: - attn_m = triton.where(attn_m == 0, -float('inf'), 0.) + attn_m = tl.where(attn_m == 0, -float('inf'), 0.) 
x = x + attn_m # computation - x = triton.softmax(x) - triton.store(px, x, mask=check) + x = tl.softmax(x) + tl.store(px, x, mask=check) @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])}) @triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']}) @triton.jit def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta): - pidhm = triton.program_id(0) - pidz = triton.program_id(1) + pidhm = tl.program_id(0) + pidz = tl.program_id(1) TN = meta['TN'] BLOCK = meta['BLOCK'] # create index ranges rxm = pidhm % BLOCK rbm = pidhm // BLOCK - rxn = triton.arange(0, TN) % BLOCK - rbn = triton.arange(0, TN) // BLOCK + rxn = tl.arange(0, TN) % BLOCK + rbn = tl.arange(0, TN) // BLOCK # extract information from look-up table header = LUT + rbm * 2 - size = triton.load(header + 0) - offset = triton.load(header + 1) + size = tl.load(header + 0) + offset = tl.load(header + 1) # bounds checking on lut check = rbn < size - rbmn = triton.where(check, rbn, size - 1) + rbmn = tl.where(check, rbn, size - 1) # initialize pointers to block-sparse input - blockid = triton.load(LUT + offset + rbmn * 4) + blockid = tl.load(LUT + offset + rbmn * 4) X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn # compute fused softmax backward - x = triton.load(X, mask=check, other=0) - dx = triton.load(DX, mask=check, other=0) - x = x.to(triton.float32) - dx = dx.to(triton.float32) - y = x * (dx - triton.sum(x * dx, 0)) * scale - triton.store(DX, y, mask=check) + x = tl.load(X, mask=check, other=0) + dx = tl.load(DX, mask=check, other=0) + x = x.to(tl.float32) + dx = dx.to(tl.float32) + y = x * (dx - tl.sum(x * dx, 0)) * scale + tl.store(DX, y, mask=check) class _softmax(torch.autograd.Function): diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index e69ad2038..87833e3c0 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -1,5 +1,6 @@ import os import triton +import triton.language as tl import torch @@ -27,25 +28,25 @@ def num_warps(N): @triton.jit def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta): BLOCK = meta['BLOCK'] - row = triton.program_id(0) - cols = triton.arange(0, BLOCK) - idx = triton.load(IDX + row) + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) # pointers to logit and probs LOGITS = LOGITS + row * N + cols WRIT_PROBS = PROBS + row * N + cols READ_PROBS = PROBS + row * N + idx # write-back negative log-probs - logits = triton.load(LOGITS, mask=cols < N, other=-float('inf')) - logits = logits.to(triton.float32) - logits = logits - triton.max(logits, 0) - probs = triton.log(triton.sum(triton.exp(logits), 0)) - logits - triton.store(WRIT_PROBS, probs, mask=cols < N) + logits = tl.load(LOGITS, mask=cols < N, other=-float('inf')) + logits = logits.to(tl.float32) + logits = logits - tl.max(logits, 0) + probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits + tl.store(WRIT_PROBS, probs, mask=cols < N) # There is a bug in the compiler, which fails to insert a barrier here. # We add it explicitly for now. Will be fixed soon. 
- triton.debug_barrier() + tl.debug_barrier() # write-back loss - probs = triton.load(READ_PROBS) - triton.store(LOSS + row, probs) + probs = tl.load(READ_PROBS) + tl.store(LOSS + row, probs) @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])}) @@ -53,20 +54,20 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta): @triton.jit def _backward(PROBS, IDX, DPROBS, N, **meta): BLOCK = meta['BLOCK'] - row = triton.program_id(0) - cols = triton.arange(0, BLOCK) - idx = triton.load(IDX + row) + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) # pointers to probs PROBS = PROBS + row * N + cols # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] # and we have -log(p[k]) stored in PROBS, so this is easy - probs = -triton.load(PROBS, mask=cols < N, other=float('inf')) - probs = triton.exp(probs.to(triton.float32)) + probs = -tl.load(PROBS, mask=cols < N, other=float('inf')) + probs = tl.exp(probs.to(tl.float32)) delta = cols == idx # write result in-place in PROBS - dout = triton.load(DPROBS + row) + dout = tl.load(DPROBS + row) din = (probs - delta) * dout - triton.store(PROBS, din.to(triton.float16), mask=cols < N) + tl.store(PROBS, din.to(tl.float16), mask=cols < N) class _cross_entropy(torch.autograd.Function): diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 9af671a88..f5bc6afa3 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -1,4 +1,5 @@ import torch +import triton.language as tl import triton @@ -27,8 +28,8 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride GROUP_M = META['GROUP_M'] SPLIT_K = META['SPLIT_K'] # matrix multiplication - pid = triton.program_id(0) - pid_z = triton.program_id(1) + pid = tl.program_id(0) + pid_z = tl.program_id(1) grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_n = (N + BLOCK_N - 1) // BLOCK_N # re-order program ID for better L2 performance @@ -38,46 +39,46 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) # do matrix multiplication - rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) - rk = triton.arange(0, BLOCK_K) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) # pointers K = K // SPLIT_K A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak) B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn) - acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(K, 0, -BLOCK_K): if META['EVEN_K']: - a = triton.load(A) - b = triton.load(B) + a = tl.load(A) + b = tl.load(B) else: - a = triton.load(A, mask=rk[None, :] < k, other=0.) - b = triton.load(B, mask=rk[:, None] < k, other=0.) - acc += triton.dot(a, b) + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
+ acc += tl.dot(a, b) A += BLOCK_K * stride_ak B += BLOCK_K * stride_bk - acc = acc.to(triton.float16) + acc = acc.to(tl.float16) # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) mask = (rm < M)[:, None] & (rn < N)[None, :] # handles write-back with reduction-splitting if SPLIT_K == 1: - triton.store(C, acc, mask=mask) + tl.store(C, acc, mask=mask) else: - LOCKS = LOCKS + triton.program_id(0) - COUNT = LOCKS + triton.num_programs(0) - while triton.atomic_cas(LOCKS, 0, 1) == 1: + LOCKS = LOCKS + tl.program_id(0) + COUNT = LOCKS + tl.num_programs(0) + while tl.atomic_cas(LOCKS, 0, 1) == 1: pass - count = triton.load(COUNT) + count = tl.load(COUNT) if count == 0: - triton.store(C, acc, mask=mask) + tl.store(C, acc, mask=mask) else: - curr = triton.load(C, mask=mask, other=0.) - triton.store(C, acc + curr, mask=mask) - triton.atomic_xchg(COUNT, (count + 1) % SPLIT_K) - triton.atomic_xchg(LOCKS, 0) + curr = tl.load(C, mask=mask, other=0.) + tl.store(C, acc + curr, mask=mask) + tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K) + tl.atomic_xchg(LOCKS, 0) class _matmul(torch.autograd.Function): diff --git a/python/triton/testing.py b/python/triton/testing.py index 317730781..eb4d89956 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -81,6 +81,22 @@ def random(shape, dtype, device): def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param percentiles: Performance percentile to return in addition to the median. + :type percentiles: list[float] + """ + # Estimate the runtime of the function fn() torch.cuda.synchronize() @@ -125,13 +141,16 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]): class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ def __init__( self, x_names, x_vals, - y_name, - y_vals, - y_lines, + line_arg, + line_vals, + line_names, plot_name, args, xlabel='', @@ -139,12 +158,38 @@ class Benchmark: x_log=False, y_log=False, ): + """ + Constructor + + :param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[str] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: List of arguments to remain fixed throughout the benchmark. 
+ :type args: List[str] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + """ self.x_names = x_names self.x_vals = x_vals self.x_log = x_log - self.y_name = y_name - self.y_vals = y_vals - self.y_lines = y_lines + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names self.y_log = y_log # plot info self.xlabel = xlabel @@ -162,15 +207,15 @@ class Mark: import matplotlib.pyplot as plt import pandas as pd import os - y_mean = bench.y_lines - y_min = [f'{x}-min' for x in bench.y_lines] - y_max = [f'{x}-max' for x in bench.y_lines] + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max) for x in bench.x_vals: x_args = {x_name: x for x_name in bench.x_names} row_mean, row_min, row_max = [], [], [] - for y in bench.y_vals: - ret = self.fn(**x_args, **{bench.y_name: y}, **bench.args) + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args) try: y_mean, y_min, y_max = ret except TypeError: @@ -183,7 +228,7 @@ class Mark: plt.figure() ax = plt.subplot() x = bench.x_names[0] - for y in bench.y_lines: + for y in bench.line_names: y_min, y_max = df[y + '-min'], df[y + '-max'] ax.plot(df[x], df[y], label=y) if y_min is not None and y_max is not None: @@ -199,7 +244,7 @@ class Mark: plt.show() if save_path: plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) - df = df[[bench.x_names[0]] + bench.y_lines] + df = df[[bench.x_names[0]] + bench.line_names] if print_data: print(df) if save_path: @@ -220,5 +265,11 @@ class Mark: def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. + :type benchmarks: List of :class:`Benchmark` + """ wrapper = lambda fn: Mark(fn, benchmarks) return wrapper diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 29797d579..2fa6a5fa9 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -13,6 +13,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn # -------------------------- import torch +import triton.language as tl import triton @@ -24,19 +25,19 @@ def _add( N, # Size of the vector **meta # Optional meta-parameters for the kernel ): - pid = triton.program_id(0) + pid = tl.program_id(0) # Create an offset for the blocks of pointers to be # processed by this program instance - offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK']) + offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK']) # Create a mask to guard memory operations against # out-of-bounds accesses mask = offsets < N # Load x - x = triton.load(X + offsets, mask=mask) - y = triton.load(Y + offsets, mask=mask) + x = tl.load(X + offsets, mask=mask) + y = tl.load(Y + offsets, mask=mask) # Write back x + y z = x + y - triton.store(Z + offsets, z) + tl.store(Z + offsets, z) # %% @@ -89,9 +90,9 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch. 
x_names=['size'], # argument names to use as an x-axis for the plot x_vals=[2**i for i in range(12, 28, 1)], # different possible values for `x_name` x_log=True, # x axis is logarithmic - y_name='provider', # argument name whose value corresponds to a different line in the plot - y_vals=['torch', 'triton'], # possible keys for `y_name` - y_lines=["Torch", "Triton"], # label name for the lines + line_arg='provider', # argument name whose value corresponds to a different line in the plot + line_vals=['torch', 'triton'], # possible values for `line_arg` + line_names=["Torch", "Triton"], # label name for the lines ylabel="GB/s", # label name for the y-axis plot_name="vector-add-performance", # name for the plot. Used also as a file name for saving the plot. args={} # values for function arguments not in `x_names` and `y_name` diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index f9b1b5103..dd93f8d83 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -46,28 +46,29 @@ def naive_softmax(x): # so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes: import triton +import triton.language as tl @triton.jit def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta): # row index - m = triton.program_id(0) + m = tl.program_id(0) # col indices - n = triton.arange(0, meta['BLOCK']) + n = tl.arange(0, meta['BLOCK']) # the memory address of all the elements # that we want to load can be computed as follows X = X + m * stride_xm + n - x = triton.load(X, mask=n < N, other=-float('inf')) + x = tl.load(X, mask=n < N, other=-float('inf')) # Substract maximum for numerical stability - z = x - triton.max(x, axis=0) + z = x - tl.max(x, axis=0) # Note that exponentials in Triton are fast # but approximate (i.e., think __expf in CUDA) - num = triton.exp(z) - denom = triton.sum(num, axis=0) + num = tl.exp(z) + denom = tl.sum(num, axis=0) y = num / denom # Write back to Y Y = Y + m * stride_ym + n - triton.store(Y, y, mask=n < N) + tl.store(Y, y, mask=n < N) # %% @@ -132,9 +133,9 @@ print(torch.allclose(y_tri, y_ref)) triton.testing.Benchmark( x_names=['N'], # argument names to use as an x-axis for the plot x_vals=[256 * i for i in range(2, 50)], # different possible values for `x_name` - y_name='provider', # argument name whose value corresponds to a different line in the plot - y_vals=['torch', 'triton', 'naive'], # possible keys for `y_name` - y_lines=["Torch", "Triton", 'Naive'], # label name for the lines + line_arg='provider', # argument name whose value corresponds to a different line in the plot + line_vals=['torch', 'triton', 'naive'], # possible values for `line_arg`` + line_names=["Torch", "Triton", 'Naive'], # label name for the lines ylabel="GB/s", # label name for the y-axis plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. 
args={'M': 4096} # values for function arguments not in `x_names` and `y_name` diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 0b7db9387..d80f3b3a1 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -115,6 +115,7 @@ You will specifically learn about: import torch import triton +import triton.language as tl # % # :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: @@ -124,9 +125,9 @@ import triton @triton.jit def sigmoid(x): - ret_true = 1 / (1 + triton.exp(-x)) - ret_false = triton.exp(x) / (1 + triton.exp(x)) - return triton.where(x >= 0, ret_true, ret_false) + ret_true = 1 / (1 + tl.exp(-x)) + ret_false = tl.exp(x) / (1 + tl.exp(x)) + return tl.where(x >= 0, ret_true, ret_false) @triton.jit @@ -151,7 +152,7 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride BLOCK_K = META['BLOCK_K'] GROUP_M = 8 # matrix multiplication - pid = triton.program_id(0) + pid = tl.program_id(0) grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_n = (N + BLOCK_N - 1) // BLOCK_N # re-order program ID for better L2 performance @@ -161,16 +162,16 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) # do matrix multiplication - rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) - rk = triton.arange(0, BLOCK_K) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak) B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn) - acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(K, 0, -BLOCK_K): - a = triton.load(A) - b = triton.load(B) - acc += triton.dot(a, b) + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) A += BLOCK_K * stride_ak B += BLOCK_K * stride_bk # triton can accept arbitrary activation function @@ -178,11 +179,11 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride if META['ACTIVATION']: acc = META['ACTIVATION'](acc) # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) mask = (rm[:, None] < M) & (rn[None, :] < N) - triton.store(C, acc, mask=mask) + tl.store(C, acc, mask=mask) # %% @@ -238,9 +239,9 @@ print(triton.testing.allclose(c_0, c_1)) triton.testing.Benchmark( x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name` - y_name='provider', # argument name whose value corresponds to a different line in the plot - y_vals=['cublas', 'triton'], # possible keys for `y_name` - y_lines=["cuBLAS", "Triton"], # label name for the lines + line_arg='provider', # argument name whose value corresponds to a different line in the plot + line_vals=['cublas', 'triton'], # possible values for `line_arg`` + line_names=["cuBLAS", "Triton"], # label name for the lines ylabel="TFLOPS", # label name for the y-axis plot_name="matmul-performance", # name for the plot. 
Used also as a file name for saving the plot. args={} @@ -258,4 +259,4 @@ def benchmark(M, N, K, provider): return perf(ms), perf(max_ms), perf(min_ms) -benchmark.run(print_data=True) \ No newline at end of file +benchmark.run(show_plots=True, print_data=True) \ No newline at end of file
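
For reference, here is a minimal sketch of what user-facing kernel code looks like once this patch lands: the DSL builtins move from the `triton` namespace into `triton.language` (imported as `tl` throughout the patch), while `triton.jit` and the newly exported `triton.cdiv` helper stay at the top level. The kernel mirrors the updated 01-vector-add tutorial; the host-side `add` wrapper and the `BLOCK=1024` value follow the tutorial's launch convention and are illustrative rather than part of the diff.

import torch
import triton
import triton.language as tl

@triton.jit
def _add(X, Y, Z, N, **meta):
    # one program instance handles one BLOCK-sized chunk of the vectors
    pid = tl.program_id(0)
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    mask = offsets < N  # guard against out-of-bounds accesses
    x = tl.load(X + offsets, mask=mask)
    y = tl.load(Y + offsets, mask=mask)
    tl.store(Z + offsets, x + y, mask=mask)

def add(x, y):
    z = torch.empty_like(x)
    N = z.shape[0]
    # triton.cdiv is the ceiling-division helper added to triton/code_gen.py in this patch
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
    _add[grid](x, y, z, N, BLOCK=1024)
    return z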
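
The `triton.testing.Benchmark` keywords are renamed consistently across the benches and tutorials (`y_name` becomes `line_arg`, `y_vals` becomes `line_vals`, `y_lines` becomes `line_names`). A sketch of the updated `perf_report` usage, following the vector-add tutorial and reusing the `add` wrapper from the sketch above:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],                      # argument(s) swept along the x-axis
        x_vals=[2**i for i in range(12, 28)],  # values taken by `size`
        x_log=True,
        line_arg='provider',                   # one line per value of this argument
        line_vals=['torch', 'triton'],         # values taken by `provider`
        line_names=['Torch', 'Triton'],        # legend labels for the lines
        ylabel='GB/s',
        plot_name='vector-add-performance',
        args={},                               # arguments held fixed across the sweep
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
    gbps = lambda ms: 12 * size / ms * 1e-6    # 3 tensors x 4 bytes, ms -> s, bytes -> GB
    return gbps(ms), gbps(max_ms), gbps(min_ms)

benchmark.run(show_plots=True, print_data=True)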
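
The new `triton.jit` docstring notes that a jit'd function only sees Python primitives, objects from `triton.language`, its own arguments, and other jit'd functions. The last bullet is what allows helpers such as `tl.minimum`, `tl.softmax`, or the tutorial's `sigmoid` to be written in Triton itself and inlined into their callers. A hypothetical sketch under that assumption (neither `leaky_relu` nor `_activate` is part of this patch; imports as in the first sketch):

@triton.jit
def leaky_relu(x):
    # jit'd helper: inlined wherever it is called from another jit'd function
    return tl.where(x >= 0, x, 0.01 * x)

@triton.jit
def _activate(X, Y, N, **meta):
    offsets = tl.program_id(0) * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    mask = offsets < N
    x = tl.load(X + offsets, mask=mask)
    tl.store(Y + offsets, leaky_relu(x), mask=mask)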