[CODEGEN] Major performance improvements on A100 (#70)

Improved handling of asynchronous copy, scheduling and synchronization for A100. Now achieving CUTLASS-like performance on large square dense matrix multiplication tasks
This commit is contained in:
Philippe Tillet
2021-02-21 15:19:39 -08:00
committed by Philippe Tillet
parent 045ab5d62a
commit 5b83259592
31 changed files with 1331 additions and 1115 deletions

View File

@@ -2,58 +2,74 @@ import triton
import torch
# square benchmarks
nt = {False: 'n', True: 't'}
nt = {False: "n", True: "t"}
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
y_name = 'provider',
y_vals = ['torch', 'triton', 'cutlass'],
y_lines = ['Torch', 'Triton', 'CUTLASS'],
ylabel = 'TFLOPS',
loglog = False,
plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
args = {'AT': False, 'BT': False, 'dtype': torch.float16}
)\
for AT in [False, True] for BT in [False, True]
x_names=["M", "N", "K"],
x_vals=[512 * i for i in range(1, 16)],
y_name="provider",
y_vals=["torch", "triton", "cutlass"],
y_lines=["Torch", "Triton", "CUTLASS"],
ylabel="TFLOPS",
loglog=False,
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
args={"AT": AT, "BT": BT, "dtype": torch.float16},
) for AT in [False, True] for BT in [False, True]
]
@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
import os
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
if AT: a = a.t()
if BT: b = b.t()
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT:
a = a.t()
if BT:
b = b.t()
num_flops = 2 * M * N * K
if provider == 'torch':
if provider == "torch":
torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
torch_tflops = num_flops / torch_ms * 1e-9
return torch_tflops
if provider == 'triton':
if provider == "triton":
triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
triton_tflops = num_flops / triton_ms * 1e-9
return triton_tflops
if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
import subprocess
import tempfile
import pandas as pd
# run program specified by CUTLASS_PROFILER env variable
layout_a = 'column' if AT else 'row'
layout_b = 'column' if BT else 'row'
layout_a = "column" if AT else "row"
layout_b = "column" if BT else "row"
# create temporary file name
fd, fname = tempfile.mkstemp()
# run program and gets its output
cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
'--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \
f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
cmd = [
os.environ["CUTLASS_PROFILER"],
f"--m={M}",
f"--n={N}",
f"--k={K}",
f"--A=f16:{layout_a}",
f"--B=f16:{layout_b}",
"--C=f16:column",
"--accum=f32",
"--operation=gemm",
"--verification-enabled=false",
f"--warmup-iterations={warmup}",
f"--profiling-iterations={rep}",
f"--output={fname}",
"--verbose=false",
]
# run cmd
subprocess.run(cmd, stdout=subprocess.PIPE)
# read CSV output
df_c = pd.read_csv(f'{fname}.gemm.csv')
cutlass_tflops = max(df_c['GFLOPs']) / 1e3
df_c = pd.read_csv(f"{fname}.gemm.csv")
cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
return cutlass_tflops
return None
if __name__ == '__main__':
if __name__ == "__main__":
bench_op.run()

View File

@@ -15,21 +15,21 @@ import distutils.spawn
import torch
def find_llvm():
versions = ['-10', '-10.0', '']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
versions = ["-10", "-10.0", ""]
supported = ["llvm-config{v}".format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return paths[0]
config = distutils.spawn.find_executable('llvm-config')
instructions = 'Please install llvm-10-dev'
config = distutils.spawn.find_executable("llvm-config")
instructions = "Please install llvm-10-dev"
if config is None:
raise RuntimeError('Could not find llvm-config. ' + instructions)
version = os.popen('{config} --version'.format(config=config)).read()
raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions)
raise RuntimeError("Could not find llvm-config. " + instructions)
version = os.popen("{config} --version".format(config=config)).read()
raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)
class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=''):
def __init__(self, name, path, sourcedir=""):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
self.path = path
@@ -37,84 +37,84 @@ class CMakeExtension(Extension):
class CMakeBuild(build_ext):
def run(self):
try:
out = subprocess.check_output(['cmake', '--version'])
out = subprocess.check_output(["cmake", "--version"])
except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions))
if platform.system() == "Windows":
cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1))
if cmake_version < '3.1.0':
cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
if cmake_version < "3.1.0":
raise RuntimeError("CMake >= 3.1.0 is required on Windows")
for ext in self.extensions:
self.build_extension(ext)
def build_extension(self, ext):
#self.debug = True
# self.debug = True
self.debug = False
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# python directories
python_include_dirs = distutils.sysconfig.get_python_inc()
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
torch_include_dirs = include_paths(True)
torch_library_dirs = library_paths(True)
cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
cmake_args = [
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_TUTORIALS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DBUILD_TUTORIALS=OFF",
"-DBUILD_PYTHON_MODULE=ON",
#'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
'-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
'-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
'-DTORCH_CXX11_ABI=' + cxx11abi,
'-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
'-DLLVM_CONFIG=' + find_llvm()
"-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs] + include_paths(True)),
"-DPYTHON_LINK_DIRS=" + ";".join(library_paths(True)),
"-DTORCH_CXX11_ABI=" + cxx11abi,
"-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton",
"-DLLVM_CONFIG=" + find_llvm(),
]
# configuration
cfg = 'Debug' if self.debug else 'Release'
cfg = 'Release'
build_args = ['--config', cfg]
cfg = "Debug" if self.debug else "Release"
build_args = ["--config", cfg]
if platform.system() == "Windows":
cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)]
cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
if sys.maxsize > 2**32:
cmake_args += ['-A', 'x64']
build_args += ['--', '/m']
cmake_args += ["-A", "x64"]
build_args += ["--", "/m"]
else:
cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
build_args += ['--', '-j4']
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
build_args += ["--", "-j4"]
env = os.environ.copy()
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
setup(
name='triton',
version='1.0.0',
author='Philippe Tillet',
author_email='phil@openai.com',
description='A language and compiler for custom Deep Learning operations',
long_description='',
packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'],
install_requires=['numpy', 'torch'],
package_data={'triton/ops': ['*.c'], 'triton/ops/blocksparse': ['*.c']},
name="triton",
version="1.0.0",
author="Philippe Tillet",
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
install_requires=["numpy", "torch"],
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
include_package_data=True,
ext_modules=[CMakeExtension('triton', 'triton/_C/')],
cmdclass={'build_ext': CMakeBuild},
ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild},
zip_safe=False,
# for PyPI
keywords=['Compiler', 'Deep Learning'],
url='https://github.com/ptillet/triton/',
download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz',
keywords=["Compiler", "Deep Learning"],
url="https://github.com/ptillet/triton/",
download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz",
classifiers=[
'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
'Intended Audience :: Developers', # Define that your audience are developers
'Topic :: Software Development :: Build Tools',
'License :: OSI Approved :: MIT License', # Again, pick a license
'Programming Language :: Python :: 3.6',
"Development Status :: 3 - Alpha", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
"Intended Audience :: Developers", # Define that your audience are developers
"Topic :: Software Development :: Build Tools",
"License :: OSI Approved :: MIT License", # Again, pick a license
"Programming Language :: Python :: 3.6",
],
)

