[PYTHON] Renamed triton.core -> triton.language (#92)
Committed by: Philippe Tillet
Parent: 41410012e8
Commit: bfc0a7587d
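Before the per-file diff: the user-visible effect of this commit is that Triton's kernel builtins move out of the implicit `triton.core` re-exports and into an explicit `triton.language` module, conventionally aliased `tl`. A minimal before/after sketch (illustrative kernel, not part of the diff itself):

```python
# Illustrative sketch, not part of this commit: kernel builtins now live in
# triton.language (aliased `tl`) instead of being re-exported from triton.core.
import triton
import triton.language as tl

@triton.jit
def _copy(Z, X, **meta):
    off = tl.arange(0, meta['SIZE'])   # was: triton.arange(0, meta['SIZE'])
    x = tl.load(X + off)               # was: triton.load(X + off)
    tl.store(Z + off, x)               # was: triton.store(Z + off, x)
```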
@@ -1,7 +1,7 @@
 Welcome to Triton's documentation!
 ==================================

-Triton is an imperative language and compiler for parallel programming. It aims to provide a programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
+Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.

 Getting Started
 ---------------
@@ -17,18 +17,22 @@ Getting Started
    getting-started/installation
    getting-started/tutorials/index

-Language Reference
+Python API
 -------------------

-- Checkout the :doc:`Python API Documentation <language-reference/python-api/index>`
+- :doc:`triton <python-api/triton>`
+- :doc:`triton.language <python-api/triton.language>`
+- :doc:`triton.testing <python-api/triton.testing>`


 .. toctree::
    :maxdepth: 1
-   :caption: Language Reference
+   :caption: Python API
    :hidden:

-   language-reference/python-api/index
+   python-api/triton
+   python-api/triton.language
+   python-api/triton.testing


 Going Further
@@ -1,7 +1,7 @@
-Python API
-===========
+triton.language
+================

-.. currentmodule:: triton
+.. currentmodule:: triton.language


 Programming Model
docs/python-api/triton.rst (new file, 10 lines)
@@ -0,0 +1,10 @@
+triton
+========
+
+.. currentmodule:: triton
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    jit
docs/python-api/triton.testing.rst (new file, 12 lines)
@@ -0,0 +1,12 @@
+triton.testing
+================
+
+.. currentmodule:: triton.testing
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    do_bench
+    Benchmark
+    perf_report
@@ -10,9 +10,9 @@ square_confs = [
     triton.testing.Benchmark(
         x_names = ['M', 'N', 'K'],
         x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-        y_name = 'block',
-        y_vals = [16, 32, 64],
-        y_lines = ['Block16', 'Block32', 'Block64'],
+        line_arg = 'block',
+        line_vals = [16, 32, 64],
+        line_names = ['Block16', 'Block32', 'Block64'],
         ylabel = 'TFLOPS',
         plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
         args = {'layout_mode': layout_mode, 'op_mode': op_mode,
@@ -60,9 +60,9 @@ square_confs = [
     triton.testing.Benchmark(
         x_names = ['M', 'N'],
         x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-        y_name = 'block',
-        y_vals = [16, 32, 64],
-        y_lines = ['Block16', 'Block32', 'Block64'],
+        line_arg = 'block',
+        line_vals = [16, 32, 64],
+        line_names = ['Block16', 'Block32', 'Block64'],
         ylabel = 'GBPS',
         plot_name = f'{layout_mode}-square',
         args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
@@ -5,9 +5,9 @@ confs = [
    triton.testing.Benchmark(
        x_names = ['N'],
        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
-       y_name = 'provider',
-       y_vals = ['triton', 'torch'],
-       y_lines = ['Triton', 'Torch'],
+       line_arg = 'provider',
+       line_vals = ['triton', 'torch'],
+       line_names = ['Triton', 'Torch'],
        ylabel = 'GBPS',
        plot_name = f'{mode}-2048',
        args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
@@ -16,9 +16,9 @@ square_confs = [
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=rounded_linspace(512, 8192, 32, 128),
-       y_name="provider",
-       y_vals=["cublas", "triton", "cutlass"],
-       y_lines=["cuBLAS", "Triton", "CUTLASS"],
+       line_arg="provider",
+       line_vals=["cublas", "triton", "cutlass"],
+       line_names=["cuBLAS", "Triton", "CUTLASS"],
        ylabel="TFLOPS",
        plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
        args={"AT": AT, "BT": BT, "dtype": torch.float16},
@@ -30,9 +30,9 @@ transformer_confs = [
    triton.testing.Benchmark(
        x_names=[x],
        x_vals = rounded_linspace(NK//16, NK, 32, 128),
-       y_name="provider",
-       y_vals=["cublas", "triton", "cutlass"],
-       y_lines=["cuBLAS", "Triton", "CUTLASS"],
+       line_arg="provider",
+       line_vals=["cublas", "triton", "cutlass"],
+       line_names=["cuBLAS", "Triton", "CUTLASS"],
        ylabel="TFLOPS",
        plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
        args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
@@ -1,5 +1,6 @@
 import torch
 import triton
+import triton.language as tl
 import copy
 import pytest
 import ast
@@ -37,10 +38,10 @@ def _test_unary(dtype_x, expr, device='cuda'):
     # define the kernel / launch-grid
     @triton.jit
     def kernel(Z, X, **meta):
-        off = triton.arange(0, meta['SIZE'])
-        x = triton.load(X + off)
+        off = tl.arange(0, meta['SIZE'])
+        x = tl.load(X + off)
         z = GENERATE_TEST_HERE
-        triton.store(Z + off, z)
+        tl.store(Z + off, z)

     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
     # inputs
@@ -59,11 +60,11 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
     # define the kernel / launch-grid
     @triton.jit
     def kernel(Z, X, Y, **meta):
-        off = triton.arange(0, meta['SIZE'])
-        x = triton.load(X + off)
-        y = triton.load(Y + off)
+        off = tl.arange(0, meta['SIZE'])
+        x = tl.load(X + off)
+        y = tl.load(Y + off)
         z = GENERATE_TEST_HERE
-        triton.store(Z + off, z)
+        tl.store(Z + off, z)

     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
     # inputs
@@ -144,7 +145,7 @@ def make_ptr_str(name, shape):
     stride = 1
     for i in reversed(range(rank)):
         idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
-        offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}']
+        offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
         stride *= shape[i]
     return f"{name} + {' + '.join(offsets)}"

@@ -164,11 +165,11 @@ def test_index1d(expr, device='cuda'):
     @triton.jit
     def kernel(Z, X, **meta):
         SIZE = meta['SIZE']
-        m = triton.arange(0, SIZE)
-        n = triton.arange(0, SIZE)
-        x = triton.load(X_PTR_EXPR)
+        m = tl.arange(0, SIZE)
+        n = tl.arange(0, SIZE)
+        x = tl.load(X_PTR_EXPR)
         z = GENERATE_TEST_HERE
-        triton.store(Z_PTR_EXPR, z)
+        tl.store(Z_PTR_EXPR, z)

     to_replace = {
         'X_PTR_EXPR': make_ptr_str('X', shape_x),
@@ -2,9 +2,9 @@
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch
 # submodules
-from .code_gen import jit, autotune, heuristics, Config, Autotuner
-from .core import *
+from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner
+from . import language
 from . import code_gen
 from . import testing
 from . import ops
@@ -26,21 +26,21 @@ class CodeGenerator(ast.NodeVisitor):
             ret = self.builtins[name]
         else:
             raise ValueError(f'{name} is not defined')
-        if isinstance(ret, triton.block):
+        if isinstance(ret, triton.language.block):
             handle = self.module.get_value(name)
-            return triton.block(handle)
+            return triton.language.block(handle)
         return ret

     def set_value(self, name, value):
         if isinstance(value, _triton.ir.value):
-            value = triton.block(value)
-        if isinstance(value, triton.block):
+            value = triton.language.block(value)
+        if isinstance(value, triton.language.block):
             self.module.set_value(name, value.handle)
             self.module.scope.set_type(name, value.handle.type)
         self.lscope[name] = value

     def is_triton_object(self, value):
-        return isinstance(value, triton.block)
+        return isinstance(value, triton.language.block)

     def visit_compound_statement(self, stmts, add_scope=False):
         if add_scope:
@@ -63,7 +63,14 @@ class CodeGenerator(ast.NodeVisitor):
         self.constants = constants
         self.kwargs = kwargs
         self.last_node = None
-        self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr}
+        self.builtins = {
+            'range': range,
+            'min': triton.language.minimum,
+            'float': float,
+            'int': int,
+            'print': print,
+            'getattr': getattr,
+        }

     def visit_Module(self, node):
         self.module.add_new_scope()
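For context on the builtins table above: inside a `@triton.jit` function, names are resolved through this dictionary, so Python's `min` compiles to `triton.language.minimum` (which, per the `language.py` hunks below, is just `tl.where(x < y, x, y)`). An illustrative kernel, not part of this commit:

```python
# Illustrative sketch only: `min` inside a jit'd function resolves through
# CodeGenerator.builtins to triton.language.minimum.
import triton
import triton.language as tl

@triton.jit
def elementwise_min(Z, X, Y, **meta):
    off = tl.arange(0, meta['SIZE'])
    x = tl.load(X + off)
    y = tl.load(Y + off)
    tl.store(Z + off, min(x, y))  # compiled as triton.language.minimum(x, y)
```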
@@ -303,7 +310,7 @@ class CodeGenerator(ast.NodeVisitor):
         pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
         neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
         pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
-        build_cond = lambda: triton.where(self.visit(pos_step_node),\
+        build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
                                           self.visit(pos_cond_node),\
                                           self.visit(neg_cond_node),\
                                           builder=self.builder)
@@ -359,7 +366,7 @@ class CodeGenerator(ast.NodeVisitor):
         if isinstance(fn, JITFunction):
             return fn(*args, generator=self, **kws)
         if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
-           sys.modules[fn.__module__] is triton.core:
+           sys.modules[fn.__module__] is triton.language:
             return fn(*args, builder=self.builder, **kws)
         return fn(*args, **kws)

@@ -613,6 +620,11 @@ class JITFunction:
             raise e
         raise CompilationError(self.src, node, e)

+    def __setattr__(self, name, value):
+        if name == 'kernel_decorators':
+            self.kernel = None
+        super(JITFunction, self).__setattr__(name, value)
+
     def _init_kernel(self):
         if self.kernel is None:
             self.kernel = Kernel(self)
@@ -659,4 +671,23 @@ def heuristics(values):


 def jit(fn):
+    """
+    Decorator for JIT-compiling a function using the Triton compiler.
+
+    :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
+
+    :note: This function will be compiled and run on the GPU. It will only have access to:
+
+           * python primitives,
+           * objects within the triton.language package,
+           * arguments to this function,
+           * other jit'd functions
+
+    :param fn: the function to be jit-compiled
+    :type fn: Callable
+    """
     return JITFunction(fn)
+
+
+def cdiv(x, y):
+    return (x + y - 1) // y
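The docstring above also explains the implicit tensor-to-pointer conversion, and this hunk moves `cdiv` into `code_gen.py` (re-exported as `triton.cdiv`; see the `__init__.py` hunk earlier). A hypothetical end-to-end launch sketch, modeled on the tutorials touched later in this commit:

```python
# Hypothetical usage sketch (not from this commit) of triton.jit + triton.cdiv.
import torch
import triton
import triton.language as tl

@triton.jit
def _scale(Z, X, N, **meta):
    off = tl.program_id(0) * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    mask = off < N
    x = tl.load(X + off, mask=mask)
    tl.store(Z + off, 2 * x, mask=mask)

x = torch.randn(4096, device='cuda')
z = torch.empty_like(x)
# cdiv(x, y) == (x + y - 1) // y, i.e. the number of BLOCK-sized programs
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']),)
_scale[grid](z, x, x.numel(), BLOCK=1024)  # tensors become pointers via .data_ptr()
```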
@@ -1,6 +1,6 @@
+import triton
 from triton._C.libtriton.triton import ir
 from triton._C.libtriton.triton import frontend
-import triton
 from functools import wraps


@@ -25,7 +25,7 @@ def _patch(fn):
             if x.type.is_void():
                 return None
             return block(x)
-        return x
+        return tl

     def wrapper(*args, **kwargs):
         builder = args[-1]
@@ -547,7 +547,7 @@ def minimum(x, y):
    :param other: the second input block
    :type other: Block
    """
-    return triton.where(x < y, x, y)
+    return triton.language.where(x < y, x, y)


 @triton.jit
@@ -560,7 +560,7 @@ def maximum(x, y):
    :param other: the second input block
    :type other: Block
    """
-    return triton.where(x > y, x, y)
+    return triton.language.where(x > y, x, y)


 @triton.jit
@@ -571,7 +571,7 @@ def sigmoid(x):
    :param x: the input block
    :type x: Block
    """
-    return 1 / (1 + np.exp(-x))
+    return 1 / (1 + triton.language.exp(-x))


 @triton.jit
@@ -582,9 +582,9 @@ def softmax(x):
    :param x: the input block
    :type x: Block
    """
-    z = x - triton.max(x, 0)
-    num = triton.exp(z)
-    den = triton.sum(num, 0)
+    z = x - triton.language.max(x, 0)
+    num = triton.language.exp(z)
+    den = triton.language.sum(num, 0)
     return num / den


@@ -596,8 +596,4 @@ def ravel(x):
    :param x: the input block
    :type x: Block
    """
-    return triton.reshape(x, [x.type.numel])
-
-
-def cdiv(x, y):
-    return (x + y - 1) // y
+    return triton.language.reshape(x, [x.type.numel])
@@ -1,4 +1,5 @@
 import triton
+import triton.language as tl
 import triton._C.libtriton as libtriton
 import torch

@@ -16,21 +17,21 @@ def _kernel(
     #------------#
     #- Prologue -#
     #------------#
-    pid0 = triton.program_id(0)
-    pid1 = triton.program_id(1)
-    pidz = triton.program_id(2)
+    pid0 = tl.program_id(0)
+    pid1 = tl.program_id(1)
+    pidz = tl.program_id(2)
     if meta['SDD']:
         pid1 = pid1 + SDD_off_width
-        blockidm = triton.arange(0, TM) // BLOCK
-        blockidn = triton.arange(0, TN) // BLOCK
+        blockidm = tl.arange(0, TM) // BLOCK
+        blockidn = tl.arange(0, TN) // BLOCK
         offlutm = blockidm * (TN // BLOCK) * 4
         offlutn = blockidn * 4
         header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
-        z = triton.load(header + 0)
-        i = triton.load(header + 1 + offlutm)
-        j = triton.load(header + 2 + offlutn)
+        z = tl.load(header + 0)
+        i = tl.load(header + 1 + offlutm)
+        j = tl.load(header + 2 + offlutn)
         AS1 = SDD_K // TZ
-        lockid = triton.where(TZ > 1, 1, 0)
+        lockid = tl.where(TZ > 1, 1, 0)
         offka = pid0 * AS1
         offkb = pid0 * AS1
         offmc = 0
@@ -41,16 +42,16 @@ def _kernel(
         offhc = 0
         offha = z
         offhb = z
-        ram = i * BLOCK + (triton.arange(0, TM) % BLOCK)
-        rbn = j * BLOCK + (triton.arange(0, TN) % BLOCK)
+        ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
+        rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
     else:
         header = lut + pid0 * 6
-        offset = triton.load(header + 0)
-        AS1 = triton.load(header + 1)
-        column = triton.load(header + 2)
-        depth = triton.load(header + 3)
-        lockid = triton.load(header + 4)
-        maxid = triton.load(header + 5)
+        offset = tl.load(header + 0)
+        AS1 = tl.load(header + 1)
+        column = tl.load(header + 2)
+        depth = tl.load(header + 3)
+        lockid = tl.load(header + 4)
+        maxid = tl.load(header + 5)
         pinc = lut + offset
         offhc = depth
         if meta['DSD']:
@@ -60,14 +61,14 @@ def _kernel(
             offpc = 0
             # dense input offset
             offnb = pid1 * TN
-            offkb = triton.load(pinc)
-            offkb = triton.multiple_of(offkb, 8)  # compiler hint
+            offkb = tl.load(pinc)
+            offkb = tl.multiple_of(offkb, 8)  # compiler hint
             offpb = 0
             # sparse input offset
             offma = 0
             offka = 0
-            offpa = triton.load(pinc + 1)
-            offpa = triton.multiple_of(offpa, 8)  # compiler hint
+            offpa = tl.load(pinc + 1)
+            offpa = tl.multiple_of(offpa, 8)  # compiler hint
             offpa = offpa * BLOCK * BLOCK
             offha = 0
             offhb = depth
@@ -78,23 +79,23 @@ def _kernel(
             offpc = 0
             # dense input offset
             offma = pid1 * TM
-            offka = triton.load(pinc)
-            offka = triton.multiple_of(offka, 8)  # compiler hint
+            offka = tl.load(pinc)
+            offka = tl.multiple_of(offka, 8)  # compiler hint
             offpa = 0
             # sparse input offset
             offnb = 0
             offkb = 0
-            offpb = triton.load(pinc + 1)
-            offpb = triton.multiple_of(offpb, 8)  # compiler hint
+            offpb = tl.load(pinc + 1)
+            offpb = tl.multiple_of(offpb, 8)  # compiler hint
             offpb = offpb * BLOCK * BLOCK
             offha = depth
             offhb = 0
-        ram = offma + triton.arange(0, TM)
-        rbn = offnb + triton.arange(0, TN)
+        ram = offma + tl.arange(0, TM)
+        rbn = offnb + tl.arange(0, TN)

     # initialize a, b pointers
-    rka = offka + triton.arange(0, TK)
-    rkb = offkb + triton.arange(0, TK)
+    rka = offka + tl.arange(0, TK)
+    rkb = offkb + tl.arange(0, TK)
     pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
     pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
     if meta['DDS']:
@@ -105,31 +106,31 @@ def _kernel(
         checkbn = rbn[None, :] < DS0
     else:
         checkbn = AS1 > 0
-    a = triton.load(pa, mask=checkam, other=0.)
-    b = triton.load(pb, mask=checkbn, other=0.)
+    a = tl.load(pa, mask=checkam, other=0.)
+    b = tl.load(pb, mask=checkbn, other=0.)

     ## ---------------- ##
     ##    Inner Loop    ##
     ## ---------------- ##
-    acc = triton.zeros((TM, TN), dtype=triton.float32)
+    acc = tl.zeros((TM, TN), dtype=tl.float32)
     for k in range(AS1, 0, -TK):
-        acc += triton.dot(a, b)
+        acc += tl.dot(a, b)
         if meta['SDD']:
             inc_a = TK * stride_ka
             inc_b = TK * stride_kb
         else:
             pinc += 2
         if meta['DSD']:
-            inc_b = triton.load(pinc)
-            inc_a = triton.load(pinc + 1)
-            inc_b = triton.multiple_of(inc_b, 8)
-            inc_a = triton.multiple_of(inc_a, 8)
+            inc_b = tl.load(pinc)
+            inc_a = tl.load(pinc + 1)
+            inc_b = tl.multiple_of(inc_b, 8)
+            inc_a = tl.multiple_of(inc_a, 8)
             inc_b = inc_b * stride_kb
         if meta['DDS']:
-            inc_a = triton.load(pinc)
-            inc_b = triton.load(pinc + 1)
-            inc_a = triton.multiple_of(inc_a, 8)
-            inc_b = triton.multiple_of(inc_b, 8)
+            inc_a = tl.load(pinc)
+            inc_b = tl.load(pinc + 1)
+            inc_a = tl.multiple_of(inc_a, 8)
+            inc_b = tl.multiple_of(inc_b, 8)
             inc_a = inc_a * stride_ka
         pa += inc_a
         pb += inc_b
@@ -138,24 +139,24 @@ def _kernel(
         checkbk = k > TK
         checka = checkam & checkak
         checkb = checkbn & checkbk
-        a = triton.load(pa, mask=checka)
-        b = triton.load(pb, mask=checkb)
+        a = tl.load(pa, mask=checka)
+        b = tl.load(pb, mask=checkb)
     c = acc.to(C.dtype.element_ty)

     if meta['SDD']:
         checkc = True
-        rr_blockidm = triton.arange(0, TM) // BLOCK
-        rr_blockidn = triton.arange(0, TN) // BLOCK
+        rr_blockidm = tl.arange(0, TM) // BLOCK
+        rr_blockidn = tl.arange(0, TN) // BLOCK
         rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
         rr_offlutn = rr_blockidn * 4
         off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
-        bkid = triton.load(header + off_bkid)
+        bkid = tl.load(header + off_bkid)
         offpc = bkid * BLOCK * BLOCK
-        rcm = triton.arange(0, TM) % BLOCK
-        rcn = triton.arange(0, TN) % BLOCK
+        rcm = tl.arange(0, TM) % BLOCK
+        rcn = tl.arange(0, TN) % BLOCK
     else:
-        rcm = offmc + triton.arange(0, TM)
-        rcn = offnc + triton.arange(0, TN)
+        rcm = offmc + tl.arange(0, TM)
+        rcn = offnc + tl.arange(0, TN)
     if meta['DSD']:
         checkc = rcn[None, :] < DS0
     if meta['DDS']:
@@ -164,21 +165,21 @@ def _kernel(
     pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
     # write-back directly
     if lockid == 0:
-        triton.store(pc, c, mask=checkc)
+        tl.store(pc, c, mask=checkc)
     # accumulate partial results using spin-locks
     else:
-        plock = locks + triton.program_id(2) * nlocks * triton.num_programs(1) + triton.program_id(1) * nlocks + lockid - 1
-        pcount = plock + triton.num_programs(2) * triton.num_programs(1) * nlocks
-        while triton.atomic_cas(plock, 0, 1) == 1:
+        plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1
+        pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
+        while tl.atomic_cas(plock, 0, 1) == 1:
             pass
-        count = triton.load(pcount)
+        count = tl.load(pcount)
         if count == 0:
-            triton.store(pc, c, mask=checkc)
+            tl.store(pc, c, mask=checkc)
         else:
-            d = triton.load(pc, mask=checkc)
-            triton.store(pc, d + c, mask=checkc)
-        triton.atomic_xchg(pcount, (count + 1) % maxid)
-        triton.atomic_xchg(plock, 0)
+            d = tl.load(pc, mask=checkc)
+            tl.store(pc, d + c, mask=checkc)
+        tl.atomic_xchg(pcount, (count + 1) % maxid)
+        tl.atomic_xchg(plock, 0)


 ##############
@@ -1,3 +1,4 @@
+import triton.language as tl
 import triton
 import torch
 import os
@@ -31,86 +32,86 @@ def _forward(
 ):
     TN = meta['TN']
     BLOCK = meta['BLOCK']
-    pidhm = triton.program_id(0)
-    pidz = triton.program_id(1)
+    pidhm = tl.program_id(0)
+    pidz = tl.program_id(1)
     # create index ranges
     rxm = pidhm % BLOCK
     rbm = pidhm // BLOCK
-    rxn = triton.arange(0, TN) % BLOCK
-    rbn = triton.arange(0, TN) // BLOCK
+    rxn = tl.arange(0, TN) % BLOCK
+    rbn = tl.arange(0, TN) // BLOCK
     # extract information from LUT
     header = LUT + rbm * 2
-    size = triton.load(header + 0)
-    offset = triton.load(header + 1)
+    size = tl.load(header + 0)
+    offset = tl.load(header + 1)
     check = rbn < size
-    rbmn = triton.where(check, rbn, size - 1)
+    rbmn = tl.where(check, rbn, size - 1)
     # block id and column id
-    blockid = triton.load(LUT + offset + rbmn * 4 + 0)
-    columnid = triton.load(LUT + offset + rbmn * 4 + 1)
-    rowid = triton.load(LUT + offset + rbmn * 4 + 2)
-    headid = triton.load(LUT + offset + rbmn * 4 + 3)
+    blockid = tl.load(LUT + offset + rbmn * 4 + 0)
+    columnid = tl.load(LUT + offset + rbmn * 4 + 1)
+    rowid = tl.load(LUT + offset + rbmn * 4 + 2)
+    headid = tl.load(LUT + offset + rbmn * 4 + 3)
     # pointers to X
     px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
-    x = triton.load(px, mask=check, other=-float('inf'))
-    x = x.to(triton.float32)
+    x = tl.load(px, mask=check, other=-float('inf'))
+    x = x.to(tl.float32)
     # apply scale
     if meta['APPLY_SCALE']:
         x = x * scale
     # apply RPE
     if meta['APPLY_RPE']:
         prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
-        rpe = triton.load(prpe, mask=check, other=0)
+        rpe = tl.load(prpe, mask=check, other=0)
         x = x + rpe
     # apply key-padding mask
     if meta['APPLY_KP_MASK']:
         pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
-        kp_m = triton.load(pkp_m, mask=check, other=-float('inf'))
+        kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
         if meta['KP_MASK_MUL']:
-            kp_m = triton.where(kp_m == 0, -float('inf'), 0.)
+            kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
         x = x + kp_m
     # apply attention mask
     if meta['APPLY_ATTN_MASK']:
         pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
-        attn_m = triton.load(pattn_m, mask=check, other=-float('inf'))
+        attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
         if meta['ATTN_MASK_MUL']:
-            attn_m = triton.where(attn_m == 0, -float('inf'), 0.)
+            attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
         x = x + attn_m
     # computation
-    x = triton.softmax(x)
-    triton.store(px, x, mask=check)
+    x = tl.softmax(x)
+    tl.store(px, x, mask=check)


 @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
 @triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
 @triton.jit
 def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
-    pidhm = triton.program_id(0)
-    pidz = triton.program_id(1)
+    pidhm = tl.program_id(0)
+    pidz = tl.program_id(1)
     TN = meta['TN']
     BLOCK = meta['BLOCK']
     # create index ranges
     rxm = pidhm % BLOCK
     rbm = pidhm // BLOCK
-    rxn = triton.arange(0, TN) % BLOCK
-    rbn = triton.arange(0, TN) // BLOCK
+    rxn = tl.arange(0, TN) % BLOCK
+    rbn = tl.arange(0, TN) // BLOCK
     # extract information from look-up table
     header = LUT + rbm * 2
-    size = triton.load(header + 0)
-    offset = triton.load(header + 1)
+    size = tl.load(header + 0)
+    offset = tl.load(header + 1)
     # bounds checking on lut
     check = rbn < size
-    rbmn = triton.where(check, rbn, size - 1)
+    rbmn = tl.where(check, rbn, size - 1)
     # initialize pointers to block-sparse input
-    blockid = triton.load(LUT + offset + rbmn * 4)
+    blockid = tl.load(LUT + offset + rbmn * 4)
     X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
     DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
     # compute fused softmax backward
-    x = triton.load(X, mask=check, other=0)
-    dx = triton.load(DX, mask=check, other=0)
-    x = x.to(triton.float32)
-    dx = dx.to(triton.float32)
-    y = x * (dx - triton.sum(x * dx, 0)) * scale
-    triton.store(DX, y, mask=check)
+    x = tl.load(X, mask=check, other=0)
+    dx = tl.load(DX, mask=check, other=0)
+    x = x.to(tl.float32)
+    dx = dx.to(tl.float32)
+    y = x * (dx - tl.sum(x * dx, 0)) * scale
+    tl.store(DX, y, mask=check)


 class _softmax(torch.autograd.Function):
@@ -1,5 +1,6 @@
 import os
 import triton
+import triton.language as tl
 import torch


@@ -27,25 +28,25 @@ def num_warps(N):
 @triton.jit
 def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
     BLOCK = meta['BLOCK']
-    row = triton.program_id(0)
-    cols = triton.arange(0, BLOCK)
-    idx = triton.load(IDX + row)
+    row = tl.program_id(0)
+    cols = tl.arange(0, BLOCK)
+    idx = tl.load(IDX + row)
     # pointers to logit and probs
     LOGITS = LOGITS + row * N + cols
     WRIT_PROBS = PROBS + row * N + cols
     READ_PROBS = PROBS + row * N + idx
     # write-back negative log-probs
-    logits = triton.load(LOGITS, mask=cols < N, other=-float('inf'))
-    logits = logits.to(triton.float32)
-    logits = logits - triton.max(logits, 0)
-    probs = triton.log(triton.sum(triton.exp(logits), 0)) - logits
-    triton.store(WRIT_PROBS, probs, mask=cols < N)
+    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
+    logits = logits.to(tl.float32)
+    logits = logits - tl.max(logits, 0)
+    probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
+    tl.store(WRIT_PROBS, probs, mask=cols < N)
     # There is a bug in the compiler, which fails to insert a barrier here.
     # We add it explicitly for now. Will be fixed soon.
-    triton.debug_barrier()
+    tl.debug_barrier()
     # write-back loss
-    probs = triton.load(READ_PROBS)
-    triton.store(LOSS + row, probs)
+    probs = tl.load(READ_PROBS)
+    tl.store(LOSS + row, probs)


 @triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
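For reference, the negative log-probability computed in `_forward` above follows the standard shifted log-sum-exp identity: -log p_i = log(sum_j exp(z_j)) - z_i, with logits shifted by their max for numerical stability. A hypothetical host-side check, not part of this commit:

```python
# Hypothetical reference check of the identity used in _forward above.
import torch

logits = torch.randn(8)
z = logits - logits.max()                       # shift for stability
neg_log_probs = torch.log(torch.exp(z).sum()) - z
assert torch.allclose(neg_log_probs, -torch.log_softmax(logits, dim=0))
```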
@@ -53,20 +54,20 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
 @triton.jit
 def _backward(PROBS, IDX, DPROBS, N, **meta):
     BLOCK = meta['BLOCK']
-    row = triton.program_id(0)
-    cols = triton.arange(0, BLOCK)
-    idx = triton.load(IDX + row)
+    row = tl.program_id(0)
+    cols = tl.arange(0, BLOCK)
+    idx = tl.load(IDX + row)
     # pointers to probs
     PROBS = PROBS + row * N + cols
     # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
     # and we have -log(p[k]) stored in PROBS, so this is easy
-    probs = -triton.load(PROBS, mask=cols < N, other=float('inf'))
-    probs = triton.exp(probs.to(triton.float32))
+    probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
+    probs = tl.exp(probs.to(tl.float32))
     delta = cols == idx
     # write result in-place in PROBS
-    dout = triton.load(DPROBS + row)
+    dout = tl.load(DPROBS + row)
     din = (probs - delta) * dout
-    triton.store(PROBS, din.to(triton.float16), mask=cols < N)
+    tl.store(PROBS, din.to(tl.float16), mask=cols < N)


 class _cross_entropy(torch.autograd.Function):
@@ -1,4 +1,5 @@
 import torch
+import triton.language as tl
 import triton


@@ -27,8 +28,8 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
     GROUP_M = META['GROUP_M']
     SPLIT_K = META['SPLIT_K']
     # matrix multiplication
-    pid = triton.program_id(0)
-    pid_z = triton.program_id(1)
+    pid = tl.program_id(0)
+    pid_z = tl.program_id(1)
     grid_m = (M + BLOCK_M - 1) // BLOCK_M
     grid_n = (N + BLOCK_N - 1) // BLOCK_N
     # re-order program ID for better L2 performance
@@ -38,46 +39,46 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
     pid_m = group_id * GROUP_M + (pid % group_size)
     pid_n = (pid % width) // (group_size)
     # do matrix multiplication
-    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
-    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
-    rk = triton.arange(0, BLOCK_K)
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    rk = tl.arange(0, BLOCK_K)
     # pointers
     K = K // SPLIT_K
     A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak)
     B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn)
-    acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(K, 0, -BLOCK_K):
         if META['EVEN_K']:
-            a = triton.load(A)
-            b = triton.load(B)
+            a = tl.load(A)
+            b = tl.load(B)
         else:
-            a = triton.load(A, mask=rk[None, :] < k, other=0.)
-            b = triton.load(B, mask=rk[:, None] < k, other=0.)
-        acc += triton.dot(a, b)
+            a = tl.load(A, mask=rk[None, :] < k, other=0.)
+            b = tl.load(B, mask=rk[:, None] < k, other=0.)
+        acc += tl.dot(a, b)
         A += BLOCK_K * stride_ak
         B += BLOCK_K * stride_bk
-    acc = acc.to(triton.float16)
+    acc = acc.to(tl.float16)
     # rematerialize rm and rn to save registers
-    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
-    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
     C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
     mask = (rm < M)[:, None] & (rn < N)[None, :]
     # handles write-back with reduction-splitting
     if SPLIT_K == 1:
-        triton.store(C, acc, mask=mask)
+        tl.store(C, acc, mask=mask)
     else:
-        LOCKS = LOCKS + triton.program_id(0)
-        COUNT = LOCKS + triton.num_programs(0)
-        while triton.atomic_cas(LOCKS, 0, 1) == 1:
+        LOCKS = LOCKS + tl.program_id(0)
+        COUNT = LOCKS + tl.num_programs(0)
+        while tl.atomic_cas(LOCKS, 0, 1) == 1:
             pass
-        count = triton.load(COUNT)
+        count = tl.load(COUNT)
         if count == 0:
-            triton.store(C, acc, mask=mask)
+            tl.store(C, acc, mask=mask)
         else:
-            curr = triton.load(C, mask=mask, other=0.)
-            triton.store(C, acc + curr, mask=mask)
-        triton.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
-        triton.atomic_xchg(LOCKS, 0)
+            curr = tl.load(C, mask=mask, other=0.)
+            tl.store(C, acc + curr, mask=mask)
+        tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
+        tl.atomic_xchg(LOCKS, 0)


 class _matmul(torch.autograd.Function):
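For context on the lock-based write-back above: with SPLIT_K > 1, several program instances each own a slice of the K dimension and must serialize their accumulation into the same C tile. A standalone distillation of the pattern, using the same `tl` primitives (hypothetical kernel, not part of this commit):

```python
# Hypothetical sketch of the spin-lock accumulation pattern used above.
import triton
import triton.language as tl

@triton.jit
def _accumulate(C, PARTIAL, LOCK, COUNT, SPLIT_K, **meta):
    off = tl.arange(0, meta['BLOCK'])
    part = tl.load(PARTIAL + off)
    while tl.atomic_cas(LOCK, 0, 1) == 1:       # spin until lock acquired
        pass
    count = tl.load(COUNT)
    if count == 0:
        tl.store(C + off, part)                 # first writer initializes
    else:
        curr = tl.load(C + off)
        tl.store(C + off, curr + part)          # later writers accumulate
    tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)  # wraps back to 0 for reuse
    tl.atomic_xchg(LOCK, 0)                     # release the lock
```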
@@ -81,6 +81,22 @@ def random(shape, dtype, device):


 def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
+    """
+    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
+    the 20-th and 80-th performance percentile.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param warmup: Warmup time (in ms)
+    :type warmup: int
+    :param rep: Repetition time (in ms)
+    :type rep: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param percentiles: Performance percentile to return in addition to the median.
+    :type percentiles: list[float]
+    """
+
     # Estimate the runtime of the function
     fn()
     torch.cuda.synchronize()
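A hypothetical call sketch for the signature documented above; the unpacking order matches how the tutorials in this commit consume the result:

```python
# Hypothetical usage sketch for do_bench as documented above.
import torch
import triton

x = torch.randn(4096, 4096, device='cuda')
# median runtime (ms) plus the default 20th/80th percentiles
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
print(f'median: {ms:.3f} ms ({min_ms:.3f} / {max_ms:.3f})')
```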
@@ -125,13 +141,16 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):


 class Benchmark:
+    """
+    This class is used by the :code:`perf_report` function to generate line plots with a concise API.
+    """
     def __init__(
         self,
         x_names,
         x_vals,
-        y_name,
-        y_vals,
-        y_lines,
+        line_arg,
+        line_vals,
+        line_names,
         plot_name,
         args,
         xlabel='',
@@ -139,12 +158,38 @@ class Benchmark:
         x_log=False,
         y_log=False,
     ):
+        """
+        Constructor
+
+        :param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value.
+        :type x_names: List[str]
+        :param x_vals: List of values to use for the arguments in :code:`x_names`.
+        :type x_vals: List[Any]
+        :param line_arg: Argument name for which different values correspond to different lines in the plot.
+        :type line_arg: str
+        :param line_vals: List of values to use for the arguments in :code:`line_arg`.
+        :type line_vals: List[str]
+        :param line_names: Label names for the different lines.
+        :type line_names: List[str]
+        :param plot_name: Name of the plot.
+        :type plot_name: str
+        :param args: List of arguments to remain fixed throughout the benchmark.
+        :type args: List[str]
+        :param xlabel: Label for the x axis of the plot.
+        :type xlabel: str, optional
+        :param ylabel: Label for the y axis of the plot.
+        :type ylabel: str, optional
+        :param x_log: Whether the x axis should be log scale.
+        :type x_log: bool, optional
+        :param y_log: Whether the y axis should be log scale.
+        :type y_log: bool, optional
+        """
         self.x_names = x_names
         self.x_vals = x_vals
         self.x_log = x_log
-        self.y_name = y_name
-        self.y_vals = y_vals
-        self.y_lines = y_lines
+        self.line_arg = line_arg
+        self.line_vals = line_vals
+        self.line_names = line_names
         self.y_log = y_log
         # plot info
         self.xlabel = xlabel
@@ -162,15 +207,15 @@ class Mark:
         import matplotlib.pyplot as plt
         import pandas as pd
         import os
-        y_mean = bench.y_lines
-        y_min = [f'{x}-min' for x in bench.y_lines]
-        y_max = [f'{x}-max' for x in bench.y_lines]
+        y_mean = bench.line_names
+        y_min = [f'{x}-min' for x in bench.line_names]
+        y_max = [f'{x}-max' for x in bench.line_names]
         df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
         for x in bench.x_vals:
             x_args = {x_name: x for x_name in bench.x_names}
             row_mean, row_min, row_max = [], [], []
-            for y in bench.y_vals:
-                ret = self.fn(**x_args, **{bench.y_name: y}, **bench.args)
+            for y in bench.line_vals:
+                ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
                 try:
                     y_mean, y_min, y_max = ret
                 except TypeError:
@@ -183,7 +228,7 @@ class Mark:
             plt.figure()
             ax = plt.subplot()
             x = bench.x_names[0]
-            for y in bench.y_lines:
+            for y in bench.line_names:
                 y_min, y_max = df[y + '-min'], df[y + '-max']
                 ax.plot(df[x], df[y], label=y)
                 if y_min is not None and y_max is not None:
@@ -199,7 +244,7 @@ class Mark:
                 plt.show()
             if save_path:
                 plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
-        df = df[[bench.x_names[0]] + bench.y_lines]
+        df = df[[bench.x_names[0]] + bench.line_names]
         if print_data:
             print(df)
         if save_path:
@@ -220,5 +265,11 @@ class Mark:


 def perf_report(benchmarks):
+    """
+    Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
+
+    :param benchmarks: Benchmarking configurations.
+    :type benchmarks: List of :class:`Benchmark`
+    """
     wrapper = lambda fn: Mark(fn, benchmarks)
     return wrapper
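A hypothetical usage sketch combining the renamed `Benchmark` fields with `perf_report`, mirroring the tutorial configurations updated later in this commit:

```python
# Hypothetical usage sketch (not from this commit) of the renamed API.
import torch
import triton

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[2**i for i in range(10, 14)],
        line_arg='provider',              # was: y_name
        line_vals=['torch'],              # was: y_vals
        line_names=['Torch'],             # was: y_lines
        ylabel='GB/s',
        plot_name='example-performance',
        args={},
    ))
def benchmark(N, provider):
    x = torch.randn(N, device='cuda')
    return triton.testing.do_bench(lambda: torch.softmax(x, dim=0))

benchmark.run(show_plots=False, print_data=True)
```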
@@ -13,6 +13,7 @@ In this tutorial, you will write a simple vector addition using Triton and learn
 # --------------------------

 import torch
+import triton.language as tl
 import triton


@@ -24,19 +25,19 @@ def _add(
     N,  # Size of the vector
     **meta  # Optional meta-parameters for the kernel
 ):
-    pid = triton.program_id(0)
+    pid = tl.program_id(0)
     # Create an offset for the blocks of pointers to be
     # processed by this program instance
-    offsets = pid * meta['BLOCK'] + triton.arange(0, meta['BLOCK'])
+    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
     # Create a mask to guard memory operations against
     # out-of-bounds accesses
     mask = offsets < N
     # Load x
-    x = triton.load(X + offsets, mask=mask)
-    y = triton.load(Y + offsets, mask=mask)
+    x = tl.load(X + offsets, mask=mask)
+    y = tl.load(Y + offsets, mask=mask)
     # Write back x + y
     z = x + y
-    triton.store(Z + offsets, z)
+    tl.store(Z + offsets, z)


 # %%
@@ -89,9 +90,9 @@ print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.
         x_names=['size'],  # argument names to use as an x-axis for the plot
         x_vals=[2**i for i in range(12, 28, 1)],  # different possible values for `x_name`
         x_log=True,  # x axis is logarithmic
-        y_name='provider',  # argument name whose value corresponds to a different line in the plot
-        y_vals=['torch', 'triton'],  # possible keys for `y_name`
-        y_lines=["Torch", "Triton"],  # label name for the lines
+        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
+        line_vals=['torch', 'triton'],  # possible values for `line_arg`
+        line_names=["Torch", "Triton"],  # label name for the lines
         ylabel="GB/s",  # label name for the y-axis
         plot_name="vector-add-performance",  # name for the plot. Used also as a file name for saving the plot.
         args={}  # values for function arguments not in `x_names` and `y_name`
@@ -46,28 +46,29 @@ def naive_softmax(x):
 # so we need to internally "pad" tiles and guard the memory operations properly if we want to handle any possible input shapes:

 import triton
+import triton.language as tl


 @triton.jit
 def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
     # row index
-    m = triton.program_id(0)
+    m = tl.program_id(0)
     # col indices
-    n = triton.arange(0, meta['BLOCK'])
+    n = tl.arange(0, meta['BLOCK'])
     # the memory address of all the elements
     # that we want to load can be computed as follows
     X = X + m * stride_xm + n
-    x = triton.load(X, mask=n < N, other=-float('inf'))
+    x = tl.load(X, mask=n < N, other=-float('inf'))
     # Substract maximum for numerical stability
-    z = x - triton.max(x, axis=0)
+    z = x - tl.max(x, axis=0)
     # Note that exponentials in Triton are fast
     # but approximate (i.e., think __expf in CUDA)
-    num = triton.exp(z)
-    denom = triton.sum(num, axis=0)
+    num = tl.exp(z)
+    denom = tl.sum(num, axis=0)
     y = num / denom
     # Write back to Y
     Y = Y + m * stride_ym + n
-    triton.store(Y, y, mask=n < N)
+    tl.store(Y, y, mask=n < N)


 # %%
@@ -132,9 +133,9 @@ print(torch.allclose(y_tri, y_ref))
     triton.testing.Benchmark(
         x_names=['N'],  # argument names to use as an x-axis for the plot
         x_vals=[256 * i for i in range(2, 50)],  # different possible values for `x_name`
-        y_name='provider',  # argument name whose value corresponds to a different line in the plot
-        y_vals=['torch', 'triton', 'naive'],  # possible keys for `y_name`
-        y_lines=["Torch", "Triton", 'Naive'],  # label name for the lines
+        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
+        line_vals=['torch', 'triton', 'naive'],  # possible values for `line_arg`
+        line_names=["Torch", "Triton", 'Naive'],  # label name for the lines
         ylabel="GB/s",  # label name for the y-axis
         plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
         args={'M': 4096}  # values for function arguments not in `x_names` and `y_name`
@@ -115,6 +115,7 @@ You will specifically learn about:

 import torch
 import triton
+import triton.language as tl

 # %
 # :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
@@ -124,9 +125,9 @@ import triton

 @triton.jit
 def sigmoid(x):
-    ret_true = 1 / (1 + triton.exp(-x))
-    ret_false = triton.exp(x) / (1 + triton.exp(x))
-    return triton.where(x >= 0, ret_true, ret_false)
+    ret_true = 1 / (1 + tl.exp(-x))
+    ret_false = tl.exp(x) / (1 + tl.exp(x))
+    return tl.where(x >= 0, ret_true, ret_false)


 @triton.jit
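For context on the two-branch sigmoid above: for x >= 0 the form 1/(1+exp(-x)) is safe, while for x < 0 the form exp(x)/(1+exp(x)) avoids overflowing exp() on large-magnitude negative inputs. A hypothetical host-side check, not part of this commit:

```python
# Hypothetical reference check of the branch-stable sigmoid used above.
import torch

x = torch.linspace(-100, 100, 11)
ret_true = 1 / (1 + torch.exp(-x))            # safe for x >= 0
ret_false = torch.exp(x) / (1 + torch.exp(x)) # safe for x < 0
stable = torch.where(x >= 0, ret_true, ret_false)
assert torch.allclose(stable, torch.sigmoid(x))
```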
@@ -151,7 +152,7 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
     BLOCK_K = META['BLOCK_K']
     GROUP_M = 8
     # matrix multiplication
-    pid = triton.program_id(0)
+    pid = tl.program_id(0)
     grid_m = (M + BLOCK_M - 1) // BLOCK_M
     grid_n = (N + BLOCK_N - 1) // BLOCK_N
     # re-order program ID for better L2 performance
@@ -161,16 +162,16 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
     pid_m = group_id * GROUP_M + (pid % group_size)
     pid_n = (pid % width) // (group_size)
     # do matrix multiplication
-    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
-    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
-    rk = triton.arange(0, BLOCK_K)
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    rk = tl.arange(0, BLOCK_K)
     A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
     B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
-    acc = triton.zeros((BLOCK_M, BLOCK_N), dtype=triton.float32)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(K, 0, -BLOCK_K):
-        a = triton.load(A)
-        b = triton.load(B)
-        acc += triton.dot(a, b)
+        a = tl.load(A)
+        b = tl.load(B)
+        acc += tl.dot(a, b)
         A += BLOCK_K * stride_ak
         B += BLOCK_K * stride_bk
     # triton can accept arbitrary activation function
@@ -178,11 +179,11 @@ def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
     if META['ACTIVATION']:
         acc = META['ACTIVATION'](acc)
     # rematerialize rm and rn to save registers
-    rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
-    rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
     C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
     mask = (rm[:, None] < M) & (rn[None, :] < N)
-    triton.store(C, acc, mask=mask)
+    tl.store(C, acc, mask=mask)


 # %%
@@ -238,9 +239,9 @@ print(triton.testing.allclose(c_0, c_1))
     triton.testing.Benchmark(
         x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
         x_vals=[256 * i for i in range(2, 33)],  # different possible values for `x_name`
-        y_name='provider',  # argument name whose value corresponds to a different line in the plot
-        y_vals=['cublas', 'triton'],  # possible keys for `y_name`
-        y_lines=["cuBLAS", "Triton"],  # label name for the lines
+        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
+        line_vals=['cublas', 'triton'],  # possible values for `line_arg`
+        line_names=["cuBLAS", "Triton"],  # label name for the lines
         ylabel="TFLOPS",  # label name for the y-axis
         plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
         args={}
|
|||||||
return perf(ms), perf(max_ms), perf(min_ms)
|
return perf(ms), perf(max_ms), perf(min_ms)
|
||||||
|
|
||||||
|
|
||||||
benchmark.run(print_data=True)
|
benchmark.run(show_plots=True, print_data=True)
|