Merge branch 'master' into keren/improve-hook
@@ -1,5 +0,0 @@
# Run the benchmarks

Install the required dependencies via `pip install -r requirements-bench.txt` from the triton/python/bench folder.

Run the benchmarks with `python3 bench/run.py`; this will produce an HTML report in a results folder.
@@ -1,92 +0,0 @@
import torch

import triton

# -------------------------------
# Matrix Multiplication
# -------------------------------

nt = {False: 'n', True: 't'}
square_confs = [
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],
        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
        line_arg='block',
        line_vals=[16, 32, 64, 128],
        line_names=['Block16', 'Block32', 'Block64', 'Block128'],
        ylabel='TFLOPS',
        plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
        args={'layout_mode': layout_mode, 'op_mode': op_mode,
              'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
    )
    for AT in [False] for BT in [False]
    for op_mode in ['dsd'] for layout_mode in ['dense']
]


@triton.testing.perf_report(square_confs)
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
    Z, H = 1, 1
    make_layout = {
        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[layout_mode]
    # create layout
    shape = {'sdd': (M, N), 'dsd': (K, M) if AT else (M, K), 'dds': (N, K) if BT else (K, N)}[op_mode]
    layout = make_layout(H, shape[0] // block, shape[1] // block)
    # create inputs
    a = torch.randn((Z, H, K, M) if AT else (Z, H, M, K), dtype=dtype, device='cuda')
    b = torch.randn((Z, H, N, K) if BT else (Z, H, K, N), dtype=dtype, device='cuda')
    # create op
    tflops = lambda ms: num_flops / ms * 1e3
    if provider == 'triton':
        op = triton.ops.blocksparse.matmul(layout, block, op_mode, device="cuda", trans_a=AT, trans_b=BT)
        # inputs
        a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a
        b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
        num_flops = {
            'sdd': 2 * Z * K * float(layout.sum()) * block * block,
            'dsd': 2 * Z * N * float(layout.sum()) * block * block,
            'dds': 2 * Z * M * float(layout.sum()) * block * block
        }[op_mode] * 1e-12
        return tflops(mean_ms), tflops(min_ms), tflops(max_ms)


# -------------------------------
# Softmax
# -------------------------------

square_confs = [
    triton.testing.Benchmark(
        x_names=['M', 'N'],
        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
        line_arg='block',
        line_vals=[16, 32, 64],
        line_names=['Block16', 'Block32', 'Block64'],
        ylabel='GBPS',
        plot_name=f'{layout_mode}-square',
        args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
    )
    for layout_mode in ['dense', 'tril']
]


@triton.testing.perf_report(square_confs)
def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
    Z, H = 1, 1
    make_layout = {
        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[layout_mode]
    layout = make_layout(H, M // block, N // block)
    a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda')
    if provider == 'triton':
        a = triton.testing.sparsify_tensor(a, layout, block)
        op = triton.ops.blocksparse.softmax(layout, block, device="cuda")
        gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
        return gbps(mean_ms), gbps(min_ms), gbps(max_ms)


bench_matmul.run(print_data=True, show_plots=True)
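As a quick sanity check on the FLOP bookkeeping above (illustrative arithmetic only, not part of the original file): with a fully dense layout, the 'dsd' formula reduces to the familiar 2*M*N*K of a dense matmul.

M = N = K = 1024
block = 16
Z = 1
nnz_blocks = (M // block) * (K // block)    # layout.sum() for a dense 'dsd' layout of shape (M, K)
num_flops = 2 * Z * N * nnz_blocks * block * block
assert num_flops == 2 * M * N * K           # 2,147,483,648 FLOPs either way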
@@ -1,41 +0,0 @@
import torch

import triton

confs = [
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch'],
        ylabel='GBPS',
        plot_name=f'{mode}-2048',
        args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
    )
    for mode in ['forward', 'backward']
]


@triton.testing.perf_report(confs)
def bench_op(M, N, dtype, mode, provider):
    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
    num_gb = (2 * x.numel() * x.element_size() * 1e-9)
    gbps = lambda ms: num_gb / ms * 1e3
    # forward pass
    op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
          'triton': triton.ops.cross_entropy}[provider]
    if mode == 'forward':
        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
    if mode == 'backward':
        y = op(x, idx)
        dy = torch.randn_like(y)
        fn = lambda: y.backward(dy, retain_graph=True)
        mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
    return gbps(mean_ms), gbps(min_ms), gbps(max_ms)


if __name__ == '__main__':
    bench_op.run(print_data=True)
@@ -1,67 +0,0 @@
import torch

import triton


def rounded_linspace(low, high, steps, div):
    ret = torch.linspace(low, high, steps)
    ret = torch.div(ret.int() + div - 1, div, rounding_mode='trunc') * div
    ret = torch.unique(ret)
    return list(map(int, ret))


# Square benchmarks
nt = {False: "n", True: "t"}
square_confs = [
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=rounded_linspace(512, 8192, 32, 128),
        line_arg="provider",
        line_vals=["cublas", "triton", "cutlass"],
        line_names=["cuBLAS", "Triton", "CUTLASS"],
        ylabel="TFLOPS",
        plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
        args={"AT": AT, "BT": BT, "dtype": torch.float16},
    ) for AT in [False] for BT in [False]
]

# Transformer training benchmarks
transformer_confs = [
    triton.testing.Benchmark(
        x_names=[x],
        x_vals=rounded_linspace(NK // 16, NK, 32, 128),
        line_arg="provider",
        line_vals=["cublas", "triton", "cutlass"],
        line_names=["cuBLAS", "Triton", "CUTLASS"],
        ylabel="TFLOPS",
        plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
        args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
    ) for NK in [12288]
    for i, x in enumerate(["N", "K"])
    for M in [2048]
]


@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
    a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
    b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
    if AT:
        a = a.t()
    if BT:
        b = b.t()
    tflops = lambda ms: 2. * M * N * K / ms * 1e-9
    if provider == "cublas":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
        return tflops(ms), tflops(max_ms), tflops(min_ms)
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
        return tflops(ms), tflops(max_ms), tflops(min_ms)
    if provider == "cutlass":
        cutlass_matmul = triton.testing.cutlass_matmul
        try:
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
            return tflops(ms), tflops(max_ms), tflops(min_ms)
        except Exception:
            return None
    return None
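The rounded_linspace helper above keeps every benchmarked shape aligned to a multiple of `div`; a small self-contained check of that property (illustration only):

import torch

def rounded_linspace(low, high, steps, div):
    # same helper as above: round each sample up to a multiple of `div`, then deduplicate
    ret = torch.linspace(low, high, steps)
    ret = torch.div(ret.int() + div - 1, div, rounding_mode='trunc') * div
    ret = torch.unique(ret)
    return list(map(int, ret))

sizes = rounded_linspace(512, 8192, 32, 128)
assert all(s % 128 == 0 for s in sizes)   # every benchmarked M/N/K is 128-aligned
assert sizes == sorted(set(sizes))        # duplicates removed, ascending order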
@@ -1,2 +0,0 @@
pandas >= 1.3.3
matplotlib >= 3.4.3
@@ -1,44 +0,0 @@
import argparse
import inspect
import os
import sys

import triton


def run_all(result_dir, names):
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    for mod in os.listdir(os.path.dirname(os.path.realpath(__file__))):
        # skip non-Python files
        if not mod.endswith('.py'):
            continue
        # skip files that don't match the provided name filter
        if names and names not in mod:
            continue
        # skip files that don't start with 'bench_'
        if not mod.startswith('bench_'):
            continue
        print(f'running {mod}...')
        mod = __import__(os.path.splitext(mod)[0])
        benchmarks = inspect.getmembers(mod, lambda x: isinstance(x, triton.testing.Mark))
        for name, bench in benchmarks:
            curr_dir = os.path.join(result_dir, mod.__name__.replace('bench_', ''))
            if len(benchmarks) > 1:
                curr_dir = os.path.join(curr_dir, name.replace('bench_', ''))
            if not os.path.exists(curr_dir):
                os.makedirs(curr_dir)
            bench.run(save_path=curr_dir)


def main(args):
    parser = argparse.ArgumentParser(description="Run the benchmark suite.")
    parser.add_argument("-r", "--result-dir", type=str, default='results', required=False)
    parser.add_argument("-n", "--names", type=str, default='', required=False)
    parser.set_defaults(feature=False)
    args = parser.parse_args(args)
    run_all(args.result_dir, args.names)


if __name__ == '__main__':
    main(sys.argv[1:])
python/examples/copy_strided.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import triton
import triton.language as tl


# triton kernel
@triton.jit
def kernel(X, stride_xm,
           Z, stride_zn,
           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    off_m = tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, BLOCK_N)
    Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
    Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
    tl.store(Zs, tl.load(Xs))


ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
print(ret)
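The example above only lowers the kernel to TTGIR. A hypothetical launch on real tensors, using the same launch syntax as empty.py below (the shapes, strides, and grid here are assumptions for illustration, not part of the original example):

import torch
x = torch.randn(64, 64, device="cuda")        # row-major source, strides (64, 1)
z = torch.empty(64, 64, device="cuda").t()    # column-major destination, strides (1, 64)
kernel[(1,)](x, x.stride(0), z, z.stride(1), BLOCK_M=64, BLOCK_N=64)
torch.testing.assert_close(z, x)              # the copy preserves (m, n) indexing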
python/examples/empty.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import torch

import triton
import triton.language as tl


@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
    pass


X = torch.randn(1, device="cuda")
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)
@@ -1,5 +1,4 @@
import distutils
import distutils.spawn
import os
import platform
import re
@@ -25,42 +24,54 @@ def get_build_type():
        return "Debug"
    elif check_env_flag("REL_WITH_DEB_INFO"):
        return "RelWithDebInfo"
    elif check_env_flag("TRITON_REL_BUILD_WITH_ASSERTS"):
        return "TritonRelBuildWithAsserts"
    else:
        return "Release"
        # TODO: change to release when stable enough
        return "TritonRelBuildWithAsserts"


def use_system_llvm():
    if platform.system() == "Windows":
        return True
    versions = ['-11.0', '-11', '-11-64']
    supported = ['llvm-config{v}'.format(v=v) for v in versions]
    paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
    return any(p is not None for p in paths)


# --- third party packages -----
class Package(NamedTuple):
    package: str
    name: str
    url: str
    test_file: str
    include_flag: str
    lib_flag: str
    syspath_var_name: str


def get_pybind11_package_info():
    name = "pybind11-2.10.0"
    url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
    return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")


def get_llvm_package_info():
    # download if nothing is installed
    system = platform.system()
    system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
    use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
    if use_assert_enabled_llvm:
        name = 'llvm+mlir-14.0.0-x86_64-{}-assert'.format(system_suffix)
        url = "https://github.com/shintaro-iwasaki/llvm-releases/releases/download/llvm-14.0.0-329fda39c507/{}.tar.xz".format(name)
    else:
        name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix)
        url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name)
    return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")


def get_thirdparty_packages(triton_cache_path):
    class Package(NamedTuple):
        package: str
        name: str
        url: str
        test_file: str
        include_flag: str
        lib_flag: str

    packages = [
        Package("pybind11", "pybind11-2.10.0", "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz", "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
    ]
    if not use_system_llvm():
        # download LLVM if no suitable system LLVM is installed
        packages.append(
            Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
        )

    packages = [get_pybind11_package_info(), get_llvm_package_info()]
    thirdparty_cmake_args = []
    for p in packages:
        package_root_dir = os.path.join(triton_cache_path, p.package)
        package_dir = os.path.join(package_root_dir, p.name)
        test_file_path = os.path.join(package_dir, p.test_file)
        if p.syspath_var_name in os.environ:
            package_dir = os.environ[p.syspath_var_name]
        if not os.path.exists(test_file_path):
            try:
                shutil.rmtree(package_root_dir)
@@ -77,6 +88,8 @@ def get_thirdparty_packages(triton_cache_path):
            thirdparty_cmake_args.append("-D{}={}/lib".format(p.lib_flag, package_dir))
    return thirdparty_cmake_args


# ---- cmake extension ----


class CMakeExtension(Extension):
    def __init__(self, name, path, sourcedir=""):
@@ -113,22 +126,27 @@ class CMakeBuild(build_ext):
            self.build_extension(ext)

    def build_extension(self, ext):
        lit_dir = shutil.which('lit')
        triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
        # lit is used by the test suite
        thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
        # create build directories
        if not os.path.exists(self.build_temp):
            os.makedirs(self.build_temp)
        # python directories
        python_include_dirs = [distutils.sysconfig.get_python_inc()]
        python_include_dir = distutils.sysconfig.get_python_inc()
        cmake_args = [
            "-DLLVM_ENABLE_WERROR=ON",
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
            "-DBUILD_TUTORIALS=OFF",
            "-DBUILD_PYTHON_MODULE=ON",
            # '-DPYTHON_EXECUTABLE=' + sys.executable,
            # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
            "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
            "-DTRITON_BUILD_TUTORIALS=OFF",
            "-DTRITON_BUILD_PYTHON_MODULE=ON",
            "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
            "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON",
            "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
            "-DLLVM_EXTERNAL_LIT=" + lit_dir,
        ] + thirdparty_cmake_args

        # configuration
        cfg = get_build_type()
        build_args = ["--config", cfg]
@@ -155,16 +173,17 @@ setup(
    author_email="phil@openai.com",
    description="A language and compiler for custom Deep Learning operations",
    long_description="",
    packages=["triton", "triton/_C", "triton/language", "triton/runtime", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
    packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
    install_requires=[
        "cmake",
        "filelock",
        "torch",
        "lit",
    ],
    package_data={
        "triton/ops": ["*.c"],
        "triton/ops/blocksparse": ["*.c"],
        "triton/language": ["*.bc"],
        "triton/language": ["*.bc"]
    },
    include_package_data=True,
    ext_modules=[CMakeExtension("triton", "triton/_C/")],
@@ -180,6 +199,7 @@ setup(
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
    test_suite="tests",
    extras_require={
        "tests": [
            "autopep8",
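check_env_flag is referenced in the hunks above but not shown in this diff; the sketch below is purely an assumption about how such a helper typically looks (the real definition lives elsewhere in setup.py and may differ):

import os

def check_env_flag(name, default=""):
    # treat common truthy spellings of an environment variable as "on"
    return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]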
@@ -1,202 +0,0 @@
#include "cutlass/library/handle.h"
#include "cutlass/library/library.h"
#include "cutlass/library/operation_table.h"
#include "cutlass/library/singleton.h"
#include "pybind11/pybind11.h"
#include "triton/tools/bench.hpp"

using namespace cutlass;
using namespace cutlass::library;

std::map<std::vector<size_t>, const Operation *> op_cache_;

static int const kHostWorkspaceSize = (4 << 10);
static int const kDeviceWorkspaceSize = (4 << 20);

void run(int M, int N, int K,
         int lda, int ldb, int ldc, int ldd,
         void const *ptr_A, void const *ptr_B, void const *ptr_C, void *ptr_D,
         void const *alpha, void const *beta,
         ScalarPointerMode scalar_mode,
         const Operation *operation,
         cudaStream_t stream) {

  GemmUniversalConfiguration configuration{
      GemmUniversalMode::kGemm,
      {M, N, K},
      1,
      lda,
      ldb,
      ldc,
      ldd};

  // host workspace size
  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);
  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed)
    throw std::runtime_error("Unable to find gemm operation");
  char host_workspace[kHostWorkspaceSize];

  // device workspace size
  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
  if (uint64_t(kDeviceWorkspaceSize) < device_workspace_size_needed)
    throw std::runtime_error("Unable to find gemm operation");
  static void *device_workspace;

  // Initialize host and device workspaces
  Status status = operation->initialize(&configuration, host_workspace, device_workspace, stream);
  if (status != cutlass::Status::kSuccess)
    throw std::runtime_error("Unable to initialize workspace");

  // Run the operator
  GemmArguments arguments{ptr_A, ptr_B, ptr_C, ptr_D, alpha, beta, scalar_mode};
  operation->run(&arguments, host_workspace, device_workspace, stream);
}

const Operation *autotune(int M, int N, int K,
                          NumericTypeID element_compute,
                          NumericTypeID element_scalar,
                          void const *alpha,
                          NumericTypeID element_A,
                          LayoutTypeID layout_A,
                          ComplexTransform transform_A,
                          void const *ptr_A,
                          int lda,
                          NumericTypeID element_B,
                          LayoutTypeID layout_B,
                          ComplexTransform transform_B,
                          void const *ptr_B,
                          int ldb,
                          void const *beta,
                          NumericTypeID element_C,
                          void const *ptr_C,
                          int ldc,
                          void *ptr_D,
                          int ldd,
                          ScalarPointerMode scalar_mode,
                          int device_id,
                          cudaStream_t stream) {

  // index operation table with functional key
  GemmFunctionalKey key(
      Provider::kCUTLASS,
      GemmKind::kUniversal,
      element_compute,
      element_scalar,
      element_A,
      layout_A,
      transform_A,
      element_B,
      layout_B,
      transform_B,
      element_C);

  auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);
  if (operators_it == Singleton::get().operation_table.gemm_operations.end())
    throw std::runtime_error("Unable to find gemm operation");
  if (operators_it->second.empty())
    throw std::runtime_error("Unable to find gemm operation");

  cudaDeviceProp device_prop;
  cudaError_t error = cudaGetDeviceProperties(&device_prop, device_id);
  if (error != cudaSuccess)
    throw std::runtime_error("Unable to get device properties");
  int cc = device_prop.major * 10 + device_prop.minor;

  // index operation table with preference key
  // assume 8-bytes aligned memory pointers
  int alignment = 8;
  GemmPreferenceKey preference_key(cc, alignment);
  auto autotune_it = operators_it->second.find(preference_key);
  if (autotune_it == operators_it->second.end())
    throw std::runtime_error("Unable to find gemm operation");
  const std::vector<const Operation *> &operations = autotune_it->second;
  if (operations.empty())
    throw std::runtime_error("Unable to find gemm operation");

  // auto-tune
  const Operation *best = nullptr;
  double best_ms = std::numeric_limits<double>::max();
  for (const Operation *op : operations) {
    auto fn = [&]() { run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D,
                          alpha, beta, scalar_mode, op, stream); };
    triton::driver::cu_stream tt_stream((CUstream)stream, false);
    double ms = triton::tools::bench(fn, &tt_stream, 10, 25);
    if (ms < best_ms) {
      best_ms = ms;
      best = op;
    }
  }
  return best;
}

// map of torch datatypes to cutlass datatypes
std::map<std::string, NumericTypeID> type_map = {
    {"float16", NumericTypeID::kF16},
    {"float32", NumericTypeID::kF32},
    {"float64", NumericTypeID::kF64}};

void cutlass_matmul(uintptr_t A, uintptr_t B, uintptr_t C,
                    size_t M, size_t N, size_t K,
                    size_t stride_a_0, size_t stride_a_1,
                    size_t stride_b_0, size_t stride_b_1,
                    size_t stride_c_0, size_t stride_c_1,
                    std::string type_a, std::string type_b, std::string type_c,
                    size_t dev_id, uint64_t stream_handle) {
  void *ptr_A = (void *)A;
  void *ptr_B = (void *)B;
  void *ptr_C = (void *)C;
  void *ptr_D = ptr_C;
  size_t lda = stride_a_0;
  size_t ldb = stride_b_0;
  size_t ldc = stride_c_1;
  size_t ldd = ldc;
  float alpha = 1.0f;
  float beta = 0.0f;
  // layout for A
  LayoutTypeID layout_A;
  if (stride_a_0 == 1)
    layout_A = LayoutTypeID::kColumnMajor;
  else if (stride_a_1 == 1)
    layout_A = LayoutTypeID::kRowMajor;
  else
    throw std::runtime_error("A layout is not supported");
  // layout for B
  LayoutTypeID layout_B;
  if (stride_b_0 == 1)
    layout_B = LayoutTypeID::kColumnMajor;
  else if (stride_b_1 == 1)
    layout_B = LayoutTypeID::kRowMajor;
  else
    throw std::runtime_error("B layout is not supported");
  // data types
  NumericTypeID element_compute = NumericTypeID::kF32;
  NumericTypeID element_A = type_map[type_a];
  NumericTypeID element_B = type_map[type_b];
  NumericTypeID element_C = type_map[type_c];
  // misc. flags
  ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
  NumericTypeID element_scalar = NumericTypeID::kF32;
  ComplexTransform transform_A = ComplexTransform::kNone;
  ComplexTransform transform_B = ComplexTransform::kNone;
  // runtime flags
  cudaStream_t stream = (cudaStream_t)stream_handle;
  // auto-tune
  std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
                                  dev_id, (size_t)element_compute, (size_t)scalar_mode};
  auto it = op_cache_.find(tune_key);
  if (it == op_cache_.end()) {
    const Operation *op = autotune(M, N, K, element_compute, element_scalar, &alpha,
                                   element_A, layout_A, transform_A, ptr_A, lda,
                                   element_B, layout_B, transform_B, ptr_B, ldb,
                                   &beta, element_C, ptr_C, ldc, ptr_D, ldd, scalar_mode,
                                   dev_id, stream);
    it = op_cache_.insert({tune_key, op}).first;
  }
  run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D, &alpha, &beta,
      scalar_mode, it->second, stream);
}

void init_cutlass(pybind11::module &m) {
  pybind11::module subm = m.def_submodule("cutlass");
  subm.def("matmul", &cutlass_matmul, "matrix multiplication");
}
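A Python sketch (illustration only; the names cached_matmul, candidates and bench are hypothetical) of the caching strategy used by cutlass_matmul above: the fastest GEMM variant is selected once per problem key by timing every candidate, then reused on later calls.

op_cache = {}

def cached_matmul(a, b, candidates, bench):
    key = (a.shape, b.shape, str(a.dtype), str(a.device))
    if key not in op_cache:
        # time every candidate and keep the fastest, mirroring autotune() above
        op_cache[key] = min(candidates, key=lambda op: bench(lambda: op(a, b)))
    return op_cache[key](a, b)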
@@ -1,696 +0,0 @@
#include "triton/ir/builder.h"
#include <functional>
#include <iostream>
#include <pybind11/pybind11.h>

namespace ir = triton::ir;
namespace py = pybind11;

static const std::string _builder_doc = R"pbdoc(
  :param builder: IR builder to generate code into, optional, set automatically when called inside a @triton.jit function
  :type builder: triton.ir.builder
)pbdoc";

#define VA_ARGS(...) , ##__VA_ARGS__
#define DEF_FUNC(MOD, PY_NAME, C_FUNC, ...) \
  MOD.def(PY_NAME, C_FUNC, (C_FUNC##_docstr + _builder_doc).c_str(), \
          ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)

void throw_not_implemented(std::string key) {
  throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. This is likely a bug on our side.");
}

void throw_not_int_or_float(std::string key) {
  throw std::runtime_error("`" + key + "` only supported for integer and floating point types.");
}

enum type_code {
  _bool,
  int8,
  int16,
  int32,
  int64,
  float16,
  float32,
  float64
};

ir::type *make_ir(type_code ty, ir::builder *builder) {
  switch (ty) {
  case float16:
    return builder->get_half_ty();
  case float32:
    return builder->get_float_ty();
  default:
    throw_not_implemented("make_ir");
  }
}

type_code from_ir(ir::type *ty) {
  if (ty->is_half_ty())
    return float16;
  if (ty->is_float_ty())
    return float32;
  throw_not_implemented("from_ir");
}

/*----------------------------------------------
  definition of triton.cast / triton.ir.value.to
  ----------------------------------------------*/
std::string cast_docstr = R"pbdoc(
  Tries to cast a block to a new data type.

  :param input: The input block.
  :type input: triton.ir.value
)pbdoc";

ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
  ir::type *src_ty = input->get_type();
  ir::type *dst_ty = make_ir(_dtype, builder);
  if (src_ty->is_block_ty())
    dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
  ir::type *src_sca_ty = src_ty->get_scalar_ty();
  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
  // FP Truncation
  bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
                     dst_sca_ty->is_floating_point_ty() &&
                     src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
  if (truncate_fp)
    return builder->create_fp_trunc(input, dst_ty);
  // FP Extension
  bool ext_fp = src_sca_ty->is_floating_point_ty() &&
                dst_sca_ty->is_floating_point_ty() &&
                src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
  if (ext_fp)
    return builder->create_fp_ext(input, dst_ty);
  // Int cast
  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
      src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
    return builder->create_int_cast(input, dst_ty, true);
  // Float -> Int
  if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty())
    return builder->create_fp_to_si(input, dst_ty);
  // Int -> Float
  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty())
    return builder->create_si_to_fp(input, dst_ty);
  // Ptr -> Ptr
  if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
    return builder->create_cast(ir::BitCast, input, dst_ty);
  // * -> Bool
  if (dst_sca_ty->is_bool_ty()) {
    if (src_sca_ty->is_pointer_ty())
      input = cast(input, int64, builder);
    ir::value *other = builder->get_int64(0);
    if (src_ty->is_bool_ty())
      other = builder->create_splat(other, src_ty->get_block_shapes());
    return builder->create_icmpNE(input, other);
  }
  throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
/*----------------------------------------------
  definition of triton.broadcast_check
  ----------------------------------------------*/
std::string try_broadcast_docstr = R"pbdoc(
  Tries to broadcast two blocks to a common compatible shape.

  :param input: The first input block.
  :type input: triton.ir.value
  :param other: The second input block.
  :type other: triton.ir.value
)pbdoc";

std::tuple<ir::value *, ir::value *> try_broadcast(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
  ir::type *lhs_ty = lhs->get_type();
  ir::type *rhs_ty = rhs->get_type();
  // make_shape_compatible(block, scalar)
  if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
    rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
  // make_shape_compatible(scalar, block)
  else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
    lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
  // make_shape_compatible(block, block)
  else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
    auto lhs_shape = lhs_ty->get_block_shapes();
    auto rhs_shape = rhs_ty->get_block_shapes();
    if (lhs_shape.size() != rhs_shape.size())
      throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
    ir::type::block_shapes_t ret_shape;
    for (size_t i = 0; i < lhs_shape.size(); ++i) {
      unsigned left = lhs_shape[i];
      unsigned right = rhs_shape[i];
      if (left == 1)
        ret_shape.push_back(right);
      else if (right == 1)
        ret_shape.push_back(left);
      else if (left == right)
        ret_shape.push_back(left);
      else
        throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
                                 ": " + std::to_string(left) + " and " + std::to_string(right));
    }
    if (lhs_shape != ret_shape)
      lhs = builder->create_broadcast(lhs, ret_shape);
    if (rhs_shape != ret_shape)
      rhs = builder->create_broadcast(rhs, ret_shape);
  }
  return std::make_tuple(lhs, rhs);
}
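The shape-compatibility rule implemented by try_broadcast above, sketched in Python for illustration: along each axis the sizes must match, or one of them must be 1, in which case it is stretched to the other.

def broadcast_shapes(lhs, rhs):
    if len(lhs) != len(rhs):
        raise ValueError("blocks must have the same rank")
    out = []
    for left, right in zip(lhs, rhs):
        if left == 1:
            out.append(right)
        elif right == 1 or left == right:
            out.append(left)
        else:
            raise ValueError(f"incompatible dimensions: {left} and {right}")
    return tuple(out)

assert broadcast_shapes((128, 1), (1, 64)) == (128, 64)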
/*----------------------------------------------
  definition of triton.broadcast_to
  ----------------------------------------------*/
std::string broadcast_to_docstr = R"pbdoc(
  Tries to broadcast a block to a new shape.

  :param input: The input block.
  :type input: triton.value
  :param shape: The new shape.
  :type shape: tuple of int
)pbdoc";

ir::value *broadcast_to(ir::value *input, const ir::type::block_shapes_t &shape, ir::builder *builder) {
  if (!input->get_type()->is_block_ty())
    return builder->create_splat(input, shape);
  auto src_shape = input->get_type()->get_block_shapes();
  if (src_shape.size() != shape.size())
    throw std::runtime_error("Cannot broadcast");
  return builder->create_broadcast(input, shape);
}

/*----------------------------------------------
  definition of triton.load
  ----------------------------------------------*/
std::string load_docstr = R"pbdoc(
  Return a block of data whose values are, elementwise, loaded from memory at locations defined by `pointer`.

  :param pointer: Pointer to the data to be loaded.
  :type pointer: Block of triton.pointer
  :param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
  :type mask: Block of triton.bool, optional
  :param other: if mask[idx] is false, return other[idx] instead of `pointer[idx]`.
  :type other: Block of triton.value, optional
)pbdoc";

ir::value *load(ir::value *pointer, std::optional<ir::value *> _mask, std::optional<ir::value *> _other, ir::builder *builder) {
  if (!_mask.has_value() && !_other.has_value())
    return builder->create_load(pointer);
  if (!_mask.has_value())
    throw std::runtime_error("`other` cannot be provided without `mask`");
  ir::value *mask = _mask.value();
  ir::type *elt_ty = pointer->get_type()->get_scalar_ty()->get_pointer_element_ty();
  auto shape = pointer->get_type()->get_block_shapes();
  ir::value *other = _other.has_value() ? _other.value() : ir::undef_value::get(elt_ty);
  other = cast(other, from_ir(elt_ty), builder);
  other = broadcast_to(other, shape, builder);
  mask = broadcast_to(mask, shape, builder);
  return builder->create_masked_load(pointer, mask, other);
}

/*----------------------------------------------
  definition of triton.store
  ----------------------------------------------*/
std::string store_docstr = R"pbdoc(
  Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.

  :param pointer: The memory locations where the elements of `value` are stored.
  :type pointer: Block of triton.pointer
  :param value: The block of elements to be stored.
  :type value: Block of triton.value
  :param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
  :type mask: Block of triton.bool, optional
)pbdoc";
ir::value *store(ir::value *ptr, ir::value *val, std::optional<ir::value *> _mask, ir::builder *builder) {
  if (!_mask.has_value())
    return builder->create_store(ptr, val);
  ir::value *mask = _mask.value();
  return builder->create_masked_store(ptr, val, mask);
}

/*----------------------------------------------
  definition of triton.dot
  ----------------------------------------------*/
std::string dot_docstr = R"pbdoc(
  Returns the matrix product of two blocks.
  The two blocks must be two-dimensional and have compatible inner dimensions.

  :param input: The first block to be multiplied.
  :type input: 2D block of scalar-type in {`float16`, `float32`}
  :param other: The second block to be multiplied.
  :type other: 2D block of scalar-type in {`float16`, `float32`}
)pbdoc";
ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
  ir::value *_0 = builder->get_float32(0);
  unsigned M = lhs->get_type()->get_block_shapes()[0];
  unsigned N = rhs->get_type()->get_block_shapes()[1];
  _0 = builder->create_splat(_0, {M, N});
  return builder->create_dot(lhs, rhs, _0);
}
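A minimal Triton-Python sketch (illustration only, not part of the original file) of the masked load/store semantics documented above: out-of-bounds lanes are masked off, and `other` supplies the value returned for masked-off loads.

import torch
import triton
import triton.language as tl

@triton.jit
def copy(X, Z, N, BLOCK: tl.constexpr):
    off = tl.arange(0, BLOCK)
    mask = off < N                              # guard the tail of the vector
    x = tl.load(X + off, mask=mask, other=0.0)  # masked-off lanes read 0.0
    tl.store(Z + off, x, mask=mask)             # masked-off lanes are not written

x = torch.randn(1000, device="cuda")
z = torch.zeros_like(x)
copy[(1,)](x, z, x.numel(), BLOCK=1024)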
/*----------------------------------------------
  definition of triton.where
  ----------------------------------------------*/
std::string where_docstr = R"pbdoc(
  Returns a block of elements from either `x` or `y`, depending on `condition`.
  Note that `x` and `y` are always evaluated regardless of the value of `condition`.
  If you want to avoid unintended memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.

  :param condition: When True (nonzero), yield x, otherwise yield y.
  :type condition: Block of triton.bool
  :param x: values selected at indices where condition is True.
  :param y: values selected at indices where condition is False.
)pbdoc";
ir::value *where(ir::value *condition, ir::value *x, ir::value *y, ir::builder *builder) {
  return builder->create_select(condition, x, y);
};

/*----------------------------------------------
  definition of triton.arange
  ----------------------------------------------*/
std::string arange_docstr = R"pbdoc(
  Returns contiguous values within the half-open interval [start, end).

  :param start: Start of the interval.
  :type start: int
  :param stop: End of the interval.
  :type stop: int
)pbdoc";
ir::value *arange(int start, int end, ir::builder *builder) {
  return builder->get_range(start, end);
};

/*----------------------------------------------
  definition of triton.program_id
  ----------------------------------------------*/
std::string program_id_docstr = R"pbdoc(
  Returns the id of the current program instance along the given `axis`.
  Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s.

  :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
  :type axis: int
)pbdoc";
ir::value *program_id(int axis, ir::builder *builder) {
  return builder->create_get_program_id(axis);
};

/*----------------------------------------------
  definition of triton.num_programs
  ----------------------------------------------*/
std::string num_programs_docstr = R"pbdoc(
  Returns the number of program instances launched along the given `axis`.

  :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
  :type axis: int
)pbdoc";
ir::value *num_programs(int axis, ir::builder *builder) {
  return builder->create_get_num_programs(axis);
};

/*----------------------------------------------
  definition of triton.zeros
  ----------------------------------------------*/
std::string zeros_docstr = R"pbdoc(
  Returns a block filled with the scalar value 0 and the given shape.

  :param shape: Shape of the new array, e.g., (8, 16) or (8, )
  :type shape: tuple of ints
  :param dtype: Data-type of the new array, e.g., tl.float16
  :type dtype: triton.ir.dtype
)pbdoc";
ir::value *zeros(ir::type::block_shapes_t shape, type_code _dtype, ir::builder *builder) {
  ir::type *dtype = make_ir(_dtype, builder);
  ir::value *_0 = ir::constant::get_null_value(dtype);
  return builder->create_splat(_0, shape);
};

/*----------------------------------------------
  definition of triton.exp
  ----------------------------------------------*/
std::string _exp_docstr = R"pbdoc(
  Returns the element-wise exponential of `input`.
)pbdoc";
ir::value *_exp(ir::value *input, ir::builder *builder) {
  return builder->create_exp(input);
};

/*----------------------------------------------
  definition of triton.log
  ----------------------------------------------*/
std::string _log_docstr = R"pbdoc(
  Returns the element-wise natural logarithm of `input`.
)pbdoc";
ir::value *_log(ir::value *input, ir::builder *builder) {
  return builder->create_log(input);
};

/*----------------------------------------------
  definition of triton.sqrt
  ----------------------------------------------*/
std::string sqrt_docstr = R"pbdoc(
  Returns the element-wise square root of `input`.
)pbdoc";
ir::value *sqrt(ir::value *input, ir::builder *builder) {
  return builder->create_sqrt(input);
};

ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
                       ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  if (scalar_ty->is_floating_point_ty())
    return builder->create_reduce(input, FLOAT_OP, axis);
  else if (scalar_ty->is_integer_ty())
    return builder->create_reduce(input, INT_OP, axis);
  else
    throw_not_int_or_float(name);
}

/*----------------------------------------------
  definition of triton.min
  ----------------------------------------------*/
std::string min_docstr = R"pbdoc(
  Returns the minimum value of `input`.
)pbdoc";
ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
};

/*----------------------------------------------
  definition of triton.arg_min
  ----------------------------------------------*/
std::string argmin_docstr = R"pbdoc(
  Returns the index of the minimum value of `input`.
)pbdoc";
ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN);
};

/*----------------------------------------------
  definition of triton.max
  ----------------------------------------------*/
std::string max_docstr = R"pbdoc(
  Returns the maximum value of `input`.
)pbdoc";
ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
};

/*----------------------------------------------
  definition of triton.arg_max
  ----------------------------------------------*/
std::string argmax_docstr = R"pbdoc(
  Returns the index of the maximum value of `input`.
)pbdoc";
ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX);
};

/*----------------------------------------------
  definition of triton.sum
  ----------------------------------------------*/
std::string sum_docstr = R"pbdoc(
  Returns the sum of `input`.
)pbdoc";
ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
};

/*----------------------------------------------
  definition of triton.atomic_cas
  ----------------------------------------------*/
std::string atomic_cas_docstr = R"pbdoc(
  Atomic compare-and-swap.
)pbdoc";
ir::value *atomic_cas(ir::value *ptr, ir::value *cmp, ir::value *val, ir::builder *builder) {
  return builder->create_atomic_cas(ptr, cmp, val);
};

/*----------------------------------------------
  definition of triton.atomic_xchg
  ----------------------------------------------*/
std::string atomic_xchg_docstr = R"pbdoc(
  Atomic exchange.
)pbdoc";
ir::value *atomic_xchg(ir::value *ptr, ir::value *val, ir::builder *builder) {
  return builder->create_atomic_exch(ptr, val);
};

/*----------------------------------------------
  debug barrier
  ----------------------------------------------*/
std::string debug_barrier_docstr = R"pbdoc(
  Temporary hacky fixup for when the compiler forgets to insert sync barriers
)pbdoc";
ir::value *debug_barrier(ir::builder *builder) {
  return builder->create_barrier();
}

#define DEF_BINARY_OP(MOD, PY_NAME, C_FUNC, ...) \
  MOD.def(PY_NAME, binary_op(C_FUNC), (C_FUNC##_docstr + _builder_doc).c_str(), \
          ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)

template <class FN>
std::function<ir::value *(ir::value *, ir::value *, ir::builder *builder)>
binary_op(const FN &fn) {
  auto ret = [&fn](ir::value *self, ir::value *other, ir::builder *builder) {
    //std::tie(self, other) = try_broadcast(self, other, builder);
    return fn(self, other, builder);
  };
  return ret;
}
/*----------------------------------------------
  definition of self + other
  ----------------------------------------------*/
std::string add_docstr = R"pbdoc(
  Returns self + other, element-wise.
)pbdoc";
ir::value *add(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // ptr + offset
  if (scalar_ty->is_pointer_ty())
    return builder->create_gep(self, {other});
  // float + float
  else if (scalar_ty->is_floating_point_ty())
    return builder->create_fadd(self, other);
  // int + int
  else if (scalar_ty->is_integer_ty())
    return builder->create_add(self, other);
  throw_not_implemented("add");
}

/*----------------------------------------------
  definition of self - other
  ----------------------------------------------*/
std::string sub_docstr = R"pbdoc(
  Returns self - other, element-wise.
)pbdoc";
ir::value *sub(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // ptr - offset
  if (scalar_ty->is_pointer_ty())
    return builder->create_gep(self, {other});
  // float - float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fsub(self, other);
  // int - int
  else if (scalar_ty->is_integer_ty())
    return builder->create_sub(self, other);
  throw_not_implemented("sub");
}

/*----------------------------------------------
  definition of self * other
  ----------------------------------------------*/
std::string mul_docstr = R"pbdoc(
  Returns self * other, element-wise.
)pbdoc";
ir::value *mul(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float * float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fmul(self, other);
  // int * int
  else if (scalar_ty->is_integer_ty())
    return builder->create_mul(self, other);
  throw_not_implemented("mul");
}

/*----------------------------------------------
  definition of self > other
  ----------------------------------------------*/
std::string greater_than_docstr = R"pbdoc(
  Returns self > other, element-wise.
)pbdoc";
ir::value *greater_than(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float > float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOGT(self, other);
  // int > int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSGT(self, other);
  throw_not_implemented("greater_than");
}

/*----------------------------------------------
  definition of self >= other
  ----------------------------------------------*/
std::string greater_equal_docstr = R"pbdoc(
  Returns self >= other, element-wise.
)pbdoc";
ir::value *greater_equal(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float >= float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOGE(self, other);
  // int >= int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSGE(self, other);
  throw_not_implemented("greater_equal");
}

/*----------------------------------------------
  definition of self < other
  ----------------------------------------------*/
std::string less_than_docstr = R"pbdoc(
  Returns self < other, element-wise.
)pbdoc";
ir::value *less_than(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float < float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOLT(self, other);
  // int < int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSLT(self, other);
  throw_not_implemented("less_than");
}

/*----------------------------------------------
  definition of self <= other
  ----------------------------------------------*/
std::string less_equal_docstr = R"pbdoc(
  Returns self <= other, element-wise.
)pbdoc";
ir::value *less_equal(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float <= float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOLE(self, other);
  // int <= int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSLE(self, other);
  throw_not_implemented("less_equal");
}

/*----------------------------------------------
  definition of self == other
  ----------------------------------------------*/
std::string equal_docstr = R"pbdoc(
  Returns self == other, element-wise.
)pbdoc";
ir::value *equal(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float == float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOEQ(self, other);
  // int == int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpEQ(self, other);
  throw_not_implemented("equal");
}

/*----------------------------------------------
  definition of self / other
  ----------------------------------------------*/
std::string _div_docstr = R"pbdoc(
  Returns self / other, element-wise.
)pbdoc";
ir::value *_div(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float / float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fdiv(self, other);
  // int / int
  else if (scalar_ty->is_integer_ty())
    return builder->create_sdiv(self, other);
  throw_not_implemented("div");
}

/*----------------------------------------------
  definition of self % other
  ----------------------------------------------*/
std::string mod_docstr = R"pbdoc(
  Returns self % other, element-wise.
)pbdoc";
ir::value *mod(ir::value *self, ir::value *other, ir::builder *builder) {
  ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float % float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_frem(self, other);
  // int % int
  else if (scalar_ty->is_integer_ty())
    return builder->create_srem(self, other);
  throw_not_implemented("mod");
}

/*----------------------------------------------
  definition of self & other
  ----------------------------------------------*/
std::string _and_docstr = R"pbdoc(
  Returns self & other, element-wise.
)pbdoc";
ir::value *_and(ir::value *self, ir::value *other, ir::builder *builder) {
  return builder->create_and(self, other);
}

/*----------------------------------------------
  definition of minimum(self, other)
  ----------------------------------------------*/
std::string minimum_docstr = R"pbdoc(
  Returns the element-wise minimum of self and other.
)pbdoc";
ir::value *minimum(ir::value *self, ir::value *other, ir::builder *builder) {
  return where(less_than(self, other, builder), self, other, builder);
}

/*----------------------------------------------
  definition of self[slices]
  ----------------------------------------------*/

enum slice_mode_t {
  NEWAXIS,
  ALL
};

std::string subscript_docstr = R"pbdoc(
  Returns self[slices].

  :param slices: The slices to subscript with.
  :type slices: List of `None` or `:` slices.
)pbdoc";
ir::value *subscript(ir::value *self, std::vector<py::object> slices, ir::builder *builder) {
  std::vector<slice_mode_t> modes;
  for (py::object slice : slices) {
    py::object none = py::none();
    py::object all = py::make_tuple(none, none, none);
    if (slice.is(none))
      modes.push_back(NEWAXIS);
    else if (all.attr("__eq__")(slice))
      modes.push_back(ALL);
    else
      throw std::runtime_error("slice must be None or (None, None, None)");
  }

  ir::type::block_shapes_t shape;
  size_t curr = 0;
  for (slice_mode_t mode : modes) {
    if (mode == NEWAXIS)
      shape.push_back(1);
    else {
      assert(mode == ALL);
      shape.push_back(self->get_type()->get_block_shapes()[curr++]);
    }
  }
  return builder->create_reshape(self, shape);
}
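The NEWAXIS/ALL handling in subscript above mirrors NumPy-style indexing, which is what expressions like off_m[:, None] in the examples rely on; a small shape-only illustration using torch:

import torch
off_m = torch.arange(0, 64)
assert off_m[:, None].shape == (64, 1)   # None inserts a dimension of size 1
assert off_m[None, :].shape == (1, 64)   # ':' keeps the existing dimension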
@@ -8,8 +8,4 @@ void init_cutlass(pybind11::module &m);
PYBIND11_MODULE(libtriton, m) {
  m.doc() = "Python bindings to the C++ Triton API";
  init_triton(m);
  init_superblocking(m);
#ifdef WITH_CUTLASS_BINDINGS
  init_cutlass(m);
#endif
}
@@ -1,119 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#ifdef _OPENMP
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
// row-major 3d tensor
|
||||
class tensor_3d {
|
||||
public:
|
||||
tensor_3d(int size_0, int size_1, int size_2, int *data = nullptr) : data_(size_0 * size_1 * size_2, 0) {
|
||||
if (data)
|
||||
std::copy(data, data + data_.size(), data_.begin());
|
||||
stride_0_ = size_1 * size_2;
|
||||
stride_1_ = size_2;
|
||||
stride_2_ = 1;
|
||||
}
|
||||
|
||||
int &operator()(int i, int j, int k) {
|
||||
return data_[i * stride_0_ + j * stride_1_ + k];
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> data_;
|
||||
int stride_0_;
|
||||
int stride_1_;
|
||||
int stride_2_;
|
||||
};
|
||||
|
||||
std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width, int H, int M, int N) {
|
||||
tensor_3d tmp(H, M, N);
|
||||
std::vector<int> current(H, 0);
|
||||
int num = 0;
|
||||
std::vector<int> lut(H * M * N * 4);
|
||||
for (size_t h = 0; h < H; h++) {
|
||||
// surrounding indices
|
||||
std::vector<int> ii_left(max_width, -1);
|
||||
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
|
||||
// start the dynamic programming algorithm
|
||||
for (size_t m = 0; m < M; m++) {
|
||||
for (size_t n = 0; n < N; n++) {
|
||||
int v = layout(h, m, n);
|
||||
if (v == 0)
|
||||
continue;
|
||||
int n_left = ii_left[max_width - 1];
|
||||
int m_top = ii_top[max_width - 1][n];
|
||||
int top = (m_top >= 0) ? tmp(h, m_top, n) : 0;
|
||||
int left = (n_left >= 0) ? tmp(h, m, n_left) : 0;
|
||||
int topleft = (m_top >= 0 && n_left >= 0) ? tmp(h, m_top, n_left) : 0;
|
||||
int width = std::min(left, std::min(top, topleft)) + 1;
|
||||
// reset width if blocks cannot be
|
||||
// packed together (i.e., there's a 1 "in the middle")
|
||||
for (int nn = n_left + 1; nn < n; nn++)
|
||||
if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n])
|
||||
width = 1;
|
||||
tmp(h, m, n) = width;
|
||||
// update n_left ring buffer
|
||||
for (int k = 0; k < max_width - 1; k++)
|
||||
ii_left[k] = ii_left[k + 1];
|
||||
ii_left[max_width - 1] = n;
|
||||
// update ii_top ring buffer
|
||||
for (int k = 0; k < max_width - 1; k++)
|
||||
ii_top[k][n] = ii_top[k + 1][n];
|
||||
ii_top[max_width - 1][n] = m;
|
||||
// block is too small -- skip
|
||||
if (width != max_width)
|
||||
continue;
|
||||
// retained blocks are set to zeros
|
||||
for (size_t km = 0; km < max_width; km++)
|
||||
for (size_t kn = 0; kn < max_width; kn++) {
|
||||
int mm = ii_top[km][n];
|
||||
int nn = ii_left[kn];
|
||||
if (mm < 0 || nn < 0)
|
||||
continue;
|
||||
layout(h, mm, nn) = 0;
|
||||
tmp(h, mm, nn) = 0;
|
||||
lut[num++] = (int)h;
|
||||
lut[num++] = (int)mm;
|
||||
lut[num++] = (int)nn;
|
||||
lut[num++] = idx(h, mm, nn);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
lut.resize(num);
|
||||
return lut;
|
||||
}
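The width recurrence above is the classic maximal-square dynamic program: width(m, n) = min(top, left, top-left) + 1 over the most recent non-zero neighbours. A minimal Python sketch of the recurrence alone, assuming a dense layout (the ring buffers, the "1 in the middle" reset and the LUT emission are omitted):

def max_square_widths(layout):
    # layout: list of lists of 0/1; returns the table of maximal square widths
    M, N = len(layout), len(layout[0])
    width = [[0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            if layout[m][n] == 0:
                continue
            top = width[m - 1][n] if m > 0 else 0
            left = width[m][n - 1] if n > 0 else 0
            topleft = width[m - 1][n - 1] if m > 0 and n > 0 else 0
            width[m][n] = min(top, left, topleft) + 1
    return width

assert max_square_widths([[1, 1], [1, 1]])[1][1] == 2  # a 2x2 superblock of ones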
|
||||
|
||||
typedef std::pair<int, pybind11::array_t<int>> lut_t;
|
||||
|
||||
std::vector<lut_t> superblock(uintptr_t LAYOUT, int H, int M, int N, int start_width) {
|
||||
std::vector<lut_t> ret;
|
||||
int current = 0;
|
||||
tensor_3d layout(H, M, N, (int *)LAYOUT);
|
||||
tensor_3d idx(H, M, N);
|
||||
for (int64_t h = 0; h < H; h++)
|
||||
for (int64_t m = 0; m < M; m++)
|
||||
for (int64_t n = 0; n < N; n++) {
|
||||
if (layout(h, m, n) == 0)
|
||||
continue;
|
||||
idx(h, m, n) = current++;
|
||||
}
|
||||
// create lut
|
||||
for (int max_width = start_width; max_width > 0; max_width /= 2) {
|
||||
auto lut = segment_blocks(layout, idx, max_width, H, M, N);
|
||||
if (lut.size() == 0)
|
||||
continue;
|
||||
ret.push_back(std::make_pair(max_width, pybind11::array_t<int>(lut.size(), lut.data())));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void init_superblocking(pybind11::module &m) {
|
||||
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
|
||||
}
|
python/src/triton.cc: 2292 lines changed (file diff suppressed because it is too large)
python/test/unit/language/printf_helper.py: new file, 56 lines
@@ -0,0 +1,56 @@
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
torch_type = {
|
||||
"bool": torch.bool,
|
||||
'int8': torch.int8,
|
||||
'uint8': torch.uint8,
|
||||
'int16': torch.int16,
|
||||
"int32": torch.int32,
|
||||
'int64': torch.long,
|
||||
'float16': torch.float16,
|
||||
'bfloat16': torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64
|
||||
}
|
||||
|
||||
|
||||
def get_tensor(shape, data_type, b_positive=False):
|
||||
x = None
|
||||
if data_type.startswith('int'):
|
||||
x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
|
||||
else:
|
||||
x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
|
||||
|
||||
return x
|
||||
|
||||
# @pytest.mark.parametrize('data_type',
|
||||
# [("int8"),
|
||||
# ('int16'),
|
||||
# ('int32'),
|
||||
# ("int64"),
|
||||
# ('float16'),
|
||||
# ("float32"),
|
||||
# ("float64")])
|
||||
|
||||
|
||||
def printf(data_type):
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.printf("", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
shape = (128, )
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = get_tensor(shape, data_type)
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
kernel[(1,)](x, y, BLOCK=shape[0])
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
printf("float16")
|
||||
printf("int8")
|
@@ -1,5 +1,6 @@
|
||||
# flake8: noqa: F821,F841
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -104,8 +105,8 @@ def check_type_supported(dtype):
|
||||
'''
|
||||
skip test if dtype is not supported on the current device
|
||||
'''
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
|
||||
cc = torch.cuda.get_device_capability()
|
||||
if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
|
||||
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
|
||||
|
||||
|
||||
@@ -414,8 +415,8 @@ def test_where(dtype):
|
||||
def test_where_broadcast():
|
||||
@triton.jit
|
||||
def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
|
||||
xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
|
||||
yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
|
||||
xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
|
||||
yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
|
||||
|
||||
mask = tl.load(cond_ptr + yoffsets)
|
||||
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
|
||||
@@ -424,8 +425,8 @@ def test_where_broadcast():
|
||||
|
||||
@triton.jit
|
||||
def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
|
||||
xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
|
||||
yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
|
||||
xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
|
||||
yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
|
||||
mask = 0
|
||||
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
|
||||
res = tl.where(mask, vals, 0.)
|
||||
@@ -462,9 +463,6 @@ def test_unary_op(dtype_x, expr, device='cuda'):
|
||||
# ----------------
|
||||
# test math ops
|
||||
# ----------------
|
||||
# @pytest.mark.parametrize("expr", [
|
||||
# 'exp', 'log', 'cos', 'sin'
|
||||
# ])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("expr", [
|
||||
@@ -490,9 +488,12 @@ def make_ptr_str(name, shape):
|
||||
return f"{name} + {' + '.join(offsets)}"
|
||||
|
||||
|
||||
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
|
||||
@pytest.mark.parametrize("expr, dtype_str", [
|
||||
(f'x[{s}]', d)
|
||||
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
|
||||
for s in ['None, :', ':, None',
|
||||
'None, :, :',
|
||||
':, :, None']
|
||||
for d in ['int32', 'uint32', 'uint16']
|
||||
])
|
||||
def test_index1d(expr, dtype_str, device='cuda'):
|
||||
@@ -605,8 +606,8 @@ def test_tuples():
|
||||
]
|
||||
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
|
||||
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 70:
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
if dtype_x_str == 'float16':
|
||||
pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
|
||||
n_programs = 5
|
||||
@@ -651,9 +652,10 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("axis", [0, 1])
|
||||
def test_tensor_atomic_rmw(axis, device="cuda"):
|
||||
shape0, shape1 = 8, 8
|
||||
@pytest.mark.parametrize("shape, axis",
|
||||
[(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
|
||||
def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
||||
shape0, shape1 = shape
|
||||
# triton kernel
|
||||
|
||||
@triton.jit
|
||||
@@ -662,14 +664,18 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
|
||||
off1 = tl.arange(0, SHAPE1)
|
||||
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
|
||||
z = tl.sum(x, axis=AXIS)
|
||||
tl.atomic_add(Z + off0, z)
|
||||
if AXIS == 1:
|
||||
tl.atomic_add(Z + off0, z)
|
||||
else:
|
||||
tl.atomic_add(Z + off1, z)
|
||||
rs = RandomState(17)
|
||||
x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
|
||||
# reference result
|
||||
z_ref = np.sum(x, axis=axis)
|
||||
z_ref = np.sum(x, axis=axis, keepdims=False)
|
||||
# triton result
|
||||
x_tri = to_triton(x, device=device)
|
||||
z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
|
||||
z_shape = (shape0, ) if axis == 1 else (shape1, )
|
||||
z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
|
||||
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||
|
||||
@@ -724,6 +730,10 @@ def test_atomic_cas():
|
||||
(f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
|
||||
])
|
||||
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
# bfloat16 on cc < 80 will not be tested
|
||||
check_type_supported(dtype_x)
|
||||
check_type_supported(dtype_z)
|
||||
|
||||
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
|
||||
x0 = 43 if dtype_x in int_dtypes else 43.5
|
||||
if dtype_x in float_dtypes and dtype_z == 'int1':
|
||||
@@ -737,9 +747,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, BITCAST: tl.constexpr):
|
||||
x = tl.load(X)
|
||||
x_ptr = X + tl.arange(0, 1)
|
||||
z_ptr = Z + tl.arange(0, 1)
|
||||
x = tl.load(x_ptr)
|
||||
z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
|
||||
tl.store(Z, z)
|
||||
tl.store(z_ptr, z)
|
||||
|
||||
dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
|
||||
# triton result
|
||||
@@ -869,9 +881,19 @@ def test_f16_to_f8_rounding():
|
||||
# ---------------
|
||||
|
||||
|
||||
def get_reduced_dtype(dtype_str, op):
|
||||
if op == 'argmin' or op == 'argmax':
|
||||
return 'int32'
|
||||
if dtype_str in ['int8', 'uint8', 'int16', 'uint16']:
|
||||
return 'int32'
|
||||
if dtype_str == 'bfloat16':
|
||||
return 'float32'
|
||||
return dtype_str
|
||||
|
||||
|
||||
@pytest.mark.parametrize("op, dtype_str, shape",
|
||||
[(op, dtype, shape)
|
||||
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
|
||||
for op in ['min', 'max', 'sum']
|
||||
for dtype in dtypes_with_bfloat16
|
||||
for shape in [32, 64, 128, 512]])
|
||||
def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
@@ -892,7 +914,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
|
||||
'argmin': np.argmin, 'argmax': np.argmax}[op]
|
||||
# numpy result
|
||||
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
|
||||
z_dtype_str = get_reduced_dtype(dtype_str, op)
|
||||
z_tri_dtype_str = z_dtype_str
|
||||
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
|
||||
z_dtype_str = 'float32'
|
||||
@@ -919,21 +941,35 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
np.testing.assert_equal(z_ref, z_tri)
|
||||
|
||||
|
||||
# TODO: [Qingyi] Fix argmin / argmax
|
||||
reduce_configs1 = [
|
||||
(op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
|
||||
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
|
||||
for op in ['min', 'max', 'sum']
|
||||
for axis in [1]
|
||||
]
|
||||
|
||||
|
||||
# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory
|
||||
# exceeds the limit of 99KB
|
||||
reduce2d_shapes = [(2, 32), (4, 32), (4, 128)]
|
||||
# TODO: fix and uncomment
|
||||
# , (32, 64), (64, 128)]
|
||||
if 'V100' in torch.cuda.get_device_name(0):
|
||||
reduce2d_shapes += [(128, 256), (32, 1024)]
|
||||
|
||||
|
||||
reduce_configs2 = [
|
||||
(op, 'float32', shape, axis)
|
||||
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
|
||||
for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
|
||||
for op in ['min', 'max', 'sum']
|
||||
for shape in reduce2d_shapes
|
||||
for axis in [0, 1]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
|
||||
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
|
||||
@@ -954,7 +990,7 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
x_tri = to_triton(x)
|
||||
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
|
||||
'argmin': np.argmin, 'argmax': np.argmax}[op]
|
||||
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
|
||||
z_dtype_str = get_reduced_dtype(dtype_str, op)
|
||||
z_tri_dtype_str = z_dtype_str
|
||||
# numpy result
|
||||
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
|
||||
@@ -992,7 +1028,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, shape, perm",
|
||||
[(dtype, shape, perm)
|
||||
for dtype in ['bfloat16', 'float16', 'float32']
|
||||
# TODO: bfloat16
|
||||
for dtype in ['float16', 'float32']
|
||||
for shape in [(64, 64), (128, 128)]
|
||||
for perm in [(1, 0)]])
|
||||
def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
@@ -1038,25 +1075,37 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
|
||||
[(epilogue, allow_tf32, dtype)
|
||||
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
|
||||
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
|
||||
for shape in [(64, 64, 64)]
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
||||
for allow_tf32 in [True, False]
|
||||
for dtype in ['float16']
|
||||
if not (allow_tf32 and (dtype in ['float16']))])
|
||||
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 70:
|
||||
for dtype in ['float16', 'float32']
|
||||
if not (allow_tf32 and (dtype in ['float16']))] +
|
||||
|
||||
[(*shape_nw, col_a, col_b, 'none', allow_tf32, dtype)
|
||||
for shape_nw in [[128, 256, 32, 8],
|
||||
[128, 16, 32, 4],
|
||||
[32, 128, 64, 4],
|
||||
[128, 128, 64, 4],
|
||||
[64, 128, 128, 4],
|
||||
[32, 128, 64, 2],
|
||||
[128, 128, 64, 2],
|
||||
[64, 128, 128, 4]]
|
||||
for allow_tf32 in [True]
|
||||
for col_a in [True, False]
|
||||
for col_b in [True, False]
|
||||
for dtype in ['int8', 'float16', 'float32']])
|
||||
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
if cc < 80:
|
||||
if capability[0] < 8:
|
||||
if dtype == 'int8':
|
||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
|
||||
M, N, K = 128, 128, 64
|
||||
num_warps = 8
|
||||
trans_a, trans_b = False, False
|
||||
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -1068,7 +1117,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
ALLOW_TF32: tl.constexpr,
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
||||
TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
|
||||
COL_A: tl.constexpr, COL_B: tl.constexpr):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
off_l = tl.arange(0, BLOCK_N)
|
||||
@@ -1077,7 +1126,9 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
|
||||
Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
|
||||
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||
z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
|
||||
x = tl.load(Xs)
|
||||
y = tl.load(Ys)
|
||||
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
|
||||
if ADD_MATRIX:
|
||||
z += tl.load(Zs)
|
||||
if ADD_ROWS:
|
||||
@@ -1093,16 +1144,24 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
den = tl.sum(num, 1)
|
||||
z = num / den[:, None]
|
||||
if CHAIN_DOT:
|
||||
# tl.store(Zs, z)
|
||||
# tl.debug_barrier()
|
||||
z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
|
||||
w = tl.load(Ws)
|
||||
z = tl.dot(z.to(w.dtype), w)
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
rs = RandomState(17)
|
||||
x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
|
||||
y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
|
||||
w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
|
||||
if allow_tf32:
|
||||
if col_a:
|
||||
x = numpy_random((K, M), dtype_str=dtype, rs=rs).T
|
||||
else:
|
||||
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
|
||||
if col_b:
|
||||
y = numpy_random((N, K), dtype_str=dtype, rs=rs).T
|
||||
else:
|
||||
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
|
||||
w = numpy_random((N, N), dtype_str=dtype, rs=rs)
|
||||
if 'int' not in dtype:
|
||||
x *= .1
|
||||
y *= .1
|
||||
if dtype == 'float32' and allow_tf32:
|
||||
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
@@ -1110,7 +1169,11 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
y_tri = to_triton(y, device=device)
|
||||
w_tri = to_triton(w, device=device)
|
||||
# triton result
|
||||
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
|
||||
if dtype == 'int8':
|
||||
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
|
||||
else:
|
||||
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
|
||||
|
||||
z_tri = to_triton(z, device=device)
|
||||
if epilogue == 'trans':
|
||||
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
||||
@@ -1118,7 +1181,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
TRANS_A=trans_a, TRANS_B=trans_b,
|
||||
COL_A=col_a, COL_B=col_b,
|
||||
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
ADD_ROWS=epilogue == 'add-rows',
|
||||
@@ -1128,9 +1191,12 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
ALLOW_TF32=allow_tf32,
|
||||
num_warps=num_warps)
|
||||
# torch result
|
||||
x_ref = x.T if trans_a else x
|
||||
y_ref = y.T if trans_b else y
|
||||
z_ref = np.matmul(x_ref, y_ref)
|
||||
if dtype == 'int8':
|
||||
z_ref = np.matmul(x.astype(np.float32),
|
||||
y.astype(np.float32)).astype(np.int32)
|
||||
else:
|
||||
z_ref = np.matmul(x, y)
|
||||
|
||||
if epilogue == 'add-matrix':
|
||||
z_ref += z
|
||||
if epilogue == 'add-rows':
|
||||
@@ -1142,17 +1208,21 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
denom = np.sum(num, axis=-1, keepdims=True)
|
||||
z_ref = num / denom
|
||||
if epilogue == 'chain-dot':
|
||||
z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
|
||||
z_ref = np.matmul(z_ref, w)
|
||||
# compare
|
||||
# print(z_ref[:,0], z_tri[:,0])
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
if dtype == 'float32':
|
||||
# XXX: Somehow there's a larger difference when we use float32
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
else:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
if allow_tf32:
|
||||
if dtype == 'float32' and allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
elif dtype == 'float32':
|
||||
elif dtype == 'float32' and not allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
elif dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
@@ -1216,7 +1286,7 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
|
||||
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
|
||||
in_offsets = tl.arange(0, out_size)
|
||||
# Load inputs.
|
||||
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1.0)
|
||||
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
|
||||
# Store output
|
||||
output_offsets = tl.arange(0, out_size)
|
||||
tl.store(out_ptr + output_offsets, x)
|
||||
@@ -1227,16 +1297,12 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
|
||||
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
|
||||
triton.testing.allclose(output, reference_out)
|
||||
|
||||
|
||||
# 'bfloat16': torch.bfloat16,
|
||||
# Testing masked loads with an intermediate copy to shared memory.
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
|
||||
def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 70:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
|
||||
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
M = 32
|
||||
@@ -1325,6 +1391,7 @@ def test_vectorization(N):
|
||||
else:
|
||||
assert "ld.global.b32" in ptx
|
||||
# triton.testing.assert_almost_equal(dst, src[:N])
|
||||
|
||||
# ---------------
|
||||
# test store
|
||||
# ---------------
|
||||
@@ -1402,6 +1469,10 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
|
||||
JITFunction.cache_hook = None
|
||||
assert spec_type == value_type
|
||||
|
||||
# --------------------
|
||||
# value specialization
|
||||
# --------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, overflow",
|
||||
@@ -1552,9 +1623,23 @@ def test_num_warps_pow2():
|
||||
# -------------
|
||||
|
||||
|
||||
def system_libdevice_path() -> str:
|
||||
_SYSTEM_LIBDEVICE_SEARCH_PATHS = [
|
||||
'/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc',
|
||||
'/usr/local/cuda/nvvm/libdevice/libdevice.10.bc',
|
||||
]
|
||||
SYSTEM_LIBDEVICE_PATH: Optional[str] = None
|
||||
for _p in _SYSTEM_LIBDEVICE_SEARCH_PATHS:
|
||||
if os.path.exists(_p):
|
||||
SYSTEM_LIBDEVICE_PATH = _p
|
||||
assert SYSTEM_LIBDEVICE_PATH is not None, \
|
||||
"Could not find libdevice.10.bc path"
|
||||
return SYSTEM_LIBDEVICE_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
||||
[('int32', 'libdevice.ffs', ''),
|
||||
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
|
||||
('float32', 'libdevice.pow', system_libdevice_path()),
|
||||
('float64', 'libdevice.norm4d', '')])
|
||||
def test_libdevice_tensor(dtype_str, expr, lib_path):
|
||||
|
||||
@@ -1621,3 +1706,95 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
|
||||
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
|
||||
# compare
|
||||
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
||||
|
||||
# -----------------------
|
||||
# test layout conversions
|
||||
# -----------------------
|
||||
# TODO: backend should be tested separately
|
||||
|
||||
|
||||
class MmaLayout:
|
||||
def __init__(self, version, warps_per_cta):
|
||||
self.version = version
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
|
||||
|
||||
|
||||
class BlockedLayout:
|
||||
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
|
||||
self.sz_per_thread = str(size_per_thread)
|
||||
self.threads_per_warp = str(threads_per_warp)
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
self.order = str(order)
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
|
||||
|
||||
|
||||
layouts = [
|
||||
# MmaLayout(version=1, warps_per_cta=[1, 4]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
|
||||
# MmaLayout(version=1, warps_per_cta=[4, 1]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
|
||||
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
|
||||
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
|
||||
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
|
||||
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
|
||||
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
|
||||
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
|
||||
]
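For reference, the layouts above render to the TritonGPU attribute strings that are spliced into the #src/#dst placeholders below; a self-contained check of the format emitted by BlockedLayout.__str__ (the values mirror the first BlockedLayout entry):

size_per_thread, threads_per_warp, warps_per_cta, order = [1, 8], [2, 16], [4, 1], [1, 0]
attr = (f"#triton_gpu.blocked<{{sizePerThread={size_per_thread}, "
        f"threadsPerWarp={threads_per_warp}, warpsPerCTA={warps_per_cta}, order={order}}}>")
assert attr == "#triton_gpu.blocked<{sizePerThread=[1, 8], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0]}>"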
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(128, 128)])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
@pytest.mark.parametrize("dst_layout", layouts)
|
||||
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
if str(src_layout) == str(dst_layout):
|
||||
pytest.skip()
|
||||
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
|
||||
pytest.skip()
|
||||
|
||||
ir = f"""
|
||||
#src = {src_layout}
|
||||
#dst = {dst_layout}
|
||||
""" + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
||||
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
|
||||
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
|
||||
%4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
|
||||
%5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
|
||||
%6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
|
||||
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
|
||||
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
|
||||
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
|
||||
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
|
||||
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
|
||||
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
|
||||
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
|
||||
tt.store %14, %13 : tensor<128x128xf16, #dst>
|
||||
return
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
x = to_triton(numpy_random(shape, dtype_str=dtype))
|
||||
z = torch.empty_like(x)
|
||||
|
||||
# write the IR to a temporary file using mkstemp
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
|
||||
f.write(ir)
|
||||
f.flush()
|
||||
kernel = triton.compile(f.name)
|
||||
kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
|
||||
|
||||
assert torch.equal(z, x)
|
||||
|
@@ -1,261 +0,0 @@
|
||||
# flake8: noqa: F821,F841
|
||||
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
|
||||
mask = w_offsets < (size // 4)
|
||||
input_ptrs = input_ptr + 1 + w_offsets
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale_shift = tl.load(input_ptr)
|
||||
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
output = tl.dequantize(input, scale, shift, 8)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
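A host-side sketch of the scale/shift packing this kernel unpacks: two fp16 values stored back-to-back and viewed as one int32, so the low 16 bits carry the scale and the high 16 bits carry the shift (assumes a little-endian host, as the tests below do):

import torch

scale_val, shift_val = 0.5, -3.0
packed = torch.tensor([scale_val, shift_val], dtype=torch.float16).view(torch.int32)  # one int32
scale = (packed & 0xFFFF).to(torch.int16).view(torch.float16)  # low half  -> scale
shift = (packed >> 16).to(torch.int16).view(torch.float16)     # high half -> shift
assert scale.item() == scale_val and shift.item() == shift_val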
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_scale_shift_int8(
|
||||
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
|
||||
mask = w_offsets < (size // 4)
|
||||
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(scale_ptr)
|
||||
shift = tl.load(shift_ptr)
|
||||
output = tl.dequantize(input, scale, shift, 8)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = input_ptr + 1 + w_offsets
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale_shift = tl.load(input_ptr)
|
||||
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
|
||||
output = tl.dequantize(input, scale, shift, 4)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_scale_shift_int4(
|
||||
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(scale_ptr)
|
||||
shift = tl.load(shift_ptr)
|
||||
output = tl.dequantize(input, scale, shift, 4)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
|
||||
shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
|
||||
output = tl.dequantize(input, scale, shift, 2)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize_kernel_scale_shift_int2(
|
||||
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
|
||||
mask = w_offsets < (size // 8)
|
||||
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
|
||||
input = tl.load(input_ptrs, mask=mask, other=0)
|
||||
scale = tl.load(scale_ptr)
|
||||
shift = tl.load(shift_ptr)
|
||||
output = tl.dequantize(input, scale, shift, 2)
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
|
||||
tl.store(output_ptrs, output, mask=offsets < size)
|
||||
|
||||
|
||||
def test_dequantize_int8() -> None:
|
||||
for i in range(10):
|
||||
if i < 5:
|
||||
size = random.randrange(16, 128, 4)
|
||||
else:
|
||||
size = random.randrange(132, 1024, 4)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
|
||||
scale_val = random.uniform(0.1, 4.0)
|
||||
shift_val = random.uniform(-10.0, 10.0)
|
||||
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
|
||||
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
|
||||
scale_shift = torch.tensor(
|
||||
[scale_val, shift_val],
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
).view(torch.int32)
|
||||
|
||||
input_int8 = torch.randint(
|
||||
0, 256, (size,), dtype=torch.uint8, device=device
|
||||
)
|
||||
input_int32 = input_int8.view(torch.int32)
|
||||
|
||||
input = torch.cat((scale_shift, input_int32))
|
||||
expected = (input_int8 * scale + shift).to(torch.float16)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
block_size = max(triton.next_power_of_2(size), 128)
|
||||
grid = (1,)
|
||||
dequantize_kernel_int8[grid](
|
||||
output, input, size, BLOCK_SIZE=block_size, num_warps=1
|
||||
)
|
||||
rtol, atol = 1e-02, 1e-02
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
dequantize_kernel_scale_shift_int8[grid](
|
||||
output,
|
||||
input_int32,
|
||||
scale,
|
||||
shift,
|
||||
size,
|
||||
BLOCK_SIZE=block_size,
|
||||
num_warps=1,
|
||||
)
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
|
||||
def test_dequantize_int4() -> None:
|
||||
for i in range(10):
|
||||
if i < 5:
|
||||
size = random.randrange(16, 256, 8)
|
||||
else:
|
||||
size = random.randrange(264, 1024, 8)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
|
||||
scale_val = random.uniform(0.1, 4.0)
|
||||
shift_val = random.uniform(-10.0, 10.0)
|
||||
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
|
||||
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
|
||||
scale_shift = torch.tensor(
|
||||
[scale_val, shift_val],
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
).view(torch.int32)
|
||||
|
||||
input_int8 = torch.randint(
|
||||
0, 256, (size // 2,), dtype=torch.uint8, device=device
|
||||
)
|
||||
input_int32 = input_int8.view(torch.int32)
|
||||
|
||||
input_int8_h1 = input_int8 >> 4
|
||||
input_int8_h0 = input_int8 & 15
|
||||
|
||||
input_int4_val = torch.stack(
|
||||
(input_int8_h0, input_int8_h1), dim=1
|
||||
).flatten()
|
||||
|
||||
input = torch.cat((scale_shift, input_int32))
|
||||
expected = (input_int4_val * scale + shift).to(torch.float16)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
block_size = max(triton.next_power_of_2(size), 256)
|
||||
grid = (1,)
|
||||
dequantize_kernel_int4[grid](
|
||||
output, input, size, BLOCK_SIZE=block_size, num_warps=1
|
||||
)
|
||||
rtol, atol = 1e-02, 1e-02
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
dequantize_kernel_scale_shift_int4[grid](
|
||||
output,
|
||||
input_int32,
|
||||
scale,
|
||||
shift,
|
||||
size,
|
||||
BLOCK_SIZE=block_size,
|
||||
num_warps=1,
|
||||
)
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
|
||||
def test_dequantize_int2() -> None:
|
||||
for i in range(10):
|
||||
if i < 5:
|
||||
size = random.randrange(16, 256, 8)
|
||||
else:
|
||||
size = random.randrange(264, 1024, 8)
|
||||
device = torch.device(torch.cuda.current_device())
|
||||
|
||||
scale_val = random.uniform(0.1, 4.0)
|
||||
shift_val = random.uniform(-10.0, 10.0)
|
||||
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
|
||||
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
|
||||
scale_shift = torch.tensor(
|
||||
[scale_val, shift_val],
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
).view(torch.int16)
|
||||
|
||||
input_int8 = torch.randint(
|
||||
0, 256, (size // 4,), dtype=torch.uint8, device=device
|
||||
)
|
||||
input_int16 = input_int8.view(torch.int16)
|
||||
|
||||
input_int8_q3 = input_int8 >> 6
|
||||
input_int8_q2 = (input_int8 >> 4) & 3
|
||||
input_int8_q1 = (input_int8 >> 2) & 3
|
||||
input_int8_q0 = input_int8 & 3
|
||||
|
||||
input_int2_val = torch.stack(
|
||||
(input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
|
||||
).flatten()
|
||||
|
||||
input = torch.cat((scale_shift, input_int16))
|
||||
expected = (input_int2_val * scale + shift).to(torch.float16)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
block_size = max(triton.next_power_of_2(size), 256)
|
||||
grid = (1,)
|
||||
|
||||
dequantize_kernel_int2[grid](
|
||||
output, input, size, BLOCK_SIZE=block_size, num_warps=1
|
||||
)
|
||||
rtol, atol = 1e-02, 1e-02
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
||||
|
||||
output = torch.empty([size], dtype=torch.float16, device=device)
|
||||
dequantize_kernel_scale_shift_int2[grid](
|
||||
output,
|
||||
input_int16,
|
||||
scale,
|
||||
shift,
|
||||
size,
|
||||
BLOCK_SIZE=block_size,
|
||||
num_warps=1,
|
||||
)
|
||||
assert torch.allclose(output, expected, rtol, atol)
|
python/test/unit/language/test_printf.py: new file, 22 lines
@@ -0,0 +1,22 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
printf_path = os.path.join(dir_path, "printf_helper.py")
|
||||
|
||||
|
||||
def test_printf():
|
||||
proc = subprocess.Popen([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
|
||||
(outs, err) = proc.communicate()
|
||||
outs = outs.split()
|
||||
new_lines = set()
|
||||
for line in outs:
|
||||
try:
|
||||
value = int(float(line))
|
||||
new_lines.add(value)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
for i in range(128):
|
||||
assert i in new_lines
|
||||
assert len(new_lines) == 128
|
@@ -2,13 +2,13 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
|
||||
@pytest.mark.parametrize("TRANS_A", [False, True])
|
||||
@pytest.mark.parametrize("TRANS_B", [False, True])
|
||||
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
|
||||
# TODO: float32 fails
|
||||
@pytest.mark.parametrize("DTYPE", [torch.float16])
|
||||
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
|
||||
seed = 0
|
||||
@@ -32,9 +32,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
|
||||
layout[1, 2, :] = 0
|
||||
layout[1, :, 1] = 0
|
||||
# create data
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
|
||||
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
|
||||
dc_ref, dc_tri = triton.testing.make_pair(c_shape)
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
|
||||
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
|
||||
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
|
||||
# compute [torch]
|
||||
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
|
||||
a_ref = do_mask(a_ref) if is_dsd else a_ref
|
||||
@@ -126,8 +126,8 @@ def test_attention_fwd_bwd(
|
||||
batch_size=2,
|
||||
n_heads=2,
|
||||
):
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 70:
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
|
||||
# inputs
|
||||
|
@@ -2,20 +2,19 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, dtype, mode",
|
||||
[
|
||||
(M, N, dtype, mode) for M in [1024, 821]
|
||||
for N in [512, 857, 1871, 2089, 8573, 31000]
|
||||
for dtype in ['bfloat16', 'float16', 'float32']
|
||||
for dtype in ['float16', 'float32']
|
||||
for mode in ['forward', 'backward']
|
||||
]
|
||||
)
|
||||
def test_op(M, N, dtype, mode):
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80 and dtype == "bfloat16":
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8 and dtype == "bfloat16":
|
||||
pytest.skip("Only test bfloat16 on devices with sm >= 80")
|
||||
dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
|
||||
# create inputs
|
||||
|
@@ -4,7 +4,6 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -67,10 +66,10 @@ import triton._C.libtriton.triton as _triton
|
||||
),
|
||||
)
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 70:
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
if cc < 80 and DTYPE == "bfloat16":
|
||||
if capability[0] < 8 and DTYPE == "bfloat16":
|
||||
pytest.skip("Only test bfloat16 on devices with sm >= 80")
|
||||
if DTYPE == "bfloat16" and SPLIT_K != 1:
|
||||
pytest.skip("bfloat16 matmuls don't allow split_k for now")
|
||||
|
@@ -1,15 +1,52 @@
|
||||
"""isort:skip_file"""
|
||||
# flake8: noqa: F401
|
||||
__version__ = '2.0.0'
|
||||
|
||||
# ---------------------------------------
|
||||
# Note: import order is significant here.
|
||||
|
||||
# TODO: torch needs to be imported first
|
||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||
import torch
|
||||
import torch # noqa: F401
|
||||
|
||||
# submodules
|
||||
from .utils import *
|
||||
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
|
||||
from . import impl
|
||||
from .utils import (
|
||||
cdiv,
|
||||
MockTensor,
|
||||
next_power_of_2,
|
||||
reinterpret,
|
||||
TensorWrapper,
|
||||
)
|
||||
from .runtime import (
|
||||
autotune,
|
||||
Config,
|
||||
heuristics,
|
||||
JITFunction,
|
||||
KernelInterface,
|
||||
)
|
||||
from .runtime.jit import jit
|
||||
from .compiler import compile, CompilationError
|
||||
from . import language
|
||||
from . import testing
|
||||
from . import ops
|
||||
|
||||
__all__ = [
|
||||
"autotune",
|
||||
"cdiv",
|
||||
"CompilationError",
|
||||
"compile",
|
||||
"Config",
|
||||
"heuristics",
|
||||
"impl",
|
||||
"jit",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"language",
|
||||
"MockTensor",
|
||||
"next_power_of_2",
|
||||
"ops",
|
||||
"reinterpret",
|
||||
"runtime",
|
||||
"TensorWrapper",
|
||||
"testing",
|
||||
]
|
||||
|
(file diff suppressed because it is too large)
python/triton/impl/__init__.py: new file, 18 lines
@@ -0,0 +1,18 @@
|
||||
"""Triton internal implementation details.
|
||||
|
||||
Client libraries should not import interfaces from the `triton.impl` module,
|
||||
as the details are subject to change.
|
||||
|
||||
APIs defined in the `triton.impl` module which are public will be re-exported
|
||||
in other relevant `triton` module namespaces.
|
||||
"""
|
||||
|
||||
from .base import builtin, extern, is_builtin
|
||||
from triton._C.libtriton.triton import ir
|
||||
|
||||
__all__ = [
|
||||
"builtin",
|
||||
"extern",
|
||||
"ir",
|
||||
"is_builtin",
|
||||
]
|
python/triton/impl/base.py: new file, 36 lines
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
TRITON_BUILTIN = "__triton_builtin__"
|
||||
|
||||
|
||||
def builtin(fn: T) -> T:
|
||||
"""Mark a function as a builtin."""
|
||||
assert callable(fn)
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if "_builder" not in kwargs or kwargs["_builder"] is None:
|
||||
raise ValueError(
|
||||
"Did you forget to add @triton.jit ? "
|
||||
"(`_builder` argument must be provided outside of JIT functions.)"
|
||||
)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
setattr(wrapper, TRITON_BUILTIN, True)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_builtin(fn) -> bool:
|
||||
"""Is this a registered triton builtin function?"""
|
||||
return getattr(fn, TRITON_BUILTIN, False)
|
||||
|
||||
|
||||
def extern(fn: T) -> T:
|
||||
"""A decorator for external functions."""
|
||||
return builtin(fn)
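A small usage sketch of the decorator above, relying only on the re-exports added in python/triton/impl/__init__.py (the decorated function is hypothetical):

from triton.impl import builtin, is_builtin

@builtin
def add_one(x, _builder=None):
    # _builder is normally injected by the @triton.jit machinery
    return x + 1

assert is_builtin(add_one)
assert add_one(41, _builder=object()) == 42  # any non-None builder satisfies the check
try:
    add_one(41)  # outside of @triton.jit there is no _builder, so this raises
except ValueError as exc:
    print(exc)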
|
@@ -1,4 +1,181 @@
|
||||
# flake8: noqa: F401
|
||||
from . import core, extern, libdevice, random
|
||||
from .core import *
|
||||
from .random import *
|
||||
"""isort:skip_file"""
|
||||
# Import order is significant here.
|
||||
|
||||
from ..impl import (
|
||||
ir,
|
||||
builtin,
|
||||
)
|
||||
from . import libdevice
|
||||
from .core import (
|
||||
abs,
|
||||
arange,
|
||||
argmin,
|
||||
argmax,
|
||||
atomic_add,
|
||||
atomic_and,
|
||||
atomic_cas,
|
||||
atomic_max,
|
||||
atomic_min,
|
||||
atomic_or,
|
||||
atomic_xchg,
|
||||
atomic_xor,
|
||||
bfloat16,
|
||||
block_type,
|
||||
broadcast,
|
||||
broadcast_to,
|
||||
cat,
|
||||
cdiv,
|
||||
constexpr,
|
||||
cos,
|
||||
debug_barrier,
|
||||
dot,
|
||||
dtype,
|
||||
exp,
|
||||
fdiv,
|
||||
float16,
|
||||
float32,
|
||||
float64,
|
||||
float8,
|
||||
function_type,
|
||||
int1,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
int8,
|
||||
load,
|
||||
log,
|
||||
max,
|
||||
max_contiguous,
|
||||
maximum,
|
||||
min,
|
||||
minimum,
|
||||
multiple_of,
|
||||
num_programs,
|
||||
pi32_t,
|
||||
pointer_type,
|
||||
printf,
|
||||
program_id,
|
||||
ravel,
|
||||
reshape,
|
||||
sigmoid,
|
||||
sin,
|
||||
softmax,
|
||||
sqrt,
|
||||
store,
|
||||
sum,
|
||||
swizzle2d,
|
||||
tensor,
|
||||
trans,
|
||||
triton,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
uint8,
|
||||
umulhi,
|
||||
view,
|
||||
void,
|
||||
where,
|
||||
xor_sum,
|
||||
zeros,
|
||||
zeros_like,
|
||||
)
|
||||
from .random import (
|
||||
pair_uniform_to_normal,
|
||||
philox,
|
||||
philox_impl,
|
||||
rand,
|
||||
rand4x,
|
||||
randint,
|
||||
randint4x,
|
||||
randn,
|
||||
randn4x,
|
||||
uint32_to_uniform_float,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"abs",
|
||||
"arange",
|
||||
"argmin",
|
||||
"argmax",
|
||||
"atomic_add",
|
||||
"atomic_and",
|
||||
"atomic_cas",
|
||||
"atomic_max",
|
||||
"atomic_min",
|
||||
"atomic_or",
|
||||
"atomic_xchg",
|
||||
"atomic_xor",
|
||||
"bfloat16",
|
||||
"block_type",
|
||||
"broadcast",
|
||||
"broadcast_to",
|
||||
"builtin",
|
||||
"cat",
|
||||
"cdiv",
|
||||
"constexpr",
|
||||
"cos",
|
||||
"debug_barrier",
|
||||
"dot",
|
||||
"dtype",
|
||||
"exp",
|
||||
"fdiv",
|
||||
"float16",
|
||||
"float32",
|
||||
"float64",
|
||||
"float8",
|
||||
"function_type",
|
||||
"int1",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"int8",
|
||||
"ir",
|
||||
"libdevice",
|
||||
"load",
|
||||
"log",
|
||||
"max",
|
||||
"max_contiguous",
|
||||
"maximum",
|
||||
"min",
|
||||
"minimum",
|
||||
"multiple_of",
|
||||
"num_programs",
|
||||
"pair_uniform_to_normal",
|
||||
"philox",
|
||||
"philox_impl",
|
||||
"pi32_t",
|
||||
"pointer_type",
|
||||
"printf",
|
||||
"program_id",
|
||||
"rand",
|
||||
"rand4x",
|
||||
"randint",
|
||||
"randint4x",
|
||||
"randn",
|
||||
"randn4x",
|
||||
"ravel",
|
||||
"reshape",
|
||||
"sigmoid",
|
||||
"sin",
|
||||
"softmax",
|
||||
"sqrt",
|
||||
"store",
|
||||
"sum",
|
||||
"swizzle2d",
|
||||
"tensor",
|
||||
"trans",
|
||||
"triton",
|
||||
"uint16",
|
||||
"uint32",
|
||||
"uint32_to_uniform_float",
|
||||
"uint64",
|
||||
"uint8",
|
||||
"umulhi",
|
||||
"view",
|
||||
"void",
|
||||
"where",
|
||||
"xor_sum",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
]
|
||||
|
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import List
|
||||
from typing import Callable, List, TypeVar
|
||||
|
||||
import triton
|
||||
from . import semantic
|
||||
from . import builtin, semantic
|
||||
from triton._C.libtriton.triton import ir
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def _to_tensor(x, builder):
|
||||
if isinstance(x, bool):
|
||||
@@ -17,41 +18,28 @@ def _to_tensor(x, builder):
|
||||
if -2**31 <= x < 2**31:
|
||||
return tensor(builder.get_int32(x), int32)
|
||||
elif 2**31 <= x < 2**32:
|
||||
return tensor(builder.get_uint32(x), uint32)
|
||||
return tensor(builder.get_int32(x), uint32)
|
||||
elif -2**63 <= x < 2**63:
|
||||
return tensor(builder.get_int64(x), int64)
|
||||
elif 2**63 <= x < 2**64:
|
||||
return tensor(builder.get_uint64(x), uint64)
|
||||
return tensor(builder.get_int64(x), uint64)
|
||||
else:
|
||||
raise RuntimeError(f'Nonrepresentable integer {x}.')
|
||||
elif isinstance(x, float):
|
||||
return tensor(builder.get_float32(x), float32)
|
||||
elif isinstance(x, constexpr):
|
||||
if x.value is None:
|
||||
return None
|
||||
return _to_tensor(x.value, builder)
|
||||
elif isinstance(x, tensor):
|
||||
return x
|
||||
elif x is None:
|
||||
return None
|
||||
assert False, f'cannot convert {x} to tensor'
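A standalone sketch of just the integer-range rules _to_tensor applies above (the helper name is hypothetical and no builder is involved; it only illustrates which dtype each Python int literal maps to):

def classify_int_literal(x: int) -> str:
    # mirrors the ranges checked in _to_tensor
    if -2**31 <= x < 2**31:
        return 'int32'
    elif 2**31 <= x < 2**32:
        return 'uint32'
    elif -2**63 <= x < 2**63:
        return 'int64'
    elif 2**63 <= x < 2**64:
        return 'uint64'
    raise RuntimeError(f'Nonrepresentable integer {x}.')

assert classify_int_literal(-1) == 'int32'
assert classify_int_literal(2**31) == 'uint32'
assert classify_int_literal(2**63) == 'uint64'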
|
||||
|
||||
|
||||
def builtin(fn):
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if '_builder' not in kwargs or \
|
||||
kwargs['_builder'] is None:
|
||||
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class dtype:
|
||||
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
|
||||
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||
FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64']
|
||||
CUSTOMIZED_FP_TYPES = ['fp8']
|
||||
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
|
||||
OTHER_TYPES = ['void']
|
||||
|
||||
class SIGNEDNESS(Enum):
|
||||
@@ -133,6 +121,12 @@ class dtype:
|
||||
def is_floating(self):
|
||||
return self.name in dtype.FP_TYPES
|
||||
|
||||
def is_customized_floating(self):
|
||||
return self.name in dtype.CUSTOMIZED_FP_TYPES
|
||||
|
||||
def is_standard_floating(self):
|
||||
return self.name in dtype.STANDARD_FP_TYPES
|
||||
|
||||
def is_int_signed(self):
|
||||
return self.name in dtype.SINT_TYPES
|
||||
|
||||
@@ -146,7 +140,7 @@ class dtype:
|
||||
return self.is_int1()
|
||||
|
||||
def is_void(self):
|
||||
return self.name == 'void'
|
||||
raise RuntimeError("Not implemented")
|
||||
|
||||
def is_block(self):
|
||||
return False
|
||||
@@ -216,7 +210,7 @@ class pointer_type(dtype):
|
||||
self.name = self.__str__()
|
||||
|
||||
def to_ir(self, builder: ir.builder) -> ir.pointer_type:
|
||||
return ir.type.make_ptr(self.element_ty.to_ir(builder), 1)
|
||||
return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1)
|
||||
|
||||
def __str__(self):
|
||||
return f'pointer<{self.element_ty}>'
|
||||
@@ -241,22 +235,27 @@ class pointer_type(dtype):
|
||||
|
||||
|
||||
class block_type(dtype):
|
||||
def __init__(self, element_ty: dtype, shape: List[int]):
|
||||
def __init__(self, element_ty: dtype, shape: List):
|
||||
self.element_ty = element_ty
|
||||
# FIXME:
|
||||
# block_type's shape is a list of int
|
||||
# while tensor's shape is a list of constexpr
|
||||
|
||||
# Note that block_type's shape is a list of int
|
||||
# while tensor's shape is a list of constexpr.
|
||||
|
||||
# shape can be empty ([]) when an input is a 0D tensor.
|
||||
if not shape:
|
||||
raise TypeError('0d block_type is forbidden')
|
||||
if isinstance(shape[0], constexpr):
|
||||
shape = [s.value for s in shape]
|
||||
|
||||
self.shape = shape
|
||||
self.numel = 1
|
||||
for i, s in enumerate(self.shape):
|
||||
if isinstance(s, constexpr):
|
||||
self.shape[i] = s.value
|
||||
self.numel *= self.shape[i]
|
||||
for s in self.shape:
|
||||
self.numel *= s
|
||||
|
||||
self.name = self.__str__()
|
||||
|
||||
def to_ir(self, builder: ir.builder) -> ir.block_type:
|
||||
return ir.type.make_block(self.element_ty.to_ir(builder), self.shape)
|
||||
return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape)
|
||||
|
||||
def __str__(self):
|
||||
return f'<{self.shape}, {self.element_ty}>'
|
||||
@@ -284,28 +283,17 @@ class block_type(dtype):
|
||||
|
||||
|
||||
class function_type(dtype):
|
||||
def __init__(self, ret_type: dtype, param_types: List[dtype]) -> None:
|
||||
self.ret_type = ret_type
|
||||
def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
|
||||
self.ret_types = ret_types
|
||||
self.param_types = param_types
|
||||
|
||||
def __str__(self):
|
||||
return f'fn ({self.param_types}) -> {self.ret_type}'
|
||||
return f'fn ({self.param_types}) -> {self.ret_types}'
|
||||
|
||||
def to_ir(self, builder: ir.builder):
|
||||
ir_param_types = [ty.to_ir(builder) for ty in self.param_types]
|
||||
return ir.type.make_function(self.ret_type.to_ir(builder), ir_param_types)
|
||||
|
||||
|
||||
class tuple_type(dtype):
|
||||
def __init__(self, element_types: List[dtype]) -> None:
|
||||
self.element_types = element_types
|
||||
|
||||
def __str__(self):
|
||||
return f'<{self.element_types}>'
|
||||
|
||||
def to_ir(self, builder: ir.builder):
|
||||
ir_element_types = [ty.to_ir(builder) for ty in self.element_types]
|
||||
return ir.struct_type.get(ir_element_types, True)
|
||||
ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types]
|
||||
return builder.get_function_ty(ir_param_types, ret_types)
|
||||
|
||||
|
||||
# scalar types
|
||||
@@ -346,83 +334,96 @@ class constexpr:
|
||||
def __repr__(self) -> str:
|
||||
return f"constexpr[{self.value}]"
|
||||
|
||||
def __add__(self, other):
|
||||
return constexpr(self.value + other.value)
|
||||
|
||||
def __radd__(self, other):
|
||||
return constexpr(other.value + self.value)
|
||||
|
||||
def __sub__(self, other):
|
||||
return constexpr(self.value - other.value)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return constexpr(other.value - self.value)
|
||||
|
||||
def __mul__(self, other):
|
||||
return constexpr(self.value * other.value)
|
||||
|
||||
def __mod__(self, other):
|
||||
return constexpr(self.value % other.value)
|
||||
|
||||
def __rmul__(self, other):
|
||||
return constexpr(other.value * self.value)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return constexpr(self.value / other.value)
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return constexpr(other.value / self.value)
|
||||
|
||||
def __floordiv__(self, other):
|
||||
return constexpr(self.value // other.value)
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
return constexpr(other.value // self.value)
|
||||
|
||||
def __gt__(self, other):
|
||||
return constexpr(self.value > other.value)
|
||||
|
||||
def __rgt__(self, other):
|
||||
return constexpr(other.value > self.value)
|
||||
|
||||
def __ge__(self, other):
|
||||
return constexpr(self.value >= other.value)
|
||||
|
||||
def __rge__(self, other):
|
||||
return constexpr(other.value >= self.value)
|
||||
|
||||
def __lt__(self, other):
|
||||
return constexpr(self.value < other.value)
|
||||
|
||||
def __rlt__(self, other):
|
||||
return constexpr(other.value < self.value)
|
||||
|
||||
def __le__(self, other):
|
||||
return constexpr(self.value <= other.value)
|
||||
|
||||
def __rle__(self, other):
|
||||
return constexpr(other.value <= self.value)
|
||||
|
||||
def __eq__(self, other):
|
||||
return constexpr(self.value == other.value)
|
||||
|
||||
def __ne__(self, other):
|
||||
return constexpr(self.value != other.value)
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.value)
|
||||
|
||||
def __ge__(self, other):
|
||||
other = other.value if isinstance(other, constexpr) else other
|
||||
return self.value >= other
|
||||
def __neg__(self):
|
||||
return constexpr(-self.value)
|
||||
|
||||
def __gt__(self, other):
|
||||
other = other.value if isinstance(other, constexpr) else other
|
||||
return self.value > other
|
||||
def __pos__(self):
|
||||
return constexpr(+self.value)
|
||||
|
||||
def __le__(self, other):
|
||||
other = other.value if isinstance(other, constexpr) else other
|
||||
return self.value <= other
|
||||
|
||||
def __lt__(self, other):
|
||||
other = other.value if isinstance(other, constexpr) else other
|
||||
return self.value < other
|
||||
|
||||
def __eq__(self, other):
|
||||
other = other.value if isinstance(other, constexpr) else other
|
||||
return self.value == other
|
||||
def __invert__(self):
|
||||
return constexpr(~self.value)
|
||||
|
||||
def __call__(self, *args, **kwds):
|
||||
return self.value(*args, **kwds)
|
||||
|
||||
def to(self, dtype, bitcast=False, _builder=None):
|
||||
if dtype in [float8, float16, bfloat16]:
|
||||
raise ValueError("floating point constexpr must be float64")
|
||||
if dtype.is_int():
|
||||
ret_ty = int
|
||||
elif dtype.is_bool():
|
||||
ret_ty = bool
|
||||
elif dtype.is_floating():
|
||||
ret_ty = float
|
||||
return constexpr(ret_ty(self.value))
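Reviewer note: after this hunk, `constexpr` comparison operators unwrap a `constexpr` operand and return a plain Python value, while the arithmetic operators still wrap their result. A minimal sketch of that behavior, using a simplified stand-in class rather than the real `triton.language.constexpr`:

class MiniConstexpr:
    # simplified stand-in that mimics the semantics of the hunk above
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        # arithmetic still returns a wrapped value
        return MiniConstexpr(self.value + other.value)

    def __eq__(self, other):
        # comparisons unwrap and return a plain Python bool
        other = other.value if isinstance(other, MiniConstexpr) else other
        return self.value == other

a, b = MiniConstexpr(4), MiniConstexpr(4)
assert isinstance(a + b, MiniConstexpr)   # wrapped result
assert (a == b) is True                   # plain bool, directly usable in `if`
assert (a == 4) is True                   # also works against raw Python ints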
|
||||
|
||||
|
||||
class tensor:
|
||||
# infer dtype from ir type
|
||||
@staticmethod
|
||||
def _to_dtype(ir_type):
|
||||
# block type
|
||||
if ir_type.is_block():
|
||||
scalar_ty = tensor._to_dtype(ir_type.scalar)
|
||||
return block_type(scalar_ty, ir_type.get_block_shapes())
|
||||
# pointer type
|
||||
if ir_type.is_ptr():
|
||||
element_ty = tensor._to_dtype(ir_type.element)
|
||||
return pointer_type(element_ty)
|
||||
# primitive type
|
||||
if ir_type.is_void(): return void
|
||||
if ir_type.is_int1(): return int1
|
||||
if ir_type.is_int8(): return int8
|
||||
if ir_type.is_int16(): return int16
|
||||
if ir_type.is_int32(): return int32
|
||||
if ir_type.is_int64(): return int64
|
||||
if ir_type.is_fp8(): return float8
|
||||
if ir_type.is_fp16(): return float16
|
||||
if ir_type.is_bf16(): return bfloat16
|
||||
if ir_type.is_fp32(): return float32
|
||||
if ir_type.is_fp64(): return float64
|
||||
raise ValueError(f"Unsupported type {ir_type.repr()}")
|
||||
|
||||
def __init__(self, handle, type: dtype):
|
||||
# IR handle
|
||||
self.handle = handle
|
||||
# Block shape
|
||||
self.shape = (1, )
|
||||
if self.handle.type.is_block():
|
||||
self.shape = self.handle.type.shape
|
||||
if type.is_block():
|
||||
self.shape = type.shape
|
||||
self.numel = 1
|
||||
for s in self.shape:
|
||||
self.numel *= s
|
||||
is_pow2 = (self.numel and (not (self.numel & (self.numel - 1))))
|
||||
if not is_pow2:
|
||||
raise ValueError("Triton tensors must have a power-of-two number of elements")
|
||||
self.numel = constexpr(self.numel)
|
||||
self.type = type # Tensor type (can be block_type)
|
||||
# Following the practice in pytorch, dtype is scalar type
|
||||
@@ -580,22 +581,34 @@ class tensor:
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.not_equal(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def logical_and(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.logical_and(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def logical_or(self, other, _builder=None):
|
||||
other = _to_tensor(other, _builder)
|
||||
return semantic.logical_or(self, other, _builder)
|
||||
|
||||
@builtin
|
||||
def __getitem__(self, slices, _builder=None):
|
||||
if isinstance(slices, slice):
|
||||
slices = [slices]
|
||||
src_shape = self.shape
|
||||
dst_shape = []
|
||||
curr = 0
|
||||
for sl in slices:
|
||||
ret = self
|
||||
for dim, sl in enumerate(slices):
|
||||
if isinstance(sl, constexpr) and sl.value is None:
|
||||
dst_shape.append(1)
|
||||
ret = semantic.expand_dims(ret, dim, _builder)
|
||||
elif sl == slice(None, None, None):
|
||||
dst_shape.append(src_shape[curr].value)
|
||||
curr += 1
|
||||
ret = semantic.reshape(self, dst_shape, _builder)
|
||||
pass
|
||||
else:
|
||||
assert False, "unsupported"
|
||||
return ret
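Reviewer note: the new `__getitem__` above expands dimensions one at a time via `semantic.expand_dims` when it sees a `None` index, which is what makes `x[:, None]` / `x[None, :]` work inside kernels. A NumPy analogue of the intended semantics (illustration only, not what the compiler does internally):

import numpy as np

x = np.arange(8)                 # stands in for a 1D block inside a kernel
row = x[None, :]                 # shape (1, 8): new leading axis
col = x[:, None]                 # shape (8, 1): new trailing axis
assert row.shape == (1, 8) and col.shape == (8, 1)
assert np.array_equal(col, np.expand_dims(x, 1))   # same as expand_dims at axis 1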
|
||||
|
||||
@property
|
||||
def T(self):
|
||||
assert False, "Transposition must be created by the AST Visitor"
|
||||
|
||||
@builtin
|
||||
def to(self, dtype, bitcast=False, _builder=None):
|
||||
if isinstance(bitcast, constexpr):
|
||||
@@ -685,20 +698,6 @@ def zeros(shape, dtype, _builder=None):
|
||||
return semantic.zeros(shape, dtype, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# dequantize
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def dequantize(input, scale, shift, nbit, dst_ty=float16, _builder=None):
|
||||
"""
|
||||
Tries to dequantize the input to the given dtype
|
||||
"""
|
||||
nbit = _constexpr_to_value(nbit)
|
||||
return semantic.dequantize(input, scale, shift, nbit, dst_ty, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Shape Manipulation
|
||||
# -----------------------
|
||||
@@ -731,7 +730,12 @@ def broadcast_to(input, shape, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def cat(input, other, _builder=None):
|
||||
def trans(input, _builder=None):
|
||||
return semantic.trans(input, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def cat(input, other, can_reorder=False, _builder=None):
|
||||
"""
|
||||
Concatenate the given blocks
|
||||
|
||||
@@ -739,14 +743,19 @@ def cat(input, other, _builder=None):
|
||||
:type input:
|
||||
:param other: The second input tensor.
|
||||
:type other:
|
||||
:param can_reorder: Compiler hint. If true, the compiler is
|
||||
allowed to reorder elements while concatenating inputs.
|
||||
Only use if the order does not matter (e.g., result is
|
||||
only used in reduction ops)
|
||||
"""
|
||||
return semantic.cat(input, other, _builder)
|
||||
return semantic.cat(input, other, can_reorder, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def reshape(input, shape, _builder=None):
|
||||
def view(input, shape, _builder=None):
|
||||
"""
|
||||
Tries to reshape the given tensor to a new shape.
|
||||
Returns a tensor with the same elements as `input` but a different shape.
|
||||
The order of the elements may not be preserved.
|
||||
|
||||
:param input: The input tensor.
|
||||
:type input:
|
||||
@@ -755,20 +764,26 @@ def reshape(input, shape, _builder=None):
|
||||
|
||||
"""
|
||||
shape = [x.value for x in shape]
|
||||
return semantic.reshape(input, shape, _builder)
|
||||
return semantic.view(input, shape, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def reshape(input, shape, _builder=None):
|
||||
# TODO: should be more than just a view
|
||||
shape = [x.value for x in shape]
|
||||
return semantic.view(input, shape, _builder)
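Reviewer note: with this hunk, `view` returns a tensor with the same elements but makes no guarantee about element order, and `reshape` is for now just an alias for it (see the TODO above). The only structural check performed at this level is that the element counts agree; a small sketch of that check, where `_numel` and `can_view` are hypothetical helpers, not Triton API:

from functools import reduce
from operator import mul

def _numel(shape):
    # product of all dimensions
    return reduce(mul, shape, 1)

def can_view(src_shape, dst_shape):
    # a view is only legal when the element counts agree
    return _numel(src_shape) == _numel(dst_shape)

assert can_view([4, 32], [128])
assert not can_view([4, 32], [4, 16])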
|
||||
|
||||
# -----------------------
|
||||
# Linear Algebra
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=None):
|
||||
def dot(input, other, allow_tf32=True, _builder=None):
|
||||
"""
|
||||
Returns the matrix product of two blocks.
|
||||
|
||||
The two blocks must be two dimensionals and have compatible inner dimensions.
|
||||
The two blocks must be two-dimensional and have compatible inner dimensions.
|
||||
|
||||
:param input: The first tensor to be multiplied.
|
||||
:type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
|
||||
@@ -776,7 +791,7 @@ def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=No
|
||||
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
|
||||
"""
|
||||
allow_tf32 = _constexpr_to_value(allow_tf32)
|
||||
return semantic.dot(input, other, trans_a, trans_b, allow_tf32, _builder)
|
||||
return semantic.dot(input, other, allow_tf32, _builder)
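Reviewer note: since `trans_a`/`trans_b` are dropped here, callers are expected to transpose explicitly (e.g. via the new `trans` builtin or the `.T` handled by the AST visitor) before calling `dot`. A rough NumPy analogue of the calling-convention change (illustration only, not Triton code):

import numpy as np

a = np.random.rand(16, 32).astype(np.float32)
b = np.random.rand(16, 32).astype(np.float32)

# old style (conceptually): dot(a, b, trans_b=True)
# new style: transpose explicitly, then call dot
c = a @ b.T
assert c.shape == (16, 16)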
|
||||
|
||||
|
||||
# -----------------------
|
||||
@@ -814,7 +829,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
|
||||
|
||||
|
||||
@builtin
|
||||
def store(pointer, value, mask=None, eviction_policy="", _builder=None):
|
||||
def store(pointer, value, mask=None, _builder=None):
|
||||
"""
|
||||
Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
|
||||
|
||||
@@ -831,46 +846,27 @@ def store(pointer, value, mask=None, eviction_policy="", _builder=None):
|
||||
value = _to_tensor(value, _builder)
|
||||
if mask is not None:
|
||||
mask = _to_tensor(mask, _builder)
|
||||
return semantic.store(pointer, value, mask, eviction_policy, _builder)
|
||||
return semantic.store(pointer, value, mask, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Atomic Memory Operations
|
||||
# -----------------------
|
||||
|
||||
@builtin
|
||||
def atomic_cas(pointer, cmp, val, _builder=None):
|
||||
"""
|
||||
Performs an atomic compare-and-swap at the memory location specified by :code:`pointer`.
|
||||
def _add_atomic_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
Return the data stored at :code:`pointer` before the atomic operation.
|
||||
|
||||
:param pointer: The memory locations to compare-and-swap.
|
||||
:type pointer: Block of dtype=triton.PointerDType
|
||||
:param cmp: The values expected to be found in the atomic object
|
||||
:type cmp: Block of dtype=`pointer.dtype.element_ty`
|
||||
:param val: The values to copy in case the expected value matches the contained value.
|
||||
:type val: Block of dtype=`pointer.dtype.element_ty`
|
||||
"""
|
||||
cmp = _to_tensor(cmp, _builder)
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_cas(pointer, cmp, val, _builder)
|
||||
|
||||
|
||||
def _add_atomic_docstr(name):
|
||||
|
||||
def _decorator(func):
|
||||
def _decorator(func: T) -> T:
|
||||
docstr = """
|
||||
Performs an atomic {name} at the memory location specified by :code:`pointer`.
|
||||
|
||||
Return the data stored at :code:`pointer` before the atomic operation.
|
||||
|
||||
:param pointer: The memory locations to apply {name}.
|
||||
:param pointer: The memory locations to compare-and-swap.
|
||||
:type pointer: Block of dtype=triton.PointerDType
|
||||
:param val: The values to {name} in the atomic object.
|
||||
:param cmp: The values expected to be found in the atomic object
|
||||
:type cmp: Block of dtype=`pointer.dtype.element_ty`
|
||||
:param val: The values to copy in case the expected value matches the contained value.
|
||||
:type val: Block of dtype=`pointer.dtype.element_ty`
|
||||
:param mask: If mask[idx] is false, do not apply {name}.
|
||||
:type mask: Block of triton.int1, optional
|
||||
"""
|
||||
func.__doc__ = docstr.format(name=name)
|
||||
return func
|
||||
@@ -878,6 +874,14 @@ def _add_atomic_docstr(name):
|
||||
return _decorator
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("compare-and-swap")
|
||||
def atomic_cas(pointer, cmp, val, _builder=None):
|
||||
cmp = _to_tensor(cmp, _builder)
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_cas(pointer, cmp, val, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("exchange")
|
||||
def atomic_xchg(pointer, val, mask=None, _builder=None):
|
||||
@@ -972,9 +976,9 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
|
||||
return semantic.fdiv(x, y, ieee_rounding, _builder)
|
||||
|
||||
|
||||
def _add_math_1arg_docstr(name):
|
||||
def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func):
|
||||
def _decorator(func: T) -> T:
|
||||
docstr = """
|
||||
Computes the element-wise {name} of :code:`x`
|
||||
|
||||
@@ -1021,9 +1025,9 @@ def sqrt(x, _builder=None):
|
||||
# Reductions
|
||||
# -----------------------
|
||||
|
||||
def _add_reduction_docstr(name):
|
||||
def _add_reduction_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func):
|
||||
def _decorator(func: T) -> T:
|
||||
docstr = """
|
||||
Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
|
||||
|
||||
@@ -1077,19 +1081,6 @@ def xor_sum(input, axis, _builder=None):
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.xor_sum(input, axis, _builder)
|
||||
|
||||
# -----------------------
|
||||
# Utilities
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def globaltimer(_builder=None):
|
||||
return semantic.globaltimer(_builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def clock(_builder=None):
|
||||
return semantic.clock(_builder)
|
||||
|
||||
# -----------------------
|
||||
# Internal for debugging
|
||||
@@ -1189,7 +1180,7 @@ def sigmoid(x):
|
||||
|
||||
@triton.jit
|
||||
@_add_math_1arg_docstr("softmax")
|
||||
def softmax(x, ieee_rounding: constexpr = False):
|
||||
def softmax(x, ieee_rounding=False):
|
||||
z = x - triton.language.max(x, 0)
|
||||
num = triton.language.exp(z)
|
||||
den = triton.language.sum(num, 0)
|
||||
@@ -1204,13 +1195,13 @@ def ravel(x):
|
||||
:param x: the input tensor
|
||||
:type x: Block
|
||||
"""
|
||||
return triton.language.reshape(x, [x.numel])
|
||||
return triton.language.view(x, [x.numel])
|
||||
|
||||
|
||||
@triton.jit
|
||||
def swizzle2d(i, j, size_i, size_j, size_g):
|
||||
"""
|
||||
transformes indices of a row-major size_i*size_j matrix into those
|
||||
Transforms indices of a row-major size_i*size_j matrix into those
|
||||
of one where indices are row major for each group of size_j rows.
|
||||
For example, for size_i = size_j = 4 and size_g = 2, it will transform
|
||||
[[0 , 1 , 2 , 3 ],
|
||||
@@ -1243,26 +1234,22 @@ def swizzle2d(i, j, size_i, size_j, size_g):
|
||||
@triton.jit
|
||||
def zeros_like(input):
|
||||
return zeros(input.shape, input.dtype)
|
||||
# -----------------------
|
||||
# Dynamic Parallelism
|
||||
# -----------------------
|
||||
|
||||
|
||||
# class LaunchProxy:
|
||||
|
||||
# def __init__(self, fn, args, constants, grid, num_warps) -> None:
|
||||
# self.args = args
|
||||
# self.grid = grid
|
||||
# self.constants = constants
|
||||
# self.num_warps = num_warps
|
||||
# self.fn = fn
|
||||
|
||||
|
||||
# @builtin
|
||||
# def launch(fn, args, grid, num_warps=None, _builder=None):
|
||||
# constants = {i: x for i, x in enumerate(args) if isinstance(x, constexpr)}
|
||||
# args = [_to_ir(x, builder=_builder) for x in args if not isinstance(x, constexpr)]
|
||||
# grid = [_to_ir(x, builder=_builder) for x in grid]
|
||||
# if num_warps is None:
|
||||
# num_warps = _to_ir(4, builder=_builder)
|
||||
# return LaunchProxy(fn, args, constants, grid, num_warps)
|
||||
@builtin
|
||||
def printf(prefix, *args, _builder=None):
|
||||
import string
|
||||
new_prefix = prefix
|
||||
if isinstance(prefix, constexpr):
|
||||
new_prefix = prefix.value
|
||||
assert isinstance(new_prefix, str), f"{new_prefix} is not string"
|
||||
b_ascii = True
|
||||
for ch in new_prefix:
|
||||
if ch not in string.printable:
|
||||
b_ascii = False
|
||||
break
|
||||
assert b_ascii, f"{new_prefix} is not an ascii string"
|
||||
new_args = []
|
||||
for arg in args:
|
||||
new_args.append(_to_tensor(arg, _builder))
|
||||
return semantic.printf(new_prefix, new_args, _builder)
|
||||
|
@@ -6,7 +6,6 @@ from . import core, semantic
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None):
|
||||
'''
|
||||
Dispatch a function to a library
|
||||
|
||||
:param func: the function to dispatch
|
||||
:param lib_name: the name of the library
|
||||
:param lib_path: the path of the library
|
||||
@@ -14,7 +13,6 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
|
||||
:param arg_type_symbol_dict: the type of the arguments
|
||||
:param ret_shape: the shape of the return value
|
||||
:param _builder: the builder
|
||||
|
||||
:return: the return value of the function
|
||||
'''
|
||||
if len(arg_type_symbol_dict) == 0:
|
||||
@@ -42,20 +40,19 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
|
||||
else:
|
||||
symbol = arg_type_symbol_dict[arg_types][0]
|
||||
ret_type = arg_type_symbol_dict[arg_types][1]
|
||||
ret_type = core.block_type(ret_type, ret_shape) if ret_shape is not None else ret_type
|
||||
if ret_shape:
|
||||
ret_type = core.block_type(ret_type, ret_shape)
|
||||
return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type)
|
||||
|
||||
|
||||
def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None):
|
||||
'''
|
||||
Dispatch an elementwise function to a library
|
||||
|
||||
:param lib_name: the name of the library
|
||||
:param lib_path: the path of the library
|
||||
:param args: the arguments of the function
|
||||
:param arg_type_symbol_dict: the type of the arguments
|
||||
:param _builder: the builder
|
||||
|
||||
:return: the return value of the function
|
||||
'''
|
||||
dispatch_args = args.copy()
|
||||
@@ -87,27 +84,5 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder)
|
||||
ret_shape = broadcast_arg.shape
|
||||
func = getattr(_builder, "create_extern_elementwise")
|
||||
func = getattr(_builder, "create_external_elementwise")
|
||||
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
|
||||
|
||||
|
||||
class ExternalFunction:
|
||||
'''
|
||||
A wrapper for external functions
|
||||
'''
|
||||
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if '_builder' not in kwargs or \
|
||||
kwargs['_builder'] is None:
|
||||
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
|
||||
def extern(fn):
|
||||
'''
|
||||
A decorator for external functions
|
||||
'''
|
||||
return ExternalFunction(fn)
|
||||
|
BIN python/triton/language/libdevice.10.bc (Normal file → Executable file; binary file not shown)
File diff suppressed because it is too large
@@ -1,10 +1,10 @@
|
||||
import triton
|
||||
from . import core as tl
|
||||
|
||||
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
|
||||
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
|
||||
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
|
||||
PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
|
||||
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
|
||||
PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
|
||||
PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
|
||||
PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
|
||||
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
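Reviewer note: the constant change above is cosmetic; the old negative decimal literals are the same 32-bit patterns as the new unsigned hex literals, just read as signed int32. A quick throwaway check:

def as_int32(u):
    # reinterpret an unsigned 32-bit pattern as a signed int32
    return u - (1 << 32) if u >= (1 << 31) else u

assert as_int32(0x9E3779B9) == -1640531527   # PHILOX_KEY_A
assert as_int32(0xBB67AE85) == -1150833019   # PHILOX_KEY_B
assert as_int32(0xD2511F53) == -766435501    # PHILOX_ROUND_A
assert as_int32(0xCD9E8D57) == -845247145    # PHILOX_ROUND_B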
|
||||
|
||||
# -------------------
|
||||
|
@@ -7,12 +7,12 @@ from triton._C.libtriton.triton import ir
|
||||
|
||||
|
||||
# Custom exception raised when two operand types are incompatible
|
||||
class IncompatibleTypeErrorimpl(Exception):
|
||||
class IncompatibleTypeErrorImpl(Exception):
|
||||
def __init__(self, type_a, type_b):
|
||||
self.type_a = type_a
|
||||
self.type_b = type_b
|
||||
self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__()
|
||||
super(IncompatibleTypeErrorimpl, self).__init__(self.message)
|
||||
super(IncompatibleTypeErrorImpl, self).__init__(self.message)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===##
|
||||
@@ -88,13 +88,13 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
|
||||
def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
|
||||
if type_a.is_ptr():
|
||||
if not allow_ptr_a:
|
||||
raise IncompatibleTypeErrorimpl(type_a, type_b)
|
||||
raise IncompatibleTypeErrorImpl(type_a, type_b)
|
||||
# T* + U* with T != U
|
||||
if type_b.is_ptr() and (type_a != type_b):
|
||||
raise IncompatibleTypeErrorimpl(type_a, type_b)
|
||||
raise IncompatibleTypeErrorImpl(type_a, type_b)
|
||||
# T* + float
|
||||
if type_b.is_floating():
|
||||
raise IncompatibleTypeErrorimpl(type_a, type_b)
|
||||
raise IncompatibleTypeErrorImpl(type_a, type_b)
|
||||
|
||||
|
||||
def binary_op_type_checking_impl(lhs: tl.tensor,
|
||||
@@ -129,7 +129,7 @@ def add(input: tl.tensor,
|
||||
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
|
||||
input, other = other, input
|
||||
if input_scalar_ty.is_ptr():
|
||||
return tl.tensor(builder.create_gep(input.handle, [other.handle]), input.type)
|
||||
return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
|
||||
# float + float
|
||||
elif input_scalar_ty.is_floating():
|
||||
return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
|
||||
@@ -146,7 +146,7 @@ def sub(input: tl.tensor,
|
||||
scalar_ty = input.type.scalar
|
||||
# ptr - offset
|
||||
if scalar_ty.is_ptr():
|
||||
return tl.tensor(builder.create_gep(input.handle, [minus(other, builder).handle]),
|
||||
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle),
|
||||
input.type)
|
||||
# float - float
|
||||
if scalar_ty.is_floating():
|
||||
@@ -223,10 +223,9 @@ def fdiv(input: tl.tensor,
|
||||
input_scalar_ty = input.type.scalar
|
||||
other_scalar_ty = other.type.scalar
|
||||
if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
|
||||
raise ValueError("both operands of fdiv must have floating poscalar type")
|
||||
raise ValueError("both operands of fdiv must have floating scalar type")
|
||||
input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
|
||||
ret = builder.create_fdiv(input.handle, other.handle)
|
||||
ret.set_fdiv_ieee_rounding(ieee_rounding)
|
||||
return tl.tensor(ret, input.type)
|
||||
|
||||
|
||||
@@ -238,7 +237,11 @@ def mod(input: tl.tensor,
|
||||
other_scalar_ty = other.type.scalar
|
||||
# float % float
|
||||
if scalar_ty.is_floating():
|
||||
return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
|
||||
# input - input.div(other, rounding_mode="floor") * other
|
||||
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder),
|
||||
other, builder),
|
||||
builder)
|
||||
return ret
|
||||
# % int
|
||||
elif scalar_ty.is_int():
|
||||
if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
|
||||
@@ -263,7 +266,7 @@ def bitwise_op_type_checking_impl(input: tl.tensor,
|
||||
input_sca_ty = input.type.scalar
|
||||
other_sca_ty = other.type.scalar
|
||||
if not input_sca_ty.is_int() or not other_sca_ty.is_int():
|
||||
raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty)
|
||||
raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
|
||||
ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty)
|
||||
if ret_sca_ty != input_sca_ty:
|
||||
input = cast(input, ret_sca_ty, builder)
|
||||
@@ -293,6 +296,22 @@ def xor_(input: tl.tensor,
|
||||
return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)
|
||||
|
||||
|
||||
def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
if not input.type.is_int1():
|
||||
input = bitcast(input, tl.dtype("int1"), builder)
|
||||
if not other.type.is_int1():
|
||||
other = bitcast(other, tl.dtype("int1"), builder)
|
||||
return and_(input, other, builder)
|
||||
|
||||
|
||||
def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
if not input.type.is_int1():
|
||||
input = bitcast(input, tl.dtype("int1"), builder)
|
||||
if not other.type.is_int1():
|
||||
other = bitcast(other, tl.dtype("int1"), builder)
|
||||
return or_(input, other, builder)
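Reviewer note: the two helpers above coerce non-`int1` operands to `int1` and then defer to the existing bitwise `and_`/`or_` lowerings, so for boolean operands a logical op is just a bitwise op. A tiny NumPy sketch of why that identity holds for booleans (illustration only):

import numpy as np

a = np.array([True, True, False, False])
b = np.array([True, False, True, False])

# for 1-bit (boolean) operands, logical and bitwise ops coincide
assert np.array_equal(np.logical_and(a, b), a & b)
assert np.array_equal(np.logical_or(a, b), a | b)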
|
||||
|
||||
|
||||
def lshr(input: tl.tensor,
|
||||
other: tl.tensor,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
@@ -320,7 +339,7 @@ def minus(input: tl.tensor,
|
||||
input_sca_ty = input.type.scalar
|
||||
if input_sca_ty.is_ptr():
|
||||
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
|
||||
_0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
|
||||
_0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
|
||||
return sub(_0, input, builder)
|
||||
|
||||
|
||||
@@ -329,7 +348,7 @@ def invert(input: tl.tensor,
|
||||
input_sca_ty = input.type.scalar
|
||||
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
|
||||
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
|
||||
_1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
|
||||
_1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
|
||||
return xor_(input, _1, builder)
|
||||
|
||||
|
||||
@@ -449,11 +468,11 @@ def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
|
||||
|
||||
shape = [end - start]
|
||||
ret_ty = tl.block_type(tl.int32, shape)
|
||||
return tl.tensor(builder.get_range(start, end), ret_ty)
|
||||
return tl.tensor(builder.create_make_range(start, end), ret_ty)
|
||||
|
||||
|
||||
def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
_0 = ir.constant.get_null_value(dtype.to_ir(builder))
|
||||
_0 = builder.get_null_value(dtype.to_ir(builder))
|
||||
ret_ty = tl.block_type(dtype, shape)
|
||||
return tl.tensor(builder.create_splat(_0, shape), ret_ty)
|
||||
|
||||
@@ -462,24 +481,40 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
def reshape(input: tl.tensor,
|
||||
dst_shape: List[int],
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
def view(input: tl.tensor,
|
||||
dst_shape: List[int],
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
# TODO: disable when TritonToTritonGPU handles views properly
|
||||
|
||||
# assert len(input.shape) == len(dst_shape)
|
||||
numel = 1
|
||||
for s in dst_shape:
|
||||
numel *= s
|
||||
if input.type.numel != numel:
|
||||
raise ValueError("cannot reshape block of different shape")
|
||||
raise ValueError("cannot view block of different shape")
|
||||
ret_ty = tl.block_type(input.type.scalar, dst_shape)
|
||||
return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty)
|
||||
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
|
||||
|
||||
|
||||
def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
assert lhs.type.shape[1:] == rhs.type.shape[1:]
|
||||
ret_shape = [lhs.type.shape[0] + rhs.type.shape[0]]
|
||||
ret_ty = tl.block_type(lhs.type.scalar, ret_shape)
|
||||
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_ty)
|
||||
def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
dst_shape = [s for s in input.type.shape]
|
||||
dst_shape.insert(axis, 1)
|
||||
ret_ty = tl.block_type(input.type.scalar, dst_shape)
|
||||
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
|
||||
|
||||
|
||||
def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
|
||||
assert can_reorder, "current implementation of `cat` may reorder elements"
|
||||
assert len(lhs.shape) == 1
|
||||
ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
|
||||
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type)
|
||||
|
||||
|
||||
def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
if len(input.shape) != 2:
|
||||
raise ValueError("Only 2D tensors can be transposed")
|
||||
ret_type = tl.block_type(input.type.scalar, [input.shape[1], input.shape[0]])
|
||||
return tl.tensor(builder.create_trans(input.handle), ret_type)
|
||||
|
||||
|
||||
def broadcast_impl_shape(input: tl.tensor,
|
||||
@@ -496,7 +531,7 @@ def broadcast_impl_shape(input: tl.tensor,
|
||||
for i in range(len(src_shape)):
|
||||
if shape[i] != src_shape[i] and src_shape[i] != 1:
|
||||
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
|
||||
f" must match the existing size ({src_shape[1]}) at non-singleton dimension"
|
||||
f" must match the existing size ({src_shape[i]}) at non-singleton dimension"
|
||||
f" {i}: {src_shape}, {shape}")
|
||||
ret_ty = tl.block_type(input.type.scalar, shape)
|
||||
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
|
||||
@@ -520,8 +555,21 @@ def broadcast_impl_value(lhs: tl.tensor,
|
||||
elif lhs_ty.is_block() and rhs_ty.is_block():
|
||||
lhs_shape = lhs_ty.get_block_shapes()
|
||||
rhs_shape = rhs_ty.get_block_shapes()
|
||||
if len(lhs_shape) != len(rhs_shape):
|
||||
raise ValueError("Cannot make_shape_compatible: blocks must have the same rank")
|
||||
|
||||
if len(lhs_shape) < len(rhs_shape):
|
||||
# Add new axes to lhs
|
||||
for dim in range(len(lhs_shape), len(rhs_shape)):
|
||||
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, dim), tl.block_type(lhs_ty.scalar, lhs_shape + [1]))
|
||||
lhs_ty = lhs.type
|
||||
lhs_shape = lhs_ty.get_block_shapes()
|
||||
elif len(rhs_shape) < len(lhs_shape):
|
||||
# Add new axes to rhs
|
||||
for dim in range(len(rhs_shape), len(lhs_shape)):
|
||||
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, dim), tl.block_type(rhs_ty.scalar, rhs_shape + [1]))
|
||||
rhs_ty = rhs.type
|
||||
rhs_shape = rhs_ty.get_block_shapes()
|
||||
assert len(rhs_shape) == len(lhs_shape)
|
||||
|
||||
ret_shape = []
|
||||
for i in range(len(lhs_shape)):
|
||||
left = lhs_shape[i]
|
||||
@@ -544,31 +592,6 @@ def broadcast_impl_value(lhs: tl.tensor,
|
||||
# (scalar, scalar) => returns original blocks
|
||||
return lhs, rhs
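Reviewer note: if I read the rank-promotion hunk above correctly, missing axes are appended as trailing size-1 dimensions (`expand_dims` at `dim = len(shape)`, ...), not prepended as leading ones the way NumPy does it. A pure-Python sketch of the resulting shape rule under that reading (the helper is illustrative, not a Triton function):

def broadcast_shape(lhs, rhs):
    lhs, rhs = list(lhs), list(rhs)
    # pad the shorter shape with trailing 1s (NumPy would pad leading dims instead)
    while len(lhs) < len(rhs):
        lhs.append(1)
    while len(rhs) < len(lhs):
        rhs.append(1)
    out = []
    for l, r in zip(lhs, rhs):
        if l != r and 1 not in (l, r):
            raise ValueError(f"incompatible dimensions {l} and {r}")
        out.append(max(l, r))
    return out

assert broadcast_shape([128, 1], [128, 64]) == [128, 64]
assert broadcast_shape([128], [128, 64]) == [128, 64]   # [128] -> [128, 1] first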
|
||||
|
||||
|
||||
#######
|
||||
# dequantize
|
||||
#######
|
||||
|
||||
def dequantize(input: tl.tensor,
|
||||
scale: tl.tensor,
|
||||
shift: tl.tensor,
|
||||
nbit: int,
|
||||
dst_ty: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
input_ty = input.type
|
||||
assert input_ty.is_block()
|
||||
assert input_ty.element_ty.is_int32() or input_ty.element_ty.is_int16()
|
||||
assert nbit in [2, 4, 8]
|
||||
assert dst_ty == tl.float16
|
||||
|
||||
shape = input_ty.get_block_shapes()
|
||||
factor = input_ty.element_ty.primitive_bitwidth // nbit
|
||||
dst_shape = shape[:-1] + [factor * shape[-1]]
|
||||
|
||||
dst_ty = tl.block_type(dst_ty, dst_shape)
|
||||
return tl.tensor(builder.create_dequantize(input.handle, scale.handle, shift.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
|
||||
|
||||
#######
|
||||
# cast
|
||||
#######
|
||||
@@ -600,62 +623,78 @@ def cast(input: tl.tensor,
|
||||
dst_ty: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
src_ty = input.type
|
||||
if src_ty.is_block() and not dst_ty.is_block():
|
||||
if src_ty.is_block():
|
||||
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
|
||||
if src_ty == dst_ty:
|
||||
return input
|
||||
|
||||
src_sca_ty = src_ty.scalar
|
||||
dst_sca_ty = dst_ty.scalar
|
||||
# fp8 <=> bf16/fp16
|
||||
if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
|
||||
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
|
||||
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
|
||||
|
||||
# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
|
||||
if (src_sca_ty.is_customized_floating() and dst_sca_ty.is_floating()) or \
|
||||
(src_sca_ty.is_floating() and dst_sca_ty.is_customized_floating()):
|
||||
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
|
||||
# bf16 <=> (not fp32)
|
||||
if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
|
||||
(dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
|
||||
if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
|
||||
(src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()):
|
||||
return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)
|
||||
|
||||
# FP Truncation
|
||||
# Standard floating types' casting: truncation
|
||||
# fp64 => fp32, fp16, bf16
|
||||
# fp32 => fp16, bf16
|
||||
truncate_fp = src_sca_ty.is_floating() and \
|
||||
dst_sca_ty.is_floating() and \
|
||||
src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width
|
||||
src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
|
||||
if truncate_fp:
|
||||
return tl.tensor(builder.create_fp_trunc(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
|
||||
# FP Extension
|
||||
# Standard floating types' casting: extension
|
||||
# fp32 => fp64
|
||||
# fp16 => fp32, fp64
|
||||
# bf16 => fp32, fp64
|
||||
ext_fp = src_sca_ty.is_floating() and \
|
||||
dst_sca_ty.is_floating() and \
|
||||
src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width
|
||||
src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
|
||||
if ext_fp:
|
||||
return tl.tensor(builder.create_fp_ext(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
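Reviewer note: the truncation/extension criterion switches from mantissa width to total bit width here. The two mostly agree, except for bf16 vs. fp16 (same 16-bit width, different mantissa), which is why the bf16/fp16 special case earlier in `cast` goes through an fp32 round-trip. A sketch of the decision rule; the widths are the usual IEEE/bfloat16 figures, not values read from the Triton source:

# (total bits, explicit mantissa bits) for the standard float formats
WIDTHS = {"bf16": (16, 7), "fp16": (16, 10), "fp32": (32, 23), "fp64": (64, 52)}

def fp_cast_kind(src, dst):
    src_bits, _ = WIDTHS[src]
    dst_bits, _ = WIDTHS[dst]
    if src_bits > dst_bits:
        return "trunc"
    if src_bits < dst_bits:
        return "ext"
    return "neither"   # e.g. bf16 <-> fp16: routed through fp32 above

assert fp_cast_kind("fp32", "fp16") == "trunc"
assert fp_cast_kind("bf16", "fp64") == "ext"
assert fp_cast_kind("bf16", "fp16") == "neither"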
|
||||
|
||||
# Int cast
|
||||
# Casting between integer types
|
||||
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
|
||||
(src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
|
||||
sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
|
||||
return tl.tensor(builder.create_int_cast(input.handle,
|
||||
dst_ty.to_ir(builder), sign_extend),
|
||||
dst_ty)
|
||||
|
||||
# Float to Int
|
||||
if src_sca_ty.is_floating() and dst_sca_ty.is_int():
|
||||
# TODO: is this correct?
|
||||
if dst_sca_ty.is_bool():
|
||||
return not_equal(input, tl._to_tensor(0, builder), builder)
|
||||
ty = input.dtype.to_ir(builder)
|
||||
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
|
||||
return not_equal(input, _0, builder)
|
||||
else:
|
||||
return tl.tensor(builder.create_int_cast(input.handle,
|
||||
dst_ty.to_ir(builder), sign_extend),
|
||||
dst_ty)
|
||||
|
||||
# Casting standard floating types to integer types
|
||||
if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
|
||||
if dst_sca_ty.is_bool():
|
||||
ty = input.dtype.to_ir(builder)
|
||||
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
|
||||
return not_equal(input, _0, builder)
|
||||
elif dst_sca_ty.is_int_signed():
|
||||
return tl.tensor(builder.create_fp_to_si(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
else:
|
||||
return tl.tensor(builder.create_fp_to_ui(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
|
||||
# int => float
|
||||
if src_sca_ty.is_int() and dst_sca_ty.is_floating():
|
||||
# Casting integer types to standard floating types
|
||||
if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
|
||||
if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
|
||||
return tl.tensor(builder.create_ui_to_fp(input.handle,
|
||||
dst_ty.to_ir(builder)),
|
||||
@@ -665,7 +704,7 @@ def cast(input: tl.tensor,
|
||||
dst_ty.to_ir(builder)),
|
||||
dst_ty)
|
||||
|
||||
# ptr => int
|
||||
# Casting pointer types to integer types
|
||||
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
|
||||
bitwidth = dst_sca_ty.int_bitwidth
|
||||
if bitwidth == 64:
|
||||
@@ -676,19 +715,14 @@ def cast(input: tl.tensor,
|
||||
tl.tensor(builder.get_int64(0), tl.int64),
|
||||
builder)
|
||||
|
||||
if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
|
||||
# Casting integer types to pointer types
|
||||
if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
|
||||
return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
# Ptr . Ptr
|
||||
|
||||
# Casting pointer types to pointer types
|
||||
if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
|
||||
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
|
||||
# * . Bool
|
||||
if dst_sca_ty.is_bool():
|
||||
if src_sca_ty.is_ptr():
|
||||
input = cast(input, tl.int64, builder)
|
||||
other = builder.get_int64(0)
|
||||
if src_ty.is_bool():
|
||||
other = builder.create_splat(other, src_ty.get_block_shapes())
|
||||
return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty)
|
||||
|
||||
assert False, f'cannot cast {input} to {dst_ty}'
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -696,18 +730,6 @@ def cast(input: tl.tensor,
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
def _parse_eviction_policy(eviction_policy):
|
||||
eviction = ir.EVICTION_POLICY.NORMAL # default
|
||||
if eviction_policy:
|
||||
if eviction_policy == "evict_last":
|
||||
eviction = ir.EVICTION_POLICY.EVICT_LAST
|
||||
elif eviction_policy == "evict_first":
|
||||
eviction = ir.EVICTION_POLICY.EVICT_FIRST
|
||||
else:
|
||||
raise ValueError(f"Eviction policy {eviction_policy} not supported")
|
||||
return eviction
|
||||
|
||||
|
||||
def load(ptr: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
other: Optional[tl.tensor],
|
||||
@@ -723,16 +745,18 @@ def load(ptr: tl.tensor,
|
||||
if other:
|
||||
other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
|
||||
|
||||
if other:
|
||||
other = cast(other, ptr.type.scalar.element_ty, builder)
|
||||
ptr_ty = ptr.type.scalar
|
||||
elt_ty = ptr_ty.element_ty
|
||||
|
||||
# treat bool* as tl.int8*
|
||||
if elt_ty == tl.int1:
|
||||
elt_ty = tl.int8
|
||||
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
|
||||
ptr = cast(ptr, ptr_ty, builder)
|
||||
|
||||
if other:
|
||||
other = cast(other, elt_ty, builder)
|
||||
|
||||
# cache modifier
|
||||
cache = ir.CACHE_MODIFIER.NONE # default
|
||||
if cache_modifier:
|
||||
@@ -744,7 +768,14 @@ def load(ptr: tl.tensor,
|
||||
raise ValueError(f"Cache modifier {cache_modifier} not supported")
|
||||
|
||||
# eviction policy
|
||||
eviction = _parse_eviction_policy(eviction_policy)
|
||||
eviction = ir.EVICTION_POLICY.NORMAL # default
|
||||
if eviction_policy:
|
||||
if eviction_policy == "evict_last":
|
||||
eviction = ir.EVICTION_POLICY.EVICT_LAST
|
||||
elif eviction_policy == "evict_first":
|
||||
eviction = ir.EVICTION_POLICY.EVICT_FIRST
|
||||
else:
|
||||
raise ValueError(f"Eviction policy {eviction_policy} not supported")
|
||||
|
||||
if ptr.type.is_block():
|
||||
shape = ptr.type.get_block_shapes()
|
||||
@@ -752,29 +783,22 @@ def load(ptr: tl.tensor,
|
||||
else:
|
||||
dst_ty = elt_ty
|
||||
|
||||
if not mask and not other:
|
||||
if not mask:
|
||||
if other:
|
||||
raise ValueError("`other` cannot be provided without `mask`")
|
||||
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
|
||||
dst_ty)
|
||||
if not mask:
|
||||
raise ValueError("`other` cannot be provided without `mask`")
|
||||
|
||||
if not other:
|
||||
other_ir = ir.undef.get(elt_ty.to_ir(builder))
|
||||
if ptr.type.is_block():
|
||||
other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes())
|
||||
other = tl.tensor(other_ir, dst_ty)
|
||||
|
||||
return tl.tensor(builder.create_masked_load(ptr.handle,
|
||||
mask.handle,
|
||||
other.handle,
|
||||
cache, eviction, is_volatile),
|
||||
dst_ty)
|
||||
else:
|
||||
return tl.tensor(builder.create_masked_load(ptr.handle,
|
||||
mask.handle,
|
||||
other.handle if other else None,
|
||||
cache, eviction, is_volatile),
|
||||
dst_ty)
|
||||
|
||||
|
||||
def store(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: Optional[tl.tensor],
|
||||
eviction_policy: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
if not ptr.type.scalar.is_ptr():
|
||||
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
|
||||
@@ -786,20 +810,17 @@ def store(ptr: tl.tensor,
|
||||
elt_ty = ptr_ty.element_ty
|
||||
# treat bool* as tl.int8*
|
||||
if elt_ty == tl.int1:
|
||||
# convert to bool first and then store as int8
|
||||
val = cast(val, tl.int1, builder)
|
||||
elt_ty = tl.int8
|
||||
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
|
||||
ptr = cast(ptr, ptr_ty, builder)
|
||||
# eviction policy
|
||||
eviction = _parse_eviction_policy(eviction_policy)
|
||||
|
||||
# cast to target data-type
|
||||
val = cast(val, elt_ty, builder)
|
||||
if not mask:
|
||||
return tl.tensor(builder.create_store(ptr.handle, val.handle, eviction), tl.void)
|
||||
return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
|
||||
if not mask.type.scalar.is_bool():
|
||||
raise ValueError("Mask must have boolean scalar type")
|
||||
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, eviction), tl.void)
|
||||
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void)
|
||||
|
||||
#########
|
||||
# atomic
|
||||
@@ -870,8 +891,8 @@ def atomic_max(ptr: tl.tensor,
|
||||
# return atomic_umin(i_ptr, i_val) if val < 0
|
||||
i_val = bitcast(val, tl.int32, builder)
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
|
||||
pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
|
||||
neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
|
||||
pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
|
||||
neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type)
|
||||
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type)
|
||||
return where(pos, pos_ret, neg_ret, builder)
|
||||
@@ -902,8 +923,8 @@ def atomic_min(ptr: tl.tensor,
|
||||
# return atomic_umax(i_ptr, i_val) if val < 0
|
||||
i_val = bitcast(val, tl.int32, builder)
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
|
||||
pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
|
||||
neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
|
||||
pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
|
||||
neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
|
||||
i_ptr.handle,
|
||||
i_val.handle,
|
||||
@@ -963,31 +984,28 @@ def atomic_xchg(ptr: tl.tensor,
|
||||
# ===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
def dot(a: tl.tensor,
|
||||
b: tl.tensor,
|
||||
trans_a: bool,
|
||||
trans_b: bool,
|
||||
def dot(lhs: tl.tensor,
|
||||
rhs: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
in_a = 1 if not trans_a else 0
|
||||
in_b = 1 if trans_b else 0
|
||||
assert a.type.is_block() and b.type.is_block()
|
||||
assert len(a.shape) == 2 and len(b.shape) == 2
|
||||
assert a.shape[in_a] == b.shape[in_b]
|
||||
assert a.shape[0] >= 16 and a.shape[1] >= 16 and b.shape[1] >= 16,\
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
|
||||
assert lhs.shape[1].value == rhs.shape[0].value
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16,\
|
||||
"small blocks not supported!"
|
||||
if a.type.scalar.is_int():
|
||||
if lhs.type.scalar.is_int():
|
||||
_0 = builder.get_int32(0)
|
||||
ret_scalar_ty = tl.int32
|
||||
else:
|
||||
_0 = builder.get_float32(0)
|
||||
ret_scalar_ty = tl.float32
|
||||
M = a.type.shape[in_a ^ 1]
|
||||
N = b.type.shape[in_b ^ 1]
|
||||
M = lhs.type.shape[0]
|
||||
N = rhs.type.shape[1]
|
||||
_0 = builder.create_splat(_0, [M, N])
|
||||
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
|
||||
ret = builder.create_dot(a.handle, b.handle, _0, trans_a, trans_b, allow_tf32)
|
||||
return tl.tensor(ret, ret_ty)
|
||||
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
|
||||
ret_ty)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
@@ -1010,11 +1028,11 @@ def where(condition: tl.tensor,
|
||||
ret_ty = x.type
|
||||
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Reductions
|
||||
# ===----------------------------------------------------------------------===
|
||||
|
||||
|
||||
def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
|
||||
FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
|
||||
scalar_ty = input.type.scalar
|
||||
@@ -1039,16 +1057,24 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
|
||||
if INT_OP in int_op_to_unit:
|
||||
INT_OP = int_op_to_unit[INT_OP]
|
||||
|
||||
# If we are doing an argmin or argmax we want to use an int32 output type
|
||||
out_scalar_ty = scalar_ty
|
||||
if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
|
||||
out_scalar_ty = tl.int32
|
||||
elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
|
||||
out_scalar_ty = tl.int32
|
||||
|
||||
# get result type
|
||||
shape = input.type.shape
|
||||
ret_shape = []
|
||||
for i, s in enumerate(shape):
|
||||
if i != axis:
|
||||
ret_shape.append(s)
|
||||
if len(ret_shape) == 0:
|
||||
res_ty = scalar_ty
|
||||
if ret_shape:
|
||||
res_ty = tl.block_type(out_scalar_ty, ret_shape)
|
||||
else:
|
||||
res_ty = tl.block_type(scalar_ty, ret_shape)
|
||||
# 0d-tensor -> scalar
|
||||
res_ty = out_scalar_ty
|
||||
|
||||
if scalar_ty.is_floating():
|
||||
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
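Reviewer note: with the change above, argmin/argmax reductions now produce `int32` indices (and a plain scalar for 0-d results) instead of reusing the input element type. NumPy behaves analogously, which makes a convenient illustration (unrelated to the actual IR lowering):

import numpy as np

x = np.array([[1.5, 3.0, 2.0],
              [0.5, 7.0, 4.0]], dtype=np.float16)

idx = np.argmax(x, axis=1)
assert idx.dtype.kind == "i"       # integer indices, not float16
assert idx.tolist() == [1, 1]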
|
||||
@@ -1084,25 +1110,19 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
|
||||
return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Utilities
|
||||
# -----------------------
|
||||
|
||||
def clock(builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_clock(), tl.int64)
|
||||
|
||||
|
||||
def globaltimer(builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_globaltimer(), tl.int64)
|
||||
|
||||
|
||||
# ===----------------------------------------------------------------------===
|
||||
# Math
|
||||
# ===----------------------------------------------------------------------===
|
||||
|
||||
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
x, y = binary_op_type_checking_impl(x, y, builder)
|
||||
return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type)
|
||||
from . import libdevice
|
||||
return libdevice.mulhi(x, y, _builder=builder)
|
||||
|
||||
|
||||
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
from . import libdevice
|
||||
return libdevice.floor(x, _builder=builder)
|
||||
|
||||
|
||||
def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
@@ -1130,16 +1150,23 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
|
||||
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||
if len(x.shape) != len(values):
|
||||
raise ValueError("Shape of input to multiple_of does not match the length of values")
|
||||
x.handle.multiple_of(values)
|
||||
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
|
||||
return x
|
||||
|
||||
|
||||
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
|
||||
if len(x.shape) != len(values):
|
||||
raise ValueError("Shape of input to max_contiguous does not match the length of values")
|
||||
x.handle.max_contiguous(values)
|
||||
x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
|
||||
return x
|
||||
|
||||
|
||||
def debug_barrier(builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_barrier(''), tl.void)
|
||||
return tl.tensor(builder.create_barrier(), tl.void)
|
||||
|
||||
|
||||
def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
|
||||
new_args = []
|
||||
for arg in args:
|
||||
new_args.append(arg.handle)
|
||||
return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
|
||||
|
@@ -1,5 +1,12 @@
|
||||
# flake8: noqa: F401
|
||||
#from .conv import _conv, conv
|
||||
# from .conv import _conv, conv
|
||||
from . import blocksparse
|
||||
from .cross_entropy import _cross_entropy, cross_entropy
|
||||
from .matmul import _matmul, matmul
|
||||
|
||||
__all__ = [
|
||||
"blocksparse",
|
||||
"_cross_entropy",
|
||||
"cross_entropy",
|
||||
"_matmul",
|
||||
"matmul",
|
||||
]
|
||||
|
@@ -1,3 +1,7 @@
|
||||
# flake8: noqa: F401
|
||||
from .matmul import matmul
|
||||
from .softmax import softmax
|
||||
|
||||
__all__ = [
|
||||
"matmul",
|
||||
"softmax",
|
||||
]
|
||||
|
@@ -18,8 +18,8 @@ def num_warps(n):
|
||||
|
||||
@triton.jit
|
||||
def _blocksparse_softmax_fwd(
|
||||
Out, A, LUT, R, stride_xz,
|
||||
extent, stride_zr, stride_hr, # relative attention
|
||||
Out, A, stride_xz, LUT,
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
@@ -164,8 +164,8 @@ class _softmax(torch.autograd.Function):
|
||||
# enqueue kernel
|
||||
out = torch.empty_like(a)
|
||||
_blocksparse_softmax_fwd[grid](
|
||||
out, a, lut, rel_logits, a.stride(0),
|
||||
rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
out, a, a.stride(0), lut,
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
scale,
|
||||
is_causal,
|
||||
BLOCK_SIZE=block,
|
||||
|
@@ -10,7 +10,9 @@ from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcor
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
triton.compiler.init_cuda_utils()
|
||||
|
||||
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
|
||||
return tflops
|
||||
|
||||
@@ -18,14 +20,14 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
''' return compute throughput in TOPS '''
|
||||
total_warps = num_ctas * min(num_warps, 4)
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
|
||||
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
|
||||
return tflops
|
||||
|
||||
|
||||
def get_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
if cc < 80 and dtype == torch.float32:
|
||||
capability = torch.cuda.get_device_capability(device)
|
||||
if capability[0] < 8 and dtype == torch.float32:
|
||||
return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
|
||||
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
|
||||
|
||||
@@ -59,7 +61,7 @@ def estimate_matmul_time(
|
||||
compute_ms = total_ops / tput
|
||||
|
||||
# time to load data
|
||||
num_sm = _triton.runtime.num_sm(backend, device)
|
||||
num_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
|
||||
active_cta_ratio = min(1, num_ctas / num_sm)
|
||||
active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
|
||||
active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
|
||||
@@ -97,9 +99,8 @@ def estimate_matmul_time(
|
||||
|
||||
|
||||
def early_config_prune(configs, named_args):
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
device = torch.cuda.current_device()
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
capability = torch.cuda.get_device_capability()
|
||||
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
|
||||
dtsize = named_args['A'].element_size()
|
||||
dtype = named_args['A'].dtype
|
||||
@@ -110,7 +111,10 @@ def early_config_prune(configs, named_args):
|
||||
kw = config.kwargs
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
|
||||
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
|
||||
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
|
||||
|
||||
# TODO: move to `cuda_utils` submodule
|
||||
triton.compiler.init_cuda_utils()
|
||||
max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"]
|
||||
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
|
||||
if required_shared_memory <= max_shared_memory:
|
||||
pruned_configs.append(config)
|
||||
@@ -136,7 +140,7 @@ def early_config_prune(configs, named_args):
|
||||
pruned_configs = []
|
||||
for k, v in configs_map.items():
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
|
||||
if cc >= 80:
|
||||
if capability[0] >= 8:
|
||||
# compute cycles (only works for ampere GPUs)
|
||||
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
|
||||
mma_cycles = mmas / min(4, num_warps) * 8
|
||||
|
@@ -1,2 +1,12 @@
|
||||
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
|
||||
from .jit import JITFunction, KernelInterface, version_key # noqa: F401
|
||||
from .autotuner import Config, Heuristics, autotune, heuristics
|
||||
from .jit import JITFunction, KernelInterface, version_key
|
||||
|
||||
__all__ = [
|
||||
"Config",
|
||||
"Heuristics",
|
||||
"autotune",
|
||||
"heuristics",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"version_key",
|
||||
]
|
||||
|
@@ -4,6 +4,7 @@ import builtins
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from ..compiler import OutOfResources
|
||||
from ..testing import do_bench
|
||||
from .jit import KernelInterface
|
||||
|
||||
@@ -60,7 +61,10 @@ class Autotuner(KernelInterface):
|
||||
config.pre_hook(self.nargs)
|
||||
self.hook(args)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
return do_bench(kernel_call)
|
||||
try:
|
||||
return do_bench(kernel_call)
|
||||
except OutOfResources:
|
||||
return float('inf')
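Reviewer note: the try/except added here makes the autotuner treat configurations that exhaust device resources as infinitely slow instead of aborting the whole sweep. The control flow in isolation, with a stand-in exception and hypothetical helper names:

class OutOfResources(Exception):
    # stand-in for triton's OutOfResources
    pass

def bench_config(run_config):
    # time a candidate config, but never let a failing one abort the sweep
    try:
        return run_config()
    except OutOfResources:
        return float('inf')

def bad_config():
    raise OutOfResources("required shared memory exceeds the device limit")

timings = [bench_config(lambda: 0.42), bench_config(bad_config)]
assert timings == [0.42, float('inf')]
assert min(range(len(timings)), key=timings.__getitem__) == 0   # failing config never wins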
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
@@ -118,7 +122,6 @@ class Autotuner(KernelInterface):
|
||||
class Config:
|
||||
"""
|
||||
An object that represents a possible kernel configuration for the auto-tuner to try.
|
||||
|
||||
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
|
||||
:type meta: dict[Str, Any]
|
||||
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
|
||||
@@ -150,10 +153,8 @@ class Config:
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
@@ -164,12 +165,10 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
|
||||
:note: When all the configurations are evaluated, the kernel will run multiple times.
|
||||
This means that whatever value the kernel updates will be updated multiple times.
|
||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||
resets the value of the provided tensor to `zero` before running any configuration.
|
||||
|
||||
:param configs: a list of :code:`triton.Config` objects
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
@@ -204,16 +203,12 @@ def heuristics(values):
|
||||
"""
|
||||
Decorator for specifying how the values of certain meta-parameters may be computed.
|
||||
This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
|
||||
|
||||
|
||||
.param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
|
||||
each such function takes a list of positional arguments as input.
|
||||
.type values: dict[str, Callable[[list[Any]], Any]]
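As a concrete illustration of the heuristic above (arithmetic only, not part of the diff):
for a hypothetical launch where args[1] == 1000, the computed meta-parameter is

import math
BLOCK_SIZE = 2 ** int(math.ceil(math.log2(1000)))  # == 1024, the smallest power of two >= 1000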
|
||||
|
@@ -8,6 +8,7 @@ import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
|
||||
|
||||
import torch
|
||||
|
||||
@@ -19,6 +20,9 @@ try:
|
||||
except ImportError:
|
||||
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -94,20 +98,19 @@ def version_key():
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
class KernelInterface:
|
||||
class KernelInterface(Generic[T]):
|
||||
run: T
|
||||
|
||||
def __getitem__(self, grid):
|
||||
def __getitem__(self, grid) -> T:
|
||||
"""
|
||||
A JIT function is launched with: fn[grid](*args, **kwargs).
|
||||
Hence JITFunction.__getitem__ returns a callable proxy that
|
||||
memorizes the grid.
|
||||
"""
|
||||
def launcher(*args, **kwargs):
return self.run(*args, grid=grid, **kwargs)
return launcher
return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
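# Illustrative note (not part of the diff): with the functools.partial form above, indexing a
# kernel with a grid just binds the grid argument, so the following two calls are equivalent
# (argument names are hypothetical):
#
#     kernel[grid](x, y, BLOCK_SIZE=1024)
#     functools.partial(kernel.run, grid=grid)(x, y, BLOCK_SIZE=1024)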
|
||||
|
||||
|
||||
class JITFunction(KernelInterface):
|
||||
class JITFunction(KernelInterface[T]):
|
||||
|
||||
# Hook for inspecting compiled functions and modules
|
||||
cache_hook = None
|
||||
@@ -152,8 +155,8 @@ class JITFunction(KernelInterface):
|
||||
if x is None:
|
||||
return True
|
||||
return False
|
||||
divisible_by_16 = [i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize]
|
||||
equal_to_1 = [i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize]
|
||||
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
|
||||
equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
|
||||
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
|
||||
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
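# Worked example for the descriptor returned above (hypothetical arguments, assuming
# do_not_specialize is empty and that is_divisible_by_16 returns True for 16-byte-aligned
# pointers and for ints divisible by 16): for args = (x_ptr_aligned, 4096, 1),
# divisible_by_16 == {0, 1} and equal_to_1 == {2}, so the result is
# instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(2,)).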
|
||||
|
||||
@@ -177,6 +180,9 @@ class JITFunction(KernelInterface):
|
||||
triton.language.uint32: 'u32',
|
||||
triton.language.uint64: 'u64',
|
||||
triton.language.float8: 'fp8',
|
||||
triton.language.float16: 'fp16',
|
||||
triton.language.bfloat16: 'bf16',
|
||||
triton.language.float32: 'fp32',
|
||||
}[key]
|
||||
return f'*{ty}'
|
||||
if key is None:
|
||||
@@ -272,7 +278,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
bin = triton.compile(self, signature, device, constants, constexpr_key, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
||||
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args, constexpr_key)
|
||||
self.cache[device][key] = bin
|
||||
@@ -364,29 +370,55 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def jit(*args, **kwargs):
|
||||
@overload
|
||||
def jit(fn: T) -> JITFunction[T]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def jit(
|
||||
*,
|
||||
version=None,
|
||||
do_not_specialize: Optional[Iterable[int]] = None,
|
||||
) -> Callable[[T], JITFunction[T]]:
|
||||
...
|
||||
|
||||
|
||||
def jit(
|
||||
fn: Optional[T] = None,
|
||||
*,
|
||||
version=None,
|
||||
do_not_specialize: Optional[Iterable[int]] = None,
|
||||
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are
|
||||
implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
|
||||
:note: This function will be compiled and run on the GPU. It will only have access to:
|
||||
|
||||
* python primitives,
|
||||
* objects within the triton.language package,
|
||||
* builtins within the triton package,
|
||||
* arguments to this function,
|
||||
* other jit'd functions
|
||||
|
||||
:param fn: the function to be jit-compiled
|
||||
:type fn: Callable
|
||||
"""
|
||||
if args:
|
||||
assert len(args) == 1
|
||||
assert callable(args[0])
|
||||
return JITFunction(args[0], **kwargs)
|
||||
|
||||
def decorator(fn: T) -> JITFunction[T]:
|
||||
assert callable(fn)
|
||||
return JITFunction(
|
||||
fn,
|
||||
version=version,
|
||||
do_not_specialize=do_not_specialize,
|
||||
)
|
||||
|
||||
if fn is not None:
return decorator(fn)

else:
def decorator(fn):
return JITFunction(fn, **kwargs)
return decorator
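For reference, a minimal sketch of how the two overloads above are used (kernel bodies
omitted; names are hypothetical):

@triton.jit                            # bare form: returns a JITFunction directly
def kernel_a(x_ptr, n_elements):
    ...

@triton.jit(do_not_specialize=[1])     # parameterized form: returns a decorator
def kernel_b(x_ptr, n_elements):
    ...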
|
||||
|
||||
|
||||
|
@@ -16,6 +16,9 @@ except ImportError:
|
||||
_cutlass = None
|
||||
has_cutlass = False
|
||||
|
||||
# TODO: move to separate module
|
||||
import triton
|
||||
|
||||
|
||||
def catch_oor(kernel, pytest_handle=None):
|
||||
try:
|
||||
@@ -34,12 +37,12 @@ def sparsify_tensor(x, mask, block):
|
||||
return ret
|
||||
|
||||
|
||||
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None):
|
||||
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
|
||||
if data is None:
|
||||
data = torch.randn(shape, dtype=torch.float32, device=device)
|
||||
data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
|
||||
ref_ret = data
|
||||
ref_ret = ref_ret * alpha + beta
|
||||
ref_ret = ref_ret.half().float()
|
||||
ref_ret = ref_ret.half().to(dtype)
|
||||
if trans:
|
||||
ref_ret = ref_ret.t().requires_grad_()
|
||||
ref_ret = ref_ret.detach().requires_grad_()
|
||||
@@ -336,8 +339,8 @@ def get_dram_gbps(backend=None, device=None):
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
bus_width = _triton.runtime.global_memory_bus_width(backend, device)
mem_clock_khz = triton.compiler.cuda_utils.get_device_properties(device)["mem_clock_rate"] # in kHz
bus_width = triton.compiler.cuda_utils.get_device_properties(device)["mem_bus_width"]
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
return bw_gbps
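# Sanity check of the formula above with hypothetical A100-class numbers:
# mem_clock_khz = 1_215_000 (1215 MHz) and bus_width = 5120 bits give
# 1_215_000 * 5120 * 2 / 1e6 / 8 ≈ 1555 GB/s, matching the advertised HBM2 bandwidth
# of such a device.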
|
||||
|
||||
@@ -347,11 +350,13 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clo
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
|
||||
triton.compiler.init_cuda_utils()
|
||||
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
|
||||
if not clock_rate:
|
||||
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
if cc < 80:
|
||||
clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"] # in kHz
|
||||
capability = torch.cuda.get_device_capability(device)
|
||||
if capability[0] < 8:
|
||||
assert dtype == torch.float16
|
||||
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
|
||||
else:
|
||||
|
python/triton/tools/aot.py (new file, 61 lines)
@@ -0,0 +1,61 @@
|
||||
import argparse
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as libtriton
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# valid source and target formats
|
||||
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx']
|
||||
|
||||
# set up the argument parser
|
||||
# TODO: conditional requirements
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('src', help="Source file to compile")
|
||||
parser.add_argument('--target', required=True,
|
||||
help="Target format, one of: " + ', '.join(VALID_FORMATS))
|
||||
parser.add_argument('--sm', type=int, help="Compute capability to compile for")
|
||||
parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")
|
||||
|
||||
# parse the args
|
||||
args = parser.parse_args()
|
||||
|
||||
# TODO: clean-up and re-use triton.compiler primitive functions
|
||||
# check for validity of format arguments
|
||||
if args.target not in VALID_FORMATS:
|
||||
print("Invalid target format: " + args.target)
|
||||
exit(0)
|
||||
|
||||
# parse source file to MLIR module
|
||||
context = libtriton.ir.context()
|
||||
module = libtriton.ir.parse_mlir_module(args.src, context)
|
||||
module.context = context
|
||||
|
||||
# optimize triton-ir
|
||||
module = triton.compiler.optimize_triton_ir(module)
|
||||
if args.target == 'triton-ir':
|
||||
print(module.str())
|
||||
exit(0)
|
||||
|
||||
if not args.sm:
|
||||
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
|
||||
if args.target == 'triton-gpu-ir':
|
||||
print(module.str())
|
||||
exit(0)
|
||||
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
module = triton.compiler.ttgir_to_llir(module, extern_libs=None, compute_capability=args.sm)
|
||||
if args.target == 'llvm-ir':
|
||||
print(module)
|
||||
exit(0)
|
||||
|
||||
if not args.ptx_version:
|
||||
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
|
||||
|
||||
# llvm-ir -> ptx
module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
assert args.target == 'ptx'
print(module)
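# Example invocation of this tool (file name and flag values are illustrative only):
#
#     python python/triton/tools/aot.py my_kernel.ttir --target=ptx --sm=80 --ptx-version=70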
|
@@ -21,7 +21,6 @@ class Symbol:
|
||||
) -> None:
|
||||
'''
|
||||
A symbol is a function declaration.
|
||||
|
||||
:param name: name of the symbol
|
||||
:param op_name: name of the operation
|
||||
:param ret_type: return type of the operation
|
||||
@@ -65,9 +64,9 @@ def convert_type(type_str) -> Optional[str]:
|
||||
elif type_str == "u64":
|
||||
return "uint64"
|
||||
elif type_str == "float":
|
||||
return "float32"
|
||||
return "fp32"
|
||||
elif type_str == "double":
|
||||
return "float64"
|
||||
return "fp64"
|
||||
else:
|
||||
# ignore other types, such as pointer types
|
||||
return None
|
||||
@@ -98,7 +97,6 @@ class ExternLibrary(ABC):
|
||||
) -> None:
|
||||
'''
|
||||
Abstract class for extern library.
|
||||
|
||||
:param name: name of the library
|
||||
:param path: path of the library
|
||||
:param format: whether to format the generated stub file
|
||||
@@ -154,7 +152,6 @@ class Libdevice(ExternLibrary):
|
||||
def __init__(self, path) -> None:
|
||||
'''
|
||||
Constructor for Libdevice.
|
||||
|
||||
:param path: path of the libdevice library
|
||||
'''
|
||||
super().__init__("libdevice", path)
|
||||
@@ -177,7 +174,6 @@ class Libdevice(ExternLibrary):
|
||||
func_strs = func_str.split("(")
|
||||
func_name = func_strs[0].replace("@", "")
|
||||
op_name = func_name.replace("__nv_", "")
|
||||
# Filter out interfaces that are not listed in NVIDIA's official documents.
|
||||
if 'ieee' in op_name:
|
||||
return None
|
||||
# Get arg_types
|
||||
@@ -310,8 +306,8 @@ class Libdevice(ExternLibrary):
|
||||
for symbol in symbols:
|
||||
arg_type_symbol_dict_str += "("
|
||||
for arg_type in symbol.arg_types:
|
||||
arg_type_symbol_dict_str += f"core.{arg_type},"
|
||||
ret_type = f"core.{symbol.ret_type}"
|
||||
arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),'
|
||||
ret_type = f'core.dtype("{symbol.ret_type}")'
|
||||
arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n"
|
||||
arg_type_symbol_dict_str += "}"
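# For reference, one generated entry produced by the string building above would look
# roughly like this (the "__nv_sinf" symbol name is a hypothetical example):
#
#     (core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")),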
|
||||
|
||||
@@ -331,7 +327,6 @@ class LLVMDisassembler:
|
||||
def __init__(self, path) -> None:
|
||||
'''
|
||||
Invoke llvm-dis to disassemble the given file.
|
||||
|
||||
:param path: path to llvm-dis
|
||||
'''
|
||||
self._path = path
|
||||
@@ -361,7 +356,6 @@ def build(
|
||||
) -> None:
|
||||
'''
|
||||
Interface function to build the library file.
|
||||
|
||||
:param llvm_dis_path: path to the llvm-dis binary
|
||||
:param lib_path: path to the external library file
|
||||
:param lib_name: name of the library
|
||||
|
@@ -1,76 +0,0 @@
|
||||
'''
|
||||
Compare cached triton kernels in 2 directories.
|
||||
|
||||
example:
|
||||
python compare_asm.py --dir0=triton-works/ --dir1=triton-fails/ --asm=ttir \
|
||||
--diff-out0=diff-works.ll --diff-out1=diff-fails.ll
|
||||
'''
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
|
||||
parser = argparse.ArgumentParser(description="unpickle")
|
||||
parser.add_argument('--dir0', dest='dir0', required=True,
|
||||
help="Triton cache dir 0")
|
||||
parser.add_argument('--dir1', dest='dir1', required=True,
|
||||
help="Triton cache dir 1")
|
||||
parser.add_argument('--asm', dest='asm',
|
||||
choices=['ttir', 'llir', 'ptx', 'cubin'], required=True)
|
||||
parser.add_argument('--early-stop', dest='early_stop', action='store_true',
|
||||
help="Stop after first diff")
|
||||
parser.set_defaults(early_stop=True)
|
||||
parser.add_argument('--diff-out0', dest='diff_out0', required=True,
|
||||
help="output file path for kernels in dir0")
|
||||
parser.add_argument('--diff-out1', dest='diff_out1', required=True,
|
||||
help="output file path for kernels in dir1")
|
||||
args = parser.parse_args()
|
||||
dir0 = args.dir0
|
||||
dir1 = args.dir1
|
||||
asm = args.asm
|
||||
|
||||
dir0_files = {}
|
||||
dir1_files = {}
|
||||
for root, _, files in os.walk(dir0):
|
||||
for file in files:
|
||||
if not file.endswith('.lock'):
|
||||
path = os.path.join(root, file)
|
||||
with open(path, 'rb') as f:
|
||||
loaded_file = pickle.load(f)
|
||||
bin = loaded_file['binary']
|
||||
key = loaded_file['key']
|
||||
info = key.split('-')[-3:] # num_warps, num_stages, signature
|
||||
dict_key = bin.name + '-'.join(info)
|
||||
dir0_files[dict_key] = bin.asm
|
||||
|
||||
for root, _, files in os.walk(dir1):
|
||||
for file in files:
|
||||
if not file.endswith('.lock'):
|
||||
path = os.path.join(root, file)
|
||||
with open(path, 'rb') as f:
|
||||
loaded_file = pickle.load(f)
|
||||
bin = loaded_file['binary']
|
||||
key = loaded_file['key']
|
||||
info = key.split('-')[-3:] # num_warps, num_stages, signature
|
||||
dict_key = bin.name + '-'.join(info)
|
||||
dir1_files[dict_key] = bin.asm
|
||||
|
||||
diff_keys = []
|
||||
for key in dir0_files:
|
||||
asm0 = dir0_files[key]
|
||||
if key not in dir1_files:
|
||||
continue
|
||||
asm1 = dir1_files[key]
|
||||
if asm0[asm] != asm1[asm]:
|
||||
diff_keys.append(key)
|
||||
|
||||
if args.early_stop:
|
||||
diff_keys = diff_keys[:1]
|
||||
if diff_keys:
|
||||
with open(args.diff_out0, 'w') as f0, open(args.diff_out1, 'w') as f1:
|
||||
for key in diff_keys:
|
||||
f0.write(f'{asm} mismatch at {key}')
|
||||
f0.write(dir0_files[key][asm])
|
||||
f0.write('\n')
|
||||
f1.write(f'{asm} mismatch at {key}')
|
||||
f1.write(dir1_files[key][asm])
|
||||
f1.write('\n')
|
@@ -80,7 +80,7 @@ def softmax_kernel(
|
||||
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
|
||||
# Subtract maximum for numerical stability
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
|
||||
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
@@ -188,4 +188,4 @@ benchmark.run(show_plots=True, print_data=True)
|
||||
#
|
||||
# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
|
||||
# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
|
||||
# Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.
|
||||
# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.
|
||||
|
@@ -156,16 +156,7 @@ import triton.language as tl
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@@ -236,8 +227,8 @@ def matmul_kernel(
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
# you can fuse arbitrary activation functions here
|
||||
# while the accumulator is still in FP32!
|
||||
if ACTIVATION == "leaky_relu":
|
||||
accumulator = leaky_relu(accumulator)
|
||||
if ACTIVATION:
|
||||
accumulator = ACTIVATION(accumulator)
|
||||
c = accumulator.to(tl.float16)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
@@ -252,7 +243,6 @@ def matmul_kernel(
|
||||
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
|
||||
@triton.jit
def leaky_relu(x):
    x = x + 1
    return tl.where(x >= 0, x, 0.01 * x)
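# Usage sketch for the updated activation meta-parameter (assuming, as in the original
# tutorial, that the `matmul` wrapper below forwards `activation` to the kernel's
# ACTIVATION argument):
#
#     c = matmul(a, b)                         # plain matmul
#     c = matmul(a, b, activation=leaky_relu)  # matmul with the fused leaky_relu above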
|
||||
|
||||
|
||||
@@ -261,7 +251,7 @@ def leaky_relu(x):
|
||||
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
|
||||
|
||||
|
||||
def matmul(a, b, activation=""):
|
||||
def matmul(a, b, activation=None):
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
assert a.is_contiguous(), "matrix A must be contiguous"
|
||||
@@ -297,7 +287,7 @@ def matmul(a, b, activation=""):
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
triton_output = matmul(a, b)
|
||||
triton_output = matmul(a, b, activation=None)
|
||||
torch_output = torch.matmul(a, b)
|
||||
print(f"triton_output={triton_output}")
|
||||
print(f"torch_output={torch_output}")
|
||||
@@ -319,13 +309,13 @@ else:
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[
|
||||
128 * i for i in range(2, 33)
|
||||
8192
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
# possible values for `line_arg`
|
||||
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
|
||||
line_vals=['cublas', 'triton'],
|
||||
# label name for the lines
|
||||
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
|
||||
line_names=["cuBLAS", "Triton"],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
@@ -337,18 +327,9 @@ def benchmark(M, N, K, provider):
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
|
||||
if provider == 'cublas + relu':
|
||||
torch_relu = torch.nn.ReLU(inplace=True)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch_relu(torch.matmul(a, b))
|
||||
)
|
||||
if provider == 'triton + relu':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: matmul(a, b, activation="leaky_relu")
|
||||
)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100)
|
||||
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
@@ -19,8 +19,8 @@ except ModuleNotFoundError:
|
||||
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_fused(
|
||||
Out,
|
||||
A,
|
||||
Out,
|
||||
Weight,
|
||||
Bias,
|
||||
Mean, Rstd,
|
||||
@@ -36,14 +36,14 @@ def _layer_norm_fwd_fused(
|
||||
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
for off in range(0, N, BLOCK_SIZE):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||
a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
|
||||
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
_mean += a
|
||||
mean = tl.sum(_mean, axis=0) / N
|
||||
# compute variance
|
||||
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
for off in range(0, N, BLOCK_SIZE):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||
a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
|
||||
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
a = tl.where(cols < N, a - mean, 0.)
|
||||
_var += a * a
|
||||
var = tl.sum(_var, axis=0) / N
|
||||
@@ -57,192 +57,155 @@ def _layer_norm_fwd_fused(
|
||||
mask = cols < N
|
||||
weight = tl.load(Weight + cols, mask=mask)
|
||||
bias = tl.load(Bias + cols, mask=mask)
|
||||
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
|
||||
a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)
|
||||
a_hat = (a - mean) * rstd
|
||||
out = a_hat * weight + bias
|
||||
# # write-back
|
||||
tl.store(Out + cols, out, mask=mask)
|
||||
|
||||
# Backward pass (DA + partial DW + partial DB)
|
||||
|
||||
|
||||
# Backward pass (DX + partial DW + partial DB)
|
||||
@triton.jit
|
||||
def _layer_norm_bwd_dx_fused(
|
||||
_DA,
|
||||
_DOut,
|
||||
_A,
|
||||
Weight,
|
||||
Mean, Rstd,
|
||||
stride, NumRows, NumCols, eps,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
):
|
||||
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
|
||||
# position of elements processed by this program
|
||||
pid = tl.program_id(0)
|
||||
row = pid
|
||||
A = _A + row * stride
|
||||
DOut = _DOut + row * stride
|
||||
DA = _DA + row * stride
|
||||
mean = tl.load(Mean + row)
|
||||
rstd = tl.load(Rstd + row)
|
||||
row = tl.program_id(0)
|
||||
cols = tl.arange(0, BLOCK_SIZE_N)
|
||||
mask = cols < N
|
||||
# offset data pointers to start at the row of interest
|
||||
X += row * stride
|
||||
DY += row * stride
|
||||
DX += row * stride
|
||||
# offset locks and weight/bias gradient pointer
|
||||
# each kernel instance accumulates partial sums for
|
||||
# DW and DB into one of GROUP_SIZE_M independent buffers
|
||||
# these buffers stay in the L2, which allows this kernel
|
||||
# to be fast
|
||||
lock_id = row % GROUP_SIZE_M
|
||||
Lock += lock_id
|
||||
Count = Lock + GROUP_SIZE_M
|
||||
DW = DW + lock_id * N + cols
|
||||
DB = DB + lock_id * N + cols
|
||||
# load data to SRAM
|
||||
_mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
|
||||
_mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
|
||||
for off in range(0, NumCols, BLOCK_SIZE_N):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE_N)
|
||||
mask = cols < NumCols
|
||||
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
|
||||
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
|
||||
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
|
||||
a_hat = (a - mean) * rstd
|
||||
wdout = weight * dout
|
||||
_mean1 += a_hat * wdout
|
||||
_mean2 += wdout
|
||||
mean1 = tl.sum(_mean1, axis=0) / NumCols
|
||||
mean2 = 0.
|
||||
mean2 = tl.sum(_mean2, axis=0) / NumCols
|
||||
for off in range(0, NumCols, BLOCK_SIZE_N):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE_N)
|
||||
mask = cols < NumCols
|
||||
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
|
||||
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
|
||||
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
|
||||
a_hat = (a - mean) * rstd
|
||||
wdout = weight * dout
|
||||
da = (wdout - (a_hat * mean1 + mean2)) * rstd
|
||||
# write-back dx
|
||||
tl.store(DA + cols, da, mask=mask)
|
||||
|
||||
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
||||
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
mean = tl.load(M + row)
|
||||
rstd = tl.load(V + row)
|
||||
# compute dx
|
||||
xhat = (x - mean) * rstd
|
||||
wdy = w * dy
|
||||
xhat = tl.where(mask, xhat, 0.)
|
||||
wdy = tl.where(mask, wdy, 0.)
|
||||
mean1 = tl.sum(xhat * wdy, axis=0) / N
|
||||
mean2 = tl.sum(wdy, axis=0) / N
|
||||
dx = (wdy - (xhat * mean1 + mean2)) * rstd
|
||||
# write-back dx
|
||||
tl.store(DX + cols, dx, mask=mask)
|
||||
# accumulate partial sums for dw/db
|
||||
partial_dw = (dy * xhat).to(w.dtype)
|
||||
partial_db = (dy).to(w.dtype)
|
||||
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
pass
|
||||
count = tl.load(Count)
|
||||
# first store doesn't accumulate
|
||||
if count == 0:
|
||||
tl.atomic_xchg(Count, 1)
|
||||
else:
|
||||
partial_dw += tl.load(DW, mask=mask)
|
||||
partial_db += tl.load(DB, mask=mask)
|
||||
tl.store(DW, partial_dw, mask=mask)
|
||||
tl.store(DB, partial_db, mask=mask)
|
||||
# release lock
|
||||
tl.atomic_xchg(Lock, 0)
|
||||
|
||||
# Backward pass (total DW + total DB)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _layer_norm_bwd_dwdb(
|
||||
A, DOut,
|
||||
Mean, Var,
|
||||
DW,
|
||||
DB,
|
||||
M, N,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
):
|
||||
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
UNROLL: tl.constexpr = 4
|
||||
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
|
||||
for j in range(UNROLL):
|
||||
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
||||
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
|
||||
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
|
||||
a_hat = (a - mean[:, None]) * rstd[:, None]
|
||||
dw += dout * a_hat
|
||||
db += dout
|
||||
for i in range(0, M, BLOCK_SIZE_M):
|
||||
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
dw += tl.load(DW + offs, mask=mask, other=0.)
|
||||
db += tl.load(DB + offs, mask=mask, other=0.)
|
||||
sum_dw = tl.sum(dw, axis=0)
|
||||
sum_db = tl.sum(db, axis=0)
|
||||
tl.store(DW + cols, sum_dw, mask=cols < N)
|
||||
tl.store(DB + cols, sum_db, mask=cols < N)
|
||||
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
|
||||
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
|
||||
|
||||
|
||||
class LayerNorm(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, normalized_shape, weight, bias, eps):
|
||||
def forward(ctx, x, normalized_shape, weight, bias, eps):
|
||||
# allocate output
|
||||
out = torch.empty_like(a)
|
||||
y = torch.empty_like(x)
|
||||
# reshape input data into 2D tensor
|
||||
a_arg = a.reshape(-1, a.shape[-1])
|
||||
M, N = a_arg.shape
|
||||
mean = torch.empty((M,), dtype=torch.float32, device="cuda")
|
||||
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
|
||||
x_arg = x.reshape(-1, x.shape[-1])
|
||||
M, N = x_arg.shape
|
||||
mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
|
||||
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // a.element_size()
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
BLOCK_SIZE = max(BLOCK_SIZE, 128)
|
||||
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
|
||||
if N > BLOCK_SIZE:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||
_layer_norm_fwd_fused[(M,)](
|
||||
out,
|
||||
a_arg,
|
||||
weight,
|
||||
bias,
|
||||
mean, rstd,
|
||||
a_arg.stride(0), N, eps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
ctx.save_for_backward(
|
||||
a, weight, bias, mean, rstd,
|
||||
)
|
||||
# enqueue kernel
|
||||
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
|
||||
x_arg.stride(0), N, eps,
|
||||
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd)
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.eps = eps
|
||||
if hasattr(bias, "config"):
|
||||
assert bias.config.grad_scale_name == weight.config.grad_scale_name
|
||||
grad_scale_name = bias.config.grad_scale_name
|
||||
else:
|
||||
grad_scale_name = None
|
||||
ctx.grad_scale_gain_bias_name = grad_scale_name
|
||||
return out
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
assert dout.is_contiguous()
|
||||
a, weight, bias, mean, var = ctx.saved_tensors
|
||||
def backward(ctx, dy):
|
||||
x, w, b, m, v = ctx.saved_tensors
|
||||
# heuristics for amount of parallel reduction stream for DG/DB
|
||||
N = weight.shape[0]
|
||||
N = w.shape[0]
|
||||
GROUP_SIZE_M = 64
|
||||
if N <= 8192: GROUP_SIZE_M = 96
|
||||
if N <= 4096: GROUP_SIZE_M = 128
|
||||
if N <= 1024: GROUP_SIZE_M = 256
|
||||
# allocate output
|
||||
da = torch.empty_like(dout)
|
||||
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
|
||||
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
|
||||
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
|
||||
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
|
||||
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
|
||||
dx = torch.empty_like(dy)
|
||||
# enqueue kernel using forward pass heuristics
|
||||
# also compute partial sums for DW and DB
|
||||
x_arg = a.reshape(-1, a.shape[-1])
|
||||
x_arg = x.reshape(-1, x.shape[-1])
|
||||
M, N = x_arg.shape
|
||||
dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
|
||||
dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
|
||||
_layer_norm_bwd_dx_fused[(M,)](
|
||||
da,
|
||||
dout,
|
||||
a,
|
||||
weight,
|
||||
mean, var,
|
||||
x_arg.stride(0), M, N,
|
||||
ctx.eps,
|
||||
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
|
||||
num_warps=ctx.num_warps,
|
||||
)
|
||||
if N > 10240:
|
||||
BLOCK_SIZE_N = 128
|
||||
BLOCK_SIZE_M = 32
|
||||
num_warps = 4
|
||||
else:
|
||||
# maximize occupancy for small N
|
||||
BLOCK_SIZE_N = 16
|
||||
BLOCK_SIZE_M = 16
|
||||
num_warps = 8
|
||||
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
|
||||
_layer_norm_bwd_dwdb[grid](
|
||||
a, dout,
|
||||
mean, var,
|
||||
dweight,
|
||||
dbias,
|
||||
M,
|
||||
N,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
num_warps=num_warps
|
||||
)
|
||||
return (da, None, dweight, dbias, None)
|
||||
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
|
||||
x_arg.stride(0), N, ctx.eps,
|
||||
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
|
||||
GROUP_SIZE_M=GROUP_SIZE_M,
|
||||
num_warps=ctx.num_warps)
|
||||
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
|
||||
# accumulate partial sums in separate kernel
|
||||
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
|
||||
BLOCK_SIZE_M=32,
|
||||
BLOCK_SIZE_N=128)
|
||||
return dx, None, dw, db, None
|
||||
|
||||
|
||||
def layer_norm(a, normalized_shape, weight, bias, eps):
|
||||
return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
|
||||
layer_norm = LayerNorm.apply
|
||||
|
||||
|
||||
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
|
||||
torch.manual_seed(0)
|
||||
# create data
|
||||
x_shape = (M, N)
|
||||
w_shape = (x_shape[-1], )
|
||||
@@ -277,11 +240,11 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
|
||||
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
|
||||
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
|
||||
ylabel='GB/s',
|
||||
plot_name='layer-norm',
|
||||
args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
|
||||
plot_name='layer-norm-backward',
|
||||
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
|
||||
)
|
||||
)
|
||||
def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
|
||||
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
|
||||
# create data
|
||||
x_shape = (M, N)
|
||||
w_shape = (x_shape[-1], )
|
||||
@@ -311,5 +274,5 @@ def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
# test_layer_norm(1151, 8192, torch.float16)
|
||||
bench_layer_norm.run(save_path='.', print_data=True)
|
||||
test_layer_norm(1151, 8192, torch.float16)
|
||||
# bench_layer_norm.run(save_path='.', print_data=True)
|
||||
|
@@ -15,7 +15,7 @@ import triton.language as tl
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to work around a compiler bug
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
@@ -39,7 +39,6 @@ def _fwd_kernel(
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
# initialize pointer to m and l
|
||||
t_ptrs = TMP + off_hz * N_CTX + offs_m
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
@@ -47,11 +46,11 @@ def _fwd_kernel(
|
||||
q = tl.load(q_ptrs)
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k, trans_b=True)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
qk *= sm_scale
|
||||
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
|
||||
# -- compute m_ij, p, l_ij
|
||||
@@ -69,8 +68,6 @@ def _fwd_kernel(
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
tl.store(t_ptrs, acc_scale)
|
||||
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs + start_n * stride_vk)
|
||||
@@ -168,26 +165,26 @@ def _bwd_kernel(
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, k, trans_b=True)
|
||||
qk = tl.dot(q, tl.trans(k))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
m = tl.load(m_ptrs + offs_m_curr)
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
|
||||
dv += tl.dot(tl.trans(p.to(tl.float16)), do)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
dp += tl.dot(do, v, trans_b=True)
|
||||
dp += tl.dot(do, tl.trans(v))
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
ds = p * dp * sm_scale
|
||||
# compute dk = dot(ds.T, q)
|
||||
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
|
||||
# # compute dq
|
||||
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
|
||||
dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
|
||||
# compute dq
|
||||
dq = tl.load(dq_ptrs)
|
||||
dq += tl.dot(ds.to(tl.float16), k)
|
||||
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
|
||||
# # increment pointers
|
||||
tl.store(dq_ptrs, dq)
|
||||
# increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
@@ -198,6 +195,9 @@ def _bwd_kernel(
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
empty = torch.empty(128, device="cuda")
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@@ -208,7 +208,7 @@ class _attention(torch.autograd.Function):
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
|
||||
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
@@ -227,6 +227,7 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.BLOCK = BLOCK
|
||||
ctx.grid = grid
|
||||
@@ -272,13 +273,13 @@ class _attention(torch.autograd.Function):
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
sm_scale = 0.3
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
|
||||
sm_scale = 0.2
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
@@ -287,13 +288,16 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
for h in range(H):
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
# p = torch.exp(p)
|
||||
ref_out = torch.matmul(p, v)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
# triton implementation
|
||||
# # triton implementation
|
||||
tri_out = attention(q, k, v, sm_scale)
|
||||
# print(ref_out)
|
||||
# print(tri_out)
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
@@ -323,7 +327,7 @@ configs = [triton.testing.Benchmark(
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
) for mode in ['bwd']]
|
||||
) for mode in ['fwd']]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
@@ -356,5 +360,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
# only works on A100 at the moment
|
||||
# bench_flash_attention.run(save_path='.', print_data=True)
|
||||
|
@@ -1,74 +0,0 @@
|
||||
"""
|
||||
Libdevice function
==================
|
||||
Triton can invoke a custom function from an external library.
|
||||
In this example, we will use the `libdevice` library to apply `asin` on a tensor.
|
||||
Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
|
||||
|
||||
In `triton/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using Triton, you can simply call `tl.libdevice.asin`.
Triton automatically selects the correct underlying device function to invoke based on input and output types.
|
||||
"""
|
||||
|
||||
# %%
|
||||
# asin Kernel
|
||||
# --------------------------
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def asin_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
x = tl.libdevice.asin(x)
|
||||
tl.store(y_ptr + offsets, x, mask=mask)
|
||||
|
||||
# %%
|
||||
# Using the default libdevice library path
|
||||
# --------------------------
|
||||
# We can use the default libdevice library path encoded in `triton/language/libdevice.py`
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
size = 98432
|
||||
x = torch.rand(size, device='cuda')
|
||||
output_triton = torch.zeros(size, device='cuda')
|
||||
output_torch = torch.asin(x)
|
||||
assert x.is_cuda and output_triton.is_cuda
|
||||
n_elements = output_torch.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|
||||
|
||||
# %%
|
||||
# Customize the libdevice library path
|
||||
# --------------------------
|
||||
# We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.
|
||||
|
||||
output_triton = torch.empty_like(x)
|
||||
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
|
||||
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
|
||||
print(output_torch)
|
||||
print(output_triton)
|
||||
print(
|
||||
f'The maximum difference between torch and triton is '
|
||||
f'{torch.max(torch.abs(output_torch - output_triton))}'
|
||||
)
|