View File

@@ -2,29 +2,17 @@ import torch
import triton
import pytest
@pytest.mark.parametrize(
"MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block)
for mode in ["sdd", "dsd", "dds"]
for at in [False, True]
for bt in [False, True]
for block in [16, 32, 64]
],
[(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
for block in [16, 32, 64]],
)
def test_matmul(
MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384
):
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
# set seed
torch.random.manual_seed(0)
# create inputs
a = torch.randn(
(Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda"
)
b = torch.randn(
(Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda"
)
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
shape = {
"sdd": (M, N),
"dsd": (a.shape[2], a.shape[3]),
@@ -32,9 +20,7 @@ def test_matmul(
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
# triton result
op = triton.ops.blocksparse.matmul(
layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B
)
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
rc = op(ra, rb)
@@ -49,7 +35,6 @@ def test_matmul(
# compare
assert triton.testing.allclose(rc, tc)
@pytest.mark.parametrize(
"BLOCK, WIDTH",
[(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
@@ -62,12 +47,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
# create inputs
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
at_mask = torch.randint(
low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda"
)
kp_mask = torch.randint(
low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda"
)
at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
kp_mask[kp_mask == 1.0] = float("-inf")
# triton result
op = triton.ops.blocksparse.softmax(layout, BLOCK)
@@ -94,7 +75,6 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
# compare
assert triton.testing.allclose(ry, ty)
def test_attention_fwd_bwd(
input_scale=1.0,
tol=2e-2,
@@ -108,10 +88,7 @@ def test_attention_fwd_bwd(
# inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True)
.to(dtype)
.cuda()
for _ in range(3)
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
]
attn_mask = torch.tril(
torch.ones(
@@ -129,11 +106,9 @@ def test_attention_fwd_bwd(
query.retain_grad()
key.retain_grad()
value.retain_grad()
attn_out = triton_attention(
layout, block, attn_mask, query=query, key=key, value=value, scale=scale
)
attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
# ad hoc loss
loss = (attn_out ** 2).mean()
loss = (attn_out**2).mean()
loss.backward()
grads = [query.grad, key.grad, value.grad]
@@ -148,17 +123,16 @@ def test_attention_fwd_bwd(
probs = torch.softmax(scores, dim=-1)
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
# ad hoc loss
torch_loss = (torch_attn_out ** 2).mean()
torch_loss = (torch_attn_out**2).mean()
torch_loss.backward()
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
# comparison
print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol)
for g1, g2 in zip(grads, torch_grads):
torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)
def triton_attention(
layout,
block: int,
@@ -168,12 +142,8 @@ def triton_attention(
value: torch.Tensor,
scale: float,
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(
layout, block, "sdd", trans_a=False, trans_b=True
)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(
layout, block, "dsd", trans_a=False, trans_b=False
)
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
sparse_softmax = triton.ops.blocksparse.softmax(
layout,
block,

View File

@@ -4,7 +4,7 @@ import triton
import torch
@pytest.mark.parametrize(
"TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE",
"TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
itertools.chain(*[
[
# 1 warp
@@ -17,14 +17,14 @@ import torch
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
# # 2 warp
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
# # 4 warp
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
@@ -40,24 +40,28 @@ import torch
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE)
] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True]
]))
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE]
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
]),
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
torch.manual_seed(0)
triton.ops._matmul._kernels = dict()
triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
if M is None: M = TM
if N is None: N = TN
if K is None: K = TK * TZ
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
if M is None:
M = TM
if N is None:
N = TN
if K is None:
K = TK * SPLITK
a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
th_c = torch.matmul(a, b)
tt_c = triton.ops.matmul(a, b)
assert triton.testing.allclose(th_c, tt_c)
assert triton.testing.allclose(th_c, tt_c)

View File

@@ -1,198 +1,199 @@
__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
TYPE* B __readonly __noalias __aligned(16),
TYPE* C __noalias __aligned(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
long stride_za __multipleof(8),
long stride_zb __multipleof(8),
long stride_zc __multipleof(8),
long stride_ha __multipleof(8),
long stride_hb __multipleof(8),
long stride_hc __multipleof(8),
int DS0, int DS1,
int SDD_K __multipleof(16),
int SDD_off_width,
int* lut, int* locks, int nlocks) {
/* ---------------- */
/* Prologue */
/* ---------------- */
// program ids
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pidz = get_program_id(2);
__global__ void NAME(TYPE *A __readonly __noalias __aligned(16),
TYPE *B __readonly __noalias __aligned(16),
TYPE *C __noalias __aligned(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
long stride_za __multipleof(8),
long stride_zb __multipleof(8),
long stride_zc __multipleof(8),
long stride_ha __multipleof(8),
long stride_hb __multipleof(8),
long stride_hc __multipleof(8),
int DS0, int DS1,
int SDD_K __multipleof(16),
int SDD_off_width,
int *lut, int *locks, int nlocks) {
/* ---------------- */
/* Prologue */
/* ---------------- */
// program ids
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pidz = get_program_id(2);
#ifdef SDD
// load LUT header
pid1 = pid1 + SDD_off_width;
int blockidm[TM] = (0 ... TM) / BLOCK;
int blockidn[TN] = (0 ... TN) / BLOCK;
int offlutm[TM] = blockidm*(TN/BLOCK)*4;
int offlutn[TN] = blockidn*4;
int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
int z = *(header + 0);
int i[TM] = *(header + 1 + offlutm);
int j[TN] = *(header + 2 + offlutn);
int AS1 = SDD_K / TZ;
int lockid = select(TZ > 1, 1, 0);
int offka = pid0 * AS1;
int offkb = pid0 * AS1;
int offmc = 0;
int offnc = 0;
int offpa = 0;
int offpb = 0;
int maxid = TZ;
int offhc = 0;
int offha = z;
int offhb = z;
int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
// load LUT header
pid1 = pid1 + SDD_off_width;
int blockidm[TM] = (0 ... TM) / BLOCK;
int blockidn[TN] = (0 ... TN) / BLOCK;
int offlutm[TM] = blockidm * (TN / BLOCK) * 4;
int offlutn[TN] = blockidn * 4;
int *header = lut + pid1 * (TM / BLOCK) * (TN / BLOCK) * 4;
int z = *(header + 0);
int i[TM] = *(header + 1 + offlutm);
int j[TN] = *(header + 2 + offlutn);
int AS1 = SDD_K / TZ;
int lockid = select(TZ > 1, 1, 0);
int offka = pid0 * AS1;
int offkb = pid0 * AS1;
int offmc = 0;
int offnc = 0;
int offpa = 0;
int offpb = 0;
int maxid = TZ;
int offhc = 0;
int offha = z;
int offhb = z;
int ram[TM] = i * BLOCK + ((0 ... TM) % BLOCK);
int rbn[TN] = j * BLOCK + ((0 ... TN) % BLOCK);
#else
// load LUT header
int *header = lut + pid0 * 6;
int offset = *(header + 0);
int AS1 = *(header + 1);
int column = *(header + 2);
int depth = *(header + 3);
int lockid = *(header + 4);
int maxid = *(header + 5);
int *pinc = lut + offset;
int offhc = depth;
// load LUT header
int *header = lut + pid0 * 6;
int offset = *(header + 0);
int AS1 = *(header + 1);
int column = *(header + 2);
int depth = *(header + 3);
int lockid = *(header + 4);
int maxid = *(header + 5);
int *pinc = lut + offset;
int offhc = depth;
#ifdef DSD
// output offset
int offnc = pid1 * TN;
int offmc = column * TM;
int offpc = 0;
// dense input offset
int offnb = pid1 * TN;
int offkb __multipleof(8) = *pinc;
int offpb = 0;
// sparse input offset
int offma = 0;
int offka = 0;
long offpa __multipleof(8) = *(pinc + 1);
offpa = offpa * BLOCK * BLOCK;
int offha = 0;
int offhb = depth;
// output offset
int offnc = pid1 * TN;
int offmc = column * TM;
int offpc = 0;
// dense input offset
int offnb = pid1 * TN;
int offkb __multipleof(8) = *pinc;
int offpb = 0;
// sparse input offset
int offma = 0;
int offka = 0;
long offpa __multipleof(8) = *(pinc + 1);
offpa = offpa * BLOCK * BLOCK;
int offha = 0;
int offhb = depth;
#endif
#ifdef DDS
// output offset
int offmc = pid1 * TM;
int offnc = column * TN;
int offpc = 0;
// dense input offset
int offma = pid1 * TM;
int offka __multipleof(8) = *pinc;
int offpa = 0;
// sparse input offset
int offnb = 0;
int offkb = 0;
long offpb __multipleof(8) = *(pinc + 1);
offpb = offpb * BLOCK * BLOCK;
int offha = depth;
int offhb = 0;
// output offset
int offmc = pid1 * TM;
int offnc = column * TN;
int offpc = 0;
// dense input offset
int offma = pid1 * TM;
int offka __multipleof(8) = *pinc;
int offpa = 0;
// sparse input offset
int offnb = 0;
int offkb = 0;
long offpb __multipleof(8) = *(pinc + 1);
offpb = offpb * BLOCK * BLOCK;
int offha = depth;
int offhb = 0;
#endif
int ram[TM] = offma + 0 ... TM;
int rbn[TN] = offnb + 0 ... TN;
int ram[TM] = offma + 0 ... TM;
int rbn[TN] = offnb + 0 ... TN;
#endif
// initialize a, b pointers
int rka[TK] = offka + 0 ... TK;
int rkb[TK] = offkb + 0 ... TK;
TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
// initialize a, b pointers
int rka[TK] = offka + 0 ... TK;
int rkb[TK] = offkb + 0 ... TK;
TYPE *pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka [newaxis, :] * STRIDE_AK;
TYPE *pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn [newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
// pre-fetch
#ifdef DDS
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
bool checkbn[TK, TN] = rbn [newaxis, :] < DS0;
#else
bool checkbn[TK, TN] = AS1 > 0;
#endif
TYPE a[TM, TK] = checkam ? *pa : 0;
TYPE b[TK, TN] = checkbn ? *pb : 0;
/* ---------------- */
/* Inner Loop */
/* ---------------- */
// create result tile
float acc[TM, TN] = 0;
int step = TK;
for (int k = AS1; k > 0; k -= step) {
acc += a @b;
// update pointers
#ifdef SDD
int inc_a = TK * STRIDE_AK;
int inc_b = TK * STRIDE_BK;
#else
pinc += 2;
#ifdef DSD
int inc_b __multipleof(8) = *pinc;
int inc_a __multipleof(8) = *(pinc + 1);
inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
int inc_a __multipleof(8) = *pinc;
int inc_b __multipleof(8) = *(pinc + 1);
inc_a = inc_a * STRIDE_AK;
#endif
#endif
pa += inc_a;
pb += inc_b;
// pre-fetch
#ifdef DDS
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
#else
bool checkbn[TK, TN] = AS1 > 0;
#endif
TYPE a[TM, TK] = checkam ? *pa : 0;
TYPE b[TK, TN] = checkbn ? *pb : 0;
bool checkak[TM, TK] = k > TK;
bool checkbk[TK, TN] = k > TK;
bool checka[TM, TK] = checkam && checkak;
bool checkb[TK, TN] = checkbk && checkbn;
a = *? (checka)pa;
b = *? (checkb)pb;
}
TYPE c[TM, TN] = acc;
/* ---------------- */
/* Inner Loop */
/* ---------------- */
// create result tile
float acc[TM, TN] = 0;
int step = TK;
for(int k = AS1; k > 0; k -= step) {
acc += a @ b;
// update pointers
/* ---------------- */
/* Epilogue */
/* ---------------- */
// initialize c pointers
#ifdef SDD
int inc_a = TK * STRIDE_AK;
int inc_b = TK * STRIDE_BK;
bool checkc[TM, TN] = 1;
// rematerialize
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
int rr_offlutm[TM] = rr_blockidm * (TN / BLOCK) * 4;
int rr_offlutn[TN] = rr_blockidn * 4;
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn [newaxis, :];
int bkid[TM, TN] = *(header + off_bkid);
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
// range within blocks
int rcm[TM] = (0 ... TM) % BLOCK;
int rcn[TN] = (0 ... TN) % BLOCK;
#else
pinc += 2;
int rcm[TM] = offmc + 0 ... TM;
int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
int inc_b __multipleof(8) = *pinc;
int inc_a __multipleof(8) = *(pinc + 1);
inc_b = inc_b * STRIDE_BK;
bool checkc[TM, TN] = rcn [newaxis, :] < DS0;
#endif
#ifdef DDS
int inc_a __multipleof(8) = *pinc;
int inc_b __multipleof(8) = *(pinc + 1);
inc_a = inc_a * STRIDE_AK;
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
pa += inc_a;
pb += inc_b;
// pre-fetch
bool checkak[TM, TK] = k > TK;
bool checkbk[TK, TN] = k > TK;
bool checka[TM, TK] = checkam && checkak;
bool checkb[TK, TN] = checkbk && checkbn;
a = *?(checka)pa;
b = *?(checkb)pb;
}
TYPE c[TM, TN] = acc;
/* ---------------- */
/* Epilogue */
/* ---------------- */
// initialize c pointers
#ifdef SDD
bool checkc[TM, TN] = 1;
// rematerialize
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
int rr_offlutn[TN] = rr_blockidn*4;
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
int bkid[TM, TN] = *(header + off_bkid);
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
// range within blocks
int rcm[TM] = (0 ... TM) % BLOCK;
int rcn[TN] = (0 ... TN) % BLOCK;
#else
int rcm[TM] = offmc + 0 ... TM;
int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
#endif
#ifdef DDS
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
// write-back directly
if(lockid == 0) {
*?(checkc) pc = c;
}
// accumulate partial result using spin-locks
else {
int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % maxid);
atomic_xchg(plock, 0);
}
}
TYPE *pc[TM, TN] = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, newaxis] * STRIDE_CM + rcn [newaxis, :] * STRIDE_CN;
// write-back directly
if (lockid == 0) {
*? (checkc)pc = c;
}
// accumulate partial result using spin-locks
else {
int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks;
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % maxid);
atomic_xchg(plock, 0);
}
}

View File

@@ -10,454 +10,416 @@ src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
# MAIN API #
##############
class _matmul(torch.autograd.Function):
sdd_cache = dict()
dsd_cache = dict()
dds_cache = dict()
locks = dict()
# Given an array sizes representing reduction size for each
# column of a block-mode matrix multiplication,
# performs load-balancing to achieve more smaller reductions
# between `seg_size` elements
@staticmethod
def load_balance(sizes, block):
# segment size
# heuristics taken from OpenAI blocksparse code
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
max_size = sizes.max()
min_size = sizes[sizes != 0].min()
#if max_size > min_size * 2.0:
# seg_max = max(triton.cdiv(max_size, 4), min_size*2)
#else:
# seg_max = max_size
seg_max = max_size
seg_min = max(triton.cdiv(seg_max, 4), 4)
# split reduction into segments
div = sizes // seg_max
rem = sizes % seg_max
packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
width = packs.sum()
segments = torch.empty(width, dtype=sizes.dtype)
column = torch.empty_like(segments)
lockid = torch.zeros_like(segments)
maxid = torch.zeros_like(segments)
nlocks = 0
current = 0
col_idx = 0
for i in range(len(sizes)):
d, r = div[i], rem[i]
isempty = sizes[i] < seg_min
last = current + d + (r >= seg_min) + isempty
# column id
column[current:last] = col_idx
# lock id
if d > 1 or (d == 1 and r >= seg_min):
nlocks += 1
lockid[current:last] = nlocks
maxid[current:last] = last - current
# segment size
segments[current:current+d] = seg_max
if r < seg_min and not isempty:
segments[current+d-1] += r
if r >= seg_min or isempty:
segments[current+d] = r
current = last
col_idx += 1
offsets = torch.zeros_like(segments)
offsets[1:] = torch.cumsum(segments[:-1], dim=0)
return segments, column, lockid, maxid, offsets
@staticmethod
def get_locks(size, dev):
if dev not in _matmul.locks or \
size > _matmul.locks[dev].size(0):
_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
return _matmul.locks[dev]
sdd_cache = dict()
dsd_cache = dict()
dds_cache = dict()
locks = dict()
##########################
# SPARSE = DENSE x DENSE #
##########################
# Given an array sizes representing reduction size for each
# column of a block-mode matrix multiplication,
# performs load-balancing to achieve more smaller reductions
# between `seg_size` elements
@staticmethod
def load_balance(sizes, block):
# segment size
# heuristics taken from OpenAI blocksparse code
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
max_size = sizes.max()
min_size = sizes[sizes != 0].min()
#if max_size > min_size * 2.0:
# seg_max = max(triton.cdiv(max_size, 4), min_size*2)
#else:
# seg_max = max_size
seg_max = max_size
seg_min = max(triton.cdiv(seg_max, 4), 4)
# split reduction into segments
div = sizes // seg_max
rem = sizes % seg_max
packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
width = packs.sum()
segments = torch.empty(width, dtype=sizes.dtype)
column = torch.empty_like(segments)
lockid = torch.zeros_like(segments)
maxid = torch.zeros_like(segments)
nlocks = 0
current = 0
col_idx = 0
for i in range(len(sizes)):
d, r = div[i], rem[i]
isempty = sizes[i] < seg_min
last = current + d + (r >= seg_min) + isempty
# column id
column[current:last] = col_idx
# lock id
if d > 1 or (d == 1 and r >= seg_min):
nlocks += 1
lockid[current:last] = nlocks
maxid[current:last] = last - current
# segment size
segments[current:current + d] = seg_max
if r < seg_min and not isempty:
segments[current + d - 1] += r
if r >= seg_min or isempty:
segments[current + d] = r
current = last
col_idx += 1
offsets = torch.zeros_like(segments)
offsets[1:] = torch.cumsum(segments[:-1], dim=0)
return segments, column, lockid, maxid, offsets
@staticmethod
def make_sdd_lut(layout, block, dtype, device):
start_width = 128 // block
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
width = nnz.shape[0] // (size*size)
h = nnz[:, 0]
i = nnz[:, 1]
j = nnz[:, 2]
b = nnz[:, 3]
lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
luts.append(lut.type(torch.int32).to(device))
widths.append(width)
packs.append(size)
# create locks
return luts, None, widths, packs
@staticmethod
def get_locks(size, dev):
if dev not in _matmul.locks or \
size > _matmul.locks[dev].size(0):
_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
return _matmul.locks[dev]
@staticmethod
def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, luts, num_locks, widths, packs):
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
device = a.device
is_16_multiple = AS3 % 16 == 0
is_32_multiple = AS3 % 32 == 0
is_64_multiple = AS3 % 64 == 0
if not is_16_multiple:
raise ValueError('Reduction size for SDD must be a multiple of 16')
# create kernel
total_width = sum([width*pack*pack for width,pack in zip(widths, packs)])
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
for lut, width, pack in zip(luts, widths, packs):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
defines = {'TM': block*pack, 'TN': block*pack,
'TMN': block*block*pack*pack,
'BLOCK': block,
'TK': 32,
'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
##########################
# SPARSE = DENSE x DENSE #
##########################
kernel = _matmul.sdd_cache[key]
# create output
locks = _matmul.get_locks(2*width*AS0*num_lock, a.device)
# maximum grid size is 65535
# so operation might be decomposed into multiple
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), b.stride(2), block,
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(0),
AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock,
grid = lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
# save for backward pass
return c
@staticmethod
def make_sdd_lut(layout, block, dtype, device):
start_width = 128 // block
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
width = nnz.shape[0] // (size * size)
h = nnz[:, 0]
i = nnz[:, 1]
j = nnz[:, 2]
b = nnz[:, 3]
lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
luts.append(lut.type(torch.int32).to(device))
widths.append(width)
packs.append(size)
# create locks
return luts, None, widths, packs
##########################
# DENSE = DENSE x SPARSE #
# DENSE = SPARSE x DENSE #
##########################
# Given a binary layout of 0s and 1s,
# Construct look-up table for efficient execution on GPUs
@staticmethod
def make_dxx_lut(layout, block, step, trans, device, transform = lambda idx: idx):
# load-balancing
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
segments = _empty.clone()
column = _empty.clone()
depth = _empty.clone()
lockid = _empty.clone()
maxid = _empty.clone()
offsets = _empty.clone()
current_offset = 0
current_maxid = 0
for z in range(layout.size(0)):
if trans:
sizes = torch.sum(layout[z, :, :], 1)
else:
sizes = torch.sum(layout[z, :, :], 0)
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
z_depth = z * torch.ones_like(z_segments)
z_lockid[z_lockid > 0] += current_maxid
current_maxid = z_lockid.max()
# concatenate depth
segments = torch.cat((segments, z_segments))
column = torch.cat((column, z_column))
depth = torch.cat((depth, z_depth))
maxid = torch.cat((maxid, z_maxid))
offsets = torch.cat((offsets, current_offset + z_offsets))
lockid = torch.cat((lockid, z_lockid))
current_offset += layout[z, :, :].sum()
segments *= step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.min(offsets, (num_blocks - 1)*torch.ones_like(offsets))
idx = transform(nnz[:, 2]*block)
xincs = idx.clone()
xincs[1:] -= idx[:-1]
# divide block into multiple steps
div = block // step
xincs = xincs.view(-1, 1).repeat(1, div)
xincs[:, 1:] = step
xincs[:, 0 ] -= (div-1)*step
# first increment for each reduction is actually the offset
xincs[offsets[segments>0], 0] = idx[offsets[segments>0]]
xincs = xincs.view(-1)
# block-mode input increments
if trans:
widx = torch.arange(num_blocks)
else:
widx = _empty.clone()
current_offset = 0
for z in range(layout.size(0)):
layoutw = layout[z, :, :].clone()
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum)
widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
widx = widx
wincs = widx*block*block
wincs[1:] -= widx[:-1]*block*block
wincs = wincs.view(-1, 1).repeat(1, div)
if trans:
wincs[:, 1:] = step
wincs[:, 0] -= (div-1)*step
else:
wincs[:, 1:] = step*block
wincs[:, 0] -= (div - 1)*step*block
wincs[offsets[segments>0], 0] = widx[offsets[segments>0]]
wincs = wincs.view(-1)
# adjust offset and segment size
offsets *= 2*div
segments *= div
# create header
width = column.size(0)
offsets += 6*width
header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
num_locks = max(1, lockid.max())
return lut, num_locks, width, None
@staticmethod
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
@staticmethod
def _dds_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = spdims[0]
BS1 = block * spdims[2 if trans_b else 1]
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
defines = {'TM': 128,
'TN': block,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1,
'STRIDE_BN': block if trans_b else 1,
'STRIDE_BK': 1 if trans_b else block,
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel',
'DDS': True}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dds_cache[key]
# output
CS0 = AS0
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
locks = _matmul.get_locks(2*AS0*AS2//32*num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), block, c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
return c
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = spdims[0]
AS1 = block * spdims[2 if trans_a else 1]
AS2 = block * spdims[1 if trans_a else 2]
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
defines = {'TM': block,
'TN': 128,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1,
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dsd_kernel',
'DSD': True}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dsd_cache[key]
# output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
locks = _matmul.get_locks(2*BS0*BS3//32*num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
block, b.stride(2), c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
return c
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
device = a.device
is_16_multiple = AS3 % 16 == 0
is_32_multiple = AS3 % 32 == 0
is_64_multiple = AS3 % 64 == 0
if not is_16_multiple:
raise ValueError('Reduction size for SDD must be a multiple of 16')
# create kernel
total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
for lut, width, pack in zip(luts, widths, packs):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
defines = {
'TM': block * pack, 'TN': block * pack, 'TMN': block * block * pack * pack, 'BLOCK': block, 'TK':
32, 'TYPE': dtype, 'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': 'ldc',
'STRIDE_CN': '1', 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'
}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
fn = {'sdd': _sdd_matmul.__get__(object),
'dsd': _dsd_matmul.__get__(object),
'dds': _dds_matmul.__get__(object)}
kernel = _matmul.sdd_cache[key]
# create output
locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
# maximum grid size is 65535
# so operation might be decomposed into multiple
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), b.stride(2), block, a.stride(0),
b.stride(0), c.stride(0), a.stride(1), b.stride(1), c.stride(0), AS2, AS2, AS3, off_width,
lut.data_ptr(), locks.data_ptr(), num_lock,
grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
# save for backward pass
return c
@staticmethod
def forward(ctx, a, b, trans_a, trans_b, trans_c,
mode, spdims, block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block,
c_lut, c_num_locks, c_width, c_packs)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_num_locks = da_num_locks
ctx.da_lut = da_lut
ctx.da_width = da_width
ctx.da_packs = da_packs
ctx.db_lut = db_lut
ctx.db_num_locks = db_num_locks
ctx.db_width = db_width
ctx.db_packs = db_packs
ctx.mode = mode
ctx.spdims = spdims
ctx.block = block
ctx.trans_a = trans_a
ctx.trans_b = trans_b
return c
##########################
# DENSE = DENSE x SPARSE #
# DENSE = SPARSE x DENSE #
##########################
@staticmethod
def backward(ctx, dc):
# saved for backward
a, b = ctx.saved_tensors
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,
ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs)
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None
# Given a binary layout of 0s and 1s,
# Construct look-up table for efficient execution on GPUs
@staticmethod
def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
# load-balancing
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
segments = _empty.clone()
column = _empty.clone()
depth = _empty.clone()
lockid = _empty.clone()
maxid = _empty.clone()
offsets = _empty.clone()
current_offset = 0
current_maxid = 0
for z in range(layout.size(0)):
if trans:
sizes = torch.sum(layout[z, :, :], 1)
else:
sizes = torch.sum(layout[z, :, :], 0)
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
z_depth = z * torch.ones_like(z_segments)
z_lockid[z_lockid > 0] += current_maxid
current_maxid = z_lockid.max()
# concatenate depth
segments = torch.cat((segments, z_segments))
column = torch.cat((column, z_column))
depth = torch.cat((depth, z_depth))
maxid = torch.cat((maxid, z_maxid))
offsets = torch.cat((offsets, current_offset + z_offsets))
lockid = torch.cat((lockid, z_lockid))
current_offset += layout[z, :, :].sum()
segments *= step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
idx = transform(nnz[:, 2] * block)
xincs = idx.clone()
xincs[1:] -= idx[:-1]
# divide block into multiple steps
div = block // step
xincs = xincs.view(-1, 1).repeat(1, div)
xincs[:, 1:] = step
xincs[:, 0] -= (div - 1) * step
# first increment for each reduction is actually the offset
xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
xincs = xincs.view(-1)
# block-mode input increments
if trans:
widx = torch.arange(num_blocks)
else:
widx = _empty.clone()
current_offset = 0
for z in range(layout.size(0)):
layoutw = layout[z, :, :].clone()
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum)
widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
widx = widx
wincs = widx * block * block
wincs[1:] -= widx[:-1] * block * block
wincs = wincs.view(-1, 1).repeat(1, div)
if trans:
wincs[:, 1:] = step
wincs[:, 0] -= (div - 1) * step
else:
wincs[:, 1:] = step * block
wincs[:, 0] -= (div - 1) * step * block
wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
wincs = wincs.view(-1)
# adjust offset and segment size
offsets *= 2 * div
segments *= div
# create header
width = column.size(0)
offsets += 6 * width
header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
num_locks = max(1, lockid.max())
return lut, num_locks, width, None
@staticmethod
def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = spdims[0]
BS1 = block * spdims[2 if trans_b else 1]
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
defines = {
'TM': 128, 'TN': block, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1, 'STRIDE_BN': block if trans_b else 1, 'STRIDE_BK':
1 if trans_b else block, 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel', 'DDS': True
}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dds_cache[key]
# output
CS0 = AS0
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), block, c.stride(2), a.stride(0), b.stride(0),
c.stride(0), a.stride(1), b.stride(1), c.stride(1), AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(),
num_locks, grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
return c
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = spdims[0]
AS1 = block * spdims[2 if trans_a else 1]
AS2 = block * spdims[1 if trans_a else 2]
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
defines = {
'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1, 'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK':
'1' if trans_b else 'ldb', 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN':
'ldc' if trans_c else '1', 'NAME': 'dsd_kernel', 'DSD': True
}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dsd_cache[key]
# output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block, b.stride(2), c.stride(2), a.stride(0), b.stride(0),
c.stride(0), a.stride(1), b.stride(1), c.stride(1), BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(),
num_locks, grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
return c
fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}
@staticmethod
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut,
da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_num_locks = da_num_locks
ctx.da_lut = da_lut
ctx.da_width = da_width
ctx.da_packs = da_packs
ctx.db_lut = db_lut
ctx.db_num_locks = db_num_locks
ctx.db_width = db_width
ctx.db_packs = db_packs
ctx.mode = mode
ctx.spdims = spdims
ctx.block = block
ctx.trans_a = trans_a
ctx.trans_b = trans_b
return c
@staticmethod
def backward(ctx, dc):
# saved for backward
a, b = ctx.saved_tensors
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut,
ctx.da_num_locks, ctx.da_width, ctx.da_packs)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut,
ctx.db_num_locks, ctx.db_width, ctx.db_packs)
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None
class matmul:
def make_lut(self, dtype, device):
key = (dtype, device)
if key in self.lut_cache:
return self.lut_cache[key]
# C look-up table
layout, block = self.layout, self.block
step = 8 if dtype == torch.float32 else 16
if self.mode == 'sdd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dsd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
elif self.mode == 'dds':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
# DA look-up table
if self.mode == 'sdd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
elif self.mode == 'dsd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dds':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
# DB look-up table
if self.mode == 'sdd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
elif self.mode == 'dsd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
elif self.mode == 'dds':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs)
return self.lut_cache[key]
def make_lut(self, dtype, device):
key = (dtype, device)
if key in self.lut_cache:
return self.lut_cache[key]
# C look-up table
layout, block = self.layout, self.block
step = 8 if dtype == torch.float32 else 16
if self.mode == 'sdd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dsd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
elif self.mode == 'dds':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
# DA look-up table
if self.mode == 'sdd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
elif self.mode == 'dsd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dds':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b,
device)
# DB look-up table
if self.mode == 'sdd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
elif self.mode == 'dsd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
elif self.mode == 'dds':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs)
return self.lut_cache[key]
def __init__(self, layout, block, mode, trans_a = False, trans_b = False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
# look-up table cache
self.lut_cache = dict()
# attributes
self.trans_a = trans_a
self.trans_b = trans_b
self.mode = mode
self.spdims = layout.shape
self.block = block
self.layout = layout
# pad shapes of a tensor to make it
# compatible with kernel calls
@staticmethod
def _pad_shape(x, is_sparse):
max_dim = 3 if is_sparse else 4
for i in range(max_dim - x.dim()):
x = x.unsqueeze(0)
return x
def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
# look-up table cache
self.lut_cache = dict()
# attributes
self.trans_a = trans_a
self.trans_b = trans_b
self.mode = mode
self.spdims = layout.shape
self.block = block
self.layout = layout
def __call__(self, a, b):
c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
# pad shapes with ones
a = matmul._pad_shape(a, self.mode == 'dsd')
b = matmul._pad_shape(b, self.mode == 'dds')
# execute
c = _matmul.apply(a, b, self.trans_a, self.trans_b, False,
self.mode, self.spdims, self.block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs)
return c
# pad shapes of a tensor to make it
# compatible with kernel calls
@staticmethod
def _pad_shape(x, is_sparse):
max_dim = 3 if is_sparse else 4
for i in range(max_dim - x.dim()):
x = x.unsqueeze(0)
return x
def __call__(self, a, b):
c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
# pad shapes with ones
a = matmul._pad_shape(a, self.mode == 'dsd')
b = matmul._pad_shape(b, self.mode == 'dds')
# execute
c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,
c_num_locks, c_width, c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks,
db_width, db_packs)
return c

