[CODEGEN] Major performance improvements on A100 (#70)
Improved handling of asynchronous copy, scheduling, and synchronization for A100. Triton now achieves CUTLASS-like performance on large square dense matrix multiplication tasks.
This commit is contained in:
committed by
Philippe Tillet
parent
045ab5d62a
commit
5b83259592
@@ -2,58 +2,74 @@ import triton
|
||||
import torch
|
||||
|
||||
# square benchmarks
|
||||
nt = {False: 'n', True: 't'}
|
||||
nt = {False: "n", True: "t"}
|
||||
square_confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names = ['M', 'N', 'K'],
|
||||
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
|
||||
y_name = 'provider',
|
||||
y_vals = ['torch', 'triton', 'cutlass'],
|
||||
y_lines = ['Torch', 'Triton', 'CUTLASS'],
|
||||
ylabel = 'TFLOPS',
|
||||
loglog = False,
|
||||
plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
|
||||
args = {'AT': False, 'BT': False, 'dtype': torch.float16}
|
||||
)\
|
||||
for AT in [False, True] for BT in [False, True]
|
||||
x_names=["M", "N", "K"],
|
||||
x_vals=[512 * i for i in range(1, 16)],
|
||||
y_name="provider",
|
||||
y_vals=["torch", "triton", "cutlass"],
|
||||
y_lines=["Torch", "Triton", "CUTLASS"],
|
||||
ylabel="TFLOPS",
|
||||
loglog=False,
|
||||
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
|
||||
args={"AT": AT, "BT": BT, "dtype": torch.float16},
|
||||
) for AT in [False, True] for BT in [False, True]
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(square_confs)
|
||||
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
|
||||
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
|
||||
import os
|
||||
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
|
||||
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
|
||||
if AT: a = a.t()
|
||||
if BT: b = b.t()
|
||||
|
||||
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
|
||||
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
|
||||
if AT:
|
||||
a = a.t()
|
||||
if BT:
|
||||
b = b.t()
|
||||
num_flops = 2 * M * N * K
|
||||
if provider == 'torch':
|
||||
if provider == "torch":
|
||||
torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
|
||||
torch_tflops = num_flops / torch_ms * 1e-9
|
||||
return torch_tflops
|
||||
if provider == 'triton':
|
||||
if provider == "triton":
|
||||
triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
|
||||
triton_tflops = num_flops / triton_ms * 1e-9
|
||||
return triton_tflops
|
||||
if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
|
||||
if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
|
||||
import subprocess
|
||||
import tempfile
|
||||
import pandas as pd
|
||||
|
||||
# run program specified by CUTLASS_PROFILER env variable
|
||||
layout_a = 'column' if AT else 'row'
|
||||
layout_b = 'column' if BT else 'row'
|
||||
layout_a = "column" if AT else "row"
|
||||
layout_b = "column" if BT else "row"
|
||||
# create temporary file name
|
||||
fd, fname = tempfile.mkstemp()
|
||||
# run program and get its output
|
||||
cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
|
||||
'--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \
|
||||
f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
|
||||
cmd = [
|
||||
os.environ["CUTLASS_PROFILER"],
|
||||
f"--m={M}",
|
||||
f"--n={N}",
|
||||
f"--k={K}",
|
||||
f"--A=f16:{layout_a}",
|
||||
f"--B=f16:{layout_b}",
|
||||
"--C=f16:column",
|
||||
"--accum=f32",
|
||||
"--operation=gemm",
|
||||
"--verification-enabled=false",
|
||||
f"--warmup-iterations={warmup}",
|
||||
f"--profiling-iterations={rep}",
|
||||
f"--output={fname}",
|
||||
"--verbose=false",
|
||||
]
|
||||
# run cmd
|
||||
subprocess.run(cmd, stdout=subprocess.PIPE)
|
||||
# read CSV output
|
||||
df_c = pd.read_csv(f'{fname}.gemm.csv')
|
||||
cutlass_tflops = max(df_c['GFLOPs']) / 1e3
|
||||
df_c = pd.read_csv(f"{fname}.gemm.csv")
|
||||
cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
|
||||
return cutlass_tflops
|
||||
return None
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
bench_op.run()
|
||||
|
102
python/setup.py
102
python/setup.py
@@ -15,21 +15,21 @@ import distutils.spawn
|
||||
import torch
|
||||
|
||||
def find_llvm():
|
||||
versions = ['-10', '-10.0', '']
|
||||
supported = ['llvm-config{v}'.format(v=v) for v in versions]
|
||||
versions = ["-10", "-10.0", ""]
|
||||
supported = ["llvm-config{v}".format(v=v) for v in versions]
|
||||
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
||||
paths = [p for p in paths if p is not None]
|
||||
if paths:
|
||||
return paths[0]
|
||||
config = distutils.spawn.find_executable('llvm-config')
|
||||
instructions = 'Please install llvm-10-dev'
|
||||
config = distutils.spawn.find_executable("llvm-config")
|
||||
instructions = "Please install llvm-10-dev"
|
||||
if config is None:
|
||||
raise RuntimeError('Could not find llvm-config. ' + instructions)
|
||||
version = os.popen('{config} --version'.format(config=config)).read()
|
||||
raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions)
|
||||
raise RuntimeError("Could not find llvm-config. " + instructions)
|
||||
version = os.popen("{config} --version".format(config=config)).read()
|
||||
raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)
|
||||
|
||||
class CMakeExtension(Extension):
|
||||
def __init__(self, name, path, sourcedir=''):
|
||||
def __init__(self, name, path, sourcedir=""):
|
||||
Extension.__init__(self, name, sources=[])
|
||||
self.sourcedir = os.path.abspath(sourcedir)
|
||||
self.path = path
|
||||
@@ -37,84 +37,84 @@ class CMakeExtension(Extension):
|
||||
class CMakeBuild(build_ext):
|
||||
def run(self):
|
||||
try:
|
||||
out = subprocess.check_output(['cmake', '--version'])
|
||||
out = subprocess.check_output(["cmake", "--version"])
|
||||
except OSError:
|
||||
raise RuntimeError("CMake must be installed to build the following extensions: " +
|
||||
", ".join(e.name for e in self.extensions))
|
||||
|
||||
if platform.system() == "Windows":
|
||||
cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1))
|
||||
if cmake_version < '3.1.0':
|
||||
cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
|
||||
if cmake_version < "3.1.0":
|
||||
raise RuntimeError("CMake >= 3.1.0 is required on Windows")
|
||||
|
||||
for ext in self.extensions:
|
||||
self.build_extension(ext)
|
||||
|
||||
def build_extension(self, ext):
|
||||
#self.debug = True
|
||||
# self.debug = True
|
||||
self.debug = False
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||
# python directories
|
||||
python_include_dirs = distutils.sysconfig.get_python_inc()
|
||||
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
|
||||
python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
|
||||
torch_include_dirs = include_paths(True)
|
||||
torch_library_dirs = library_paths(True)
|
||||
cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
|
||||
cmake_args = [
|
||||
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
|
||||
'-DBUILD_TUTORIALS=OFF',
|
||||
'-DBUILD_PYTHON_MODULE=ON',
|
||||
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
|
||||
"-DBUILD_TUTORIALS=OFF",
|
||||
"-DBUILD_PYTHON_MODULE=ON",
|
||||
#'-DPYTHON_EXECUTABLE=' + sys.executable,
|
||||
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
|
||||
'-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
|
||||
'-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
|
||||
'-DTORCH_CXX11_ABI=' + cxx11abi,
|
||||
'-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
|
||||
'-DLLVM_CONFIG=' + find_llvm()
|
||||
"-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs] + include_paths(True)),
|
||||
"-DPYTHON_LINK_DIRS=" + ";".join(library_paths(True)),
|
||||
"-DTORCH_CXX11_ABI=" + cxx11abi,
|
||||
"-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton",
|
||||
"-DLLVM_CONFIG=" + find_llvm(),
|
||||
]
|
||||
# configuration
|
||||
cfg = 'Debug' if self.debug else 'Release'
|
||||
cfg = 'Release'
|
||||
build_args = ['--config', cfg]
|
||||
cfg = "Debug" if self.debug else "Release"
|
||||
build_args = ["--config", cfg]
|
||||
|
||||
if platform.system() == "Windows":
|
||||
cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)]
|
||||
cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
|
||||
if sys.maxsize > 2**32:
|
||||
cmake_args += ['-A', 'x64']
|
||||
build_args += ['--', '/m']
|
||||
cmake_args += ["-A", "x64"]
|
||||
build_args += ["--", "/m"]
|
||||
else:
|
||||
cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
|
||||
build_args += ['--', '-j4']
|
||||
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
|
||||
build_args += ["--", "-j4"]
|
||||
|
||||
env = os.environ.copy()
|
||||
if not os.path.exists(self.build_temp):
|
||||
os.makedirs(self.build_temp)
|
||||
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
|
||||
subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
|
||||
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
|
||||
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
|
||||
subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
|
||||
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
|
||||
|
||||
setup(
|
||||
name='triton',
|
||||
version='1.0.0',
|
||||
author='Philippe Tillet',
|
||||
author_email='phil@openai.com',
|
||||
description='A language and compiler for custom Deep Learning operations',
|
||||
long_description='',
|
||||
packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'],
|
||||
install_requires=['numpy', 'torch'],
|
||||
package_data={'triton/ops': ['*.c'], 'triton/ops/blocksparse': ['*.c']},
|
||||
name="triton",
|
||||
version="1.0.0",
|
||||
author="Philippe Tillet",
|
||||
author_email="phil@openai.com",
|
||||
description="A language and compiler for custom Deep Learning operations",
|
||||
long_description="",
|
||||
packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
|
||||
install_requires=["numpy", "torch"],
|
||||
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
|
||||
include_package_data=True,
|
||||
ext_modules=[CMakeExtension('triton', 'triton/_C/')],
|
||||
cmdclass={'build_ext': CMakeBuild},
|
||||
ext_modules=[CMakeExtension("triton", "triton/_C/")],
|
||||
cmdclass={"build_ext": CMakeBuild},
|
||||
zip_safe=False,
|
||||
# for PyPI
|
||||
keywords=['Compiler', 'Deep Learning'],
|
||||
url='https://github.com/ptillet/triton/',
|
||||
download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz',
|
||||
keywords=["Compiler", "Deep Learning"],
|
||||
url="https://github.com/ptillet/triton/",
|
||||
download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz",
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
|
||||
'Intended Audience :: Developers', # Define that your audience are developers
|
||||
'Topic :: Software Development :: Build Tools',
|
||||
'License :: OSI Approved :: MIT License', # Again, pick a license
|
||||
'Programming Language :: Python :: 3.6',
|
||||
"Development Status :: 3 - Alpha", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
|
||||
"Intended Audience :: Developers", # Define that your audience are developers
|
||||
"Topic :: Software Development :: Build Tools",
|
||||
"License :: OSI Approved :: MIT License", # Again, pick a license
|
||||
"Programming Language :: Python :: 3.6",
|
||||
],
|
||||
)
|
||||
|
@@ -2,29 +2,17 @@ import torch
|
||||
import triton
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"MODE, TRANS_A, TRANS_B, BLOCK",
|
||||
[
|
||||
(mode, at, bt, block)
|
||||
for mode in ["sdd", "dsd", "dds"]
|
||||
for at in [False, True]
|
||||
for bt in [False, True]
|
||||
for block in [16, 32, 64]
|
||||
],
|
||||
[(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
|
||||
for block in [16, 32, 64]],
|
||||
)
|
||||
def test_matmul(
|
||||
MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384
|
||||
):
|
||||
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
# create inputs
|
||||
a = torch.randn(
|
||||
(Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda"
|
||||
)
|
||||
b = torch.randn(
|
||||
(Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda"
|
||||
)
|
||||
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
|
||||
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
|
||||
shape = {
|
||||
"sdd": (M, N),
|
||||
"dsd": (a.shape[2], a.shape[3]),
|
||||
@@ -32,9 +20,7 @@ def test_matmul(
|
||||
}[MODE]
|
||||
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
|
||||
# triton result
|
||||
op = triton.ops.blocksparse.matmul(
|
||||
layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B
|
||||
)
|
||||
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
|
||||
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
|
||||
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
|
||||
rc = op(ra, rb)
|
||||
@@ -49,7 +35,6 @@ def test_matmul(
|
||||
# compare
|
||||
assert triton.testing.allclose(rc, tc)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"BLOCK, WIDTH",
|
||||
[(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
|
||||
@@ -62,12 +47,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
|
||||
# create inputs
|
||||
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
|
||||
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
|
||||
at_mask = torch.randint(
|
||||
low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda"
|
||||
)
|
||||
kp_mask = torch.randint(
|
||||
low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda"
|
||||
)
|
||||
at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
|
||||
kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
|
||||
kp_mask[kp_mask == 1.0] = float("-inf")
|
||||
# triton result
|
||||
op = triton.ops.blocksparse.softmax(layout, BLOCK)
|
||||
@@ -94,7 +75,6 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
|
||||
# compare
|
||||
assert triton.testing.allclose(ry, ty)
|
||||
|
||||
|
||||
def test_attention_fwd_bwd(
|
||||
input_scale=1.0,
|
||||
tol=2e-2,
|
||||
@@ -108,10 +88,7 @@ def test_attention_fwd_bwd(
|
||||
# inputs
|
||||
qkv_shape = (batch_size, n_heads, n_ctx, 64)
|
||||
qkvs = [
|
||||
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True)
|
||||
.to(dtype)
|
||||
.cuda()
|
||||
for _ in range(3)
|
||||
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
|
||||
]
|
||||
attn_mask = torch.tril(
|
||||
torch.ones(
|
||||
@@ -129,11 +106,9 @@ def test_attention_fwd_bwd(
|
||||
query.retain_grad()
|
||||
key.retain_grad()
|
||||
value.retain_grad()
|
||||
attn_out = triton_attention(
|
||||
layout, block, attn_mask, query=query, key=key, value=value, scale=scale
|
||||
)
|
||||
attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
|
||||
# ad hoc loss
|
||||
loss = (attn_out ** 2).mean()
|
||||
loss = (attn_out**2).mean()
|
||||
loss.backward()
|
||||
grads = [query.grad, key.grad, value.grad]
|
||||
|
||||
@@ -148,17 +123,16 @@ def test_attention_fwd_bwd(
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
|
||||
# ad hoc loss
|
||||
torch_loss = (torch_attn_out ** 2).mean()
|
||||
torch_loss = (torch_attn_out**2).mean()
|
||||
torch_loss.backward()
|
||||
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
|
||||
|
||||
# comparison
|
||||
print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
|
||||
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
|
||||
torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol)
|
||||
for g1, g2 in zip(grads, torch_grads):
|
||||
torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)
|
||||
|
||||
|
||||
def triton_attention(
|
||||
layout,
|
||||
block: int,
|
||||
@@ -168,12 +142,8 @@ def triton_attention(
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(
|
||||
layout, block, "sdd", trans_a=False, trans_b=True
|
||||
)
|
||||
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(
|
||||
layout, block, "dsd", trans_a=False, trans_b=False
|
||||
)
|
||||
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
|
||||
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
|
||||
sparse_softmax = triton.ops.blocksparse.softmax(
|
||||
layout,
|
||||
block,
|
||||
|
@@ -4,7 +4,7 @@ import triton
|
||||
import torch
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE",
|
||||
"TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
|
||||
itertools.chain(*[
|
||||
[
|
||||
# 1 warp
|
||||
@@ -17,14 +17,14 @@ import torch
|
||||
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
||||
# 2 warp
|
||||
# # 2 warp
|
||||
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 4 warp
|
||||
# # 4 warp
|
||||
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
|
||||
@@ -40,24 +40,28 @@ import torch
|
||||
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
|
||||
# variable input
|
||||
(128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE)
|
||||
] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True]
|
||||
]))
|
||||
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
|
||||
DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE]
|
||||
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
|
||||
] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
|
||||
]),
|
||||
)
|
||||
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
|
||||
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
|
||||
torch.manual_seed(0)
|
||||
triton.ops._matmul._kernels = dict()
|
||||
triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
|
||||
if M is None: M = TM
|
||||
if N is None: N = TN
|
||||
if K is None: K = TK * TZ
|
||||
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
|
||||
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
|
||||
triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
|
||||
if M is None:
|
||||
M = TM
|
||||
if N is None:
|
||||
N = TN
|
||||
if K is None:
|
||||
K = TK * SPLITK
|
||||
a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
|
||||
b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
|
||||
a = a.t() if AT else a
|
||||
b = b.t() if BT else b
|
||||
th_c = torch.matmul(a, b)
|
||||
tt_c = triton.ops.matmul(a, b)
|
||||
assert triton.testing.allclose(th_c, tt_c)
|
||||
assert triton.testing.allclose(th_c, tt_c)
|
||||
|
@@ -1,198 +1,199 @@
|
||||
__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
|
||||
TYPE* B __readonly __noalias __aligned(16),
|
||||
TYPE* C __noalias __aligned(16),
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc __multipleof(8),
|
||||
long stride_za __multipleof(8),
|
||||
long stride_zb __multipleof(8),
|
||||
long stride_zc __multipleof(8),
|
||||
long stride_ha __multipleof(8),
|
||||
long stride_hb __multipleof(8),
|
||||
long stride_hc __multipleof(8),
|
||||
int DS0, int DS1,
|
||||
int SDD_K __multipleof(16),
|
||||
int SDD_off_width,
|
||||
int* lut, int* locks, int nlocks) {
|
||||
/* ---------------- */
|
||||
/* Prologue */
|
||||
/* ---------------- */
|
||||
// program ids
|
||||
int pid0 = get_program_id(0);
|
||||
int pid1 = get_program_id(1);
|
||||
int pidz = get_program_id(2);
|
||||
__global__ void NAME(TYPE *A __readonly __noalias __aligned(16),
|
||||
TYPE *B __readonly __noalias __aligned(16),
|
||||
TYPE *C __noalias __aligned(16),
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc __multipleof(8),
|
||||
long stride_za __multipleof(8),
|
||||
long stride_zb __multipleof(8),
|
||||
long stride_zc __multipleof(8),
|
||||
long stride_ha __multipleof(8),
|
||||
long stride_hb __multipleof(8),
|
||||
long stride_hc __multipleof(8),
|
||||
int DS0, int DS1,
|
||||
int SDD_K __multipleof(16),
|
||||
int SDD_off_width,
|
||||
int *lut, int *locks, int nlocks) {
|
||||
/* ---------------- */
|
||||
/* Prologue */
|
||||
/* ---------------- */
|
||||
// program ids
|
||||
int pid0 = get_program_id(0);
|
||||
int pid1 = get_program_id(1);
|
||||
int pidz = get_program_id(2);
|
||||
#ifdef SDD
|
||||
// load LUT header
|
||||
pid1 = pid1 + SDD_off_width;
|
||||
int blockidm[TM] = (0 ... TM) / BLOCK;
|
||||
int blockidn[TN] = (0 ... TN) / BLOCK;
|
||||
int offlutm[TM] = blockidm*(TN/BLOCK)*4;
|
||||
int offlutn[TN] = blockidn*4;
|
||||
int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
|
||||
int z = *(header + 0);
|
||||
int i[TM] = *(header + 1 + offlutm);
|
||||
int j[TN] = *(header + 2 + offlutn);
|
||||
int AS1 = SDD_K / TZ;
|
||||
int lockid = select(TZ > 1, 1, 0);
|
||||
int offka = pid0 * AS1;
|
||||
int offkb = pid0 * AS1;
|
||||
int offmc = 0;
|
||||
int offnc = 0;
|
||||
int offpa = 0;
|
||||
int offpb = 0;
|
||||
int maxid = TZ;
|
||||
int offhc = 0;
|
||||
int offha = z;
|
||||
int offhb = z;
|
||||
int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
|
||||
int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
|
||||
// load LUT header
|
||||
pid1 = pid1 + SDD_off_width;
|
||||
int blockidm[TM] = (0 ... TM) / BLOCK;
|
||||
int blockidn[TN] = (0 ... TN) / BLOCK;
|
||||
int offlutm[TM] = blockidm * (TN / BLOCK) * 4;
|
||||
int offlutn[TN] = blockidn * 4;
|
||||
int *header = lut + pid1 * (TM / BLOCK) * (TN / BLOCK) * 4;
|
||||
int z = *(header + 0);
|
||||
int i[TM] = *(header + 1 + offlutm);
|
||||
int j[TN] = *(header + 2 + offlutn);
|
||||
int AS1 = SDD_K / TZ;
|
||||
int lockid = select(TZ > 1, 1, 0);
|
||||
int offka = pid0 * AS1;
|
||||
int offkb = pid0 * AS1;
|
||||
int offmc = 0;
|
||||
int offnc = 0;
|
||||
int offpa = 0;
|
||||
int offpb = 0;
|
||||
int maxid = TZ;
|
||||
int offhc = 0;
|
||||
int offha = z;
|
||||
int offhb = z;
|
||||
int ram[TM] = i * BLOCK + ((0 ... TM) % BLOCK);
|
||||
int rbn[TN] = j * BLOCK + ((0 ... TN) % BLOCK);
|
||||
#else
|
||||
// load LUT header
|
||||
int *header = lut + pid0 * 6;
|
||||
int offset = *(header + 0);
|
||||
int AS1 = *(header + 1);
|
||||
int column = *(header + 2);
|
||||
int depth = *(header + 3);
|
||||
int lockid = *(header + 4);
|
||||
int maxid = *(header + 5);
|
||||
int *pinc = lut + offset;
|
||||
int offhc = depth;
|
||||
// load LUT header
|
||||
int *header = lut + pid0 * 6;
|
||||
int offset = *(header + 0);
|
||||
int AS1 = *(header + 1);
|
||||
int column = *(header + 2);
|
||||
int depth = *(header + 3);
|
||||
int lockid = *(header + 4);
|
||||
int maxid = *(header + 5);
|
||||
int *pinc = lut + offset;
|
||||
int offhc = depth;
|
||||
#ifdef DSD
|
||||
// output offset
|
||||
int offnc = pid1 * TN;
|
||||
int offmc = column * TM;
|
||||
int offpc = 0;
|
||||
// dense input offset
|
||||
int offnb = pid1 * TN;
|
||||
int offkb __multipleof(8) = *pinc;
|
||||
int offpb = 0;
|
||||
// sparse input offset
|
||||
int offma = 0;
|
||||
int offka = 0;
|
||||
long offpa __multipleof(8) = *(pinc + 1);
|
||||
offpa = offpa * BLOCK * BLOCK;
|
||||
int offha = 0;
|
||||
int offhb = depth;
|
||||
// output offset
|
||||
int offnc = pid1 * TN;
|
||||
int offmc = column * TM;
|
||||
int offpc = 0;
|
||||
// dense input offset
|
||||
int offnb = pid1 * TN;
|
||||
int offkb __multipleof(8) = *pinc;
|
||||
int offpb = 0;
|
||||
// sparse input offset
|
||||
int offma = 0;
|
||||
int offka = 0;
|
||||
long offpa __multipleof(8) = *(pinc + 1);
|
||||
offpa = offpa * BLOCK * BLOCK;
|
||||
int offha = 0;
|
||||
int offhb = depth;
|
||||
#endif
|
||||
#ifdef DDS
|
||||
// output offset
|
||||
int offmc = pid1 * TM;
|
||||
int offnc = column * TN;
|
||||
int offpc = 0;
|
||||
// dense input offset
|
||||
int offma = pid1 * TM;
|
||||
int offka __multipleof(8) = *pinc;
|
||||
int offpa = 0;
|
||||
// sparse input offset
|
||||
int offnb = 0;
|
||||
int offkb = 0;
|
||||
long offpb __multipleof(8) = *(pinc + 1);
|
||||
offpb = offpb * BLOCK * BLOCK;
|
||||
int offha = depth;
|
||||
int offhb = 0;
|
||||
// output offset
|
||||
int offmc = pid1 * TM;
|
||||
int offnc = column * TN;
|
||||
int offpc = 0;
|
||||
// dense input offset
|
||||
int offma = pid1 * TM;
|
||||
int offka __multipleof(8) = *pinc;
|
||||
int offpa = 0;
|
||||
// sparse input offset
|
||||
int offnb = 0;
|
||||
int offkb = 0;
|
||||
long offpb __multipleof(8) = *(pinc + 1);
|
||||
offpb = offpb * BLOCK * BLOCK;
|
||||
int offha = depth;
|
||||
int offhb = 0;
|
||||
#endif
|
||||
int ram[TM] = offma + 0 ... TM;
|
||||
int rbn[TN] = offnb + 0 ... TN;
|
||||
int ram[TM] = offma + 0 ... TM;
|
||||
int rbn[TN] = offnb + 0 ... TN;
|
||||
#endif
|
||||
// initialize a, b pointers
|
||||
int rka[TK] = offka + 0 ... TK;
|
||||
int rkb[TK] = offkb + 0 ... TK;
|
||||
TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
|
||||
TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
|
||||
// initialize a, b pointers
|
||||
int rka[TK] = offka + 0 ... TK;
|
||||
int rkb[TK] = offkb + 0 ... TK;
|
||||
TYPE *pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka [newaxis, :] * STRIDE_AK;
|
||||
TYPE *pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn [newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
|
||||
// pre-fetch
|
||||
#ifdef DDS
|
||||
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
|
||||
#else
|
||||
bool checkam[TM, TK] = AS1 > 0;
|
||||
#endif
|
||||
#ifdef DSD
|
||||
bool checkbn[TK, TN] = rbn [newaxis, :] < DS0;
|
||||
#else
|
||||
bool checkbn[TK, TN] = AS1 > 0;
|
||||
#endif
|
||||
TYPE a[TM, TK] = checkam ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkbn ? *pb : 0;
|
||||
|
||||
/* ---------------- */
|
||||
/* Inner Loop */
|
||||
/* ---------------- */
|
||||
// create result tile
|
||||
float acc[TM, TN] = 0;
|
||||
int step = TK;
|
||||
for (int k = AS1; k > 0; k -= step) {
|
||||
acc += a @b;
|
||||
// update pointers
|
||||
#ifdef SDD
|
||||
int inc_a = TK * STRIDE_AK;
|
||||
int inc_b = TK * STRIDE_BK;
|
||||
#else
|
||||
pinc += 2;
|
||||
#ifdef DSD
|
||||
int inc_b __multipleof(8) = *pinc;
|
||||
int inc_a __multipleof(8) = *(pinc + 1);
|
||||
inc_b = inc_b * STRIDE_BK;
|
||||
#endif
|
||||
#ifdef DDS
|
||||
int inc_a __multipleof(8) = *pinc;
|
||||
int inc_b __multipleof(8) = *(pinc + 1);
|
||||
inc_a = inc_a * STRIDE_AK;
|
||||
#endif
|
||||
#endif
|
||||
pa += inc_a;
|
||||
pb += inc_b;
|
||||
// pre-fetch
|
||||
#ifdef DDS
|
||||
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
|
||||
#else
|
||||
bool checkam[TM, TK] = AS1 > 0;
|
||||
#endif
|
||||
#ifdef DSD
|
||||
bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
|
||||
#else
|
||||
bool checkbn[TK, TN] = AS1 > 0;
|
||||
#endif
|
||||
TYPE a[TM, TK] = checkam ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkbn ? *pb : 0;
|
||||
bool checkak[TM, TK] = k > TK;
|
||||
bool checkbk[TK, TN] = k > TK;
|
||||
bool checka[TM, TK] = checkam && checkak;
|
||||
bool checkb[TK, TN] = checkbk && checkbn;
|
||||
a = *? (checka)pa;
|
||||
b = *? (checkb)pb;
|
||||
}
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
/* ---------------- */
|
||||
/* Inner Loop */
|
||||
/* ---------------- */
|
||||
// create result tile
|
||||
float acc[TM, TN] = 0;
|
||||
int step = TK;
|
||||
for(int k = AS1; k > 0; k -= step) {
|
||||
acc += a @ b;
|
||||
// update pointers
|
||||
/* ---------------- */
|
||||
/* Epilogue */
|
||||
/* ---------------- */
|
||||
// initialize c pointers
|
||||
#ifdef SDD
|
||||
int inc_a = TK * STRIDE_AK;
|
||||
int inc_b = TK * STRIDE_BK;
|
||||
bool checkc[TM, TN] = 1;
|
||||
// rematerialize
|
||||
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
|
||||
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
|
||||
int rr_offlutm[TM] = rr_blockidm * (TN / BLOCK) * 4;
|
||||
int rr_offlutn[TN] = rr_blockidn * 4;
|
||||
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn [newaxis, :];
|
||||
int bkid[TM, TN] = *(header + off_bkid);
|
||||
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
|
||||
// range within blocks
|
||||
int rcm[TM] = (0 ... TM) % BLOCK;
|
||||
int rcn[TN] = (0 ... TN) % BLOCK;
|
||||
#else
|
||||
pinc += 2;
|
||||
int rcm[TM] = offmc + 0 ... TM;
|
||||
int rcn[TN] = offnc + 0 ... TN;
|
||||
#ifdef DSD
|
||||
int inc_b __multipleof(8) = *pinc;
|
||||
int inc_a __multipleof(8) = *(pinc + 1);
|
||||
inc_b = inc_b * STRIDE_BK;
|
||||
bool checkc[TM, TN] = rcn [newaxis, :] < DS0;
|
||||
#endif
|
||||
#ifdef DDS
|
||||
int inc_a __multipleof(8) = *pinc;
|
||||
int inc_b __multipleof(8) = *(pinc + 1);
|
||||
inc_a = inc_a * STRIDE_AK;
|
||||
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
|
||||
#endif
|
||||
#endif
|
||||
pa += inc_a;
|
||||
pb += inc_b;
|
||||
// pre-fetch
|
||||
bool checkak[TM, TK] = k > TK;
|
||||
bool checkbk[TK, TN] = k > TK;
|
||||
bool checka[TM, TK] = checkam && checkak;
|
||||
bool checkb[TK, TN] = checkbk && checkbn;
|
||||
a = *?(checka)pa;
|
||||
b = *?(checkb)pb;
|
||||
}
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
/* ---------------- */
|
||||
/* Epilogue */
|
||||
/* ---------------- */
|
||||
// initialize c pointers
|
||||
#ifdef SDD
|
||||
bool checkc[TM, TN] = 1;
|
||||
// rematerialize
|
||||
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
|
||||
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
|
||||
int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
|
||||
int rr_offlutn[TN] = rr_blockidn*4;
|
||||
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
|
||||
int bkid[TM, TN] = *(header + off_bkid);
|
||||
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
|
||||
// range within blocks
|
||||
int rcm[TM] = (0 ... TM) % BLOCK;
|
||||
int rcn[TN] = (0 ... TN) % BLOCK;
|
||||
#else
|
||||
int rcm[TM] = offmc + 0 ... TM;
|
||||
int rcn[TN] = offnc + 0 ... TN;
|
||||
#ifdef DSD
|
||||
bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
|
||||
#endif
|
||||
#ifdef DDS
|
||||
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
|
||||
#endif
|
||||
#endif
|
||||
TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
|
||||
// write-back directly
|
||||
if(lockid == 0) {
|
||||
*?(checkc) pc = c;
|
||||
}
|
||||
// accumulate partial result using spin-locks
|
||||
else {
|
||||
int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
|
||||
int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
|
||||
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
|
||||
int count = *pcount;
|
||||
if(count == 0)
|
||||
*?(checkc) pc = c;
|
||||
else
|
||||
*?(checkc) pc = c + *?(checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % maxid);
|
||||
atomic_xchg(plock, 0);
|
||||
}
|
||||
}
|
||||
TYPE *pc[TM, TN] = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, newaxis] * STRIDE_CM + rcn [newaxis, :] * STRIDE_CN;
|
||||
// write-back directly
|
||||
if (lockid == 0) {
|
||||
*? (checkc)pc = c;
|
||||
}
|
||||
// accumulate partial result using spin-locks
|
||||
else {
|
||||
int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1;
|
||||
int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks;
|
||||
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
|
||||
;
|
||||
int count = *pcount;
|
||||
if (count == 0)
|
||||
*? (checkc)pc = c;
|
||||
else
|
||||
*? (checkc)pc = c + *? (checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % maxid);
|
||||
atomic_xchg(plock, 0);
|
||||
}
|
||||
}
|
@@ -10,454 +10,416 @@ src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
|
||||
# MAIN API #
|
||||
##############
|
||||
class _matmul(torch.autograd.Function):
|
||||
|
||||
sdd_cache = dict()
|
||||
dsd_cache = dict()
|
||||
dds_cache = dict()
|
||||
locks = dict()
|
||||
|
||||
# Given an array sizes representing reduction size for each
|
||||
# column of a block-mode matrix multiplication,
|
||||
# performs load-balancing to achieve more smaller reductions
|
||||
# between `seg_size` elements
|
||||
@staticmethod
|
||||
def load_balance(sizes, block):
|
||||
# segment size
|
||||
# heuristics taken from OpenAI blocksparse code
|
||||
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
|
||||
max_size = sizes.max()
|
||||
min_size = sizes[sizes != 0].min()
|
||||
#if max_size > min_size * 2.0:
|
||||
# seg_max = max(triton.cdiv(max_size, 4), min_size*2)
|
||||
#else:
|
||||
# seg_max = max_size
|
||||
seg_max = max_size
|
||||
seg_min = max(triton.cdiv(seg_max, 4), 4)
|
||||
# split reduction into segments
|
||||
div = sizes // seg_max
|
||||
rem = sizes % seg_max
|
||||
packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
|
||||
width = packs.sum()
|
||||
segments = torch.empty(width, dtype=sizes.dtype)
|
||||
column = torch.empty_like(segments)
|
||||
lockid = torch.zeros_like(segments)
|
||||
maxid = torch.zeros_like(segments)
|
||||
nlocks = 0
|
||||
current = 0
|
||||
col_idx = 0
|
||||
for i in range(len(sizes)):
|
||||
d, r = div[i], rem[i]
|
||||
isempty = sizes[i] < seg_min
|
||||
last = current + d + (r >= seg_min) + isempty
|
||||
# column id
|
||||
column[current:last] = col_idx
|
||||
# lock id
|
||||
if d > 1 or (d == 1 and r >= seg_min):
|
||||
nlocks += 1
|
||||
lockid[current:last] = nlocks
|
||||
maxid[current:last] = last - current
|
||||
# segment size
|
||||
segments[current:current+d] = seg_max
|
||||
if r < seg_min and not isempty:
|
||||
segments[current+d-1] += r
|
||||
if r >= seg_min or isempty:
|
||||
segments[current+d] = r
|
||||
current = last
|
||||
col_idx += 1
|
||||
offsets = torch.zeros_like(segments)
|
||||
offsets[1:] = torch.cumsum(segments[:-1], dim=0)
|
||||
return segments, column, lockid, maxid, offsets
|
||||
|
||||
@staticmethod
|
||||
def get_locks(size, dev):
|
||||
if dev not in _matmul.locks or \
|
||||
size > _matmul.locks[dev].size(0):
|
||||
_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
|
||||
return _matmul.locks[dev]
|
||||
# Compiled kernels, one cache per multiplication mode; each is filled lazily
# by the matching _*_matmul helper using a tuple key.
sdd_cache = {}
dsd_cache = {}
dds_cache = {}
# Lock buffers handed out by get_locks, keyed by device.
locks = {}
|
||||
|
||||
##########################
|
||||
# SPARSE = DENSE x DENSE #
|
||||
##########################
|
||||
# Given an array sizes representing reduction size for each
|
||||
# column of a block-mode matrix multiplication,
|
||||
# performs load-balancing to achieve more smaller reductions
|
||||
# between `seg_size` elements
|
||||
@staticmethod
|
||||
def load_balance(sizes, block):
|
||||
# segment size
|
||||
# heuristics taken from OpenAI blocksparse code
|
||||
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
|
||||
max_size = sizes.max()
|
||||
min_size = sizes[sizes != 0].min()
|
||||
#if max_size > min_size * 2.0:
|
||||
# seg_max = max(triton.cdiv(max_size, 4), min_size*2)
|
||||
#else:
|
||||
# seg_max = max_size
|
||||
seg_max = max_size
|
||||
seg_min = max(triton.cdiv(seg_max, 4), 4)
|
||||
# split reduction into segments
|
||||
div = sizes // seg_max
|
||||
rem = sizes % seg_max
|
||||
packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
|
||||
width = packs.sum()
|
||||
segments = torch.empty(width, dtype=sizes.dtype)
|
||||
column = torch.empty_like(segments)
|
||||
lockid = torch.zeros_like(segments)
|
||||
maxid = torch.zeros_like(segments)
|
||||
nlocks = 0
|
||||
current = 0
|
||||
col_idx = 0
|
||||
for i in range(len(sizes)):
|
||||
d, r = div[i], rem[i]
|
||||
isempty = sizes[i] < seg_min
|
||||
last = current + d + (r >= seg_min) + isempty
|
||||
# column id
|
||||
column[current:last] = col_idx
|
||||
# lock id
|
||||
if d > 1 or (d == 1 and r >= seg_min):
|
||||
nlocks += 1
|
||||
lockid[current:last] = nlocks
|
||||
maxid[current:last] = last - current
|
||||
# segment size
|
||||
segments[current:current + d] = seg_max
|
||||
if r < seg_min and not isempty:
|
||||
segments[current + d - 1] += r
|
||||
if r >= seg_min or isempty:
|
||||
segments[current + d] = r
|
||||
current = last
|
||||
col_idx += 1
|
||||
offsets = torch.zeros_like(segments)
|
||||
offsets[1:] = torch.cumsum(segments[:-1], dim=0)
|
||||
return segments, column, lockid, maxid, offsets
|
||||
|
||||
@staticmethod
|
||||
def make_sdd_lut(layout, block, dtype, device):
|
||||
start_width = 128 // block
|
||||
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
|
||||
luts, widths, packs = [], [], []
|
||||
for size, nnz in superblocks:
|
||||
width = nnz.shape[0] // (size*size)
|
||||
h = nnz[:, 0]
|
||||
i = nnz[:, 1]
|
||||
j = nnz[:, 2]
|
||||
b = nnz[:, 3]
|
||||
lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
|
||||
luts.append(lut.type(torch.int32).to(device))
|
||||
widths.append(width)
|
||||
packs.append(size)
|
||||
# create locks
|
||||
return luts, None, widths, packs
|
||||
@staticmethod
def get_locks(size, dev):
    """Return the cached int32 lock buffer for ``dev``, (re)allocating it when needed.

    A fresh zero-filled buffer of ``size`` entries is created when none is
    cached for ``dev`` yet or when the cached one has fewer than ``size``
    entries; otherwise the cached tensor is returned as is.
    """
    cached = _matmul.locks.get(dev)
    if cached is None or size > cached.size(0):
        cached = torch.zeros(size, dtype=torch.int32, device=dev)
        _matmul.locks[dev] = cached
    return cached
|
||||
|
||||
@staticmethod
|
||||
def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
|
||||
spdims, block, luts, num_locks, widths, packs):
|
||||
|
||||
if trans_c:
|
||||
a, b = b, a
|
||||
trans_a, trans_b = not trans_b, not trans_a
|
||||
AS0 = a.size(0)
|
||||
AS1 = a.size(1)
|
||||
AS2 = a.size(3 if trans_a else 2)
|
||||
AS3 = a.size(2 if trans_a else 3)
|
||||
BS0 = b.size(0)
|
||||
BS1 = b.size(1)
|
||||
BS2 = b.size(3 if trans_b else 2)
|
||||
BS3 = b.size(2 if trans_b else 3)
|
||||
dtype = a.dtype
|
||||
device = a.device
|
||||
is_16_multiple = AS3 % 16 == 0
|
||||
is_32_multiple = AS3 % 32 == 0
|
||||
is_64_multiple = AS3 % 64 == 0
|
||||
if not is_16_multiple:
|
||||
raise ValueError('Reduction size for SDD must be a multiple of 16')
|
||||
# create kernel
|
||||
total_width = sum([width*pack*pack for width,pack in zip(widths, packs)])
|
||||
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
|
||||
for lut, width, pack in zip(luts, widths, packs):
|
||||
num_lock = 1
|
||||
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
|
||||
if key not in _matmul.sdd_cache:
|
||||
defines = {'TM': block*pack, 'TN': block*pack,
|
||||
'TMN': block*block*pack*pack,
|
||||
'BLOCK': block,
|
||||
'TK': 32,
|
||||
'TYPE': dtype,
|
||||
'STRIDE_AM': '1' if trans_a else 'lda',
|
||||
'STRIDE_AK': 'lda' if trans_a else '1',
|
||||
'STRIDE_BN': 'ldb' if trans_b else '1',
|
||||
'STRIDE_BK': '1' if trans_b else 'ldb',
|
||||
'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
|
||||
'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
|
||||
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
|
||||
##########################
|
||||
# SPARSE = DENSE x DENSE #
|
||||
##########################
|
||||
|
||||
kernel = _matmul.sdd_cache[key]
|
||||
# create output
|
||||
locks = _matmul.get_locks(2*width*AS0*num_lock, a.device)
|
||||
# maximum grid size is 65535
|
||||
# so operation might be decomposed into multiple
|
||||
# kernel calls
|
||||
max_width = 49152
|
||||
for off_width in range(0, width, max_width):
|
||||
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
|
||||
a.stride(2), b.stride(2), block,
|
||||
a.stride(0), b.stride(0), c.stride(0),
|
||||
a.stride(1), b.stride(1), c.stride(0),
|
||||
AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock,
|
||||
grid = lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
|
||||
# save for backward pass
|
||||
return c
|
||||
@staticmethod
def make_sdd_lut(layout, block, dtype, device):
    """Build the look-up tables for the sparse-output ('sdd') kernel.

    ``libtriton.superblock`` is called on the int32 layout with a start width
    of ``128 // block`` and yields ``(size, nnz)`` pairs.  For each pair the
    first four columns of ``nnz`` are flattened row by row into one int32
    table on ``device``.
    NOTE(review): the exact meaning of the superblock output is not visible
    here -- confirm against libtriton.

    Args:
        layout: layout tensor passed to ``libtriton.superblock``.
        block: block size.
        dtype: not used by this routine.
        device: device the tables are moved to.

    Returns:
        ``(luts, None, widths, packs)``: one table, one width
        (``nnz.shape[0] // size**2``) and one pack size per superblock group.
    """
    superblocks = libtriton.superblock(layout.type(torch.int32), 128 // block)
    luts, widths, packs = [], [], []
    for pack, entries in superblocks:
        # columns 0..3 were named h, i, j, b in the original code
        fields = tuple(entries[:, k] for k in range(4))
        flat = torch.stack(fields, dim=1).view(-1).contiguous()
        luts.append(flat.type(torch.int32).to(device))
        widths.append(entries.shape[0] // (pack * pack))
        packs.append(pack)
    # no lock count for this mode
    return luts, None, widths, packs
|
||||
|
||||
##########################
|
||||
# DENSE = DENSE x SPARSE #
|
||||
# DENSE = SPARSE x DENSE #
|
||||
##########################
|
||||
|
||||
# Given a binary layout of 0s and 1s,
|
||||
# Construct look-up table for efficient execution on GPUs
|
||||
@staticmethod
|
||||
def make_dxx_lut(layout, block, step, trans, device, transform = lambda idx: idx):
|
||||
# load-balancing
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
segments = _empty.clone()
|
||||
column = _empty.clone()
|
||||
depth = _empty.clone()
|
||||
lockid = _empty.clone()
|
||||
maxid = _empty.clone()
|
||||
offsets = _empty.clone()
|
||||
current_offset = 0
|
||||
current_maxid = 0
|
||||
for z in range(layout.size(0)):
|
||||
if trans:
|
||||
sizes = torch.sum(layout[z, :, :], 1)
|
||||
else:
|
||||
sizes = torch.sum(layout[z, :, :], 0)
|
||||
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
|
||||
z_depth = z * torch.ones_like(z_segments)
|
||||
z_lockid[z_lockid > 0] += current_maxid
|
||||
current_maxid = z_lockid.max()
|
||||
# concatenate depth
|
||||
segments = torch.cat((segments, z_segments))
|
||||
column = torch.cat((column, z_column))
|
||||
depth = torch.cat((depth, z_depth))
|
||||
maxid = torch.cat((maxid, z_maxid))
|
||||
offsets = torch.cat((offsets, current_offset + z_offsets))
|
||||
lockid = torch.cat((lockid, z_lockid))
|
||||
current_offset += layout[z, :, :].sum()
|
||||
segments *= step
|
||||
# pointer increments
|
||||
if trans:
|
||||
nnz = layout.nonzero(as_tuple=False)
|
||||
else:
|
||||
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
|
||||
num_blocks = nnz.size(0)
|
||||
offsets = torch.min(offsets, (num_blocks - 1)*torch.ones_like(offsets))
|
||||
idx = transform(nnz[:, 2]*block)
|
||||
xincs = idx.clone()
|
||||
xincs[1:] -= idx[:-1]
|
||||
# divide block into multiple steps
|
||||
div = block // step
|
||||
xincs = xincs.view(-1, 1).repeat(1, div)
|
||||
xincs[:, 1:] = step
|
||||
xincs[:, 0 ] -= (div-1)*step
|
||||
# first increment for each reduction is actually the offset
|
||||
xincs[offsets[segments>0], 0] = idx[offsets[segments>0]]
|
||||
xincs = xincs.view(-1)
|
||||
# block-mode input increments
|
||||
if trans:
|
||||
widx = torch.arange(num_blocks)
|
||||
else:
|
||||
widx = _empty.clone()
|
||||
current_offset = 0
|
||||
for z in range(layout.size(0)):
|
||||
layoutw = layout[z, :, :].clone()
|
||||
msum = layoutw.sum()
|
||||
layoutw[layoutw > 0] = 1 + torch.arange(msum)
|
||||
widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
|
||||
current_offset += msum
|
||||
widx = widx
|
||||
wincs = widx*block*block
|
||||
wincs[1:] -= widx[:-1]*block*block
|
||||
wincs = wincs.view(-1, 1).repeat(1, div)
|
||||
if trans:
|
||||
wincs[:, 1:] = step
|
||||
wincs[:, 0] -= (div-1)*step
|
||||
else:
|
||||
wincs[:, 1:] = step*block
|
||||
wincs[:, 0] -= (div - 1)*step*block
|
||||
wincs[offsets[segments>0], 0] = widx[offsets[segments>0]]
|
||||
wincs = wincs.view(-1)
|
||||
# adjust offset and segment size
|
||||
offsets *= 2*div
|
||||
segments *= div
|
||||
# create header
|
||||
width = column.size(0)
|
||||
offsets += 6*width
|
||||
header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
|
||||
incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
|
||||
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
|
||||
# create lut
|
||||
lut = torch.cat((header, incs))
|
||||
lut = lut.type(torch.int32).to(device)
|
||||
# create locks
|
||||
num_locks = max(1, lockid.max())
|
||||
return lut, num_locks, width, None
|
||||
@staticmethod
|
||||
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
|
||||
|
||||
@staticmethod
|
||||
def _dds_matmul(a, b, trans_a, trans_b, trans_c,
|
||||
spdims, block, lut, num_locks, width, packs):
|
||||
# shapes / dtypes
|
||||
AS0 = a.size(0)
|
||||
AS1 = a.size(1)
|
||||
AS2 = a.size(3 if trans_a else 2)
|
||||
AS3 = a.size(2 if trans_a else 3)
|
||||
BS0 = spdims[0]
|
||||
BS1 = block * spdims[2 if trans_b else 1]
|
||||
BS2 = block * spdims[1 if trans_b else 2]
|
||||
dtype = a.dtype
|
||||
# kernel
|
||||
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
|
||||
if key not in _matmul.dds_cache:
|
||||
defines = {'TM': 128,
|
||||
'TN': block,
|
||||
'TK': 16,
|
||||
'BLOCK': block,
|
||||
'TYPE': dtype,
|
||||
'STRIDE_AM': 1 if trans_a else 'lda',
|
||||
'STRIDE_AK': 'lda' if trans_a else 1,
|
||||
'STRIDE_BN': block if trans_b else 1,
|
||||
'STRIDE_BK': 1 if trans_b else block,
|
||||
'STRIDE_CM': '1' if trans_c else 'ldc',
|
||||
'STRIDE_CN': 'ldc' if trans_c else '1',
|
||||
'NAME': 'dds_kernel',
|
||||
'DDS': True}
|
||||
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
|
||||
kernel = _matmul.dds_cache[key]
|
||||
# output
|
||||
CS0 = AS0
|
||||
CS1 = AS1
|
||||
CS2 = BS2 if trans_c else AS2
|
||||
CS3 = AS2 if trans_c else BS2
|
||||
locks = _matmul.get_locks(2*AS0*AS2//32*num_locks, a.device)
|
||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
|
||||
a.stride(2), block, c.stride(2),
|
||||
a.stride(0), b.stride(0), c.stride(0),
|
||||
a.stride(1), b.stride(1), c.stride(1),
|
||||
AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
|
||||
grid = lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
def _dsd_matmul(a, b, trans_a, trans_b, trans_c,
|
||||
spdims, block, lut, num_locks, width, packs):
|
||||
# shapes / dtypes
|
||||
AS0 = spdims[0]
|
||||
AS1 = block * spdims[2 if trans_a else 1]
|
||||
AS2 = block * spdims[1 if trans_a else 2]
|
||||
BS0 = b.size(0)
|
||||
BS1 = b.size(1)
|
||||
BS2 = b.size(3 if trans_b else 2)
|
||||
BS3 = b.size(2 if trans_b else 3)
|
||||
dtype = a.dtype
|
||||
# kernel
|
||||
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
|
||||
if key not in _matmul.dsd_cache:
|
||||
defines = {'TM': block,
|
||||
'TN': 128,
|
||||
'TK': 16,
|
||||
'BLOCK': block,
|
||||
'TYPE': dtype,
|
||||
'STRIDE_AM': 1 if trans_a else block,
|
||||
'STRIDE_AK': block if trans_a else 1,
|
||||
'STRIDE_BN': 'ldb' if trans_b else '1',
|
||||
'STRIDE_BK': '1' if trans_b else 'ldb',
|
||||
'STRIDE_CM': '1' if trans_c else 'ldc',
|
||||
'STRIDE_CN': 'ldc' if trans_c else '1',
|
||||
'NAME': 'dsd_kernel',
|
||||
'DSD': True}
|
||||
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
|
||||
kernel = _matmul.dsd_cache[key]
|
||||
# output
|
||||
CS0 = BS0
|
||||
CS1 = BS1
|
||||
CS2 = BS3 if trans_c else AS1
|
||||
CS3 = AS1 if trans_c else BS3
|
||||
locks = _matmul.get_locks(2*BS0*BS3//32*num_locks, a.device)
|
||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
|
||||
block, b.stride(2), c.stride(2),
|
||||
a.stride(0), b.stride(0), c.stride(0),
|
||||
a.stride(1), b.stride(1), c.stride(1),
|
||||
BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
|
||||
grid = lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
|
||||
return c
|
||||
if trans_c:
|
||||
a, b = b, a
|
||||
trans_a, trans_b = not trans_b, not trans_a
|
||||
AS0 = a.size(0)
|
||||
AS1 = a.size(1)
|
||||
AS2 = a.size(3 if trans_a else 2)
|
||||
AS3 = a.size(2 if trans_a else 3)
|
||||
BS0 = b.size(0)
|
||||
BS1 = b.size(1)
|
||||
BS2 = b.size(3 if trans_b else 2)
|
||||
BS3 = b.size(2 if trans_b else 3)
|
||||
dtype = a.dtype
|
||||
device = a.device
|
||||
is_16_multiple = AS3 % 16 == 0
|
||||
is_32_multiple = AS3 % 32 == 0
|
||||
is_64_multiple = AS3 % 64 == 0
|
||||
if not is_16_multiple:
|
||||
raise ValueError('Reduction size for SDD must be a multiple of 16')
|
||||
# create kernel
|
||||
total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
|
||||
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
|
||||
for lut, width, pack in zip(luts, widths, packs):
|
||||
num_lock = 1
|
||||
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
|
||||
if key not in _matmul.sdd_cache:
|
||||
defines = {
|
||||
'TM': block * pack, 'TN': block * pack, 'TMN': block * block * pack * pack, 'BLOCK': block, 'TK':
|
||||
32, 'TYPE': dtype, 'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1',
|
||||
'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': 'ldc',
|
||||
'STRIDE_CN': '1', 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'
|
||||
}
|
||||
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
|
||||
|
||||
fn = {'sdd': _sdd_matmul.__get__(object),
|
||||
'dsd': _dsd_matmul.__get__(object),
|
||||
'dds': _dds_matmul.__get__(object)}
|
||||
kernel = _matmul.sdd_cache[key]
|
||||
# create output
|
||||
locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
|
||||
# maximum grid size is 65535
|
||||
# so operation might be decomposed into multiple
|
||||
# kernel calls
|
||||
max_width = 49152
|
||||
for off_width in range(0, width, max_width):
|
||||
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), b.stride(2), block, a.stride(0),
|
||||
b.stride(0), c.stride(0), a.stride(1), b.stride(1), c.stride(0), AS2, AS2, AS3, off_width,
|
||||
lut.data_ptr(), locks.data_ptr(), num_lock,
|
||||
grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
|
||||
# save for backward pass
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, trans_a, trans_b, trans_c,
|
||||
mode, spdims, block,
|
||||
c_lut, c_num_locks, c_width, c_packs,
|
||||
da_lut, da_num_locks, da_width, da_packs,
|
||||
db_lut, db_num_locks, db_width, db_packs):
|
||||
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block,
|
||||
c_lut, c_num_locks, c_width, c_packs)
|
||||
# save for backward
|
||||
ctx.save_for_backward(a, b)
|
||||
ctx.da_num_locks = da_num_locks
|
||||
ctx.da_lut = da_lut
|
||||
ctx.da_width = da_width
|
||||
ctx.da_packs = da_packs
|
||||
ctx.db_lut = db_lut
|
||||
ctx.db_num_locks = db_num_locks
|
||||
ctx.db_width = db_width
|
||||
ctx.db_packs = db_packs
|
||||
ctx.mode = mode
|
||||
ctx.spdims = spdims
|
||||
ctx.block = block
|
||||
ctx.trans_a = trans_a
|
||||
ctx.trans_b = trans_b
|
||||
return c
|
||||
##########################
|
||||
# DENSE = DENSE x SPARSE #
|
||||
# DENSE = SPARSE x DENSE #
|
||||
##########################
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dc):
|
||||
# saved for backward
|
||||
a, b = ctx.saved_tensors
|
||||
mode = ctx.mode
|
||||
# gradients w.r.t. a
|
||||
if ctx.needs_input_grad[0]:
|
||||
mode_da = mode[1] + mode[0] + mode[2]
|
||||
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
|
||||
ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs)
|
||||
# gradients w.r.t. b
|
||||
if ctx.needs_input_grad[1]:
|
||||
mode_db = mode[2] + mode[1] + mode[0]
|
||||
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,
|
||||
ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs)
|
||||
return da, db, None, None, None,\
|
||||
None, None, None, None,\
|
||||
None, None, None, None, None, None,\
|
||||
None, None, None, None, None, None,\
|
||||
None, None, None, None, None, None
|
||||
# Given a binary layout of 0s and 1s,
|
||||
# Construct look-up table for efficient execution on GPUs
|
||||
@staticmethod
def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
    """Build the look-up table used by the dense-output ('dsd' / 'dds') kernels.

    Args:
        layout: layout tensor indexed as ``layout[z, :, :]``.
        block: block size.
        step: reduction step; each block is walked in ``block // step`` steps.
        trans: when true the per-slice reduction sizes are sums over dim 1 of
            ``layout[z]`` and non-zeros are enumerated from ``layout`` itself;
            otherwise sums over dim 0 and the transposed layout are used.
        device: device the finished table is moved to.
        transform: applied to ``nnz[:, 2] * block`` before the increments are
            derived from it.

    Returns:
        ``(lut, num_locks, width, None)``.  ``lut`` is an int32 tensor on
        ``device`` holding six header values per segment (offset, size,
        column, depth, lock id, max id) followed by interleaved increment
        pairs and two trailing zeros.  ``num_locks`` is a plain int >= 1,
        ``width`` is the number of segments, and the last slot is None.
    """
    # load-balancing
    _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
    segments = _empty.clone()
    column = _empty.clone()
    depth = _empty.clone()
    lockid = _empty.clone()
    maxid = _empty.clone()
    offsets = _empty.clone()
    current_offset = 0
    current_maxid = 0
    for z in range(layout.size(0)):
        if trans:
            sizes = torch.sum(layout[z, :, :], 1)
        else:
            sizes = torch.sum(layout[z, :, :], 0)
        z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
        z_depth = z * torch.ones_like(z_segments)
        # keep lock ids unique across slices
        z_lockid[z_lockid > 0] += current_maxid
        current_maxid = z_lockid.max()
        # concatenate depth
        segments = torch.cat((segments, z_segments))
        column = torch.cat((column, z_column))
        depth = torch.cat((depth, z_depth))
        maxid = torch.cat((maxid, z_maxid))
        offsets = torch.cat((offsets, current_offset + z_offsets))
        lockid = torch.cat((lockid, z_lockid))
        current_offset += layout[z, :, :].sum()
    segments *= step
    # pointer increments
    if trans:
        nnz = layout.nonzero(as_tuple=False)
    else:
        nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
    num_blocks = nnz.size(0)
    # clamp so that offsets of trailing empty segments stay valid indices
    offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
    # start offset of every non-empty segment; offsets and segments are not
    # modified again until both increment tables are finished
    starts = offsets[segments > 0]
    idx = transform(nnz[:, 2] * block)
    xincs = idx.clone()
    xincs[1:] -= idx[:-1]
    # divide block into multiple steps
    div = block // step
    xincs = xincs.view(-1, 1).repeat(1, div)
    xincs[:, 1:] = step
    xincs[:, 0] -= (div - 1) * step
    # first increment for each reduction is actually the offset
    xincs[starts, 0] = idx[starts]
    xincs = xincs.view(-1)
    # block-mode input increments
    if trans:
        widx = torch.arange(num_blocks)
    else:
        widx = _empty.clone()
        current_offset = 0
        for z in range(layout.size(0)):
            layoutw = layout[z, :, :].clone()
            msum = layoutw.sum()
            # number the non-zero blocks 1..msum, then read them back through
            # the transposed slice
            layoutw[layoutw > 0] = 1 + torch.arange(msum)
            widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
            current_offset += msum
    wincs = widx * block * block
    wincs[1:] -= widx[:-1] * block * block
    wincs = wincs.view(-1, 1).repeat(1, div)
    if trans:
        wincs[:, 1:] = step
        wincs[:, 0] -= (div - 1) * step
    else:
        wincs[:, 1:] = step * block
        wincs[:, 0] -= (div - 1) * step * block
    wincs[starts, 0] = widx[starts]
    wincs = wincs.view(-1)
    # adjust offset and segment size
    offsets *= 2 * div
    segments *= div
    # create header
    width = column.size(0)
    # the six-value-per-segment header comes first, so shift offsets past it
    offsets += 6 * width
    header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
    incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
    incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
    # create lut
    lut = torch.cat((header, incs))
    lut = lut.type(torch.int32).to(device)
    # create locks; always hand back a plain int rather than int-or-tensor
    num_locks = max(1, int(lockid.max()))
    return lut, num_locks, width, None
|
||||
|
||||
@staticmethod
def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
    """Run the 'dds' kernel on ``a`` and ``b`` and return the freshly allocated output.

    The kernel is compiled once per ``(block, device, dtypes, trans flags)``
    key and kept in ``_matmul.dds_cache``.  ``packs`` is not used here.
    NOTE(review): ``a`` is read through ``size(0..3)`` and so presumably is a
    4-D dense tensor, while the extents of ``b`` are derived from ``spdims``
    and ``block`` -- confirm against callers.
    """
    dev = a.device
    dtype = a.dtype
    # shapes of the first operand
    a_dim0, a_dim1 = a.size(0), a.size(1)
    a_m = a.size(3 if trans_a else 2)
    a_k = a.size(2 if trans_a else 3)
    # shapes of the second operand, derived from the layout dimensions
    b_dim0 = spdims[0]
    b_k = block * spdims[2 if trans_b else 1]
    b_n = block * spdims[1 if trans_b else 2]
    # kernel (compiled on first use for this configuration)
    cache_key = (block, dev, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    kernel = _matmul.dds_cache.get(cache_key)
    if kernel is None:
        defines = {
            'TM': 128,
            'TN': block,
            'TK': 16,
            'BLOCK': block,
            'TYPE': dtype,
            'STRIDE_AM': 1 if trans_a else 'lda',
            'STRIDE_AK': 'lda' if trans_a else 1,
            'STRIDE_BN': block if trans_b else 1,
            'STRIDE_BK': 1 if trans_b else block,
            'STRIDE_CM': '1' if trans_c else 'ldc',
            'STRIDE_CN': 'ldc' if trans_c else '1',
            'NAME': 'dds_kernel',
            'DDS': True,
        }
        kernel = triton.kernel(src, device=dev, defines=defines)
        _matmul.dds_cache[cache_key] = kernel
    # output: last two extents are swapped when trans_c is set
    out_shape = (a_dim0, a_dim1, b_n if trans_c else a_m, a_m if trans_c else b_n)
    locks = _matmul.get_locks(2 * a_dim0 * a_m // 32 * num_locks, dev)
    c = torch.empty(out_shape, dtype=dtype, device=dev)
    launch_args = (
        a.data_ptr(), b.data_ptr(), c.data_ptr(),
        a.stride(2), block, c.stride(2),
        a.stride(0), b.stride(0), c.stride(0),
        a.stride(1), b.stride(1), c.stride(1),
        a_m, b_n, 0, 0,
        lut.data_ptr(), locks.data_ptr(), num_locks,
    )
    kernel(*launch_args, grid=lambda opt: [width, triton.cdiv(a_m, opt.TM), a_dim0])
    return c
|
||||
|
||||
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
    """Run the 'dsd' kernel on ``a`` and ``b`` and return the freshly allocated output.

    The kernel is compiled once per ``(block, device, dtypes, trans flags)``
    key and kept in ``_matmul.dsd_cache``.  ``packs`` is not used here.
    NOTE(review): the extents of ``a`` are derived from ``spdims`` and
    ``block`` while ``b`` is read through ``size(0..3)`` and so presumably is
    a 4-D dense tensor -- confirm against callers.
    """
    dev = a.device
    dtype = a.dtype
    # shapes of the first operand, derived from the layout dimensions
    a_dim0 = spdims[0]
    a_m = block * spdims[2 if trans_a else 1]
    a_k = block * spdims[1 if trans_a else 2]
    # shapes of the second operand
    b_dim0, b_dim1 = b.size(0), b.size(1)
    b_k = b.size(3 if trans_b else 2)
    b_n = b.size(2 if trans_b else 3)
    # kernel (compiled on first use for this configuration)
    cache_key = (block, dev, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    kernel = _matmul.dsd_cache.get(cache_key)
    if kernel is None:
        defines = {
            'TM': block,
            'TN': 128,
            'TK': 16,
            'BLOCK': block,
            'TYPE': dtype,
            'STRIDE_AM': 1 if trans_a else block,
            'STRIDE_AK': block if trans_a else 1,
            'STRIDE_BN': 'ldb' if trans_b else '1',
            'STRIDE_BK': '1' if trans_b else 'ldb',
            'STRIDE_CM': '1' if trans_c else 'ldc',
            'STRIDE_CN': 'ldc' if trans_c else '1',
            'NAME': 'dsd_kernel',
            'DSD': True,
        }
        kernel = triton.kernel(src, device=dev, defines=defines)
        _matmul.dsd_cache[cache_key] = kernel
    # output: last two extents are swapped when trans_c is set
    out_shape = (b_dim0, b_dim1, b_n if trans_c else a_m, a_m if trans_c else b_n)
    locks = _matmul.get_locks(2 * b_dim0 * b_n // 32 * num_locks, dev)
    c = torch.empty(out_shape, dtype=dtype, device=dev)
    launch_args = (
        a.data_ptr(), b.data_ptr(), c.data_ptr(),
        block, b.stride(2), c.stride(2),
        a.stride(0), b.stride(0), c.stride(0),
        a.stride(1), b.stride(1), c.stride(1),
        b_n, a_m, 0, 0,
        lut.data_ptr(), locks.data_ptr(), num_locks,
    )
    kernel(*launch_args, grid=lambda opt: [width, triton.cdiv(b_n, opt.TN), b_dim0])
    return c
|
||||
|
||||
# Dispatch table from mode string to implementation.  The helpers are
# staticmethod objects at this point in the class body, so __get__ is used to
# obtain the plain callables.
fn = {
    'sdd': _sdd_matmul.__get__(object),
    'dsd': _dsd_matmul.__get__(object),
    'dds': _dds_matmul.__get__(object),
}
|
||||
|
||||
@staticmethod
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut,
            da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs):
    """Compute ``c`` with the implementation selected by ``mode`` and stash state for backward.

    The ``c_*`` arguments feed the forward product; the ``da_*`` / ``db_*``
    arguments are only stored on ``ctx`` together with ``mode``, ``spdims``,
    ``block``, ``trans_a`` and ``trans_b``.  ``a`` and ``b`` are saved through
    ``ctx.save_for_backward``.
    """
    result = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
    # save for backward
    ctx.save_for_backward(a, b)
    stash = {
        'da_num_locks': da_num_locks,
        'da_lut': da_lut,
        'da_width': da_width,
        'da_packs': da_packs,
        'db_lut': db_lut,
        'db_num_locks': db_num_locks,
        'db_width': db_width,
        'db_packs': db_packs,
        'mode': mode,
        'spdims': spdims,
        'block': block,
        'trans_a': trans_a,
        'trans_b': trans_b,
    }
    for attr, value in stash.items():
        setattr(ctx, attr, value)
    return result
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dc):
|
||||
# saved for backward
|
||||
a, b = ctx.saved_tensors
|
||||
mode = ctx.mode
|
||||
# gradients w.r.t. a
|
||||
if ctx.needs_input_grad[0]:
|
||||
mode_da = mode[1] + mode[0] + mode[2]
|
||||
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut,
|
||||
ctx.da_num_locks, ctx.da_width, ctx.da_packs)
|
||||
# gradients w.r.t. b
|
||||
if ctx.needs_input_grad[1]:
|
||||
mode_db = mode[2] + mode[1] + mode[0]
|
||||
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut,
|
||||
ctx.db_num_locks, ctx.db_width, ctx.db_packs)
|
||||
return da, db, None, None, None,\
|
||||
None, None, None, None,\
|
||||
None, None, None, None, None, None,\
|
||||
None, None, None, None, None, None,\
|
||||
None, None, None, None, None, None
|
||||
|
||||
class matmul:
    """Callable front-end that prepares look-up tables and dispatches `_matmul.apply`.

    The operator is configured once (layout, block, mode, transposition flags)
    and can then be called repeatedly on pairs of tensors.  The look-up tables
    required by `_matmul.apply` are built lazily and cached per
    ``(dtype, device)``.

    Supported modes are ``'sdd'``, ``'dsd'`` and ``'dds'``.  In ``'dsd'`` mode
    the first operand is padded as the sparse one, in ``'dds'`` mode the second
    operand is; presumably the three letters describe output / a / b as
    sparse or dense -- TODO confirm against `_matmul`.
    """

    def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
        """Store the operator configuration.

        Args:
            layout: object exposing ``.shape``; stored and forwarded to the
                look-up table builders.
            block: stored and forwarded to the look-up table builders and to
                `_matmul.apply`.
            mode: one of ``'sdd'``, ``'dsd'``, ``'dds'``.
            trans_a: forwarded to `_matmul.apply`; also selects the look-up
                table variant in ``'dsd'`` mode.
            trans_b: forwarded to `_matmul.apply`; also selects the look-up
                table variant in ``'dds'`` mode.

        Raises:
            NotImplementedError: if `mode` is not one of the supported modes.
        """
        if mode not in ['sdd', 'dsd', 'dds']:
            raise NotImplementedError('Supported modes are: sdd, dsd, dds')
        # look-up table cache, keyed by (dtype, device)
        self.lut_cache = dict()
        # attributes
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.mode = mode
        self.spdims = layout.shape
        self.block = block
        self.layout = layout

    def make_lut(self, dtype, device):
        """Return the 12-tuple of look-up table data for ``(dtype, device)``.

        The tuple is ``(lut, num_locks, width, packs)`` repeated three times,
        for the "C", "DA" and "DB" tables in that order (presumably the forward
        output and the two operand gradients -- TODO confirm).  Results are
        cached in ``self.lut_cache`` so each key is only built once.
        """
        key = (dtype, device)
        if key in self.lut_cache:
            return self.lut_cache[key]
        layout, block = self.layout, self.block
        step = 8 if dtype == torch.float32 else 16
        # One branch per mode; inside each branch the tables are built in the
        # order C, DA, DB.  `mode` was validated in __init__, so exactly one
        # branch is taken.
        if self.mode == 'sdd':
            c = _matmul.make_sdd_lut(layout, block, dtype, device)
            da = _matmul.make_dxx_lut(layout, block, step, True, device)
            db = _matmul.make_dxx_lut(layout, block, step, False, device)
        elif self.mode == 'dsd':
            c = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
            da = _matmul.make_sdd_lut(layout, block, dtype, device)
            db = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
        elif self.mode == 'dds':
            c = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
            da = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
            db = _matmul.make_sdd_lut(layout, block, dtype, device)
        # Explicit unpacking keeps the "exactly four values per table" check.
        c_lut, c_num_locks, c_width, c_packs = c
        da_lut, da_num_locks, da_width, da_packs = da
        db_lut, db_num_locks, db_width, db_packs = db
        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
                               da_lut, da_num_locks, da_width, da_packs,
                               db_lut, db_num_locks, db_width, db_packs)
        return self.lut_cache[key]

    # pad shapes of a tensor to make it
    # compatible with kernel calls
    @staticmethod
    def _pad_shape(x, is_sparse):
        """Prepend size-1 dimensions until `x` has 3 (sparse) or 4 (dense) dims.

        Tensors that already have at least that many dimensions are returned
        unchanged.
        """
        max_dim = 3 if is_sparse else 4
        for _ in range(max_dim - x.dim()):
            x = x.unsqueeze(0)
        return x

    def __call__(self, a, b):
        """Apply the configured operator to `a` and `b` and return the result.

        Look-up tables are fetched for ``a.dtype`` / ``a.device``; both
        operands are padded with leading singleton dimensions before the call.
        """
        c_lut, c_num_locks, c_width, c_packs,\
        da_lut, da_num_locks, da_width, da_packs,\
        db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
        # pad shapes with ones
        a = matmul._pad_shape(a, self.mode == 'dsd')
        b = matmul._pad_shape(b, self.mode == 'dds')
        # execute; the fifth positional argument is hard-coded to False
        # (its meaning is defined by `_matmul` and is not visible here)
        c = _matmul.apply(a, b, self.trans_a, self.trans_b, False,
                          self.mode, self.spdims, self.block,
                          c_lut, c_num_locks, c_width, c_packs,
                          da_lut, da_num_locks, da_width, da_packs,
                          db_lut, db_num_locks, db_width, db_packs)
        return c
|
||||
|
@@ -1,9 +1,9 @@
|
||||
#define STM 8
|
||||
#define STN 8
|
||||
|
||||
__global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C __noalias __aligned(16),
|
||||
__global__ void matmul(TYPE *A __noalias __readonly __aligned(16),
|
||||
TYPE *B __noalias __readonly __aligned(16),
|
||||
TYPE *C __noalias __aligned(16),
|
||||
float alpha,
|
||||
int M,
|
||||
int N,
|
||||
@@ -11,87 +11,88 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
|
||||
int lda __multipleof(LDA_POW2_DIV),
|
||||
int ldb __multipleof(LDB_POW2_DIV),
|
||||
int ldc __multipleof(LDC_POW2_DIV),
|
||||
int* locks) {
|
||||
// prologue
|
||||
int pid = get_program_id(0);
|
||||
int pidz = get_program_id(2);
|
||||
int gridm = (M + TM - 1) / TM;
|
||||
int gridn = (N + TN - 1) / TN;
|
||||
int *locks) {
|
||||
// prologue
|
||||
int pid = get_program_id(0);
|
||||
int pidz = get_program_id(2);
|
||||
int gridm = (M + TM - 1) / TM;
|
||||
int gridn = (N + TN - 1) / TN;
|
||||
|
||||
// swizzle for better L2 performance
|
||||
int width = STM*gridn;
|
||||
int stm = pid / width;
|
||||
int RSTM = min(gridm - stm*STM, STM);
|
||||
int stn = (pid % width) / (RSTM*STN);
|
||||
int RSTN = min(gridn - stn*STN, STN);
|
||||
int laneid = pid % (RSTM * RSTN);
|
||||
int lanem = laneid / RSTN;
|
||||
int lanen = laneid % RSTN;
|
||||
int pidm = stm*STM + lanem;
|
||||
int pidn = stn*STN + lanen;
|
||||
int rm[TM] = pidm * TM + 0 ... TM;
|
||||
int rn[TN] = pidn * TN + 0 ... TN;
|
||||
// swizzle for better L2 performance
|
||||
int width = STM * gridn;
|
||||
int stm = pid / width;
|
||||
int RSTM = min(gridm - stm * STM, STM);
|
||||
int stn = (pid % width) / (RSTM * STN);
|
||||
int RSTN = min(gridn - stn * STN, STN);
|
||||
int laneid = pid % (RSTM * RSTN);
|
||||
int lanem = laneid / RSTN;
|
||||
int lanen = laneid % RSTN;
|
||||
int pidm = stm * STM + lanem;
|
||||
int pidn = stn * STN + lanen;
|
||||
int rm[TM] = pidm * TM + 0 ... TM;
|
||||
int rn[TN] = pidn * TN + 0 ... TN;
|
||||
|
||||
// split-k for better parallelism
|
||||
K = K / TZ;
|
||||
int rk[TK] = 0 ... TK;
|
||||
// pointers to operands
|
||||
int offa[TM, TK] = (pidz*K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
|
||||
int offb[TK, TN] = (pidz*K + rk[:, newaxis]) * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
|
||||
TYPE* pa[TM, TK] = A + offa;
|
||||
TYPE* pb[TK, TN] = B + offb;
|
||||
// split-k for better parallelism
|
||||
K = K / SPLITK;
|
||||
int rk[TK] = 0 ... TK;
|
||||
// pointers to operands
|
||||
int offa[TM, TK] = (pidz * K + rk [newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
|
||||
int offb[TK, TN] = (pidz * K + rk[:, newaxis]) * STRIDE_BK + rn [newaxis, :] * STRIDE_BN;
|
||||
TYPE *pa[TM, TK] = A + offa;
|
||||
TYPE *pb[TK, TN] = B + offb;
|
||||
|
||||
// prefetches operands
|
||||
bool checka[TM, TK] = rk[newaxis, :] < K;
|
||||
bool checkb[TK, TN] = rk[:, newaxis] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
pa += TK * STRIDE_AK;
|
||||
pb += TK * STRIDE_BK;
|
||||
// prefetches operands
|
||||
bool checka[TM, TK] = rk [newaxis, :] < K;
|
||||
bool checkb[TK, TN] = rk[:, newaxis] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
pa += TK * STRIDE_AK;
|
||||
pb += TK * STRIDE_BK;
|
||||
|
||||
// reduction loop
|
||||
float acc[TM, TN] = 0;
|
||||
for(int k = K; k > 0; k -= TK){
|
||||
#if (IS_TK_DIV_K==1)
|
||||
bool checkk[TK] = k > TK;
|
||||
// reduction loop
|
||||
float acc[TM, TN] = 0;
|
||||
for (int k = K; k > 0; k -= TK) {
|
||||
#if (IS_TK_DIV_K == 1)
|
||||
bool checkk[TK] = k > TK;
|
||||
#else
|
||||
bool checkk[TK] = rk < k - TK;
|
||||
bool checkk[TK] = rk < k - TK;
|
||||
#endif
|
||||
bool checka[TM, TK] = checkk[newaxis, :];
|
||||
bool checkb[TK, TN] = checkk[:, newaxis];
|
||||
acc += a @ b;
|
||||
#if (IS_TK_DIV_K==1)
|
||||
a = *?(checka)pa;
|
||||
b = *?(checkb)pb;
|
||||
bool checka[TM, TK] = checkk [newaxis, :];
|
||||
bool checkb[TK, TN] = checkk[:, newaxis];
|
||||
acc += a @b;
|
||||
#if (IS_TK_DIV_K == 1)
|
||||
a = *? (checka)pa;
|
||||
b = *? (checkb)pb;
|
||||
#else
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
#endif
|
||||
pa += TK * STRIDE_AK;
|
||||
pb += TK * STRIDE_BK;
|
||||
}
|
||||
acc = acc * alpha;
|
||||
TYPE c[TM, TN] = acc;
|
||||
pa += TK * STRIDE_AK;
|
||||
pb += TK * STRIDE_BK;
|
||||
}
|
||||
acc = acc * alpha;
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
// epilogue
|
||||
int rcm[TM] = pidm * TM + 0 ... TM;
|
||||
int rcn[TN] = pidn * TN + 0 ... TN;
|
||||
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
|
||||
TYPE* pc[TM, TN] = C + offc;
|
||||
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N;
|
||||
#if (TZ==1)
|
||||
*?(checkc) pc = c;
|
||||
// epilogue
|
||||
int rcm[TM] = pidm * TM + 0 ... TM;
|
||||
int rcn[TN] = pidn * TN + 0 ... TN;
|
||||
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn [newaxis, :];
|
||||
TYPE *pc[TM, TN] = C + offc;
|
||||
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn [newaxis, :] < N;
|
||||
#if (SPLITK == 1)
|
||||
*? (checkc)pc = c;
|
||||
#else
|
||||
// accumulate partial result using spin-locks
|
||||
int *plock = locks + pid;
|
||||
int *pcount = plock + get_num_programs(0);
|
||||
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
|
||||
int count = *pcount;
|
||||
if(count == 0)
|
||||
*?(checkc) pc = c;
|
||||
else
|
||||
*?(checkc) pc = c + *?(checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % TZ);
|
||||
atomic_xchg(plock, 0);
|
||||
// accumulate partial result using spin-locks
|
||||
int *plock = locks + pid;
|
||||
int *pcount = plock + get_num_programs(0);
|
||||
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
|
||||
;
|
||||
int count = *pcount;
|
||||
if (count == 0)
|
||||
*? (checkc)pc = c;
|
||||
else
|
||||
*? (checkc)pc = c + *? (checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % SPLITK);
|
||||
atomic_xchg(plock, 0);
|
||||
#endif
|
||||
}
|
@@ -3,29 +3,32 @@ import triton
|
||||
import os
|
||||
|
||||
class _matmul(torch.autograd.Function):
|
||||
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
|
||||
src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
|
||||
|
||||
_DEFAULT_CONFIGS = [
|
||||
({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
|
||||
({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
|
||||
({'TM': '128', 'TN': '64', 'TK': '32', 'TZ': '1'}, 4),
|
||||
({'TM': '64', 'TN': '64', 'TK': '64', 'TZ': '1'}, 4),
|
||||
({'TM': '32', 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
|
||||
({'TM': '128', 'TN': '32', 'TK': '64', 'TZ': '1'}, 4),
|
||||
({'TM': '64', 'TN': '32', 'TK': '64', 'TZ': '1'}, 2),
|
||||
({'TM': '32', 'TN': '64', 'TK': '64', 'TZ': '1'}, 2),
|
||||
({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
|
||||
({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
|
||||
({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
|
||||
({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
|
||||
({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
|
||||
({'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, 4),
|
||||
({'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, 4),
|
||||
({'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 4),
|
||||
({'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, 4),
|
||||
({'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 4),
|
||||
({'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 2),
|
||||
({'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 2),
|
||||
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
|
||||
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
|
||||
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
|
||||
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
|
||||
]
|
||||
_CONFIGS = _DEFAULT_CONFIGS
|
||||
|
||||
@staticmethod
|
||||
def largest_pow2_divisor(N):
|
||||
if N % 8 == 0: return 8
|
||||
if N % 4 == 0: return 4
|
||||
if N % 2 == 0: return 2
|
||||
if N % 8 == 0:
|
||||
return 8
|
||||
if N % 4 == 0:
|
||||
return 4
|
||||
if N % 2 == 0:
|
||||
return 2
|
||||
return 1
|
||||
|
||||
_locks = dict()
|
||||
@@ -40,8 +43,10 @@ class _matmul(torch.autograd.Function):
|
||||
K, N = b.shape
|
||||
c = torch.empty((M, N), dtype=dtype, device=device)
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous()
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
# kernel hash
|
||||
is_a_row = a.stride(1) == 1
|
||||
is_b_row = b.stride(1) == 1
|
||||
@@ -52,28 +57,60 @@ class _matmul(torch.autograd.Function):
|
||||
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
|
||||
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
|
||||
is_tk_div_k = K % 64 == 0
|
||||
key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
|
||||
key = (
|
||||
device,
|
||||
dtype,
|
||||
is_a_row,
|
||||
is_b_row,
|
||||
lda_pow2_div,
|
||||
ldb_pow2_div,
|
||||
ldc_pow2_div,
|
||||
is_tk_div_k,
|
||||
)
|
||||
if key not in _matmul._kernels:
|
||||
defines = {
|
||||
'TYPE': dtype, 'STRIDE_AM': 'lda' if is_a_row else '1', 'STRIDE_AK': '1' if is_a_row else 'lda',
|
||||
'STRIDE_BK': 'ldb' if is_b_row else '1', 'STRIDE_BN': '1' if is_b_row else 'ldb', 'LDA_POW2_DIV':
|
||||
lda_pow2_div, 'LDB_POW2_DIV': ldb_pow2_div, 'LDC_POW2_DIV': ldc_pow2_div, 'IS_TK_DIV_K':
|
||||
int(is_tk_div_k)
|
||||
"TYPE": dtype,
|
||||
"STRIDE_AM": "lda" if is_a_row else "1",
|
||||
"STRIDE_AK": "1" if is_a_row else "lda",
|
||||
"STRIDE_BK": "ldb" if is_b_row else "1",
|
||||
"STRIDE_BN": "1" if is_b_row else "ldb",
|
||||
"LDA_POW2_DIV": lda_pow2_div,
|
||||
"LDB_POW2_DIV": ldb_pow2_div,
|
||||
"LDC_POW2_DIV": ldc_pow2_div,
|
||||
"IS_TK_DIV_K": int(is_tk_div_k),
|
||||
}
|
||||
_matmul._kernels[key] = triton.kernel(_matmul.src,
|
||||
device,
|
||||
defines=defines,
|
||||
autotune_vals=_matmul._CONFIGS,
|
||||
autotune_key=['M', 'N', 'K'])
|
||||
_matmul._kernels[key] = triton.kernel(
|
||||
_matmul.src,
|
||||
device,
|
||||
defines=defines,
|
||||
autotune_vals=_matmul._CONFIGS,
|
||||
autotune_key=["M", "N", "K"],
|
||||
)
|
||||
kernel = _matmul._kernels[key]
|
||||
# # locks for split-k
|
||||
if device not in _matmul._locks:
|
||||
_matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
|
||||
locks = _matmul._locks[device]
|
||||
# enqueue
|
||||
alpha = 1.
|
||||
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
|
||||
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ]
|
||||
alpha = 1.0
|
||||
args = [
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
alpha,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
locks.data_ptr(),
|
||||
]
|
||||
grid = lambda opt: [
|
||||
triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
|
||||
1,
|
||||
opt.SPLITK,
|
||||
]
|
||||
kernel(*args, grid=grid)
|
||||
return c
|
||||
|
||||
|
@@ -1,21 +1,33 @@
|
||||
import torch
|
||||
|
||||
|
||||
def sparsify_tensor(x, mask, block):
    """Gather the `block` x `block` tiles of `x` selected by the non-zero entries of `mask`.

    `x` is indexed as ``x[:, h, rows, cols]`` and every non-zero coordinate
    ``(h, i, j)`` of the 3-D `mask` selects the tile
    ``rows = i*block:(i+1)*block``, ``cols = j*block:(j+1)*block``.

    Returns:
        A tensor of shape ``(x.size(0), nnz, block, block)`` with the same
        dtype and device as `x`, where ``nnz`` is the number of non-zero
        entries of `mask`; tiles appear in the order returned by
        ``mask.nonzero()``.
    """
    # Pull the coordinates out once as plain Python ints.  Sizing the output
    # by the number of coordinates (rather than mask.sum()) keeps it correct
    # for masks whose non-zero entries are not exactly 1.
    coords = mask.nonzero(as_tuple=False).tolist()
    ret = torch.empty((x.size(0), len(coords), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(coords):
        ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret
|
||||
|
||||
|
||||
def mask_tensor(x, mask, block, value=0):
    """Return a copy of `x` with every tile whose `mask` entry is zero set to `value`.

    Each zero coordinate ``(h, i, j)`` of the 3-D `mask` overwrites
    ``ret[:, h, i*block:(i+1)*block, j*block:(j+1)*block]``.  `x` itself is
    not modified.
    """
    ret = x.clone()
    # Coordinates are converted to plain ints once, then each tile is filled
    # exactly once.
    for h, i, j in (mask == 0).nonzero(as_tuple=False).tolist():
        ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return ret
|
||||
|
||||
|
||||
def allclose(x, y):
    """Check that `x` and `y` agree to within 1% of their largest magnitude.

    The largest absolute element-wise difference is compared against the
    largest absolute value found in either tensor.

    Returns:
        A 0-dim boolean tensor: true when the relative error is below 1e-2.

    Raises:
        AssertionError: if the two tensors have different dtypes.
    """
    assert x.dtype == y.dtype
    tol = 1e-2
    max_diff = (x - y).abs().max()
    # Normalise by magnitude, not by the signed maximum: with signed maxima an
    # all-negative tensor gives a negative ratio that is always "< tol".
    scale = torch.max(x.abs().max(), y.abs().max())
    if scale == 0:
        # Both tensors are identically zero; avoid 0/0 = NaN.
        return max_diff == 0
    err = max_diff / scale
    return err < tol
|
||||
|
||||
|
||||
def do_bench(fn, flops=0, warmup=10, rep=50):
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -32,8 +44,11 @@ def do_bench(fn, flops=0, warmup=10, rep=50):
|
||||
time_ms = start_event.elapsed_time(end_event) / rep
|
||||
return time_ms
|
||||
|
||||
|
||||
class Benchmark:
|
||||
def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
|
||||
def __init__(
|
||||
self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args
|
||||
):
|
||||
self.x_names = x_names
|
||||
self.x_vals = x_vals
|
||||
self.y_name = y_name
|
||||
@@ -44,6 +59,7 @@ class Benchmark:
|
||||
self.plot_name = plot_name
|
||||
self.args = args
|
||||
|
||||
|
||||
class Mark:
|
||||
def __init__(self, fn, benchmarks):
|
||||
self.fn = fn
|
||||
@@ -53,26 +69,31 @@ class Mark:
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import os
|
||||
|
||||
df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
|
||||
for x in bench.x_vals:
|
||||
x_args = {x_name: x for x_name in bench.x_names}
|
||||
row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
|
||||
row = [
|
||||
self.fn(**x_args, **{bench.y_name: y}, **bench.args)
|
||||
for y in bench.y_vals
|
||||
]
|
||||
df.loc[len(df)] = [x] + row
|
||||
if with_plot and bench.plot_name:
|
||||
xlabel = ' = '.join(bench.x_names)
|
||||
xlabel = " = ".join(bench.x_names)
|
||||
plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
|
||||
plot.set_xlabel(xlabel)
|
||||
plot.set_ylabel(bench.ylabel)
|
||||
plot.set_title(bench.plot_name)
|
||||
plot.set_xscale('log' if bench.loglog else 'linear')
|
||||
plot.set_yscale('log' if bench.loglog else 'linear')
|
||||
plt.savefig(os.path.join(result_path, f'{bench.plot_name}.png'))
|
||||
df.to_csv(os.path.join(result_path, f'{bench.plot_name}.csv'))
|
||||
plot.set_xscale("log" if bench.loglog else "linear")
|
||||
plot.set_yscale("log" if bench.loglog else "linear")
|
||||
plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
|
||||
df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))
|
||||
|
||||
def run(self, result_path, with_plot):
    """Execute every benchmark in ``self.benchmarks`` via ``self._run``.

    `result_path` and `with_plot` are forwarded unchanged to each call.
    """
    for benchmark in self.benchmarks:
        self._run(benchmark, result_path, with_plot)
|
||||
|
||||
|
||||
def perf_report(benchmarks):
    """Return a decorator that wraps a function in a `Mark` bound to `benchmarks`.

    Usage: ``@perf_report(benchmarks)`` above a benchmark function replaces
    that function with ``Mark(fn, benchmarks)``.
    """
    # A named inner function instead of a lambda bound to a name, so the
    # decorator has a meaningful __name__ in tracebacks.
    def wrapper(fn):
        return Mark(fn, benchmarks)

    return wrapper
|
||||
|
Reference in New Issue
Block a user