Merge branch 'master' into keren/improve-hook

This commit is contained in:
Jokeren
2023-01-04 17:40:10 -05:00
285 changed files with 26795 additions and 50131 deletions

View File

@@ -1,5 +0,0 @@
# Run the benchmarks
Install the required dependencies via `pip install -r requirements-bench.txt` from the `triton/python/bench` folder.
Run the benchmarks with `python3 bench/run.py`; this produces an HTML report in a `results` folder.
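Each file in this folder follows the same pattern: declare one or more `triton.testing.Benchmark` configurations, decorate a function with `@triton.testing.perf_report`, and call `.run()`. Below is a minimal sketch of that pattern; `bench_copy` and its configuration values are illustrative and not part of the suite.

import torch

import triton

confs = [
    triton.testing.Benchmark(
        x_names=['N'], x_vals=[1024, 2048, 4096],
        line_arg='provider', line_vals=['torch'], line_names=['Torch'],
        ylabel='GBPS', plot_name='copy', args={'dtype': torch.float16},
    )
]


@triton.testing.perf_report(confs)
def bench_copy(N, dtype, provider):
    x = torch.randn(N, dtype=dtype, device='cuda')
    # one read + one write per element; do_bench returns (mean, min, max) times in ms in this version
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: x.clone())
    return gbps(mean_ms), gbps(min_ms), gbps(max_ms)


bench_copy.run(print_data=True, save_path='results')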

View File

@@ -1,92 +0,0 @@
import torch

import triton

# -------------------------------
# Matrix Multiplication
# -------------------------------

nt = {False: 'n', True: 't'}
square_confs = [
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],
        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
        line_arg='block',
        line_vals=[16, 32, 64, 128],
        line_names=['Block16', 'Block32', 'Block64', 'Block128'],
        ylabel='TFLOPS',
        plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
        args={'layout_mode': layout_mode, 'op_mode': op_mode,
              'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
    )
    for AT in [False] for BT in [False]
    for op_mode in ['dsd'] for layout_mode in ['dense']
]


@triton.testing.perf_report(square_confs)
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
    Z, H = 1, 1
    make_layout = {
        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[layout_mode]
    # create layout
    shape = {'sdd': (M, N), 'dsd': (K, M) if AT else (M, K), 'dds': (N, K) if BT else (K, N)}[op_mode]
    layout = make_layout(H, shape[0] // block, shape[1] // block)
    # create inputs
    a = torch.randn((Z, H, K, M) if AT else (Z, H, M, K), dtype=dtype, device='cuda')
    b = torch.randn((Z, H, N, K) if BT else (Z, H, K, N), dtype=dtype, device='cuda')
    # create op
    tflops = lambda ms: num_flops / ms * 1e3
    if provider == 'triton':
        op = triton.ops.blocksparse.matmul(layout, block, op_mode, device="cuda", trans_a=AT, trans_b=BT)
        # inputs
        a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a
        b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
        num_flops = {
            'sdd': 2 * Z * K * float(layout.sum()) * block * block,
            'dsd': 2 * Z * N * float(layout.sum()) * block * block,
            'dds': 2 * Z * M * float(layout.sum()) * block * block
        }[op_mode] * 1e-12
        return tflops(mean_ms), tflops(min_ms), tflops(max_ms)


# -------------------------------
# Softmax
# -------------------------------

square_confs = [
    triton.testing.Benchmark(
        x_names=['M', 'N'],
        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
        line_arg='block',
        line_vals=[16, 32, 64],
        line_names=['Block16', 'Block32', 'Block64'],
        ylabel='GBPS',
        plot_name=f'{layout_mode}-square',
        args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
    )
    for layout_mode in ['dense', 'tril']
]


@triton.testing.perf_report(square_confs)
def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
    Z, H = 1, 1
    make_layout = {
        'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
        'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
    }[layout_mode]
    layout = make_layout(H, M // block, N // block)
    a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda')
    if provider == 'triton':
        a = triton.testing.sparsify_tensor(a, layout, block)
        op = triton.ops.blocksparse.softmax(layout, block, device="cuda")
        gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
        return gbps(mean_ms), gbps(min_ms), gbps(max_ms)


bench_matmul.run(print_data=True, show_plots=True)
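For context, here is a minimal sketch of the layout/sparsify workflow the benchmark above relies on: a binary layout marks which block x block tiles are kept, `triton.testing.sparsify_tensor` packs a dense tensor into that layout, and the block-sparse op is instantiated from the layout and mode. The shapes and the 'tril' layout below are illustrative, not taken from the benchmark configuration.

import torch

import triton

block = 32
H, M, N, K = 1, 256, 256, 256
# lower-triangular layout over (M // block) x (K // block) tiles, as in the 'tril' mode above
layout = torch.tril(torch.ones((H, M // block, K // block), dtype=torch.int64))
# dense operands; for op_mode == 'dsd' the first operand is the sparse one
a = torch.randn((1, H, M, K), dtype=torch.float16, device='cuda')
b = torch.randn((1, H, K, N), dtype=torch.float16, device='cuda')
op = triton.ops.blocksparse.matmul(layout, block, 'dsd', device='cuda', trans_a=False, trans_b=False)
a_sparse = triton.testing.sparsify_tensor(a, layout, block)  # keeps only the tiles marked in `layout`
c = op(a_sparse, b)  # dense result, roughly (1, H, M, N)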

View File

@@ -1,41 +0,0 @@
import torch

import triton

confs = [
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch'],
        ylabel='GBPS',
        plot_name=f'{mode}-2048',
        args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
    )
    for mode in ['forward', 'backward']
]


@triton.testing.perf_report(confs)
def bench_op(M, N, dtype, mode, provider):
    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
    num_gb = (2 * x.numel() * x.element_size() * 1e-9)
    gbps = lambda ms: num_gb / ms * 1e3
    # forward pass
    op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
          'triton': triton.ops.cross_entropy}[provider]
    if mode == 'forward':
        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
    if mode == 'backward':
        y = op(x, idx)
        dy = torch.randn_like(y)
        fn = lambda: y.backward(dy, retain_graph=True)
        mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
    return gbps(mean_ms), gbps(min_ms), gbps(max_ms)


if __name__ == '__main__':
    bench_op.run(print_data=True)

View File

@@ -1,67 +0,0 @@
import torch

import triton


def rounded_linspace(low, high, steps, div):
    ret = torch.linspace(low, high, steps)
    ret = torch.div(ret.int() + div - 1, div, rounding_mode='trunc') * div
    ret = torch.unique(ret)
    return list(map(int, ret))


# Square benchmarks
nt = {False: "n", True: "t"}
square_confs = [
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=rounded_linspace(512, 8192, 32, 128),
        line_arg="provider",
        line_vals=["cublas", "triton", "cutlass"],
        line_names=["cuBLAS", "Triton", "CUTLASS"],
        ylabel="TFLOPS",
        plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
        args={"AT": AT, "BT": BT, "dtype": torch.float16},
    ) for AT in [False] for BT in [False]
]

# Transformer training benchmarks
transformer_confs = [
    triton.testing.Benchmark(
        x_names=[x],
        x_vals=rounded_linspace(NK // 16, NK, 32, 128),
        line_arg="provider",
        line_vals=["cublas", "triton", "cutlass"],
        line_names=["cuBLAS", "Triton", "CUTLASS"],
        ylabel="TFLOPS",
        plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
        args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
    ) for NK in [12288]
    for i, x in enumerate(["N", "K"])
    for M in [2048]
]


@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
    a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
    b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
    if AT:
        a = a.t()
    if BT:
        b = b.t()
    tflops = lambda ms: 2. * M * N * K / ms * 1e-9
    if provider == "cublas":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
        return tflops(ms), tflops(max_ms), tflops(min_ms)
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
        return tflops(ms), tflops(max_ms), tflops(min_ms)
    if provider == "cutlass":
        cutlass_matmul = triton.testing.cutlass_matmul
        try:
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
            return tflops(ms), tflops(max_ms), tflops(min_ms)
        except Exception:
            return None
    return None
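As a companion to the benchmark above, a quick correctness check is often useful before timing. A sketch (sizes and tolerances are arbitrary) comparing `triton.ops.matmul` against `torch.matmul`:

import torch

import triton

a = torch.randn(512, 512, device='cuda', dtype=torch.float16)
b = torch.randn(512, 512, device='cuda', dtype=torch.float16)
# loose tolerances: fp16 results can differ slightly between the two implementations
torch.testing.assert_close(triton.ops.matmul(a, b), torch.matmul(a, b), rtol=1e-2, atol=1e-2)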

View File

@@ -1,2 +0,0 @@
pandas >= 1.3.3
matplotlib >= 3.4.3

View File

@@ -1,44 +0,0 @@
import argparse
import inspect
import os
import sys

import triton


def run_all(result_dir, names):
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    for mod in os.listdir(os.path.dirname(os.path.realpath(__file__))):
        # skip non-python files
        if not mod.endswith('.py'):
            continue
        # skip files not matching the provided names
        if names and names not in mod:
            continue
        # skip files that don't start with 'bench_'
        if not mod.startswith('bench_'):
            continue
        print(f'running {mod}...')
        mod = __import__(os.path.splitext(mod)[0])
        benchmarks = inspect.getmembers(mod, lambda x: isinstance(x, triton.testing.Mark))
        for name, bench in benchmarks:
            curr_dir = os.path.join(result_dir, mod.__name__.replace('bench_', ''))
            if len(benchmarks) > 1:
                curr_dir = os.path.join(curr_dir, name.replace('bench_', ''))
            if not os.path.exists(curr_dir):
                os.makedirs(curr_dir)
            bench.run(save_path=curr_dir)


def main(args):
    parser = argparse.ArgumentParser(description="Run the benchmark suite.")
    parser.add_argument("-r", "--result-dir", type=str, default='results', required=False)
    parser.add_argument("-n", "--names", type=str, default='', required=False)
    parser.set_defaults(feature=False)
    args = parser.parse_args(args)
    run_all(args.result_dir, args.names)


if __name__ == '__main__':
    main(sys.argv[1:])

View File

@@ -0,0 +1,19 @@
import triton
import triton.language as tl


# triton kernel
@triton.jit
def kernel(X, stride_xm,
           Z, stride_zn,
           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    off_m = tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, BLOCK_N)
    Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1
    Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn
    tl.store(Zs, tl.load(Xs))


ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
print(ret)

python/examples/empty.py (new file)
View File

@@ -0,0 +1,13 @@
import torch

import triton
import triton.language as tl


@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
    pass


X = torch.randn(1, device="cuda")
pgm = kernel[(1,)](X, 1, 1, BLOCK=1024)

View File

@@ -1,5 +1,4 @@
import distutils
import distutils.spawn
import os
import platform
import re
@@ -25,42 +24,54 @@ def get_build_type():
return "Debug"
elif check_env_flag("REL_WITH_DEB_INFO"):
return "RelWithDebInfo"
elif check_env_flag("TRITON_REL_BUILD_WITH_ASSERTS"):
return "TritonRelBuildWithAsserts"
else:
return "Release"
# TODO: change to release when stable enough
return "TritonRelBuildWithAsserts"
def use_system_llvm():
if platform.system() == "Windows":
return True
versions = ['-11.0', '-11', '-11-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
return any(p is not None for p in paths)
# --- third party packages -----
class Package(NamedTuple):
package: str
name: str
url: str
test_file: str
include_flag: str
lib_flag: str
syspath_var_name: str
def get_pybind11_package_info():
name = "pybind11-2.10.0"
url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
def get_llvm_package_info():
# download if nothing is installed
system = platform.system()
system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
if use_assert_enabled_llvm:
name = 'llvm+mlir-14.0.0-x86_64-{}-assert'.format(system_suffix)
url = "https://github.com/shintaro-iwasaki/llvm-releases/releases/download/llvm-14.0.0-329fda39c507/{}.tar.xz".format(name)
else:
name = 'clang+llvm-14.0.0-x86_64-{}'.format(system_suffix)
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/{}.tar.xz".format(name)
return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
def get_thirdparty_packages(triton_cache_path):
class Package(NamedTuple):
package: str
name: str
url: str
test_file: str
include_flag: str
lib_flag: str
packages = [
Package("pybind11", "pybind11-2.10.0", "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz", "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
]
if not use_system_llvm():
# download LLVM if no suitable system LLVM is installed
packages.append(
Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
)
packages = [get_pybind11_package_info(), get_llvm_package_info()]
thirdparty_cmake_args = []
for p in packages:
package_root_dir = os.path.join(triton_cache_path, p.package)
package_dir = os.path.join(package_root_dir, p.name)
test_file_path = os.path.join(package_dir, p.test_file)
if p.syspath_var_name in os.environ:
package_dir = os.environ[p.syspath_var_name]
if not os.path.exists(test_file_path):
try:
shutil.rmtree(package_root_dir)
@@ -77,6 +88,8 @@ def get_thirdparty_packages(triton_cache_path):
thirdparty_cmake_args.append("-D{}={}/lib".format(p.lib_flag, package_dir))
return thirdparty_cmake_args
# ---- cmake extension ----
class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=""):
@@ -113,22 +126,27 @@ class CMakeBuild(build_ext):
self.build_extension(ext)
def build_extension(self, ext):
lit_dir = shutil.which('lit')
triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
# lit is used by the test suite
thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
# python directories
python_include_dirs = [distutils.sysconfig.get_python_inc()]
python_include_dir = distutils.sysconfig.get_python_inc()
cmake_args = [
"-DLLVM_ENABLE_WERROR=ON",
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DBUILD_TUTORIALS=OFF",
"-DBUILD_PYTHON_MODULE=ON",
# '-DPYTHON_EXECUTABLE=' + sys.executable,
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
"-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON",
"-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
"-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON",
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
"-DLLVM_EXTERNAL_LIT=" + lit_dir,
] + thirdparty_cmake_args
# configuration
cfg = get_build_type()
build_args = ["--config", cfg]
@@ -155,16 +173,17 @@ setup(
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/language", "triton/runtime", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
install_requires=[
"cmake",
"filelock",
"torch",
"lit",
],
package_data={
"triton/ops": ["*.c"],
"triton/ops/blocksparse": ["*.c"],
"triton/language": ["*.bc"],
"triton/language": ["*.bc"]
},
include_package_data=True,
ext_modules=[CMakeExtension("triton", "triton/_C/")],
@@ -180,6 +199,7 @@ setup(
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.6",
],
test_suite="tests",
extras_require={
"tests": [
"autopep8",

View File

@@ -1,202 +0,0 @@
#include "cutlass/library/handle.h"
#include "cutlass/library/library.h"
#include "cutlass/library/operation_table.h"
#include "cutlass/library/singleton.h"
#include "pybind11/pybind11.h"
#include "triton/tools/bench.hpp"
using namespace cutlass;
using namespace cutlass::library;
std::map<std::vector<size_t>, const Operation *> op_cache_;
static int const kHostWorkspaceSize = (4 << 10);
static int const kDeviceWorkspaceSize = (4 << 20);
void run(int M, int N, int K,
int lda, int ldb, int ldc, int ldd,
void const *ptr_A, void const *ptr_B, void const *ptr_C, void *ptr_D,
void const *alpha, void const *beta,
ScalarPointerMode scalar_mode,
const Operation *operation,
cudaStream_t stream) {
GemmUniversalConfiguration configuration{
GemmUniversalMode::kGemm,
{M, N, K},
1,
lda,
ldb,
ldc,
ldd};
// host workspace size
uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);
if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed)
throw std::runtime_error("Unable to find gemm operation");
char host_workspace[kHostWorkspaceSize];
// device workspace size
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
if (uint64_t(kDeviceWorkspaceSize) < device_workspace_size_needed)
throw std::runtime_error("Unable to find gemm operation");
static void *device_workspace;
// Initialize host and device workspaces
Status status = operation->initialize(&configuration, host_workspace, device_workspace, stream);
if (status != cutlass::Status::kSuccess)
throw std::runtime_error("Unable to initialize workspace");
// Run the operator
GemmArguments arguments{ptr_A, ptr_B, ptr_C, ptr_D, alpha, beta, scalar_mode};
operation->run(&arguments, host_workspace, device_workspace, stream);
}
const Operation *autotune(int M, int N, int K,
NumericTypeID element_compute,
NumericTypeID element_scalar,
void const *alpha,
NumericTypeID element_A,
LayoutTypeID layout_A,
ComplexTransform transform_A,
void const *ptr_A,
int lda,
NumericTypeID element_B,
LayoutTypeID layout_B,
ComplexTransform transform_B,
void const *ptr_B,
int ldb,
void const *beta,
NumericTypeID element_C,
void const *ptr_C,
int ldc,
void *ptr_D,
int ldd,
ScalarPointerMode scalar_mode,
int device_id,
cudaStream_t stream) {
// index operation table with functional key
GemmFunctionalKey key(
Provider::kCUTLASS,
GemmKind::kUniversal,
element_compute,
element_scalar,
element_A,
layout_A,
transform_A,
element_B,
layout_B,
transform_B,
element_C);
auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);
if (operators_it == Singleton::get().operation_table.gemm_operations.end())
throw std::runtime_error("Unable to find gemm operation");
if (operators_it->second.empty())
throw std::runtime_error("Unable to find gemm operation");
cudaDeviceProp device_prop;
cudaError_t error = cudaGetDeviceProperties(&device_prop, device_id);
if (error != cudaSuccess)
throw std::runtime_error("Unable to get device properties");
int cc = device_prop.major * 10 + device_prop.minor;
// index operation table with preference key
// assume 8-bytes aligned memory pointers
int alignment = 8;
GemmPreferenceKey preference_key(cc, alignment);
auto autotune_it = operators_it->second.find(preference_key);
if (autotune_it == operators_it->second.end())
throw std::runtime_error("Unable to find gemm operation");
const std::vector<const Operation *> &operations = autotune_it->second;
if (operations.empty())
throw std::runtime_error("Unable to find gemm operation");
// auto-tune
const Operation *best = nullptr;
double best_ms = std::numeric_limits<double>::max();
for (const Operation *op : operations) {
auto fn = [&]() { run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D,
alpha, beta, scalar_mode, op, stream); };
triton::driver::cu_stream tt_stream((CUstream)stream, false);
double ms = triton::tools::bench(fn, &tt_stream, 10, 25);
if (ms < best_ms) {
best_ms = ms;
best = op;
}
}
return best;
}
// map of torch datatypes to cutlass datatypes
std::map<std::string, NumericTypeID> type_map = {
{"float16", NumericTypeID::kF16},
{"float32", NumericTypeID::kF32},
{"float64", NumericTypeID::kF64}};
void cutlass_matmul(uintptr_t A, uintptr_t B, uintptr_t C,
size_t M, size_t N, size_t K,
size_t stride_a_0, size_t stride_a_1,
size_t stride_b_0, size_t stride_b_1,
size_t stride_c_0, size_t stride_c_1,
std::string type_a, std::string type_b, std::string type_c,
size_t dev_id, uint64_t stream_handle) {
void *ptr_A = (void *)A;
void *ptr_B = (void *)B;
void *ptr_C = (void *)C;
void *ptr_D = ptr_C;
size_t lda = stride_a_0;
size_t ldb = stride_b_0;
size_t ldc = stride_c_1;
size_t ldd = ldc;
float alpha = 1.0f;
float beta = 0.0f;
// layout for A
LayoutTypeID layout_A;
if (stride_a_0 == 1)
layout_A = LayoutTypeID::kColumnMajor;
else if (stride_a_1 == 1)
layout_A = LayoutTypeID::kRowMajor;
else
throw std::runtime_error("A layout is not supported");
// layout for B
LayoutTypeID layout_B;
if (stride_b_0 == 1)
layout_B = LayoutTypeID::kColumnMajor;
else if (stride_b_1 == 1)
layout_B = LayoutTypeID::kRowMajor;
else
throw std::runtime_error("B layout is not supported");
// data types
NumericTypeID element_compute = NumericTypeID::kF32;
NumericTypeID element_A = type_map[type_a];
NumericTypeID element_B = type_map[type_b];
NumericTypeID element_C = type_map[type_c];
// misc. flags
ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
NumericTypeID element_scalar = NumericTypeID::kF32;
ComplexTransform transform_A = ComplexTransform::kNone;
ComplexTransform transform_B = ComplexTransform::kNone;
// runtime flags
cudaStream_t stream = (cudaStream_t)stream_handle;
// auto-tune
std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
dev_id, (size_t)element_compute, (size_t)scalar_mode};
auto it = op_cache_.find(tune_key);
if (it == op_cache_.end()) {
const Operation *op = autotune(M, N, K, element_compute, element_scalar, &alpha,
element_A, layout_A, transform_A, ptr_A, lda,
element_B, layout_B, transform_B, ptr_B, ldb,
&beta, element_C, ptr_C, ldc, ptr_D, ldd, scalar_mode,
dev_id, stream);
it = op_cache_.insert({tune_key, op}).first;
}
run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D, &alpha, &beta,
scalar_mode, it->second, stream);
}
void init_cutlass(pybind11::module &m) {
pybind11::module subm = m.def_submodule("cutlass");
subm.def("matmul", &cutlass_matmul, "matrix multiplication");
}

View File

@@ -1,696 +0,0 @@
#include "triton/ir/builder.h"
#include <functional>
#include <iostream>
#include <pybind11/pybind11.h>
namespace ir = triton::ir;
namespace py = pybind11;
static const std::string _builder_doc = R"pbdoc(
:param builder: IR builder to generate code into, optional, set automatically when called inside a @triton.jit function
:type builder: triton.ir.builder
)pbdoc";
#define VA_ARGS(...) , ##__VA_ARGS__
#define DEF_FUNC(MOD, PY_NAME, C_FUNC, ...) \
MOD.def(PY_NAME, C_FUNC, (C_FUNC##_docstr + _builder_doc).c_str(), \
ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)
void throw_not_implemented(std::string key) {
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. This is likely a bug on our side.");
}
void throw_not_int_or_float(std::string key) {
throw std::runtime_error("`" + key + "` only supported for integer and floating point types.");
}
enum type_code {
_bool,
int8,
int16,
int32,
int64,
float16,
float32,
float64
};
ir::type *make_ir(type_code ty, ir::builder *builder) {
switch (ty) {
case float16:
return builder->get_half_ty();
case float32:
return builder->get_float_ty();
default:
throw_not_implemented("make_ir");
}
}
type_code from_ir(ir::type *ty) {
if (ty->is_half_ty())
return float16;
if (ty->is_float_ty())
return float32;
throw_not_implemented("from_ir");
}
/*----------------------------------------------
definition of triton.cast / triton.ir.value.to
----------------------------------------------*/
std::string cast_docstr = R"pbdoc(
Tries to cast a block to a new data type.
:param input: The input block.
:type input: triton.ir.value
)pbdoc";
ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
ir::type *src_ty = input->get_type();
ir::type *dst_ty = make_ir(_dtype, builder);
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
// FP Truncation
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp)
return builder->create_fp_trunc(input, dst_ty);
// FP Extension
bool ext_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
if (ext_fp)
return builder->create_fp_ext(input, dst_ty);
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
return builder->create_int_cast(input, dst_ty, true);
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty())
return builder->create_fp_to_si(input, dst_ty);
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty())
return builder->create_si_to_fp(input, dst_ty);
// Ptr -> Ptr
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::BitCast, input, dst_ty);
// * -> Bool
if (dst_sca_ty->is_bool_ty()) {
if (src_sca_ty->is_pointer_ty())
input = cast(input, int64, builder);
ir::value *other = builder->get_int64(0);
if (src_ty->is_bool_ty())
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
/*----------------------------------------------
definition of triton.broadcast_check
----------------------------------------------*/
std::string try_broadcast_docstr = R"pbdoc(
Tries to broadcast two blocks to a common compatible shape.
:param input: The first input block.
:type input: triton.ir.value
:param other: The second input block.
:type other: triton.ir.value
)pbdoc";
std::tuple<ir::value *, ir::value *> try_broadcast(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
// make_shape_compatible(block, scalar)
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
// make_shape_compatible(scalar, block)
else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
// make_shape_compatible(block, block)
else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
auto lhs_shape = lhs_ty->get_block_shapes();
auto rhs_shape = rhs_ty->get_block_shapes();
if (lhs_shape.size() != rhs_shape.size())
throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
ir::type::block_shapes_t ret_shape;
for (size_t i = 0; i < lhs_shape.size(); ++i) {
unsigned left = lhs_shape[i];
unsigned right = rhs_shape[i];
if (left == 1)
ret_shape.push_back(right);
else if (right == 1)
ret_shape.push_back(left);
else if (left == right)
ret_shape.push_back(left);
else
throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
": " + std::to_string(left) + " and " + std::to_string(right));
}
if (lhs_shape != ret_shape)
lhs = builder->create_broadcast(lhs, ret_shape);
if (rhs_shape != ret_shape)
rhs = builder->create_broadcast(rhs, ret_shape);
}
return std::make_tuple(lhs, rhs);
}
/*----------------------------------------------
definition of triton.broadcast_to
----------------------------------------------*/
std::string broadcast_to_docstr = R"pbdoc(
Tries to broadcast a block to a new shape.
:param input: The input block.
:type input: triton.value
:param shape: The new shape.
:type shape: tuple of int
)pbdoc";
ir::value *broadcast_to(ir::value *input, const ir::type::block_shapes_t &shape, ir::builder *builder) {
if (!input->get_type()->is_block_ty())
return builder->create_splat(input, shape);
auto src_shape = input->get_type()->get_block_shapes();
if (src_shape.size() != shape.size())
throw std::runtime_error("Cannot broadcast");
return builder->create_broadcast(input, shape);
}
/*----------------------------------------------
definition of triton.load
----------------------------------------------*/
std::string load_docstr = R"pbdoc(
Return a block of data whose values are loaded, element-wise, from memory at the locations defined by `pointer`.
:param pointer: Pointer to the data to be loaded.
:type pointer: Block of triton.pointer
:param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
:type mask: Block of triton.bool, optional
:param other: if mask[idx] is false, return other[idx] instead of `pointer[idx]`.
:type other: Block of triton.value, optional
)pbdoc";
ir::value *load(ir::value *pointer, std::optional<ir::value *> _mask, std::optional<ir::value *> _other, ir::builder *builder) {
if (!_mask.has_value() && !_other.has_value())
return builder->create_load(pointer);
if (!_mask.has_value())
throw std::runtime_error("`other` cannot be provided without `mask`");
ir::value *mask = _mask.value();
ir::type *elt_ty = pointer->get_type()->get_scalar_ty()->get_pointer_element_ty();
auto shape = pointer->get_type()->get_block_shapes();
ir::value *other = _other.has_value() ? _other.value() : ir::undef_value::get(elt_ty);
other = cast(other, from_ir(elt_ty), builder);
other = broadcast_to(other, shape, builder);
mask = broadcast_to(mask, shape, builder);
return builder->create_masked_load(pointer, mask, other);
}
/*----------------------------------------------
definition of triton.store
----------------------------------------------*/
std::string store_docstr = R"pbdoc(
Stores the `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.
:param pointer: The memory locations where the elements of `value` are stored.
:type pointer: Block of triton.pointer
:param value: The block of elements to be stored.
:type value: Block of triton.value
:param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
:type mask: Block of triton.bool, optional
)pbdoc";
ir::value *store(ir::value *ptr, ir::value *val, std::optional<ir::value *> _mask, ir::builder *builder) {
if (!_mask.has_value())
return builder->create_store(ptr, val);
ir::value *mask = _mask.value();
return builder->create_masked_store(ptr, val, mask);
}
/*----------------------------------------------
definition of triton.dot
----------------------------------------------*/
std::string dot_docstr = R"pbdoc(
Returns the matrix product of two blocks.
The two blocks must be two-dimensional and have compatible inner dimensions.
:param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {`float16`, `float32`}
:param other: The second block to be multiplied.
:type other: 2D block of scalar-type in {`float16`, `float32`}
)pbdoc";
ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
ir::value *_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});
return builder->create_dot(lhs, rhs, _0);
}
/*----------------------------------------------
definition of triton.where
----------------------------------------------*/
std::string where_docstr = R"pbdoc(
Returns a block of elements from either `x` or `y`, depending on `condition`.
Note that `x` and `y` are always evaluated regardless of the value of `condition`.
If you want to avoid unintended memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
:param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool
:param x: values selected at indices where condition is True.
:param y: values selected at indices where condition is False.
)pbdoc";
ir::value *where(ir::value *condition, ir::value *x, ir::value *y, ir::builder *builder) {
return builder->create_select(condition, x, y);
};
/*----------------------------------------------
definition of triton.arange
----------------------------------------------*/
std::string arange_docstr = R"pbdoc(
Returns contiguous values within the half-open interval [start, end).
:param start: Start of the interval.
:type start: int
:param end: End of the interval.
:type end: int
)pbdoc";
ir::value *arange(int start, int end, ir::builder *builder) {
return builder->get_range(start, end);
};
/*----------------------------------------------
definition of triton.program_id
----------------------------------------------*/
std::string program_id_docstr = R"pbdoc(
Returns the id of the current program instance along the given `axis`.
Triton uses an SPMD model in which multiple instances of the same @triton.jit function run in parallel, each with a different `program_id`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
)pbdoc";
ir::value *program_id(int axis, ir::builder *builder) {
return builder->create_get_program_id(axis);
};
/*----------------------------------------------
definition of triton.num_programs
----------------------------------------------*/
std::string num_programs_docstr = R"pbdoc(
Returns the number of program instances launched along the given `axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
)pbdoc";
ir::value *num_programs(int axis, ir::builder *builder) {
return builder->create_get_num_programs(axis);
};
/*----------------------------------------------
definition of triton.zeros
----------------------------------------------*/
std::string zeros_docstr = R"pbdoc(
Returns a block filled with the scalar value 0 and the given shape.
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., tl.float16
:type dtype: triton.ir.dtype
)pbdoc";
ir::value *zeros(ir::type::block_shapes_t shape, type_code _dtype, ir::builder *builder) {
ir::type *dtype = make_ir(_dtype, builder);
ir::value *_0 = ir::constant::get_null_value(dtype);
return builder->create_splat(_0, shape);
};
/*----------------------------------------------
definition of triton.exp
----------------------------------------------*/
std::string _exp_docstr = R"pbdoc(
Returns the element-wise exponential of `input`.
)pbdoc";
ir::value *_exp(ir::value *input, ir::builder *builder) {
return builder->create_exp(input);
};
/*----------------------------------------------
definition of triton.log
----------------------------------------------*/
std::string _log_docstr = R"pbdoc(
Returns the element-wise natural logarithm of `input`.
)pbdoc";
ir::value *_log(ir::value *input, ir::builder *builder) {
return builder->create_log(input);
};
/*----------------------------------------------
definition of triton.sqrt
----------------------------------------------*/
std::string sqrt_docstr = R"pbdoc(
Returns the element-wise square root of `input`.
)pbdoc";
ir::value *sqrt(ir::value *input, ir::builder *builder) {
return builder->create_sqrt(input);
};
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
if (scalar_ty->is_floating_point_ty())
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
else
throw_not_int_or_float(name);
}
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
std::string min_docstr = R"pbdoc(
Returns the minimum value of `input`.
)pbdoc";
ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
};
/*----------------------------------------------
definition of triton.arg_min
----------------------------------------------*/
std::string argmin_docstr = R"pbdoc(
Returns the index of the minimum value of `input`.
)pbdoc";
ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN);
};
/*----------------------------------------------
definition of triton.max
----------------------------------------------*/
std::string max_docstr = R"pbdoc(
Returns the maximum value of `input`.
)pbdoc";
ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
};
/*----------------------------------------------
definition of triton.arg_max
----------------------------------------------*/
std::string argmax_docstr = R"pbdoc(
Returns the index of the maximum value of `input`.
)pbdoc";
ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX);
};
/*----------------------------------------------
definition of triton.sum
----------------------------------------------*/
std::string sum_docstr = R"pbdoc(
Returns the sum of `input`.
)pbdoc";
ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
};
/*----------------------------------------------
definition of triton.atomic_cas
----------------------------------------------*/
std::string atomic_cas_docstr = R"pbdoc(
Atomic compare-and-swap.
)pbdoc";
ir::value *atomic_cas(ir::value *ptr, ir::value *cmp, ir::value *val, ir::builder *builder) {
return builder->create_atomic_cas(ptr, cmp, val);
};
/*----------------------------------------------
definition of triton.atomic_xchg
----------------------------------------------*/
std::string atomic_xchg_docstr = R"pbdoc(
Atomic exchange.
)pbdoc";
ir::value *atomic_xchg(ir::value *ptr, ir::value *val, ir::builder *builder) {
return builder->create_atomic_exch(ptr, val);
};
/*----------------------------------------------
debug barrier
----------------------------------------------*/
std::string debug_barrier_docstr = R"pbdoc(
Temporary hacky fixup for when the compiler forgets to insert sync barriers
)pbdoc";
ir::value *debug_barrier(ir::builder *builder) {
return builder->create_barrier();
}
#define DEF_BINARY_OP(MOD, PY_NAME, C_FUNC, ...) \
MOD.def(PY_NAME, binary_op(C_FUNC), (C_FUNC##_docstr + _builder_doc).c_str(), \
ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)
template <class FN>
std::function<ir::value *(ir::value *, ir::value *, ir::builder *builder)>
binary_op(const FN &fn) {
auto ret = [&fn](ir::value *self, ir::value *other, ir::builder *builder) {
//std::tie(self, other) = try_broadcast(self, other, builder);
return fn(self, other, builder);
};
return ret;
}
/*----------------------------------------------
definition of self + other
----------------------------------------------*/
std::string add_docstr = R"pbdoc(
Returns self + other, element-wise.
)pbdoc";
ir::value *add(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// ptr + offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(self, {other});
// float + float
else if (scalar_ty->is_floating_point_ty())
return builder->create_fadd(self, other);
// int + int
else if (scalar_ty->is_integer_ty())
return builder->create_add(self, other);
throw_not_implemented("add");
}
/*----------------------------------------------
definition of self - other
----------------------------------------------*/
std::string sub_docstr = R"pbdoc(
Returns self - other, element-wise.
)pbdoc";
ir::value *sub(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// ptr - offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(self, {other});
// float - float
if (scalar_ty->is_floating_point_ty())
return builder->create_fsub(self, other);
// int - int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(self, other);
throw_not_implemented("sub");
}
/*----------------------------------------------
definition of self * other
----------------------------------------------*/
std::string mul_docstr = R"pbdoc(
Returns self * other, element-wise.
)pbdoc";
ir::value *mul(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float * float
if (scalar_ty->is_floating_point_ty())
return builder->create_fmul(self, other);
// int * int
else if (scalar_ty->is_integer_ty())
return builder->create_mul(self, other);
throw_not_implemented("mul");
}
/*----------------------------------------------
definition of self > other
----------------------------------------------*/
std::string greater_than_docstr = R"pbdoc(
Returns self > other, element-wise.
)pbdoc";
ir::value *greater_than(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float > float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(self, other);
// int > int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGT(self, other);
throw_not_implemented("greater_than");
}
/*----------------------------------------------
definition of self >= other
----------------------------------------------*/
std::string greater_equal_docstr = R"pbdoc(
Returns self >= other, element-wise.
)pbdoc";
ir::value *greater_equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float >= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(self, other);
// int >= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGE(self, other);
throw_not_implemented("greater_equal");
}
/*----------------------------------------------
definition of self < other
----------------------------------------------*/
std::string less_than_docstr = R"pbdoc(
Returns self < other, element-wise.
)pbdoc";
ir::value *less_than(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(self, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLT(self, other);
throw_not_implemented("less_than");
}
/*----------------------------------------------
definition of self <= other
----------------------------------------------*/
std::string less_equal_docstr = R"pbdoc(
Returns self <= other, element-wise.
)pbdoc";
ir::value *less_equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float <= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(self, other);
// int <= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLE(self, other);
throw_not_implemented("less_equal");
}
/*----------------------------------------------
definition of self == other
----------------------------------------------*/
std::string equal_docstr = R"pbdoc(
Returns self == other, element-wise.
)pbdoc";
ir::value *equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOEQ(self, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(self, other);
throw_not_implemented("equal");
}
/*----------------------------------------------
definition of self / other
----------------------------------------------*/
std::string _div_docstr = R"pbdoc(
Returns self / other, element-wise.
)pbdoc";
ir::value *_div(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float / float
if (scalar_ty->is_floating_point_ty())
return builder->create_fdiv(self, other);
// int / int
else if (scalar_ty->is_integer_ty())
return builder->create_sdiv(self, other);
throw_not_implemented("div");
}
/*----------------------------------------------
definition of self % other
----------------------------------------------*/
std::string mod_docstr = R"pbdoc(
Returns self % other, element-wise.
)pbdoc";
ir::value *mod(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float % float
if (scalar_ty->is_floating_point_ty())
return builder->create_frem(self, other);
// int % int
else if (scalar_ty->is_integer_ty())
return builder->create_srem(self, other);
throw_not_implemented("mod");
}
/*----------------------------------------------
definition of self & other
----------------------------------------------*/
std::string _and_docstr = R"pbdoc(
Returns self & other, element-wise.
)pbdoc";
ir::value *_and(ir::value *self, ir::value *other, ir::builder *builder) {
return builder->create_and(self, other);
}
/*----------------------------------------------
definition of minimum(self, other)
----------------------------------------------*/
std::string minimum_docstr = R"pbdoc(
Returns element-wise minimum of self and other
)pbdoc";
ir::value *minimum(ir::value *self, ir::value *other, ir::builder *builder) {
return where(less_than(self, other, builder), self, other, builder);
}
/*----------------------------------------------
definition of self[slices]
----------------------------------------------*/
enum slice_mode_t {
NEWAXIS,
ALL
};
std::string subscript_docstr = R"pbdoc(
returns self[slices].
:param slices: The slices to subscript with.
:type slices: List of `None` or `:` slices.
)pbdoc";
ir::value *subscript(ir::value *self, std::vector<py::object> slices, ir::builder *builder) {
std::vector<slice_mode_t> modes;
for (py::object slice : slices) {
py::object none = py::none();
py::object all = py::make_tuple(none, none, none);
if (slice.is(none))
modes.push_back(NEWAXIS);
else if (all.attr("__eq__")(slice))
modes.push_back(ALL);
else
throw std::runtime_error("slice must be None or (None, None, None)");
}
ir::type::block_shapes_t shape;
size_t curr = 0;
for (slice_mode_t mode : modes) {
if (mode == NEWAXIS)
shape.push_back(1);
else {
assert(mode == ALL);
shape.push_back(self->get_type()->get_block_shapes()[curr++]);
}
}
return builder->create_reshape(self, shape);
}
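The operations bound above are the ones surfaced on the Python side through `triton.language` (tl.load, tl.store, tl.arange, tl.where, the arithmetic and comparison operators, and so on). A minimal sketch of how a few of them combine inside a @triton.jit kernel, including the mask/other semantics documented for load; the kernel and tensor names are illustrative:

import torch

import triton
import triton.language as tl


@triton.jit
def relu_copy(X, Z, N, BLOCK: tl.constexpr):
    off = tl.arange(0, BLOCK)                   # triton.arange
    mask = off < N                              # comparison-operator binding
    x = tl.load(X + off, mask=mask, other=0.0)  # masked load; `other` fills masked-off lanes
    z = tl.where(x > 0, x, 0.0)                 # triton.where
    tl.store(Z + off, z, mask=mask)             # masked store


x = torch.randn(1000, device='cuda')
z = torch.empty_like(x)
relu_copy[(1,)](x, z, x.numel(), BLOCK=1024)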

View File

@@ -8,8 +8,4 @@ void init_cutlass(pybind11::module &m);
PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
init_triton(m);
init_superblocking(m);
#ifdef WITH_CUTLASS_BINDINGS
init_cutlass(m);
#endif
}

View File

@@ -1,119 +0,0 @@
#include <iostream>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
// row-major 3d tensor
class tensor_3d {
public:
tensor_3d(int size_0, int size_1, int size_2, int *data = nullptr) : data_(size_0 * size_1 * size_2, 0) {
if (data)
std::copy(data, data + data_.size(), data_.begin());
stride_0_ = size_1 * size_2;
stride_1_ = size_2;
stride_2_ = 1;
}
int &operator()(int i, int j, int k) {
return data_[i * stride_0_ + j * stride_1_ + k];
}
private:
std::vector<int> data_;
int stride_0_;
int stride_1_;
int stride_2_;
};
std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width, int H, int M, int N) {
tensor_3d tmp(H, M, N);
std::vector<int> current(H, 0);
int num = 0;
std::vector<int> lut(H * M * N * 4);
for (size_t h = 0; h < H; h++) {
// surrounding indices
std::vector<int> ii_left(max_width, -1);
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
// start the dynamic programming algorithm
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
int v = layout(h, m, n);
if (v == 0)
continue;
int n_left = ii_left[max_width - 1];
int m_top = ii_top[max_width - 1][n];
int top = (m_top >= 0) ? tmp(h, m_top, n) : 0;
int left = (n_left >= 0) ? tmp(h, m, n_left) : 0;
int topleft = (m_top >= 0 && n_left >= 0) ? tmp(h, m_top, n_left) : 0;
int width = std::min(left, std::min(top, topleft)) + 1;
// reset width if blocks cannot be
// packed together (i.e., there's a 1 "in the middle")
for (int nn = n_left + 1; nn < n; nn++)
if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n])
width = 1;
tmp(h, m, n) = width;
// update n_left ring buffer
for (int k = 0; k < max_width - 1; k++)
ii_left[k] = ii_left[k + 1];
ii_left[max_width - 1] = n;
// update ii_top ring buffer
for (int k = 0; k < max_width - 1; k++)
ii_top[k][n] = ii_top[k + 1][n];
ii_top[max_width - 1][n] = m;
// block is too small -- skip
if (width != max_width)
continue;
// retained blocks are set to zeros
for (size_t km = 0; km < max_width; km++)
for (size_t kn = 0; kn < max_width; kn++) {
int mm = ii_top[km][n];
int nn = ii_left[kn];
if (mm < 0 || nn < 0)
continue;
layout(h, mm, nn) = 0;
tmp(h, mm, nn) = 0;
lut[num++] = (int)h;
lut[num++] = (int)mm;
lut[num++] = (int)nn;
lut[num++] = idx(h, mm, nn);
}
}
}
}
lut.resize(num);
return lut;
}
typedef std::pair<int, pybind11::array_t<int>> lut_t;
std::vector<lut_t> superblock(uintptr_t LAYOUT, int H, int M, int N, int start_width) {
std::vector<lut_t> ret;
int current = 0;
tensor_3d layout(H, M, N, (int *)LAYOUT);
tensor_3d idx(H, M, N);
for (int64_t h = 0; h < H; h++)
for (int64_t m = 0; m < M; m++)
for (int64_t n = 0; n < N; n++) {
if (layout(h, m, n) == 0)
continue;
idx(h, m, n) = current++;
}
// create lut
for (int max_width = start_width; max_width > 0; max_width /= 2) {
auto lut = segment_blocks(layout, idx, max_width, H, M, N);
if (lut.size() == 0)
continue;
ret.push_back(std::make_pair(max_width, pybind11::array_t<int>(lut.size(), lut.data())));
}
return ret;
}
void init_superblocking(pybind11::module &m) {
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,56 @@
import torch
from torch.testing import assert_close

import triton
import triton.language as tl

torch_type = {
    "bool": torch.bool,
    'int8': torch.int8,
    'uint8': torch.uint8,
    'int16': torch.int16,
    "int32": torch.int32,
    'int64': torch.long,
    'float16': torch.float16,
    'bfloat16': torch.bfloat16,
    "float32": torch.float32,
    "float64": torch.float64
}


def get_tensor(shape, data_type, b_positive=False):
    x = None
    if data_type.startswith('int'):
        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
    else:
        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
    return x


# @pytest.mark.parametrize('data_type',
#                          [("int8"),
#                           ('int16'),
#                           ('int32'),
#                           ("int64"),
#                           ('float16'),
#                           ("float32"),
#                           ("float64")])
def printf(data_type):
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        tl.printf("", x)
        tl.store(Y + tl.arange(0, BLOCK), x)

    shape = (128, )
    # limit the range of integers so that the sum does not overflow
    x = get_tensor(shape, data_type)
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    kernel[(1,)](x, y, BLOCK=shape[0])
    assert_close(y, x)


printf("float16")
printf("int8")

View File

@@ -1,5 +1,6 @@
# flake8: noqa: F821,F841
import itertools
import os
import re
from typing import Optional, Union
@@ -104,8 +105,8 @@ def check_type_supported(dtype):
'''
skip test if dtype is not supported on the current device
'''
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
cc = torch.cuda.get_device_capability()
if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
@@ -414,8 +415,8 @@ def test_where(dtype):
def test_where_broadcast():
@triton.jit
def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
mask = tl.load(cond_ptr + yoffsets)
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
@@ -424,8 +425,8 @@ def test_where_broadcast():
@triton.jit
def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
xoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [BLOCK_SIZE, 1])
yoffsets = tl.reshape(tl.arange(0, BLOCK_SIZE), [1, BLOCK_SIZE])
xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
mask = 0
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
res = tl.where(mask, vals, 0.)
@@ -462,9 +463,6 @@ def test_unary_op(dtype_x, expr, device='cuda'):
# ----------------
# test math ops
# ----------------
# @pytest.mark.parametrize("expr", [
# 'exp', 'log', 'cos', 'sin'
# ])
@pytest.mark.parametrize("expr", [
@@ -490,9 +488,12 @@ def make_ptr_str(name, shape):
return f"{name} + {' + '.join(offsets)}"
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', d)
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for s in ['None, :', ':, None',
'None, :, :',
':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
@@ -605,8 +606,8 @@ def test_tuples():
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
if dtype_x_str == 'float16':
pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
n_programs = 5
@@ -651,9 +652,10 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
@pytest.mark.parametrize("axis", [0, 1])
def test_tensor_atomic_rmw(axis, device="cuda"):
shape0, shape1 = 8, 8
@pytest.mark.parametrize("shape, axis",
[(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
def test_tensor_atomic_rmw(shape, axis, device="cuda"):
shape0, shape1 = shape
# triton kernel
@triton.jit
@@ -662,14 +664,18 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
off1 = tl.arange(0, SHAPE1)
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
z = tl.sum(x, axis=AXIS)
tl.atomic_add(Z + off0, z)
if AXIS == 1:
tl.atomic_add(Z + off0, z)
else:
tl.atomic_add(Z + off1, z)
rs = RandomState(17)
x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
# reference result
z_ref = np.sum(x, axis=axis)
z_ref = np.sum(x, axis=axis, keepdims=False)
# triton result
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
z_shape = (shape0, ) if axis == 1 else (shape1, )
z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
@@ -724,6 +730,10 @@ def test_atomic_cas():
(f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# bfloat16 on cc < 80 will not be tested
check_type_supported(dtype_x)
check_type_supported(dtype_z)
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
x0 = 43 if dtype_x in int_dtypes else 43.5
if dtype_x in float_dtypes and dtype_z == 'int1':
@@ -737,9 +747,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, Z, BITCAST: tl.constexpr):
x = tl.load(X)
x_ptr = X + tl.arange(0, 1)
z_ptr = Z + tl.arange(0, 1)
x = tl.load(x_ptr)
z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
tl.store(Z, z)
tl.store(z_ptr, z)
dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
# triton result
@@ -869,9 +881,19 @@ def test_f16_to_f8_rounding():
# ---------------
def get_reduced_dtype(dtype_str, op):
if op == 'argmin' or op == 'argmax':
return 'int32'
if dtype_str in ['int8', 'uint8', 'int16', 'uint16']:
return 'int32'
if dtype_str == 'bfloat16':
return 'float32'
return dtype_str
@pytest.mark.parametrize("op, dtype_str, shape",
[(op, dtype, shape)
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for op in ['min', 'max', 'sum']
for dtype in dtypes_with_bfloat16
for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
@@ -892,7 +914,7 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
'argmin': np.argmin, 'argmax': np.argmax}[op]
# numpy result
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
z_dtype_str = get_reduced_dtype(dtype_str, op)
z_tri_dtype_str = z_dtype_str
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
z_dtype_str = 'float32'
@@ -919,21 +941,35 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
np.testing.assert_equal(z_ref, z_tri)
# TODO: [Qingyi] Fix argmin / argmax
reduce_configs1 = [
(op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for op in ['min', 'max', 'sum']
for axis in [1]
]
# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory
# exceeds the limit of 99KB
reduce2d_shapes = [(2, 32), (4, 32), (4, 128)]
# TODO: fix and uncomment
# , (32, 64), (64, 128)]
if 'V100' in torch.cuda.get_device_name(0):
reduce2d_shapes += [(128, 256), (32, 1024)]
reduce_configs2 = [
(op, 'float32', shape, axis)
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
for op in ['min', 'max', 'sum']
for shape in reduce2d_shapes
for axis in [0, 1]
]
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -954,7 +990,7 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
x_tri = to_triton(x)
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
'argmin': np.argmin, 'argmax': np.argmax}[op]
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
z_dtype_str = get_reduced_dtype(dtype_str, op)
z_tri_dtype_str = z_dtype_str
# numpy result
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
@@ -992,7 +1028,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
@pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm)
for dtype in ['bfloat16', 'float16', 'float32']
# TODO: bfloat16
for dtype in ['float16', 'float32']
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
@@ -1038,25 +1075,37 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# ---------------
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
[(epilogue, allow_tf32, dtype)
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
for shape in [(64, 64, 64)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float16']
if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
for dtype in ['float16', 'float32']
if not (allow_tf32 and (dtype in ['float16']))] +
[(*shape_nw, col_a, col_b, 'none', allow_tf32, dtype)
for shape_nw in [[128, 256, 32, 8],
[128, 16, 32, 4],
[32, 128, 64, 4],
[128, 128, 64, 4],
[64, 128, 128, 4],
[32, 128, 64, 2],
[128, 128, 64, 2],
[64, 128, 128, 4]]
for allow_tf32 in [True]
for col_a in [True, False]
for col_b in [True, False]
for dtype in ['int8', 'float16', 'float32']])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if cc < 80:
if capability[0] < 8:
if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
M, N, K = 128, 128, 64
num_warps = 8
trans_a, trans_b = False, False
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
# triton kernel
@triton.jit
@@ -1068,7 +1117,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
ALLOW_TF32: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
COL_A: tl.constexpr, COL_B: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_l = tl.arange(0, BLOCK_N)
@@ -1077,7 +1126,9 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
x = tl.load(Xs)
y = tl.load(Ys)
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
if ADD_MATRIX:
z += tl.load(Zs)
if ADD_ROWS:
@@ -1093,16 +1144,24 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
den = tl.sum(num, 1)
z = num / den[:, None]
if CHAIN_DOT:
# tl.store(Zs, z)
# tl.debug_barrier()
z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
w = tl.load(Ws)
z = tl.dot(z.to(w.dtype), w)
tl.store(Zs, z)
# input
rs = RandomState(17)
x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
if allow_tf32:
if col_a:
x = numpy_random((K, M), dtype_str=dtype, rs=rs).T
else:
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
if col_b:
y = numpy_random((N, K), dtype_str=dtype, rs=rs).T
else:
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
w = numpy_random((N, N), dtype_str=dtype, rs=rs)
if 'int' not in dtype:
x *= .1
y *= .1
if dtype == 'float32' and allow_tf32:
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
@@ -1110,7 +1169,11 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
# triton result
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
if dtype == 'int8':
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
else:
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
@@ -1118,7 +1181,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
TRANS_A=trans_a, TRANS_B=trans_b,
COL_A=col_a, COL_B=col_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
@@ -1128,9 +1191,12 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
ALLOW_TF32=allow_tf32,
num_warps=num_warps)
# torch result
x_ref = x.T if trans_a else x
y_ref = y.T if trans_b else y
z_ref = np.matmul(x_ref, y_ref)
if dtype == 'int8':
z_ref = np.matmul(x.astype(np.float32),
y.astype(np.float32)).astype(np.int32)
else:
z_ref = np.matmul(x, y)
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
@@ -1142,17 +1208,21 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
if dtype == 'float32':
# XXX: Somehow there's a larger difference when we use float32
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if allow_tf32:
if dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32':
elif dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
@@ -1216,7 +1286,7 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
in_offsets = tl.arange(0, out_size)
# Load inputs.
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1.0)
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
# Store output
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)
@@ -1227,16 +1297,12 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)
# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermediate copy to shared memory run.
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
M = 32
@@ -1325,6 +1391,7 @@ def test_vectorization(N):
else:
assert "ld.global.b32" in ptx
# triton.testing.assert_almost_equal(dst, src[:N])
# ---------------
# test store
# ---------------
@@ -1402,6 +1469,10 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
JITFunction.cache_hook = None
assert spec_type == value_type
# --------------------
# value specialization
# --------------------
@pytest.mark.parametrize(
"value, overflow",
@@ -1552,9 +1623,23 @@ def test_num_warps_pow2():
# -------------
def system_libdevice_path() -> str:
_SYSTEM_LIBDEVICE_SEARCH_PATHS = [
'/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc',
'/usr/local/cuda/nvvm/libdevice/libdevice.10.bc',
]
SYSTEM_LIBDEVICE_PATH: Optional[str] = None
for _p in _SYSTEM_LIBDEVICE_SEARCH_PATHS:
if os.path.exists(_p):
SYSTEM_LIBDEVICE_PATH = _p
assert SYSTEM_LIBDEVICE_PATH is not None, \
"Could not find libdevice.10.bc path"
return SYSTEM_LIBDEVICE_PATH
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'libdevice.ffs', ''),
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('float32', 'libdevice.pow', system_libdevice_path()),
('float64', 'libdevice.norm4d', '')])
def test_libdevice_tensor(dtype_str, expr, lib_path):
@@ -1621,3 +1706,95 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# compare
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
# -----------------------
# test layout conversions
# -----------------------
# TODO: backend should be tested separately
class MmaLayout:
def __init__(self, version, warps_per_cta):
self.version = version
self.warps_per_cta = str(warps_per_cta)
def __str__(self):
return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
class BlockedLayout:
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
self.sz_per_thread = str(size_per_thread)
self.threads_per_warp = str(threads_per_warp)
self.warps_per_cta = str(warps_per_cta)
self.order = str(order)
def __str__(self):
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
layouts = [
# MmaLayout(version=1, warps_per_cta=[1, 4]),
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
# MmaLayout(version=1, warps_per_cta=[4, 1]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
]
@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
pytest.skip()
ir = f"""
#src = {src_layout}
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
%4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
%6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
return
}
}
"""
x = to_triton(numpy_random(shape, dtype_str=dtype))
z = torch.empty_like(x)
# write the IR to a temporary file using mkstemp
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
assert torch.equal(z, x)

View File

@@ -1,261 +0,0 @@
# flake8: noqa: F821,F841
import random
import torch
import triton
import triton.language as tl
@triton.jit
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
mask = w_offsets < (size // 4)
input_ptrs = input_ptr + 1 + w_offsets
input = tl.load(input_ptrs, mask=mask, other=0)
scale_shift = tl.load(input_ptr)
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
output = tl.dequantize(input, scale, shift, 8)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
tl.store(output_ptrs, output, mask=offsets < size)
@triton.jit
def dequantize_kernel_scale_shift_int8(
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
w_offsets = tl.arange(0, BLOCK_SIZE // 4)
mask = w_offsets < (size // 4)
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(scale_ptr)
shift = tl.load(shift_ptr)
output = tl.dequantize(input, scale, shift, 8)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
tl.store(output_ptrs, output, mask=offsets < size)
@triton.jit
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = input_ptr + 1 + w_offsets
input = tl.load(input_ptrs, mask=mask, other=0)
scale_shift = tl.load(input_ptr)
scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
output = tl.dequantize(input, scale, shift, 4)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)
@triton.jit
def dequantize_kernel_scale_shift_int4(
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(scale_ptr)
shift = tl.load(shift_ptr)
output = tl.dequantize(input, scale, shift, 4)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)
@triton.jit
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
output = tl.dequantize(input, scale, shift, 2)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)
@triton.jit
def dequantize_kernel_scale_shift_int2(
output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
w_offsets = tl.arange(0, BLOCK_SIZE // 8)
mask = w_offsets < (size // 8)
input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
input = tl.load(input_ptrs, mask=mask, other=0)
scale = tl.load(scale_ptr)
shift = tl.load(shift_ptr)
output = tl.dequantize(input, scale, shift, 2)
offsets = tl.arange(0, BLOCK_SIZE)
output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
tl.store(output_ptrs, output, mask=offsets < size)
def test_dequantize_int8() -> None:
for i in range(10):
if i < 5:
size = random.randrange(16, 128, 4)
else:
size = random.randrange(132, 1024, 4)
device = torch.device(torch.cuda.current_device())
scale_val = random.uniform(0.1, 4.0)
shift_val = random.uniform(-10.0, 10.0)
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
scale_shift = torch.tensor(
[scale_val, shift_val],
dtype=torch.float16,
device=device,
).view(torch.int32)
input_int8 = torch.randint(
0, 256, (size,), dtype=torch.uint8, device=device
)
input_int32 = input_int8.view(torch.int32)
input = torch.cat((scale_shift, input_int32))
expected = (input_int8 * scale + shift).to(torch.float16)
output = torch.empty([size], dtype=torch.float16, device=device)
block_size = max(triton.next_power_of_2(size), 128)
grid = (1,)
dequantize_kernel_int8[grid](
output, input, size, BLOCK_SIZE=block_size, num_warps=1
)
rtol, atol = 1e-02, 1e-02
assert torch.allclose(output, expected, rtol, atol)
output = torch.empty([size], dtype=torch.float16, device=device)
dequantize_kernel_scale_shift_int8[grid](
output,
input_int32,
scale,
shift,
size,
BLOCK_SIZE=block_size,
num_warps=1,
)
assert torch.allclose(output, expected, rtol, atol)
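The scale/shift header used by the fused kernel above is just two float16 values reinterpreted as one 32-bit word; a minimal standalone check of that packing (a sketch, assuming the little-endian layout the kernel's `& 65535` / `>> 16` unpacking relies on):

# Sketch: the bit-level packing that dequantize_kernel_int8 unpacks from its header word.
import numpy as np
scale, shift = np.float16(2.0), np.float16(-3.0)
word = np.array([scale, shift]).view(np.uint32)[0]                    # packed header word
lo = np.array([word & 0xFFFF], dtype=np.uint16).view(np.float16)[0]   # kernel: scale_shift & 65535
hi = np.array([word >> 16], dtype=np.uint16).view(np.float16)[0]      # kernel: scale_shift >> 16
assert lo == scale and hi == shift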
def test_dequantize_int4() -> None:
for i in range(10):
if i < 5:
size = random.randrange(16, 256, 8)
else:
size = random.randrange(264, 1024, 8)
device = torch.device(torch.cuda.current_device())
scale_val = random.uniform(0.1, 4.0)
shift_val = random.uniform(-10.0, 10.0)
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
scale_shift = torch.tensor(
[scale_val, shift_val],
dtype=torch.float16,
device=device,
).view(torch.int32)
input_int8 = torch.randint(
0, 256, (size // 2,), dtype=torch.uint8, device=device
)
input_int32 = input_int8.view(torch.int32)
input_int8_h1 = input_int8 >> 4
input_int8_h0 = input_int8 & 15
input_int4_val = torch.stack(
(input_int8_h0, input_int8_h1), dim=1
).flatten()
input = torch.cat((scale_shift, input_int32))
expected = (input_int4_val * scale + shift).to(torch.float16)
output = torch.empty([size], dtype=torch.float16, device=device)
block_size = max(triton.next_power_of_2(size), 256)
grid = (1,)
dequantize_kernel_int4[grid](
output, input, size, BLOCK_SIZE=block_size, num_warps=1
)
rtol, atol = 1e-02, 1e-02
assert torch.allclose(output, expected, rtol, atol)
output = torch.empty([size], dtype=torch.float16, device=device)
dequantize_kernel_scale_shift_int4[grid](
output,
input_int32,
scale,
shift,
size,
BLOCK_SIZE=block_size,
num_warps=1,
)
assert torch.allclose(output, expected, rtol, atol)
def test_dequantize_int2() -> None:
for i in range(10):
if i < 5:
size = random.randrange(16, 256, 8)
else:
size = random.randrange(264, 1024, 8)
device = torch.device(torch.cuda.current_device())
scale_val = random.uniform(0.1, 4.0)
shift_val = random.uniform(-10.0, 10.0)
scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
scale_shift = torch.tensor(
[scale_val, shift_val],
dtype=torch.float16,
device=device,
).view(torch.int16)
input_int8 = torch.randint(
0, 256, (size // 4,), dtype=torch.uint8, device=device
)
input_int16 = input_int8.view(torch.int16)
input_int8_q3 = input_int8 >> 6
input_int8_q2 = (input_int8 >> 4) & 3
input_int8_q1 = (input_int8 >> 2) & 3
input_int8_q0 = input_int8 & 3
input_int2_val = torch.stack(
(input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
).flatten()
input = torch.cat((scale_shift, input_int16))
expected = (input_int2_val * scale + shift).to(torch.float16)
output = torch.empty([size], dtype=torch.float16, device=device)
block_size = max(triton.next_power_of_2(size), 256)
grid = (1,)
dequantize_kernel_int2[grid](
output, input, size, BLOCK_SIZE=block_size, num_warps=1
)
rtol, atol = 1e-02, 1e-02
assert torch.allclose(output, expected, rtol, atol)
output = torch.empty([size], dtype=torch.float16, device=device)
dequantize_kernel_scale_shift_int2[grid](
output,
input_int16,
scale,
shift,
size,
BLOCK_SIZE=block_size,
num_warps=1,
)
assert torch.allclose(output, expected, rtol, atol)

View File

@@ -0,0 +1,22 @@
import os
import subprocess
import sys
dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")
def test_printf():
proc = subprocess.Popen([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
(outs, err) = proc.communicate()
outs = outs.split()
new_lines = set()
for line in outs:
try:
value = int(float(line))
new_lines.add(value)
except Exception as e:
print(e)
for i in range(128):
assert i in new_lines
assert len(new_lines) == 128
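The helper script itself is not shown in this diff; a hypothetical `printf_helper.py` that would satisfy the assertions above (every integer from 0 to 127 appearing on stdout) could be as small as the following sketch, which is not part of the commit:

# Hypothetical printf_helper.py sketch, shown only to illustrate what test_printf expects.
import triton
import triton.language as tl

@triton.jit
def print_block(BLOCK: tl.constexpr):
    x = tl.arange(0, BLOCK)
    tl.printf("", x)   # prints each of the BLOCK values

print_block[(1,)](BLOCK=128)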

View File

@@ -2,13 +2,13 @@ import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
# TODO: float32 fails
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
seed = 0
@@ -32,9 +32,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# create data
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
dc_ref, dc_tri = triton.testing.make_pair(c_shape)
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
# compute [torch]
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
a_ref = do_mask(a_ref) if is_dsd else a_ref
@@ -126,8 +126,8 @@ def test_attention_fwd_bwd(
batch_size=2,
n_heads=2,
):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
# inputs

View File

@@ -2,20 +2,19 @@ import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['bfloat16', 'float16', 'float32']
for dtype in ['float16', 'float32']
for mode in ['forward', 'backward']
]
)
def test_op(M, N, dtype, mode):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and dtype == "bfloat16":
capability = torch.cuda.get_device_capability()
if capability[0] < 8 and dtype == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
# create inputs

View File

@@ -4,7 +4,6 @@ import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize(
@@ -67,10 +66,10 @@ import triton._C.libtriton.triton as _triton
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 70:
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if cc < 80 and DTYPE == "bfloat16":
if capability[0] < 8 and DTYPE == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
if DTYPE == "bfloat16" and SPLIT_K != 1:
pytest.skip("bfloat16 matmuls don't allow split_k for now")

View File

@@ -1,15 +1,52 @@
"""isort:skip_file"""
# flake8: noqa: F401
__version__ = '2.0.0'
# ---------------------------------------
# Note: import order is significant here.
# TODO: torch needs to be imported first
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
import torch # noqa: F401
# submodules
from .utils import *
from .runtime import Config, autotune, heuristics, JITFunction, KernelInterface
from . import impl
from .utils import (
cdiv,
MockTensor,
next_power_of_2,
reinterpret,
TensorWrapper,
)
from .runtime import (
autotune,
Config,
heuristics,
JITFunction,
KernelInterface,
)
from .runtime.jit import jit
from .compiler import compile, CompilationError
from . import language
from . import testing
from . import ops
__all__ = [
"autotune",
"cdiv",
"CompilationError",
"compile",
"Config",
"heuristics",
"impl",
"jit",
"JITFunction",
"KernelInterface",
"language",
"MockTensor",
"next_power_of_2",
"ops",
"reinterpret",
"runtime",
"TensorWrapper",
"testing",
]

File diff suppressed because it is too large

View File

@@ -0,0 +1,18 @@
"""Triton internal implementation details.
Client libraries should not import interfaces from the `triton.impl` module,
as the details are subject to change.
APIs defined in the `triton.impl` module which are public will be re-exported
in other relevant `triton` module namespaces.
"""
from .base import builtin, extern, is_builtin
from triton._C.libtriton.triton import ir
__all__ = [
"builtin",
"extern",
"ir",
"is_builtin",
]

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from functools import wraps
from typing import TypeVar
T = TypeVar("T")
TRITON_BUILTIN = "__triton_builtin__"
def builtin(fn: T) -> T:
"""Mark a function as a builtin."""
assert callable(fn)
@wraps(fn)
def wrapper(*args, **kwargs):
if "_builder" not in kwargs or kwargs["_builder"] is None:
raise ValueError(
"Did you forget to add @triton.jit ? "
"(`_builder` argument must be provided outside of JIT functions.)"
)
return fn(*args, **kwargs)
setattr(wrapper, TRITON_BUILTIN, True)
return wrapper
def is_builtin(fn) -> bool:
"""Is this a registered triton builtin function?"""
return getattr(fn, TRITON_BUILTIN, False)
def extern(fn: T) -> T:
"""A decorator for external functions."""
return builtin(fn)
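A minimal sketch of the contract above (importing from `triton.impl` here purely for illustration, despite the module's warning): a wrapped function must receive `_builder`, which only happens inside @triton.jit code, so a plain call is rejected.

# Sketch of the @builtin guard defined above.
from triton.impl import builtin, is_builtin

@builtin
def my_op(x, _builder=None):
    return x

assert is_builtin(my_op)
try:
    my_op(1)              # no _builder outside of JIT code -> rejected
except ValueError as e:
    print(e)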

View File

@@ -1,4 +1,181 @@
# flake8: noqa: F401
from . import core, extern, libdevice, random
from .core import *
from .random import *
"""isort:skip_file"""
# Import order is significant here.
from ..impl import (
ir,
builtin,
)
from . import libdevice
from .core import (
abs,
arange,
argmin,
argmax,
atomic_add,
atomic_and,
atomic_cas,
atomic_max,
atomic_min,
atomic_or,
atomic_xchg,
atomic_xor,
bfloat16,
block_type,
broadcast,
broadcast_to,
cat,
cdiv,
constexpr,
cos,
debug_barrier,
dot,
dtype,
exp,
fdiv,
float16,
float32,
float64,
float8,
function_type,
int1,
int16,
int32,
int64,
int8,
load,
log,
max,
max_contiguous,
maximum,
min,
minimum,
multiple_of,
num_programs,
pi32_t,
pointer_type,
printf,
program_id,
ravel,
reshape,
sigmoid,
sin,
softmax,
sqrt,
store,
sum,
swizzle2d,
tensor,
trans,
triton,
uint16,
uint32,
uint64,
uint8,
umulhi,
view,
void,
where,
xor_sum,
zeros,
zeros_like,
)
from .random import (
pair_uniform_to_normal,
philox,
philox_impl,
rand,
rand4x,
randint,
randint4x,
randn,
randn4x,
uint32_to_uniform_float,
)
__all__ = [
"abs",
"arange",
"argmin",
"argmax",
"atomic_add",
"atomic_and",
"atomic_cas",
"atomic_max",
"atomic_min",
"atomic_or",
"atomic_xchg",
"atomic_xor",
"bfloat16",
"block_type",
"broadcast",
"broadcast_to",
"builtin",
"cat",
"cdiv",
"constexpr",
"cos",
"debug_barrier",
"dot",
"dtype",
"exp",
"fdiv",
"float16",
"float32",
"float64",
"float8",
"function_type",
"int1",
"int16",
"int32",
"int64",
"int8",
"ir",
"libdevice",
"load",
"log",
"max",
"max_contiguous",
"maximum",
"min",
"minimum",
"multiple_of",
"num_programs",
"pair_uniform_to_normal",
"philox",
"philox_impl",
"pi32_t",
"pointer_type",
"printf",
"program_id",
"rand",
"rand4x",
"randint",
"randint4x",
"randn",
"randn4x",
"ravel",
"reshape",
"sigmoid",
"sin",
"softmax",
"sqrt",
"store",
"sum",
"swizzle2d",
"tensor",
"trans",
"triton",
"uint16",
"uint32",
"uint32_to_uniform_float",
"uint64",
"uint8",
"umulhi",
"view",
"void",
"where",
"xor_sum",
"zeros",
"zeros_like",
]

View File

@@ -1,13 +1,14 @@
from __future__ import annotations
from enum import Enum
from functools import wraps
from typing import List
from typing import Callable, List, TypeVar
import triton
from . import semantic
from . import builtin, semantic
from triton._C.libtriton.triton import ir
T = TypeVar('T')
def _to_tensor(x, builder):
if isinstance(x, bool):
@@ -17,41 +18,28 @@ def _to_tensor(x, builder):
if -2**31 <= x < 2**31:
return tensor(builder.get_int32(x), int32)
elif 2**31 <= x < 2**32:
return tensor(builder.get_uint32(x), uint32)
return tensor(builder.get_int32(x), uint32)
elif -2**63 <= x < 2**63:
return tensor(builder.get_int64(x), int64)
elif 2**63 <= x < 2**64:
return tensor(builder.get_uint64(x), uint64)
return tensor(builder.get_int64(x), uint64)
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
return tensor(builder.get_float32(x), float32)
elif isinstance(x, constexpr):
if x.value is None:
return None
return _to_tensor(x.value, builder)
elif isinstance(x, tensor):
return x
elif x is None:
return None
assert False, f'cannot convert {x} to tensor'
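The integer branches above pick the narrowest type that represents the Python literal (with values at or above 2**31 and 2**63 now built through get_int32/get_int64 but tagged with the unsigned dtypes). The rule, restated as a pure-Python helper for illustration only:

# Sketch: the dtype selection rule encoded in _to_tensor for int literals.
def literal_int_dtype(x: int) -> str:
    if -2**31 <= x < 2**31:
        return 'int32'
    if 2**31 <= x < 2**32:
        return 'uint32'
    if -2**63 <= x < 2**63:
        return 'int64'
    if 2**63 <= x < 2**64:
        return 'uint64'
    raise RuntimeError(f'Nonrepresentable integer {x}.')

assert literal_int_dtype(1) == 'int32'
assert literal_int_dtype(2**31) == 'uint32'
assert literal_int_dtype(-2**40) == 'int64'
assert literal_int_dtype(2**63) == 'uint64'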
def builtin(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return fn(*args, **kwargs)
return wrapper
class dtype:
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64']
CUSTOMIZED_FP_TYPES = ['fp8']
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
OTHER_TYPES = ['void']
class SIGNEDNESS(Enum):
@@ -133,6 +121,12 @@ class dtype:
def is_floating(self):
return self.name in dtype.FP_TYPES
def is_customized_floating(self):
return self.name in dtype.CUSTOMIZED_FP_TYPES
def is_standard_floating(self):
return self.name in dtype.STANDARD_FP_TYPES
def is_int_signed(self):
return self.name in dtype.SINT_TYPES
@@ -146,7 +140,7 @@ class dtype:
return self.is_int1()
def is_void(self):
return self.name == 'void'
raise RuntimeError("Not implemented")
def is_block(self):
return False
@@ -216,7 +210,7 @@ class pointer_type(dtype):
self.name = self.__str__()
def to_ir(self, builder: ir.builder) -> ir.pointer_type:
return ir.type.make_ptr(self.element_ty.to_ir(builder), 1)
return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1)
def __str__(self):
return f'pointer<{self.element_ty}>'
@@ -241,22 +235,27 @@ class pointer_type(dtype):
class block_type(dtype):
def __init__(self, element_ty: dtype, shape: List[int]):
def __init__(self, element_ty: dtype, shape: List):
self.element_ty = element_ty
# FIXME:
# block_type's shape is a list of int
# while tensor's shape is a list of constexpr
# Note that block_type's shape is a list of int
# while tensor's shape is a list of constexpr.
# shape can be empty ([]) when an input is a 0D tensor.
if not shape:
raise TypeError('0d block_type is forbidden')
if isinstance(shape[0], constexpr):
shape = [s.value for s in shape]
self.shape = shape
self.numel = 1
for i, s in enumerate(self.shape):
if isinstance(s, constexpr):
self.shape[i] = s.value
self.numel *= self.shape[i]
for s in self.shape:
self.numel *= s
self.name = self.__str__()
def to_ir(self, builder: ir.builder) -> ir.block_type:
return ir.type.make_block(self.element_ty.to_ir(builder), self.shape)
return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape)
def __str__(self):
return f'<{self.shape}, {self.element_ty}>'
@@ -284,28 +283,17 @@ class block_type(dtype):
class function_type(dtype):
def __init__(self, ret_type: dtype, param_types: List[dtype]) -> None:
self.ret_type = ret_type
def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
self.ret_types = ret_types
self.param_types = param_types
def __str__(self):
return f'fn ({self.param_types}) -> {self.ret_type}'
return f'fn ({self.param_types}) -> {self.ret_types}'
def to_ir(self, builder: ir.builder):
ir_param_types = [ty.to_ir(builder) for ty in self.param_types]
return ir.type.make_function(self.ret_type.to_ir(builder), ir_param_types)
class tuple_type(dtype):
def __init__(self, element_types: List[dtype]) -> None:
self.element_types = element_types
def __str__(self):
return f'<{self.element_types}>'
def to_ir(self, builder: ir.builder):
ir_element_types = [ty.to_ir(builder) for ty in self.element_types]
return ir.struct_type.get(ir_element_types, True)
ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types]
return builder.get_function_ty(ir_param_types, ret_types)
# scalar types
@@ -346,83 +334,96 @@ class constexpr:
def __repr__(self) -> str:
return f"constexpr[{self.value}]"
def __add__(self, other):
return constexpr(self.value + other.value)
def __radd__(self, other):
return constexpr(other.value + self.value)
def __sub__(self, other):
return constexpr(self.value - other.value)
def __rsub__(self, other):
return constexpr(other.value - self.value)
def __mul__(self, other):
return constexpr(self.value * other.value)
def __mod__(self, other):
return constexpr(self.value % other.value)
def __rmul__(self, other):
return constexpr(other.value * self.value)
def __truediv__(self, other):
return constexpr(self.value / other.value)
def __rtruediv__(self, other):
return constexpr(other.value / self.value)
def __floordiv__(self, other):
return constexpr(self.value // other.value)
def __rfloordiv__(self, other):
return constexpr(other.value // self.value)
def __gt__(self, other):
return constexpr(self.value > other.value)
def __rgt__(self, other):
return constexpr(other.value > self.value)
def __ge__(self, other):
return constexpr(self.value >= other.value)
def __rge__(self, other):
return constexpr(other.value >= self.value)
def __lt__(self, other):
return constexpr(self.value < other.value)
def __rlt__(self, other):
return constexpr(other.value < self.value)
def __le__(self, other):
return constexpr(self.value <= other.value)
def __rle__(self, other):
return constexpr(other.value <= self.value)
def __eq__(self, other):
return constexpr(self.value == other.value)
def __ne__(self, other):
return constexpr(self.value != other.value)
def __bool__(self):
return bool(self.value)
def __ge__(self, other):
other = other.value if isinstance(other, constexpr) else other
return self.value >= other
def __neg__(self):
return constexpr(-self.value)
def __gt__(self, other):
other = other.value if isinstance(other, constexpr) else other
return self.value > other
def __pos__(self):
return constexpr(+self.value)
def __le__(self, other):
other = other.value if isinstance(other, constexpr) else other
return self.value <= other
def __lt__(self, other):
other = other.value if isinstance(other, constexpr) else other
return self.value < other
def __eq__(self, other):
other = other.value if isinstance(other, constexpr) else other
return self.value == other
def __invert__(self):
return constexpr(~self.value)
def __call__(self, *args, **kwds):
return self.value(*args, **kwds)
def to(self, dtype, bitcast=False, _builder=None):
if dtype in [float8, float16, bfloat16]:
raise ValueError("floating point constexpr must be float64")
if dtype.is_int():
ret_ty = int
elif dtype.is_bool():
ret_ty = bool
elif dtype.is_floating():
ret_ty = float
return constexpr(ret_ty(self.value))
class tensor:
# infer dtype from ir type
@staticmethod
def _to_dtype(ir_type):
# block type
if ir_type.is_block():
scalar_ty = tensor._to_dtype(ir_type.scalar)
return block_type(scalar_ty, ir_type.get_block_shapes())
# pointer type
if ir_type.is_ptr():
element_ty = tensor._to_dtype(ir_type.element)
return pointer_type(element_ty)
# primitive type
if ir_type.is_void(): return void
if ir_type.is_int1(): return int1
if ir_type.is_int8(): return int8
if ir_type.is_int16(): return int16
if ir_type.is_int32(): return int32
if ir_type.is_int64(): return int64
if ir_type.is_fp8(): return float8
if ir_type.is_fp16(): return float16
if ir_type.is_bf16(): return bfloat16
if ir_type.is_fp32(): return float32
if ir_type.is_fp64(): return float64
raise ValueError(f"Unsupported type {ir_type.repr()}")
def __init__(self, handle, type: dtype):
# IR handle
self.handle = handle
# Block shape
self.shape = (1, )
if self.handle.type.is_block():
self.shape = self.handle.type.shape
if type.is_block():
self.shape = type.shape
self.numel = 1
for s in self.shape:
self.numel *= s
is_pow2 = (self.numel and (not (self.numel & (self.numel - 1))))
if not is_pow2:
raise ValueError("Triton tensors must have a power-of-two number of elements")
self.numel = constexpr(self.numel)
self.type = type # Tensor type (can be block_type)
# Following the practice in pytorch, dtype is scalar type
@@ -580,22 +581,34 @@ class tensor:
other = _to_tensor(other, _builder)
return semantic.not_equal(self, other, _builder)
@builtin
def logical_and(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.logical_and(self, other, _builder)
@builtin
def logical_or(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.logical_or(self, other, _builder)
@builtin
def __getitem__(self, slices, _builder=None):
if isinstance(slices, slice):
slices = [slices]
src_shape = self.shape
dst_shape = []
curr = 0
for sl in slices:
ret = self
for dim, sl in enumerate(slices):
if isinstance(sl, constexpr) and sl.value is None:
dst_shape.append(1)
ret = semantic.expand_dims(ret, dim, _builder)
elif sl == slice(None, None, None):
dst_shape.append(src_shape[curr].value)
curr += 1
ret = semantic.reshape(self, dst_shape, _builder)
pass
else:
assert False, "unsupported"
return ret
@property
def T(self):
assert False, "Transposition must be created by the AST Visitor"
@builtin
def to(self, dtype, bitcast=False, _builder=None):
if isinstance(bitcast, constexpr):
@@ -685,20 +698,6 @@ def zeros(shape, dtype, _builder=None):
return semantic.zeros(shape, dtype, _builder)
# -----------------------
# dequantize
# -----------------------
@builtin
def dequantize(input, scale, shift, nbit, dst_ty=float16, _builder=None):
"""
Tries to dequantize the input to given dtype
"""
nbit = _constexpr_to_value(nbit)
return semantic.dequantize(input, scale, shift, nbit, dst_ty, _builder)
# -----------------------
# Shape Manipulation
# -----------------------
@@ -731,7 +730,12 @@ def broadcast_to(input, shape, _builder=None):
@builtin
def cat(input, other, _builder=None):
def trans(input, _builder=None):
return semantic.trans(input, _builder)
@builtin
def cat(input, other, can_reorder=False, _builder=None):
"""
Concatenate the given blocks
@@ -739,14 +743,19 @@ def cat(input, other, _builder=None):
:type input:
:param other: The second input tensor.
:type other:
    :param can_reorder: Compiler hint. If true, the compiler is
allowed to reorder elements while concatenating inputs.
Only use if the order does not matter (e.g., result is
only used in reduction ops)
"""
return semantic.cat(input, other, _builder)
return semantic.cat(input, other, can_reorder, _builder)
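A hedged usage sketch for the new signature (illustrative kernel, not from the commit): `can_reorder=True` is required today because the lowering may shuffle elements, so `cat` is only safe when the result feeds an order-insensitive op such as a reduction.

import triton
import triton.language as tl

@triton.jit
def _cat_sum(X, Y, Out, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    y = tl.load(Y + tl.arange(0, BLOCK))
    both = tl.cat(x, y, can_reorder=True)   # element order is not guaranteed...
    tl.store(Out, tl.sum(both, axis=0))     # ...but a sum does not care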
@builtin
def reshape(input, shape, _builder=None):
def view(input, shape, _builder=None):
"""
Tries to reshape the given tensor to a new shape.
Returns a tensor with the same elements as `input` but a different shape.
The order of the elements may not be preserved.
:param input: The input tensor.
:type input:
@@ -755,20 +764,26 @@ def reshape(input, shape, _builder=None):
"""
shape = [x.value for x in shape]
return semantic.reshape(input, shape, _builder)
return semantic.view(input, shape, _builder)
@builtin
def reshape(input, shape, _builder=None):
# TODO: should be more than just a view
shape = [x.value for x in shape]
return semantic.view(input, shape, _builder)
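A short hedged sketch of the distinction introduced above: `view` (and, for now, `reshape`, which merely forwards to it) changes the shape without guaranteeing element order, which is why it is the right tool for flattening ahead of reductions.

import triton
import triton.language as tl

@triton.jit
def _flatten(X, Y, M: tl.constexpr, N: tl.constexpr):
    offs = tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :]
    x = tl.load(X + offs)                                     # [M, N] block
    tl.store(Y + tl.arange(0, M * N), tl.view(x, [M * N]))    # element order unspecified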
# -----------------------
# Linear Algebra
# -----------------------
@builtin
def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=None):
def dot(input, other, allow_tf32=True, _builder=None):
"""
Returns the matrix product of two blocks.
The two blocks must be two dimensionals and have compatible inner dimensions.
The two blocks must be two-dimensional and have compatible inner dimensions.
:param input: The first tensor to be multiplied.
:type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
@@ -776,7 +791,7 @@ def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=No
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
"""
allow_tf32 = _constexpr_to_value(allow_tf32)
return semantic.dot(input, other, trans_a, trans_b, allow_tf32, _builder)
return semantic.dot(input, other, allow_tf32, _builder)
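Since `trans_a`/`trans_b` are gone from `dot`, transposition is now expressed with the separate `trans` builtin before the product. A hedged sketch of the migrated call pattern (block sizes still need to satisfy `tl.dot`'s minimum-size constraints):

import triton
import triton.language as tl

@triton.jit
def _at_b(A, B, C, K: tl.constexpr, M: tl.constexpr, N: tl.constexpr):
    a = tl.load(A + tl.arange(0, K)[:, None] * M + tl.arange(0, M)[None, :])  # [K, M]
    b = tl.load(B + tl.arange(0, K)[:, None] * N + tl.arange(0, N)[None, :])  # [K, N]
    c = tl.dot(tl.trans(a), b, allow_tf32=False)      # old form: tl.dot(a, b, trans_a=True)
    tl.store(C + tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :], c)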
# -----------------------
@@ -814,7 +829,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
@builtin
def store(pointer, value, mask=None, eviction_policy="", _builder=None):
def store(pointer, value, mask=None, _builder=None):
"""
Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
@@ -831,46 +846,27 @@ def store(pointer, value, mask=None, eviction_policy="", _builder=None):
value = _to_tensor(value, _builder)
if mask is not None:
mask = _to_tensor(mask, _builder)
return semantic.store(pointer, value, mask, eviction_policy, _builder)
return semantic.store(pointer, value, mask, _builder)
# -----------------------
# Atomic Memory Operations
# -----------------------
@builtin
def atomic_cas(pointer, cmp, val, _builder=None):
"""
Performs an atomic compare-and-swap at the memory location specified by :code:`pointer`.
def _add_atomic_docstr(name: str) -> Callable[[T], T]:
Return the data stored at :code:`pointer` before the atomic operation.
:param pointer: The memory locations to compare-and-swap.
:type pointer: Block of dtype=triton.PointerDType
:param cmp: The values expected to be found in the atomic object
:type cmp: Block of dtype=`pointer.dtype.element_ty`
:param val: The values to copy in case the expected value matches the contained value.
:type val: Block of dtype=`pointer.dtype.element_ty`
"""
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
return semantic.atomic_cas(pointer, cmp, val, _builder)
def _add_atomic_docstr(name):
def _decorator(func):
def _decorator(func: T) -> T:
docstr = """
Performs an atomic {name} at the memory location specified by :code:`pointer`.
Return the data stored at :code:`pointer` before the atomic operation.
:param pointer: The memory locations to apply {name}.
:param pointer: The memory locations to compare-and-swap.
:type pointer: Block of dtype=triton.PointerDType
:param val: The values to {name} in the atomic object.
:param cmp: The values expected to be found in the atomic object
:type cmp: Block of dtype=`pointer.dtype.element_ty`
:param val: The values to copy in case the expected value matches the contained value.
:type val: Block of dtype=`pointer.dtype.element_ty`
:param mask: If mask[idx] is false, do not apply {name}.
:type mask: Block of triton.int1, optional
"""
func.__doc__ = docstr.format(name=name)
return func
@@ -878,6 +874,14 @@ def _add_atomic_docstr(name):
return _decorator
@builtin
@_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None):
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
return semantic.atomic_cas(pointer, cmp, val, _builder)
@builtin
@_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, _builder=None):
@@ -972,9 +976,9 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
return semantic.fdiv(x, y, ieee_rounding, _builder)
def _add_math_1arg_docstr(name):
def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:
def _decorator(func):
def _decorator(func: T) -> T:
docstr = """
Computes the element-wise {name} of :code:`x`
@@ -1021,9 +1025,9 @@ def sqrt(x, _builder=None):
# Reductions
# -----------------------
def _add_reduction_docstr(name):
def _add_reduction_docstr(name: str) -> Callable[[T], T]:
def _decorator(func):
def _decorator(func: T) -> T:
docstr = """
Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
@@ -1077,19 +1081,6 @@ def xor_sum(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.xor_sum(input, axis, _builder)
# -----------------------
# Utilities
# -----------------------
@builtin
def globaltimer(_builder=None):
return semantic.globaltimer(_builder)
@builtin
def clock(_builder=None):
return semantic.clock(_builder)
# -----------------------
# Internal for debugging
@@ -1189,7 +1180,7 @@ def sigmoid(x):
@triton.jit
@_add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding: constexpr = False):
def softmax(x, ieee_rounding=False):
z = x - triton.language.max(x, 0)
num = triton.language.exp(z)
den = triton.language.sum(num, 0)
@@ -1204,13 +1195,13 @@ def ravel(x):
:param x: the input tensor
:type x: Block
"""
return triton.language.reshape(x, [x.numel])
return triton.language.view(x, [x.numel])
@triton.jit
def swizzle2d(i, j, size_i, size_j, size_g):
"""
transformes indices of a row-major size_i*size_j matrix into those
Transforms indices of a row-major size_i*size_j matrix into those
of one where indices are row major for each group of size_j rows.
For example, for size_i = size_j = 4 and size_g = 2, it will transform
[[0 , 1 , 2 , 3 ],
@@ -1243,26 +1234,22 @@ def swizzle2d(i, j, size_i, size_j, size_g):
@triton.jit
def zeros_like(input):
return zeros(input.shape, input.dtype)
# -----------------------
# Dynamic Parallelism
# -----------------------
# class LaunchProxy:
# def __init__(self, fn, args, constants, grid, num_warps) -> None:
# self.args = args
# self.grid = grid
# self.constants = constants
# self.num_warps = num_warps
# self.fn = fn
# @builtin
# def launch(fn, args, grid, num_warps=None, _builder=None):
# constants = {i: x for i, x in enumerate(args) if isinstance(x, constexpr)}
# args = [_to_ir(x, builder=_builder) for x in args if not isinstance(x, constexpr)]
# grid = [_to_ir(x, builder=_builder) for x in grid]
# if num_warps is None:
# num_warps = _to_ir(4, builder=_builder)
# return LaunchProxy(fn, args, constants, grid, num_warps)
@builtin
def printf(prefix, *args, _builder=None):
import string
new_prefix = prefix
if isinstance(prefix, constexpr):
new_prefix = prefix.value
assert isinstance(new_prefix, str), f"{new_prefix} is not string"
b_ascii = True
for ch in new_prefix:
if ch not in string.printable:
b_ascii = False
break
assert b_ascii, f"{new_prefix} is not an ascii string"
new_args = []
for arg in args:
new_args.append(_to_tensor(arg, _builder))
return semantic.printf(new_prefix, new_args, _builder)
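A hedged usage sketch: the prefix must be a plain ASCII string, and every following argument is converted to a tensor before being handed to `semantic.printf`.

import triton
import triton.language as tl

@triton.jit
def _debug(X, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.printf("pid, x: ", pid, x)   # prefix string, then any number of tensor arguments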

View File

@@ -6,7 +6,6 @@ from . import core, semantic
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None):
'''
Dispatch a function to a library
:param func: the function to dispatch
:param lib_name: the name of the library
:param lib_path: the path of the library
@@ -14,7 +13,6 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
:param arg_type_symbol_dict: the type of the arguments
:param ret_shape: the shape of the return value
:param _builder: the builder
:return: the return value of the function
'''
if len(arg_type_symbol_dict) == 0:
@@ -42,20 +40,19 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
else:
symbol = arg_type_symbol_dict[arg_types][0]
ret_type = arg_type_symbol_dict[arg_types][1]
ret_type = core.block_type(ret_type, ret_shape) if ret_shape is not None else ret_type
if ret_shape:
ret_type = core.block_type(ret_type, ret_shape)
return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type)
def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None):
'''
Dispatch an elementwise function to a library
:param lib_name: the name of the library
:param lib_path: the path of the library
:param args: the arguments of the function
:param arg_type_symbol_dict: the type of the arguments
:param _builder: the builder
:return: the return value of the function
'''
dispatch_args = args.copy()
@@ -87,27 +84,5 @@ def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict:
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder)
ret_shape = broadcast_arg.shape
func = getattr(_builder, "create_extern_elementwise")
func = getattr(_builder, "create_external_elementwise")
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
class ExternalFunction:
'''
A wrapper for external functions
'''
def __init__(self, fn):
self.fn = fn
def __call__(self, *args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return self.fn(*args, **kwargs)
def extern(fn):
'''
A decorator for external functions
'''
return ExternalFunction(fn)

BIN
python/triton/language/libdevice.10.bc Normal file → Executable file

Binary file not shown.

File diff suppressed because it is too large

View File

@@ -1,10 +1,10 @@
import triton
from . import core as tl
PHILOX_KEY_A: tl.constexpr = -1640531527 # 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = -1150833019 # 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = -766435501 # 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = -845247145 # 0xCD9E8D57
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
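The switch from negative decimal literals to hex does not change the constants: interpreted as 32-bit two's-complement values, the old and new spellings share the same bit patterns. A quick check (not part of the module):

# The old signed-int32 spellings and the new hex spellings are the same bits.
for hex_val, old_val in [(0x9E3779B9, -1640531527), (0xBB67AE85, -1150833019),
                         (0xD2511F53, -766435501), (0xCD9E8D57, -845247145)]:
    assert hex_val - 2**32 == old_val            # two's-complement reinterpretation
    assert hex_val & 0xFFFFFFFF == old_val % 2**32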
# -------------------

View File

@@ -7,12 +7,12 @@ from triton._C.libtriton.triton import ir
# Create custom exception that prints message "hello"
class IncompatibleTypeErrorimpl(Exception):
class IncompatibleTypeErrorImpl(Exception):
def __init__(self, type_a, type_b):
self.type_a = type_a
self.type_b = type_b
self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__()
super(IncompatibleTypeErrorimpl, self).__init__(self.message)
super(IncompatibleTypeErrorImpl, self).__init__(self.message)
# ===----------------------------------------------------------------------===##
@@ -88,13 +88,13 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
if type_a.is_ptr():
if not allow_ptr_a:
raise IncompatibleTypeErrorimpl(type_a, type_b)
raise IncompatibleTypeErrorImpl(type_a, type_b)
# T* + U* with T != U
if type_b.is_ptr() and (type_a != type_b):
raise IncompatibleTypeErrorimpl(type_a, type_b)
raise IncompatibleTypeErrorImpl(type_a, type_b)
# T* + float
if type_b.is_floating():
raise IncompatibleTypeErrorimpl(type_a, type_b)
raise IncompatibleTypeErrorImpl(type_a, type_b)
def binary_op_type_checking_impl(lhs: tl.tensor,
@@ -129,7 +129,7 @@ def add(input: tl.tensor,
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
input, other = other, input
if input_scalar_ty.is_ptr():
return tl.tensor(builder.create_gep(input.handle, [other.handle]), input.type)
return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
# float + float
elif input_scalar_ty.is_floating():
return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
@@ -146,7 +146,7 @@ def sub(input: tl.tensor,
scalar_ty = input.type.scalar
# ptr - offset
if scalar_ty.is_ptr():
return tl.tensor(builder.create_gep(input.handle, [minus(other, builder).handle]),
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle),
input.type)
# float - float
if scalar_ty.is_floating():
@@ -223,10 +223,9 @@ def fdiv(input: tl.tensor,
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
raise ValueError("both operands of fdiv must have floating poscalar type")
raise ValueError("both operands of fdiv must have floating scalar type")
input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
ret = builder.create_fdiv(input.handle, other.handle)
ret.set_fdiv_ieee_rounding(ieee_rounding)
return tl.tensor(ret, input.type)
@@ -238,7 +237,11 @@ def mod(input: tl.tensor,
other_scalar_ty = other.type.scalar
# float % float
if scalar_ty.is_floating():
return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
# input - input.div(other, rounding_mode="floor") * other
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder),
other, builder),
builder)
return ret
# % int
elif scalar_ty.is_int():
if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
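For floating-point operands, the new lowering of `mod` above computes `x - floor(x / y) * y` instead of emitting `frem`. A quick hedged check that this matches Python's floor-style `%` for floats, and how it differs from an fmod/frem-style result for mixed signs (assuming frem follows fmod semantics):

import math
x, y = -7.5, 2.0
floor_mod = x - math.floor(x / y) * y      # what the new lowering computes
assert floor_mod == x % y == 0.5           # Python-style floor modulo
assert math.fmod(x, y) == -1.5             # what an frem-based lowering would give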
@@ -263,7 +266,7 @@ def bitwise_op_type_checking_impl(input: tl.tensor,
input_sca_ty = input.type.scalar
other_sca_ty = other.type.scalar
if not input_sca_ty.is_int() or not other_sca_ty.is_int():
raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty)
raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty)
if ret_sca_ty != input_sca_ty:
input = cast(input, ret_sca_ty, builder)
@@ -293,6 +296,22 @@ def xor_(input: tl.tensor,
return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)
def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
if not input.type.is_int1():
input = bitcast(input, tl.dtype("int1"), builder)
if not other.type.is_int1():
other = bitcast(other, tl.dtype("int1"), builder)
return and_(input, other, builder)
def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
if not input.type.is_int1():
input = bitcast(input, tl.dtype("int1"), builder)
if not other.type.is_int1():
other = bitcast(other, tl.dtype("int1"), builder)
return or_(input, other, builder)
def lshr(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
@@ -320,7 +339,7 @@ def minus(input: tl.tensor,
input_sca_ty = input.type.scalar
if input_sca_ty.is_ptr():
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
_0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
_0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
return sub(_0, input, builder)
@@ -329,7 +348,7 @@ def invert(input: tl.tensor,
input_sca_ty = input.type.scalar
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
_1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
_1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
return xor_(input, _1, builder)
@@ -449,11 +468,11 @@ def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
shape = [end - start]
ret_ty = tl.block_type(tl.int32, shape)
return tl.tensor(builder.get_range(start, end), ret_ty)
return tl.tensor(builder.create_make_range(start, end), ret_ty)
def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
_0 = ir.constant.get_null_value(dtype.to_ir(builder))
_0 = builder.get_null_value(dtype.to_ir(builder))
ret_ty = tl.block_type(dtype, shape)
return tl.tensor(builder.create_splat(_0, shape), ret_ty)
@@ -462,24 +481,40 @@ def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
# ===----------------------------------------------------------------------===//
def reshape(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
def view(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
# TODO: disable when TritonToTritonGPU handles views properly
# assert len(input.shape) == len(dst_shape)
numel = 1
for s in dst_shape:
numel *= s
if input.type.numel != numel:
raise ValueError("cannot reshape block of different shape")
raise ValueError("cannot view block of different shape")
ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty)
return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)
def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
assert lhs.type.shape[1:] == rhs.type.shape[1:]
ret_shape = [lhs.type.shape[0] + rhs.type.shape[0]]
ret_ty = tl.block_type(lhs.type.scalar, ret_shape)
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_ty)
def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
dst_shape = [s for s in input.type.shape]
dst_shape.insert(axis, 1)
ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
assert can_reorder, "current implementation of `cat` always may reorder elements"
assert len(lhs.shape) == 1
ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type)
def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
if len(input.shape) != 2:
raise ValueError("Only 2D tensors can be transposed")
ret_type = tl.block_type(input.type.scalar, [input.shape[1], input.shape[0]])
return tl.tensor(builder.create_trans(input.handle), ret_type)
def broadcast_impl_shape(input: tl.tensor,
@@ -496,7 +531,7 @@ def broadcast_impl_shape(input: tl.tensor,
for i in range(len(src_shape)):
if shape[i] != src_shape[i] and src_shape[i] != 1:
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
f" must match the existing size ({src_shape[1]}) at non-singleton dimension"
f" must match the existing size ({src_shape[i]}) at non-singleton dimension"
f" {i}: {src_shape}, {shape}")
ret_ty = tl.block_type(input.type.scalar, shape)
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
@@ -520,8 +555,21 @@ def broadcast_impl_value(lhs: tl.tensor,
elif lhs_ty.is_block() and rhs_ty.is_block():
lhs_shape = lhs_ty.get_block_shapes()
rhs_shape = rhs_ty.get_block_shapes()
if len(lhs_shape) != len(rhs_shape):
raise ValueError("Cannot make_shape_compatible: blocks must have the same rank")
if len(lhs_shape) < len(rhs_shape):
# Add new axes to lhs
for dim in range(len(lhs_shape), len(rhs_shape)):
lhs = tl.tensor(builder.create_expand_dims(lhs.handle, dim), tl.block_type(lhs_ty.scalar, lhs_shape + [1]))
lhs_ty = lhs.type
lhs_shape = lhs_ty.get_block_shapes()
elif len(rhs_shape) < len(lhs_shape):
# Add new axes to rhs
for dim in range(len(rhs_shape), len(lhs_shape)):
rhs = tl.tensor(builder.create_expand_dims(rhs.handle, dim), tl.block_type(rhs_ty.scalar, rhs_shape + [1]))
rhs_ty = rhs.type
rhs_shape = rhs_ty.get_block_shapes()
assert len(rhs_shape) == len(lhs_shape)
ret_shape = []
for i in range(len(lhs_shape)):
left = lhs_shape[i]
@@ -544,31 +592,6 @@ def broadcast_impl_value(lhs: tl.tensor,
# (scalar, scalar) => returns original blocks
return lhs, rhs
#######
# dequantize
#######
def dequantize(input: tl.tensor,
scale: tl.tensor,
shift: tl.tensor,
nbit: int,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
input_ty = input.type
assert input_ty.is_block()
assert input_ty.element_ty.is_int32() or input_ty.element_ty.is_int16()
assert nbit in [2, 4, 8]
assert dst_ty == tl.float16
shape = input_ty.get_block_shapes()
factor = input_ty.element_ty.primitive_bitwidth // nbit
dst_shape = shape[:-1] + [factor * shape[-1]]
dst_ty = tl.block_type(dst_ty, dst_shape)
return tl.tensor(builder.create_dequantize(input.handle, scale.handle, shift.handle, dst_ty.to_ir(builder)), dst_ty)
#######
# cast
#######
@@ -600,62 +623,78 @@ def cast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
src_ty = input.type
if src_ty.is_block() and not dst_ty.is_block():
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
if src_ty == dst_ty:
return input
src_sca_ty = src_ty.scalar
dst_sca_ty = dst_ty.scalar
# fp8 <=> bf16/fp16
if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),
dst_ty)
if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
if (src_sca_ty.is_customized_floating() and dst_sca_ty.is_floating()) or \
(src_sca_ty.is_floating() and dst_sca_ty.is_customized_floating()):
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
dst_ty)
# bf16 <=> (not fp32)
if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
(dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
(src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()):
return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)
# FP Truncation
# Standard floating types' casting: truncation
# fp64 => fp32, fp16, bf16
# fp32 => fp16, bf16
truncate_fp = src_sca_ty.is_floating() and \
dst_sca_ty.is_floating() and \
src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width
src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
if truncate_fp:
return tl.tensor(builder.create_fp_trunc(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
# FP Extension
# Standard floating types' casting: extension
# fp32 => fp64
# fp16 => fp32, fp64
# bf16 => fp32, fp64
ext_fp = src_sca_ty.is_floating() and \
dst_sca_ty.is_floating() and \
src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width
src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
if ext_fp:
return tl.tensor(builder.create_fp_ext(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
# Int cast
# Casting between integer types
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
(src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
return tl.tensor(builder.create_int_cast(input.handle,
dst_ty.to_ir(builder), sign_extend),
dst_ty)
# Float to Int
if src_sca_ty.is_floating() and dst_sca_ty.is_int():
# TODO: is this correct?
if dst_sca_ty.is_bool():
return not_equal(input, tl._to_tensor(0, builder), builder)
ty = input.dtype.to_ir(builder)
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
return not_equal(input, _0, builder)
else:
return tl.tensor(builder.create_int_cast(input.handle,
dst_ty.to_ir(builder), sign_extend),
dst_ty)
# Casting standard floating types to integer types
if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
if dst_sca_ty.is_bool():
ty = input.dtype.to_ir(builder)
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
return not_equal(input, _0, builder)
elif dst_sca_ty.is_int_signed():
return tl.tensor(builder.create_fp_to_si(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
else:
return tl.tensor(builder.create_fp_to_ui(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
# int => float
if src_sca_ty.is_int() and dst_sca_ty.is_floating():
# Casting integer types to standard floating types
if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
return tl.tensor(builder.create_ui_to_fp(input.handle,
dst_ty.to_ir(builder)),
@@ -665,7 +704,7 @@ def cast(input: tl.tensor,
dst_ty.to_ir(builder)),
dst_ty)
# ptr => int
# Casting pointer types to integer types
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
bitwidth = dst_sca_ty.int_bitwidth
if bitwidth == 64:
@@ -676,19 +715,14 @@ def cast(input: tl.tensor,
tl.tensor(builder.get_int64(0), tl.int64),
builder)
if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
# Casting integer types to pointer types
if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Ptr . Ptr
# Casting pointer types to pointer types
if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
# * . Bool
if dst_sca_ty.is_bool():
if src_sca_ty.is_ptr():
input = cast(input, tl.int64, builder)
other = builder.get_int64(0)
if src_ty.is_bool():
other = builder.create_splat(other, src_ty.get_block_shapes())
return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty)
assert False, f'cannot cast {input} to {dst_ty}'
# ===----------------------------------------------------------------------===//
@@ -696,18 +730,6 @@ def cast(input: tl.tensor,
# ===----------------------------------------------------------------------===//
def _parse_eviction_policy(eviction_policy):
eviction = ir.EVICTION_POLICY.NORMAL # default
if eviction_policy:
if eviction_policy == "evict_last":
eviction = ir.EVICTION_POLICY.EVICT_LAST
elif eviction_policy == "evict_first":
eviction = ir.EVICTION_POLICY.EVICT_FIRST
else:
raise ValueError(f"Eviction policy {eviction_policy} not supported")
return eviction
def load(ptr: tl.tensor,
mask: Optional[tl.tensor],
other: Optional[tl.tensor],
@@ -723,16 +745,18 @@ def load(ptr: tl.tensor,
if other:
other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
if other:
other = cast(other, ptr.type.scalar.element_ty, builder)
ptr_ty = ptr.type.scalar
elt_ty = ptr_ty.element_ty
# treat bool* as tl.int8*
if elt_ty == tl.int1:
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)
if other:
other = cast(other, elt_ty, builder)
# cache modifier
cache = ir.CACHE_MODIFIER.NONE # default
if cache_modifier:
@@ -744,7 +768,14 @@ def load(ptr: tl.tensor,
raise ValueError(f"Cache modifier {cache_modifier} not supported")
# eviction policy
eviction = _parse_eviction_policy(eviction_policy)
eviction = ir.EVICTION_POLICY.NORMAL # default
if eviction_policy:
if eviction_policy == "evict_last":
eviction = ir.EVICTION_POLICY.EVICT_LAST
elif eviction_policy == "evict_first":
eviction = ir.EVICTION_POLICY.EVICT_FIRST
else:
raise ValueError(f"Eviction policy {eviction_policy} not supported")
if ptr.type.is_block():
shape = ptr.type.get_block_shapes()
@@ -752,29 +783,22 @@ def load(ptr: tl.tensor,
else:
dst_ty = elt_ty
if not mask and not other:
if not mask:
if other:
raise ValueError("`other` cannot be provided without `mask`")
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
dst_ty)
if not mask:
raise ValueError("`other` cannot be provided without `mask`")
if not other:
other_ir = ir.undef.get(elt_ty.to_ir(builder))
if ptr.type.is_block():
other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes())
other = tl.tensor(other_ir, dst_ty)
return tl.tensor(builder.create_masked_load(ptr.handle,
mask.handle,
other.handle,
cache, eviction, is_volatile),
dst_ty)
else:
return tl.tensor(builder.create_masked_load(ptr.handle,
mask.handle,
other.handle if other else None,
cache, eviction, is_volatile),
dst_ty)
def store(ptr: tl.tensor,
val: tl.tensor,
mask: Optional[tl.tensor],
eviction_policy: str,
builder: ir.builder) -> tl.tensor:
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
@@ -786,20 +810,17 @@ def store(ptr: tl.tensor,
elt_ty = ptr_ty.element_ty
# treat bool* as tl.int8*
if elt_ty == tl.int1:
# convert to bool first and then store as int8
val = cast(val, tl.int1, builder)
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)
# eviction policy
eviction = _parse_eviction_policy(eviction_policy)
# cast to target data-type
val = cast(val, elt_ty, builder)
if not mask:
return tl.tensor(builder.create_store(ptr.handle, val.handle, eviction), tl.void)
return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
if not mask.type.scalar.is_bool():
raise ValueError("Mask must have boolean scalar type")
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, eviction), tl.void)
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void)
#########
# atomic
@@ -870,8 +891,8 @@ def atomic_max(ptr: tl.tensor,
# return atomic_umin(i_ptr, i_val) if val < 0
i_val = bitcast(val, tl.int32, builder)
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type)
return where(pos, pos_ret, neg_ret, builder)
@@ -902,8 +923,8 @@ def atomic_min(ptr: tl.tensor,
# return atomic_umax(i_ptr, i_val) if val < 0
i_val = bitcast(val, tl.int32, builder)
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
i_ptr.handle,
i_val.handle,
@@ -963,31 +984,28 @@ def atomic_xchg(ptr: tl.tensor,
# ===----------------------------------------------------------------------===//
def dot(a: tl.tensor,
b: tl.tensor,
trans_a: bool,
trans_b: bool,
def dot(lhs: tl.tensor,
rhs: tl.tensor,
allow_tf32: bool,
builder: ir.builder) -> tl.tensor:
in_a = 1 if not trans_a else 0
in_b = 1 if trans_b else 0
assert a.type.is_block() and b.type.is_block()
assert len(a.shape) == 2 and len(b.shape) == 2
assert a.shape[in_a] == b.shape[in_b]
assert a.shape[0] >= 16 and a.shape[1] >= 16 and b.shape[1] >= 16,\
assert lhs.type.is_block() and rhs.type.is_block()
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
assert lhs.shape[1].value == rhs.shape[0].value
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16,\
"small blocks not supported!"
if a.type.scalar.is_int():
if lhs.type.scalar.is_int():
_0 = builder.get_int32(0)
ret_scalar_ty = tl.int32
else:
_0 = builder.get_float32(0)
ret_scalar_ty = tl.float32
M = a.type.shape[in_a ^ 1]
N = b.type.shape[in_b ^ 1]
M = lhs.type.shape[0]
N = rhs.type.shape[1]
_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
ret = builder.create_dot(a.handle, b.handle, _0, trans_a, trans_b, allow_tf32)
return tl.tensor(ret, ret_ty)
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)
# ===----------------------------------------------------------------------===//
@@ -1010,11 +1028,11 @@ def where(condition: tl.tensor,
ret_ty = x.type
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
# ===----------------------------------------------------------------------===//
# Reductions
# ===----------------------------------------------------------------------===
def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
scalar_ty = input.type.scalar
@@ -1039,16 +1057,24 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
if INT_OP in int_op_to_unit:
INT_OP = int_op_to_unit[INT_OP]
# If we are doing an argmin or argmax we want to use an int32 output type
out_scalar_ty = scalar_ty
if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
out_scalar_ty = tl.int32
elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
out_scalar_ty = tl.int32
# get result type
shape = input.type.shape
ret_shape = []
for i, s in enumerate(shape):
if i != axis:
ret_shape.append(s)
if len(ret_shape) == 0:
res_ty = scalar_ty
if ret_shape:
res_ty = tl.block_type(out_scalar_ty, ret_shape)
else:
res_ty = tl.block_type(scalar_ty, ret_shape)
# 0d-tensor -> scalar
res_ty = out_scalar_ty
if scalar_ty.is_floating():
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
@@ -1084,25 +1110,19 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR)
# -----------------------
# Utilities
# -----------------------
def clock(builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_clock(), tl.int64)
def globaltimer(builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_globaltimer(), tl.int64)
# ===----------------------------------------------------------------------===
# Math
# ===----------------------------------------------------------------------===
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
x, y = binary_op_type_checking_impl(x, y, builder)
return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type)
from . import libdevice
return libdevice.mulhi(x, y, _builder=builder)
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
from . import libdevice
return libdevice.floor(x, _builder=builder)
def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:
@@ -1130,16 +1150,23 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.multiple_of(values)
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
return x
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to max_contiguous does not match the length of values")
x.handle.max_contiguous(values)
x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
return x
def debug_barrier(builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_barrier(''), tl.void)
return tl.tensor(builder.create_barrier(), tl.void)
def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
new_args = []
for arg in args:
new_args.append(arg.handle)
return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
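# A quick plain-Python sanity check of the identity used by the new float `mod` lowering
# above, `input - floor(input / other) * other`; the values are arbitrary and this sketch is
# only illustrative of the semantics, not part of the change itself.
import math

def floor_mod(a: float, b: float) -> float:
    # mirrors the comment in `mod`: input - input.div(other, rounding_mode="floor") * other
    return a - math.floor(a / b) * b

assert floor_mod(7.5, 2.0) == 7.5 % 2.0 == 1.5
assert floor_mod(-7.5, 2.0) == -7.5 % 2.0 == 0.5  # result takes the sign of the divisor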

View File

@@ -1,5 +1,12 @@
# flake8: noqa: F401
#from .conv import _conv, conv
# from .conv import _conv, conv
from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from .matmul import _matmul, matmul
__all__ = [
"blocksparse",
"_cross_entropy",
"cross_entropy",
"_matmul",
"matmul",
]

View File

@@ -1,3 +1,7 @@
# flake8: noqa: F401
from .matmul import matmul
from .softmax import softmax
__all__ = [
"matmul",
"softmax",
]

View File

@@ -18,8 +18,8 @@ def num_warps(n):
@triton.jit
def _blocksparse_softmax_fwd(
Out, A, LUT, R, stride_xz,
extent, stride_zr, stride_hr, # relative attention
Out, A, stride_xz, LUT,
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
@@ -164,8 +164,8 @@ class _softmax(torch.autograd.Function):
# enqueue kernel
out = torch.empty_like(a)
_blocksparse_softmax_fwd[grid](
out, a, lut, rel_logits, a.stride(0),
rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
out, a, a.stride(0), lut,
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
scale,
is_causal,
BLOCK_SIZE=block,

View File

@@ -10,7 +10,9 @@ from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcor
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
triton.compiler.init_cuda_utils()
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
return tflops
@@ -18,14 +20,14 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
return tflops
def get_tflops(backend, device, num_ctas, num_warps, dtype):
cc = _triton.runtime.cc(backend, device)
if cc < 80 and dtype == torch.float32:
capability = torch.cuda.get_device_capability(device)
if capability[0] < 8 and dtype == torch.float32:
return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
@@ -59,7 +61,7 @@ def estimate_matmul_time(
compute_ms = total_ops / tput
# time to load data
num_sm = _triton.runtime.num_sm(backend, device)
num_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
active_cta_ratio = min(1, num_ctas / num_sm)
active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
@@ -97,9 +99,8 @@ def estimate_matmul_time(
def early_config_prune(configs, named_args):
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
cc = _triton.runtime.cc(backend, device)
capability = torch.cuda.get_device_capability()
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
dtsize = named_args['A'].element_size()
dtype = named_args['A'].dtype
@@ -110,7 +111,10 @@ def early_config_prune(configs, named_args):
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
# TODO: move to `cuda_utils` submodule
triton.compiler.init_cuda_utils()
max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"]
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory <= max_shared_memory:
pruned_configs.append(config)
@@ -136,7 +140,7 @@ def early_config_prune(configs, named_args):
pruned_configs = []
for k, v in configs_map.items():
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
if cc >= 80:
if capability[0] >= 8:
# compute cycles (only works for ampere GPUs)
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
mma_cycles = mmas / min(4, num_warps) * 8
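# A rough worked reading of the throughput estimate above; the 108-SM figure is an assumed
# A100-like device used only for illustration, not something this change depends on.
multiprocessor_count = 108                      # assumption for illustration
num_subcores = multiprocessor_count * 4         # 432 "sub-cores"
num_ctas, num_warps = 54, 8
total_warps = num_ctas * min(num_warps, 4)      # 216
occupancy = min(num_subcores, total_warps) / num_subcores
# estimated TFLOPS = occupancy * get_max_tensorcore_tflops(dtype, backend, device)
print(occupancy)                                # 0.5 -> half the peak tensor-core throughput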

View File

@@ -1,2 +1,12 @@
from .autotuner import Config, Heuristics, autotune, heuristics # noqa: F401
from .jit import JITFunction, KernelInterface, version_key # noqa: F401
from .autotuner import Config, Heuristics, autotune, heuristics
from .jit import JITFunction, KernelInterface, version_key
__all__ = [
"Config",
"Heuristics",
"autotune",
"heuristics",
"JITFunction",
"KernelInterface",
"version_key",
]

View File

@@ -4,6 +4,7 @@ import builtins
import time
from typing import Dict
from ..compiler import OutOfResources
from ..testing import do_bench
from .jit import KernelInterface
@@ -60,7 +61,10 @@ class Autotuner(KernelInterface):
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return do_bench(kernel_call)
try:
return do_bench(kernel_call)
except OutOfResources:
return float('inf')
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
@@ -118,7 +122,6 @@ class Autotuner(KernelInterface):
class Config:
"""
An object that represents a possible kernel configuration for the auto-tuner to try.
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
:type meta: dict[Str, Any]
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
@@ -150,10 +153,8 @@ class Config:
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
@@ -164,12 +165,10 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will be run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
resets the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
@@ -204,16 +203,12 @@ def heuristics(values):
"""
Decorator for specifying how the values of certain meta-parameters may be computed.
This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
.. highlight:: python
.. code-block:: python
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
each such function takes a list of positional arguments as input.
:type values: dict[str, Callable[[list[Any]], Any]]
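# A minimal usage sketch of `autotune` as documented above; the kernel, its configs and the
# tensor sizes are hypothetical, and `Config` takes a plain meta-parameter dict as in the
# matmul tutorial further down in this diff.
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['n_elements'],  # re-run the search whenever n_elements changes
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    tl.store(out_ptr + offs,
             tl.load(x_ptr + offs, mask=mask) + tl.load(y_ptr + offs, mask=mask),
             mask=mask)

x = torch.randn(4096, device='cuda')
y = torch.randn(4096, device='cuda')
out = torch.empty_like(x)
# BLOCK_SIZE is chosen by the winning Config, so the grid reads it from `meta`
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, out, x.numel())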

View File

@@ -8,6 +8,7 @@ import os
import subprocess
import textwrap
from collections import defaultdict, namedtuple
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
import torch
@@ -19,6 +20,9 @@ try:
except ImportError:
get_cuda_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream
T = TypeVar('T')
# -----------------------------------------------------------------------------
# Dependencies Finder
# -----------------------------------------------------------------------------
@@ -94,20 +98,19 @@ def version_key():
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
class KernelInterface:
class KernelInterface(Generic[T]):
run: T
def __getitem__(self, grid):
def __getitem__(self, grid) -> T:
"""
A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that
memorizes the grid.
"""
def launcher(*args, **kwargs):
return self.run(*args, grid=grid, **kwargs)
return launcher
return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
class JITFunction(KernelInterface):
class JITFunction(KernelInterface[T]):
# Hook for inspecting compiled functions and modules
cache_hook = None
@@ -152,8 +155,8 @@ class JITFunction(KernelInterface):
if x is None:
return True
return False
divisible_by_16 = [i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize]
equal_to_1 = [i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize]
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
@@ -177,6 +180,9 @@ class JITFunction(KernelInterface):
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
triton.language.float8: 'fp8',
triton.language.float16: 'fp16',
triton.language.bfloat16: 'bf16',
triton.language.float32: 'fp32',
}[key]
return f'*{ty}'
if key is None:
@@ -272,7 +278,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature, device, constants, constexpr_key, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args, constexpr_key)
self.cache[device][key] = bin
@@ -364,29 +370,55 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
# -----------------------------------------------------------------------------
def jit(*args, **kwargs):
@overload
def jit(fn: T) -> JITFunction[T]:
...
@overload
def jit(
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
) -> Callable[[T], JITFunction[T]]:
...
def jit(
fn: Optional[T] = None,
*,
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are
implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* objects within the triton.language package,
* builtins within the triton package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
"""
if args:
assert len(args) == 1
assert callable(args[0])
return JITFunction(args[0], **kwargs)
def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
return JITFunction(
fn,
version=version,
do_not_specialize=do_not_specialize,
)
if fn is not None:
return decorator(fn)
else:
def decorator(fn):
return JITFunction(fn, **kwargs)
return decorator
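# A short sketch of the two decorator forms typed above and of the `fn[grid](...)` launch
# proxy; the copy kernel and the sizes are hypothetical.
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

@triton.jit(do_not_specialize=[2])  # parameterized form: don't specialize on n_elements
def copy_kernel_nospec(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    pass  # same body as above, elided; shown only for the keyword-only decorator form

src = torch.randn(10000, device='cuda')
dst = torch.empty_like(src)
# torch tensors are implicitly converted to pointers via .data_ptr() at launch
copy_kernel[(triton.cdiv(src.numel(), 1024),)](src, dst, src.numel(), BLOCK=1024)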

View File

@@ -16,6 +16,9 @@ except ImportError:
_cutlass = None
has_cutlass = False
# TODO: move to separate module
import triton
def catch_oor(kernel, pytest_handle=None):
try:
@@ -34,12 +37,12 @@ def sparsify_tensor(x, mask, block):
return ret
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None):
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
if data is None:
data = torch.randn(shape, dtype=torch.float32, device=device)
data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
ref_ret = data
ref_ret = ref_ret * alpha + beta
ref_ret = ref_ret.half().float()
ref_ret = ref_ret.half().to(dtype)
if trans:
ref_ret = ref_ret.t().requires_grad_()
ref_ret = ref_ret.detach().requires_grad_()
@@ -336,8 +339,8 @@ def get_dram_gbps(backend=None, device=None):
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
mem_clock_khz = _triton.runtime.memory_clock_rate(backend, device)
bus_width = _triton.runtime.global_memory_bus_width(backend, device)
mem_clock_khz = triton.compiler.cuda_utils.get_device_properties(device)["mem_clock_rate"] # in kHz
bus_width = triton.compiler.cuda_utils.get_device_properties(device)["mem_bus_width"]
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
return bw_gbps
@@ -347,11 +350,13 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clo
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
triton.compiler.init_cuda_utils()
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
if not clock_rate:
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
cc = _triton.runtime.cc(backend, device)
if cc < 80:
clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"] # in kHz
capability = torch.cuda.get_device_capability(device)
if capability[0] < 8:
assert dtype == torch.float16
ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
else:

View File

@@ -0,0 +1,61 @@
import argparse
import triton
import triton._C.libtriton.triton as libtriton
if __name__ == '__main__':
# valid source and target formats
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx']
# set up the argument parser
# TODO: conditional requirements
parser = argparse.ArgumentParser()
parser.add_argument('src', help="Source file to compile")
parser.add_argument('--target', required=True,
help="Target format, one of: " + ', '.join(VALID_FORMATS))
parser.add_argument('--sm', type=int, help="Compute capability to compile for")
parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")
# parse the args
args = parser.parse_args()
# TODO: clean-up and re-use triton.compiler primitive functions
# check for validity of format arguments
if args.target not in VALID_FORMATS:
print("Invalid target format: " + args.target)
exit(0)
# parse source file to MLIR module
context = libtriton.ir.context()
module = libtriton.ir.parse_mlir_module(args.src, context)
module.context = context
# optimizer triton-ir
module = triton.compiler.optimize_triton_ir(module)
if args.target == 'triton-ir':
print(module.str())
exit(0)
if not args.sm:
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
# triton-ir -> triton-gpu-ir
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
if args.target == 'triton-gpu-ir':
print(module.str())
exit(0)
# triton-gpu-ir -> llvm-ir
module = triton.compiler.ttgir_to_llir(module, extern_libs=None, compute_capability=args.sm)
if args.target == 'llvm-ir':
print(module)
exit(0)
if not args.ptx_version:
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
# llvm-ir -> ptx
module = triton.compiler.llir_to_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
assert args.target == 'ptx'
print(module)

View File

@@ -21,7 +21,6 @@ class Symbol:
) -> None:
'''
A symbol is a function declaration.
:param name: name of the symbol
:param op_name: name of the operation
:param ret_type: return type of the operation
@@ -65,9 +64,9 @@ def convert_type(type_str) -> Optional[str]:
elif type_str == "u64":
return "uint64"
elif type_str == "float":
return "float32"
return "fp32"
elif type_str == "double":
return "float64"
return "fp64"
else:
# ignore other types, such as pointer types
return None
@@ -98,7 +97,6 @@ class ExternLibrary(ABC):
) -> None:
'''
Abstract class for extern library.
:param name: name of the library
:param path: path of the library
:param format: whether to format the generated stub file
@@ -154,7 +152,6 @@ class Libdevice(ExternLibrary):
def __init__(self, path) -> None:
'''
Constructor for Libdevice.
:param path: path of the libdevice library
'''
super().__init__("libdevice", path)
@@ -177,7 +174,6 @@ class Libdevice(ExternLibrary):
func_strs = func_str.split("(")
func_name = func_strs[0].replace("@", "")
op_name = func_name.replace("__nv_", "")
        # Filter out interfaces that are not listed in NVIDIA's official documentation.
if 'ieee' in op_name:
return None
# Get arg_types
@@ -310,8 +306,8 @@ class Libdevice(ExternLibrary):
for symbol in symbols:
arg_type_symbol_dict_str += "("
for arg_type in symbol.arg_types:
arg_type_symbol_dict_str += f"core.{arg_type},"
ret_type = f"core.{symbol.ret_type}"
arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),'
ret_type = f'core.dtype("{symbol.ret_type}")'
arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n"
arg_type_symbol_dict_str += "}"
@@ -331,7 +327,6 @@ class LLVMDisassembler:
def __init__(self, path) -> None:
'''
Invoke llvm-dis to disassemble the given file.
:param path: path to llvm-dis
'''
self._path = path
@@ -361,7 +356,6 @@ def build(
) -> None:
'''
Interface function to build the library file.
:param llvm_dis_path: path to the llvm-dis binary
:param lib_path: path to the external library file
:param lib_name: name of the library

View File

@@ -1,76 +0,0 @@
'''
Compare cached triton kernels in 2 directories.
example:
python compare_asm.py --dir0=triton-works/ --dir1=triton-fails/ --asm=ttir \
--diff-out0=diff-works.ll --diff-out1=diff-fails.ll
'''
import argparse
import os
import pickle
parser = argparse.ArgumentParser(description="unpickle")
parser.add_argument('--dir0', dest='dir0', required=True,
help="Triton cache dir 0")
parser.add_argument('--dir1', dest='dir1', required=True,
help="Triton cache dir 1")
parser.add_argument('--asm', dest='asm',
choices=['ttir', 'llir', 'ptx', 'cubin'], required=True)
parser.add_argument('--early-stop', dest='early_stop', action='store_true',
help="Stop after first diff")
parser.set_defaults(early_stop=True)
parser.add_argument('--diff-out0', dest='diff_out0', required=True,
help="output file path for kernels in dir0")
parser.add_argument('--diff-out1', dest='diff_out1', required=True,
help="output file path for kernels in dir1")
args = parser.parse_args()
dir0 = args.dir0
dir1 = args.dir1
asm = args.asm
dir0_files = {}
dir1_files = {}
for root, _, files in os.walk(dir0):
for file in files:
if not file.endswith('.lock'):
path = os.path.join(root, file)
with open(path, 'rb') as f:
loaded_file = pickle.load(f)
bin = loaded_file['binary']
key = loaded_file['key']
info = key.split('-')[-3:] # num_warps, num_stages, signature
dict_key = bin.name + '-'.join(info)
dir0_files[dict_key] = bin.asm
for root, _, files in os.walk(dir1):
for file in files:
if not file.endswith('.lock'):
path = os.path.join(root, file)
with open(path, 'rb') as f:
loaded_file = pickle.load(f)
bin = loaded_file['binary']
key = loaded_file['key']
info = key.split('-')[-3:] # num_warps, num_stages, signature
dict_key = bin.name + '-'.join(info)
dir1_files[dict_key] = bin.asm
diff_keys = []
for key in dir0_files:
asm0 = dir0_files[key]
if key not in dir1_files:
continue
asm1 = dir1_files[key]
if asm0[asm] != asm1[asm]:
diff_keys.append(key)
if args.early_stop:
diff_keys = diff_keys[:1]
if diff_keys:
with open(args.diff_out0, 'w') as f0, open(args.diff_out1, 'w') as f1:
for key in diff_keys:
f0.write(f'{asm} mismatch at {key}')
f0.write(dir0_files[key][asm])
f0.write('\n')
f1.write(f'{asm} mismatch at {key}')
f1.write(dir1_files[key][asm])
f1.write('\n')

View File

@@ -80,7 +80,7 @@ def softmax_kernel(
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
@@ -188,4 +188,4 @@ benchmark.run(show_plots=True, print_data=True)
#
# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
# Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.
# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.
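# A small PyTorch check of the stability trick discussed above: subtracting the row maximum
# leaves the softmax unchanged while keeping the exponentials bounded (the values are arbitrary).
import torch

row = torch.tensor([1000., 1001., 1002.])            # large logits overflow a naive float32 softmax
naive = torch.exp(row) / torch.exp(row).sum()        # tensor([nan, nan, nan])
shifted = row - row.max()
stable = torch.exp(shifted) / torch.exp(shifted).sum()
print(stable)                                        # ~[0.0900, 0.2447, 0.6652]
print(torch.allclose(stable, torch.softmax(row, dim=0)))  # True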

View File

@@ -156,16 +156,7 @@ import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
],
key=['M', 'N', 'K'],
)
@@ -236,8 +227,8 @@ def matmul_kernel(
b_ptrs += BLOCK_SIZE_K * stride_bk
# you can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
if ACTIVATION == "leaky_relu":
accumulator = leaky_relu(accumulator)
if ACTIVATION:
accumulator = ACTIVATION(accumulator)
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
@@ -252,7 +243,6 @@ def matmul_kernel(
# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
@triton.jit
def leaky_relu(x):
x = x + 1
return tl.where(x >= 0, x, 0.01 * x)
@@ -261,7 +251,7 @@ def leaky_relu(x):
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel
def matmul(a, b, activation=""):
def matmul(a, b, activation=None):
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
assert a.is_contiguous(), "matrix A must be contiguous"
@@ -297,7 +287,7 @@ def matmul(a, b, activation=""):
torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = matmul(a, b)
triton_output = matmul(a, b, activation=None)
torch_output = torch.matmul(a, b)
print(f"triton_output={triton_output}")
print(f"torch_output={torch_output}")
@@ -319,13 +309,13 @@ else:
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 33)
8192
], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
line_vals=['cublas', 'triton'],
# label name for the lines
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
line_names=["cuBLAS", "Triton"],
# line styles
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis
@@ -337,18 +327,9 @@ def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
if provider == 'cublas + relu':
torch_relu = torch.nn.ReLU(inplace=True)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_relu(torch.matmul(a, b))
)
if provider == 'triton + relu':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b, activation="leaky_relu")
)
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
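# With `activation` now a callable meta-parameter, fusing the jit'd `leaky_relu` defined
# earlier in this tutorial is a matter of passing the function itself (a usage sketch that
# relies on the `matmul` wrapper and `leaky_relu` above, not part of the benchmark):
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c = matmul(a, b)                               # ACTIVATION is None, nothing is fused
c_fused = matmul(a, b, activation=leaky_relu)  # leaky_relu runs on the fp32 accumulator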

View File

@@ -19,8 +19,8 @@ except ModuleNotFoundError:
@triton.jit
def _layer_norm_fwd_fused(
Out,
A,
Out,
Weight,
Bias,
Mean, Rstd,
@@ -36,14 +36,14 @@ def _layer_norm_fwd_fused(
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0., eviction_policy="evict_last").to(tl.float32)
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
a = tl.where(cols < N, a - mean, 0.)
_var += a * a
var = tl.sum(_var, axis=0) / N
@@ -57,192 +57,155 @@ def _layer_norm_fwd_fused(
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(A + cols, mask=mask, other=0., eviction_policy="evict_first").to(tl.float32)
a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# # write-back
tl.store(Out + cols, out, mask=mask)
# Backward pass (DA + partial DW + partial DB)
# Backward pass (DX + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(
_DA,
_DOut,
_A,
Weight,
Mean, Rstd,
stride, NumRows, NumCols, eps,
BLOCK_SIZE_N: tl.constexpr,
):
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock, stride, N, eps,
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# position of elements processed by this program
pid = tl.program_id(0)
row = pid
A = _A + row * stride
DOut = _DOut + row * stride
DA = _DA + row * stride
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
DY += row * stride
DX += row * stride
# offset locks and weight/bias gradient pointer
# each kernel instance accumulates partial sums for
# DW and DB into one of GROUP_SIZE_M independent buffers
    # these buffers stay in the L2, which allows this kernel
    # to be fast
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M
DW = DW + lock_id * N + cols
DB = DB + lock_id * N + cols
# load data to SRAM
_mean1 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
_mean2 = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
_mean1 += a_hat * wdout
_mean2 += wdout
mean1 = tl.sum(_mean1, axis=0) / NumCols
mean2 = 0.
mean2 = tl.sum(_mean2, axis=0) / NumCols
for off in range(0, NumCols, BLOCK_SIZE_N):
cols = off + tl.arange(0, BLOCK_SIZE_N)
mask = cols < NumCols
a = tl.load(A + cols, mask=mask, other=0).to(tl.float32)
dout = tl.load(DOut + cols, mask=mask, other=0).to(tl.float32)
weight = tl.load(Weight + cols, mask=mask, other=0).to(tl.float32)
a_hat = (a - mean) * rstd
wdout = weight * dout
da = (wdout - (a_hat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DA + cols, da, mask=mask)
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(M + row)
rstd = tl.load(V + row)
# compute dx
xhat = (x - mean) * rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy, 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * mean1 + mean2)) * rstd
# write-back dx
tl.store(DX + cols, dx, mask=mask)
# accumulate partial sums for dw/db
partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
count = tl.load(Count)
# first store doesn't accumulate
if count == 0:
tl.atomic_xchg(Count, 1)
else:
partial_dw += tl.load(DW, mask=mask)
partial_db += tl.load(DB, mask=mask)
tl.store(DW, partial_dw, mask=mask)
tl.store(DB, partial_db, mask=mask)
# release lock
tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(
A, DOut,
Mean, Var,
DW,
DB,
M, N,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
UNROLL: tl.constexpr = 4
for i in range(0, M, BLOCK_SIZE_M * UNROLL):
for j in range(UNROLL):
rows = i + j * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows < M, other=0.)
rstd = tl.load(Var + rows, mask=rows < M, other=0.)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw += dout * a_hat
db += dout
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.)
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(DW + cols, sum_dw, mask=cols < N)
tl.store(DB + cols, sum_db, mask=cols < N)
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, a, normalized_shape, weight, bias, eps):
def forward(ctx, x, normalized_shape, weight, bias, eps):
# allocate output
out = torch.empty_like(a)
y = torch.empty_like(x)
# reshape input data into 2D tensor
a_arg = a.reshape(-1, a.shape[-1])
M, N = a_arg.shape
mean = torch.empty((M,), dtype=torch.float32, device="cuda")
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // a.element_size()
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
_layer_norm_fwd_fused[(M,)](
out,
a_arg,
weight,
bias,
mean, rstd,
a_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.save_for_backward(
a, weight, bias, mean, rstd,
)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
if hasattr(bias, "config"):
assert bias.config.grad_scale_name == weight.config.grad_scale_name
grad_scale_name = bias.config.grad_scale_name
else:
grad_scale_name = None
ctx.grad_scale_gain_bias_name = grad_scale_name
return out
return y
@staticmethod
def backward(ctx, dout):
assert dout.is_contiguous()
a, weight, bias, mean, var = ctx.saved_tensors
def backward(ctx, dy):
x, w, b, m, v = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DG/DB
N = weight.shape[0]
N = w.shape[0]
GROUP_SIZE_M = 64
if N <= 8192: GROUP_SIZE_M = 96
if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256
# allocate output
da = torch.empty_like(dout)
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
_dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = a.reshape(-1, a.shape[-1])
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
dweight = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
dbias = torch.empty((weight.shape[0],), dtype=weight.dtype, device=weight.device)
_layer_norm_bwd_dx_fused[(M,)](
da,
dout,
a,
weight,
mean, var,
x_arg.stride(0), M, N,
ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
if N > 10240:
BLOCK_SIZE_N = 128
BLOCK_SIZE_M = 32
num_warps = 4
else:
# maximize occupancy for small N
BLOCK_SIZE_N = 16
BLOCK_SIZE_M = 16
num_warps = 8
grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
_layer_norm_bwd_dwdb[grid](
a, dout,
mean, var,
dweight,
dbias,
M,
N,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
num_warps=num_warps
)
return (da, None, dweight, dbias, None)
_layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
x_arg.stride(0), N, ctx.eps,
BLOCK_SIZE_N=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M,
num_warps=ctx.num_warps)
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128)
return dx, None, dw, db, None
def layer_norm(a, normalized_shape, weight, bias, eps):
return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
layer_norm = LayerNorm.apply
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
torch.manual_seed(0)
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
@@ -277,11 +240,11 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
ylabel='GB/s',
plot_name='layer-norm',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'}
plot_name='layer-norm-backward',
args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
)
)
def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
@@ -311,5 +274,5 @@ def bench_layer_norm(M, N, dtype, provider, mode, eps=1e-5, device='cuda'):
return gbps(ms), gbps(max_ms), gbps(min_ms)
# test_layer_norm(1151, 8192, torch.float16)
bench_layer_norm.run(save_path='.', print_data=True)
test_layer_norm(1151, 8192, torch.float16)
# bench_layer_norm.run(save_path='.', print_data=True)
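# For reference, what the two backward kernels accumulate, written directly in PyTorch
# (a sketch assuming 2D `x`/`dy` and the saved per-row `mean`/`rstd`); the kernels above
# split this row-sum into GROUP_SIZE_M lock-guarded partial buffers and then reduce them.
def layer_norm_bwd_dwdb_reference(x, dy, mean, rstd):
    xhat = (x - mean[:, None]) * rstd[:, None]  # normalized input, recomputed from saved stats
    dw = (dy * xhat).sum(dim=0)                 # what _layer_norm_bwd_dx_fused accumulates into DW
    db = dy.sum(dim=0)                          # ... and into DB, before _layer_norm_bwd_dwdb reduces them
    return dw, db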

View File

@@ -15,7 +15,7 @@ import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
TMP, L, M, # NOTE: TMP is a scratchpad buffer to work around a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
@@ -39,7 +39,6 @@ def _fwd_kernel(
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
t_ptrs = TMP + off_hz * N_CTX + offs_m
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -47,11 +46,11 @@ def _fwd_kernel(
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk += tl.dot(q, tl.trans(k))
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
@@ -69,8 +68,6 @@ def _fwd_kernel(
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
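# Tile-level algebra for the snippet above: `tl.dot(q, tl.trans(k))` computes q @ k^T,
# and `acc_scale = l_i / l_i_new * alpha` is the online-softmax rescaling of the running
# accumulator. A scalar PyTorch sketch of that rescaling for one query row split into
# two score chunks (names and sizes are illustrative):
s1_sketch, s2_sketch = torch.randn(128), torch.randn(128)
m1 = s1_sketch.max()
l1 = torch.exp(s1_sketch - m1).sum()
m_new = torch.maximum(m1, s2_sketch.max())
alpha_sketch = torch.exp(m1 - m_new)                          # rescale factor for the old statistics
l_new = alpha_sketch * l1 + torch.exp(s2_sketch - m_new).sum()
acc_scale_sketch = l1 / l_new * alpha_sketch                  # factor applied to acc above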
@@ -168,26 +165,26 @@ def _bwd_kernel(
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, k, trans_b=True)
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
dv += tl.dot(tl.trans(p.to(tl.float16)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v, trans_b=True)
dp += tl.dot(do, tl.trans(v))
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
# # compute dq
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(tl.float16), k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
# # increment pointers
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
@@ -198,6 +195,9 @@ def _bwd_kernel(
tl.store(dk_ptrs, dk)
empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
@@ -208,7 +208,7 @@ class _attention(torch.autograd.Function):
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
@@ -227,6 +227,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
@@ -272,13 +273,13 @@ class _attention(torch.autograd.Function):
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
@@ -287,13 +288,16 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
# # triton implementation
tri_out = attention(q, k, v, sm_scale)
# print(ref_out)
# print(tri_out)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
@@ -323,7 +327,7 @@ configs = [triton.testing.Benchmark(
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['bwd']]
) for mode in ['fwd']]
@triton.testing.perf_report(configs)
@@ -356,5 +360,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
# only works on A100 at the moment
# bench_flash_attention.run(save_path='.', print_data=True)

View File

@@ -1,74 +0,0 @@
"""
Libdevice function
==================
Triton can invoke a custom function from an external library.
In this example, we will use the `libdevice` library to apply `asin` on a tensor.
Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
In `triton/language/libdevice.py`, we aggregate functions that perform the same computation on different data types.
For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using Triton, you can simply call `tl.libdevice.asin`.
Triton automatically selects the correct underlying device function to invoke based on input and output types.
"""
# %%
# asin Kernel
# --------------------------
import torch
import triton
import triton.language as tl
@triton.jit
def asin_kernel(
x_ptr,
y_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
x = tl.libdevice.asin(x)
tl.store(y_ptr + offsets, x, mask=mask)
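# Because the underlying device function is selected from the element type, the same
# kernel body also runs on double-precision inputs (a minimal sketch; the size and
# launch grid are illustrative):
x64 = torch.rand(1024, device='cuda', dtype=torch.float64)
y64 = torch.empty_like(x64)
asin_kernel[(triton.cdiv(x64.numel(), 1024),)](x64, y64, x64.numel(), BLOCK_SIZE=1024)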
# %%
# Using the default libdevice library path
# ----------------------------------------
# We can use the default libdevice library path encoded in `triton/language/libdevice.py`
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
output_triton = torch.zeros(size, device='cuda')
output_torch = torch.asin(x)
assert x.is_cuda and output_triton.is_cuda
n_elements = output_torch.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
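# A quick numerical sanity check on top of the printout above (a minimal sketch; the
# tolerance is illustrative, not a documented guarantee):
assert torch.allclose(output_torch, output_triton, atol=1e-6)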
# %%
# Customize the libdevice library path
# ------------------------------------
# We can also customize the libdevice library path by passing an explicit path to the `libdevice` library when launching the `asin` kernel.
output_triton = torch.empty_like(x)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
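# The hard-coded path above assumes a default CUDA installation. One way to make it
# configurable (a sketch; the on-disk layout may differ across CUDA versions) is to
# derive it from CUDA_HOME:
import os
cuda_home = os.environ.get('CUDA_HOME', '/usr/local/cuda')
libdevice_path = os.path.join(cuda_home, 'nvvm', 'libdevice', 'libdevice.10.bc')
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
                  extern_libs={'libdevice': libdevice_path})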