View File

@@ -1,9 +1,9 @@
#define STM 8
#define STN 8
__global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
__global__ void matmul(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
int M,
int N,
@@ -11,87 +11,88 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
int lda __multipleof(LDA_POW2_DIV),
int ldb __multipleof(LDB_POW2_DIV),
int ldc __multipleof(LDC_POW2_DIV),
int* locks) {
// prologue
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = (M + TM - 1) / TM;
int gridn = (N + TN - 1) / TN;
int *locks) {
// prologue
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = (M + TM - 1) / TM;
int gridn = (N + TN - 1) / TN;
// swizzle for better L2 performance
int width = STM*gridn;
int stm = pid / width;
int RSTM = min(gridm - stm*STM, STM);
int stn = (pid % width) / (RSTM*STN);
int RSTN = min(gridn - stn*STN, STN);
int laneid = pid % (RSTM * RSTN);
int lanem = laneid / RSTN;
int lanen = laneid % RSTN;
int pidm = stm*STM + lanem;
int pidn = stn*STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// swizzle for better L2 performance
int width = STM * gridn;
int stm = pid / width;
int RSTM = min(gridm - stm * STM, STM);
int stn = (pid % width) / (RSTM * STN);
int RSTN = min(gridn - stn * STN, STN);
int laneid = pid % (RSTM * RSTN);
int lanem = laneid / RSTN;
int lanen = laneid % RSTN;
int pidm = stm * STM + lanem;
int pidn = stn * STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// split-k for better parrallelism
K = K / TZ;
int rk[TK] = 0 ... TK;
// pointers to operands
int offa[TM, TK] = (pidz*K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = (pidz*K + rk[:, newaxis]) * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
TYPE* pa[TM, TK] = A + offa;
TYPE* pb[TK, TN] = B + offb;
// split-k for better parrallelism
K = K / SPLITK;
int rk[TK] = 0 ... TK;
// pointers to operands
int offa[TM, TK] = (pidz * K + rk [newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = (pidz * K + rk[:, newaxis]) * STRIDE_BK + rn [newaxis, :] * STRIDE_BN;
TYPE *pa[TM, TK] = A + offa;
TYPE *pb[TK, TN] = B + offb;
// prefetches operands
bool checka[TM, TK] = rk[newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
// prefetches operands
bool checka[TM, TK] = rk [newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
#if (IS_TK_DIV_K==1)
bool checkk[TK] = k > TK;
// reduction loop
float acc[TM, TN] = 0;
for (int k = K; k > 0; k -= TK) {
#if (IS_TK_DIV_K == 1)
bool checkk[TK] = k > TK;
#else
bool checkk[TK] = rk < k - TK;
bool checkk[TK] = rk < k - TK;
#endif
bool checka[TM, TK] = checkk[newaxis, :];
bool checkb[TK, TN] = checkk[:, newaxis];
acc += a @ b;
#if (IS_TK_DIV_K==1)
a = *?(checka)pa;
b = *?(checkb)pb;
bool checka[TM, TK] = checkk [newaxis, :];
bool checkb[TK, TN] = checkk[:, newaxis];
acc += a @b;
#if (IS_TK_DIV_K == 1)
a = *? (checka)pa;
b = *? (checkb)pb;
#else
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
#endif
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rcm[TM] = pidm * TM + 0 ... TM;
int rcn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N;
#if (TZ==1)
*?(checkc) pc = c;
// epilogue
int rcm[TM] = pidm * TM + 0 ... TM;
int rcn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn [newaxis, :];
TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn [newaxis, :] < N;
#if (SPLITK == 1)
*? (checkc)pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
// accumulate partial result using spin-locks
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0);
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % SPLITK);
atomic_xchg(plock, 0);
#endif
}

View File

@@ -3,29 +3,32 @@ import triton
import os
class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
_DEFAULT_CONFIGS = [
({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '128', 'TN': '64', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '64', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '32', 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '128', 'TN': '32', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '32', 'TK': '64', 'TZ': '1'}, 2),
({'TM': '32', 'TN': '64', 'TK': '64', 'TZ': '1'}, 2),
({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
({'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, 4),
({'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, 4),
({'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 2),
({'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 2),
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
]
_CONFIGS = _DEFAULT_CONFIGS
@staticmethod
def largest_pow2_divisor(N):
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1
_locks = dict()
@@ -40,8 +43,10 @@ class _matmul(torch.autograd.Function):
K, N = b.shape
c = torch.empty((M, N), dtype=dtype, device=device)
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous()
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# kernel hash
is_a_row = a.stride(1) == 1
is_b_row = b.stride(1) == 1
@@ -52,28 +57,60 @@ class _matmul(torch.autograd.Function):
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
is_tk_div_k = K % 64 == 0
key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
key = (
device,
dtype,
is_a_row,
is_b_row,
lda_pow2_div,
ldb_pow2_div,
ldc_pow2_div,
is_tk_div_k,
)
if key not in _matmul._kernels:
defines = {
'TYPE': dtype, 'STRIDE_AM': 'lda' if is_a_row else '1', 'STRIDE_AK': '1' if is_a_row else 'lda',
'STRIDE_BK': 'ldb' if is_b_row else '1', 'STRIDE_BN': '1' if is_b_row else 'ldb', 'LDA_POW2_DIV':
lda_pow2_div, 'LDB_POW2_DIV': ldb_pow2_div, 'LDC_POW2_DIV': ldc_pow2_div, 'IS_TK_DIV_K':
int(is_tk_div_k)
"TYPE": dtype,
"STRIDE_AM": "lda" if is_a_row else "1",
"STRIDE_AK": "1" if is_a_row else "lda",
"STRIDE_BK": "ldb" if is_b_row else "1",
"STRIDE_BN": "1" if is_b_row else "ldb",
"LDA_POW2_DIV": lda_pow2_div,
"LDB_POW2_DIV": ldb_pow2_div,
"LDC_POW2_DIV": ldc_pow2_div,
"IS_TK_DIV_K": int(is_tk_div_k),
}
_matmul._kernels[key] = triton.kernel(_matmul.src,
device,
defines=defines,
autotune_vals=_matmul._CONFIGS,
autotune_key=['M', 'N', 'K'])
_matmul._kernels[key] = triton.kernel(
_matmul.src,
device,
defines=defines,
autotune_vals=_matmul._CONFIGS,
autotune_key=["M", "N", "K"],
)
kernel = _matmul._kernels[key]
# # locks for split-k
if device not in _matmul._locks:
_matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
locks = _matmul._locks[device]
# enqueue
alpha = 1.
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ]
alpha = 1.0
args = [
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
alpha,
M,
N,
K,
lda,
ldb,
ldc,
locks.data_ptr(),
]
grid = lambda opt: [
triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
1,
opt.SPLITK,
]
kernel(*args, grid=grid)
return c

View File

@@ -1,21 +1,33 @@
import torch
def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
ret = torch.empty(
(x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device
)
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
ret[:, idx, :, :] = x[
:, h, i * block : (i + 1) * block, j * block : (j + 1) * block
]
return ret
def mask_tensor(x, mask, block, value=0):
ret = x.clone()
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
ret[:, h, i * block : (i + 1) * block, j * block : (j + 1) * block] = value
return ret
def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype]
return torch.allclose(x, y, atol=atol, rtol=rtol)
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
tol = 1e-2
err = torch.max(diff) / torch.max(x_max, y_max)
return err < tol
def do_bench(fn, flops=0, warmup=10, rep=50):
start_event = torch.cuda.Event(enable_timing=True)
@@ -32,8 +44,11 @@ def do_bench(fn, flops=0, warmup=10, rep=50):
time_ms = start_event.elapsed_time(end_event) / rep
return time_ms
class Benchmark:
def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
def __init__(
self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args
):
self.x_names = x_names
self.x_vals = x_vals
self.y_name = y_name
@@ -44,6 +59,7 @@ class Benchmark:
self.plot_name = plot_name
self.args = args
class Mark:
def __init__(self, fn, benchmarks):
self.fn = fn
@@ -53,26 +69,31 @@ class Mark:
import matplotlib.pyplot as plt
import pandas as pd
import os
df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
for x in bench.x_vals:
x_args = {x_name: x for x_name in bench.x_names}
row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
row = [
self.fn(**x_args, **{bench.y_name: y}, **bench.args)
for y in bench.y_vals
]
df.loc[len(df)] = [x] + row
if with_plot and bench.plot_name:
xlabel = ' = '.join(bench.x_names)
xlabel = " = ".join(bench.x_names)
plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
plot.set_xlabel(xlabel)
plot.set_ylabel(bench.ylabel)
plot.set_title(bench.plot_name)
plot.set_xscale('log' if bench.loglog else 'linear')
plot.set_yscale('log' if bench.loglog else 'linear')
plt.savefig(os.path.join(result_path, f'{bench.plot_name}.png'))
df.to_csv(os.path.join(result_path, f'{bench.plot_name}.csv'))
plot.set_xscale("log" if bench.loglog else "linear")
plot.set_yscale("log" if bench.loglog else "linear")
plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))
def run(self, result_path, with_plot):
for bench in self.benchmarks:
self._run(bench, result_path, with_plot)
def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper