[RUNTIME] Added auto-alignment mechanism (#71)

This PR adds an automatic memory-alignment mechanism to the Triton runtime. Specifically, the JIT compiler detects the alignment (in bytes) of each pointer argument as well as the largest power-of-two divisor (between 1 and 16) of each integer argument. The corresponding .aligned and .multipleof attributes are then added to the Triton-IR on the fly for all auto-tunable kernels, and a cache remembers the kernels compiled for each possible configuration.
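The divisor detection itself is simple; a rough Python sketch of the idea (illustrative only, the actual detection lives in the C++ JIT):

def largest_pow2_divisor(x, cap=16):
    # largest power of two between 1 and `cap` that divides x
    div = 1
    while div < cap and x % (div * 2) == 0:
        div *= 2
    return div

# A pointer whose byte address is a multiple of 16 gets `.aligned(16)`;
# an integer argument such as a stride gets the matching `.multipleof` hint.
assert largest_pow2_divisor(0x7f00c0000000) == 16  # -> .aligned(16)
assert largest_pow2_divisor(4096) == 16            # -> .multipleof(16)
assert largest_pow2_divisor(100) == 4              # -> .multipleof(4)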

This PR also includes a substantial cleanup of the Python API. The new API adds 2-3us of launch overhead, mostly due to accessing integer #defines from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so this approach is preferred for now.
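Concretely, the integer #defines of the configuration picked by the auto-tuner are exposed as attributes of the options object passed to the launch-grid callback (via the new __getattr__ binding on the options class), so user code can read them directly. A minimal sketch mirroring the vector-add example added in this PR (N is an illustrative problem size):

import triton

N = 1 << 20  # illustrative problem size
# The auto-tuner hands the winning compilation options to the grid callback;
# integer #defines such as BLOCK are read as plain attributes.
grid = lambda opt: (triton.cdiv(N, opt.BLOCK),)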
Philippe Tillet
2021-03-04 01:51:11 -05:00
committed by Philippe Tillet
parent ff62f7fffc
commit 62835a0979
19 changed files with 668 additions and 707 deletions

View File

@@ -1,12 +1,19 @@
import triton
import torch
import os
# square benchmarks
def rounded_linspace(low, high, steps, div):
ret = torch.linspace(low, high, steps)
ret = (ret.int() + div - 1) // div * div
ret = torch.unique(ret)
return list(map(int, ret))
# Square benchmarks
nt = {False: "n", True: "t"}
square_confs = [
triton.testing.Benchmark(
x_names=["M", "N", "K"],
x_vals=[512 * i for i in range(1, 16)],
x_vals=rounded_linspace(512, 8192, 17, 128),
y_name="provider",
y_vals=["torch", "triton", "cutlass"],
y_lines=["Torch", "Triton", "CUTLASS"],
@@ -17,16 +24,29 @@ square_confs = [
) for AT in [False, True] for BT in [False, True]
]
@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
import os
# Transformer training benchmarks
transformer_confs = [
triton.testing.Benchmark(
x_names=[x],
x_vals = rounded_linspace(NK//16, NK, 33, 128),
y_name="provider",
y_vals=["torch", "triton", "cutlass"],
y_lines=["Torch", "Triton", "CUTLASS"],
ylabel="TFLOPS",
loglog=False,
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [8192]\
for i, x in enumerate(["N", "K"])\
for M in [2048]
]
@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT:
a = a.t()
if BT:
b = b.t()
if AT: a = a.t()
if BT: b = b.t()
num_flops = 2 * M * N * K
if provider == "torch":
torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
@@ -40,7 +60,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
import subprocess
import tempfile
import pandas as pd
# run program specified by CUTLASS_PROFILER env variable
layout_a = "column" if AT else "row"
layout_b = "column" if BT else "row"
@@ -61,6 +80,7 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
f"--warmup-iterations={warmup}",
f"--profiling-iterations={rep}",
f"--output={fname}",
"--dist=uniform,min:0,max:1,scale:-1",
"--verbose=false",
]
# run cmd
@@ -70,6 +90,3 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
return cutlass_tflops
return None
if __name__ == "__main__":
bench_op.run()

View File

@@ -38,4 +38,4 @@ def main(args):
run_all(args.result_dir, args.with_plots, args.names)
if __name__ == '__main__':
main(sys.argv[1:])
main(sys.argv[1:])

View File

@@ -5,38 +5,16 @@
#include <cuda_runtime_api.h>
#include <torch/extension.h>
std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
namespace torch_utils {
void register_device(int64_t dev_id) {
if (tt_devices.find(dev_id) != tt_devices.end())
return;
triton::driver::device *device;
if (dev_id >= 0) {
CUdevice handle;
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
device = new triton::driver::cu_device(handle, false);
} else
device = new triton::driver::host_device();
tt_devices[dev_id].reset(device);
uint64_t cu_device(int64_t dev_id) {
CUdevice handle;
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
return (uint64_t)handle;
}
void register_stream(int64_t dev_id) {
if (tt_streams.find(dev_id) != tt_streams.end())
return;
triton::driver::stream *stream;
if (dev_id >= 0) {
CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
stream = new triton::driver::cu_stream(handle, false);
} else
stream = new triton::driver::host_stream();
tt_streams[dev_id].reset(stream);
}
void synchronize(int64_t dev_id) {
tt_streams[dev_id]->synchronize();
uint64_t cu_stream(int64_t dev_id) {
return (uint64_t)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
void set_device(int64_t dev_id) {
@@ -44,23 +22,11 @@ void set_device(int64_t dev_id) {
C10_CUDA_CHECK(cudaSetDevice(dev_id));
}
torch::Tensor move_out_of_pool(torch::Tensor x) {
if (x.nbytes() == 0)
return torch::empty_like(x);
void *data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options());
ret.copy_(x);
return ret;
}
} // namespace torch_utils
void init_torch_utils(pybind11::module &m) {
pybind11::module subm = m.def_submodule("torch_utils");
subm.def("register_device", &torch_utils::register_device);
subm.def("register_stream", &torch_utils::register_stream);
subm.def("cu_device", &torch_utils::cu_device);
subm.def("cu_stream", &torch_utils::cu_stream);
subm.def("set_device", &torch_utils::set_device);
subm.def("synchronize", &torch_utils::synchronize);
subm.def("move_out_of_pool", &torch_utils::move_out_of_pool);
}
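With the register_device/register_stream bookkeeping gone, torch_utils only returns raw integer handles; the Python runtime shown further down wraps them roughly as follows (a sketch, assuming a CUDA device is available):

import torch
import triton._C.libtriton.triton as _triton
import triton._C.libtriton.torch_utils as _torch_utils

dev_id = torch.cuda.current_device()
# cu_device/cu_stream return the raw CUdevice/CUstream handles as integers;
# the driver bindings wrap them without taking ownership (second argument False).
device = _triton.driver.cu_device(_torch_utils.cu_device(dev_id), False)
stream = _triton.driver.cu_stream(_torch_utils.cu_stream(dev_id), False)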

View File

@@ -1,10 +1,4 @@
#include "triton/driver/stream.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/lang/code_gen.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
#include "triton/runtime/arg.h"
#include "triton/runtime/function.h"
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -13,72 +7,22 @@
#include <regex>
#include <string>
namespace py = pybind11;
using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
namespace lng = triton::lang;
std::unordered_map<const rt::options_t *, pybind11::object> opt_cache_;
std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
extern std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
/*****************************************************************************/
/* Python bindings for triton::tools */
/*****************************************************************************/
/* Function utilities */
void register_fn(int op_id, int dev_id,
const std::string &src, const rt::options_t &opt,
const rt::function::autotune_vals_t &autotune_vals,
const std::vector<std::string> &autotune_key) {
if (id_fn_map.find(op_id) == id_fn_map.end()) {
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
}
for (const auto &k : id_fn_map[op_id]->get_kernels()) {
const rt::options_t *opt = &k.first;
pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
for (auto x : opt->defines)
if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
obj.attr(x.first.c_str()) = std::stoi(x.second);
opt_cache_[&k.second->opt] = obj;
}
}
void delete_fn(int op_id) {
id_fn_map.erase(op_id);
}
void cleanup() {
id_fn_map.clear();
opt_cache_.clear();
}
size_t make_op_id() {
return id_fn_map.size();
}
std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
}
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string &args, size_t grid_0, size_t grid_1, size_t grid_2) {
rt::function *fn = id_fn_map.at(op_id).get();
(*fn)((void **)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
}
pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string &args, const rt::function::grid_fn_ty &grid) {
rt::function *fn = id_fn_map.at(op_id).get();
auto wrapper = [&grid](const rt::options_t &opt) {
pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
for (auto x : opt.defines)
if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
obj.attr(x.first.c_str()) = std::stoi(x.second);
return grid(*obj.cast<rt::options_t *>());
};
rt::kernel *kernel = fn->autotune((void **)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
return opt_cache_.at(&kernel->opt);
}
/*!
@brief Function for extracting kernels out of a given source-string
This can be important to enable pre-processor macros (or tunable parameters) that should only
be defined within the scope of a single kernel function
*/
std::string extract_kernels(const std::string &str, const std::vector<std::string> &names) {
if (names.empty())
return str;
@@ -94,50 +38,82 @@ std::string extract_kernels(const std::string &str, const std::vector<std::strin
std::string name = it->str(1);
kernels.push_back(std::make_tuple(name, pos, len));
}
// check that all the kernels provided actually exist
for (const std::string &name : names) {
// check that str matches any string in kernels using std::any_of
auto pred = [&name](const std::tuple<std::string, int, int> &t) { return std::get<0>(t) == name; };
bool found = std::any_of(kernels.begin(), kernels.end(), pred);
if (!found)
throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
}
// extract functions
// simple parsing logic to extract the declaration and body of each specified kernel
std::string ret;
for (const auto &k : kernels) {
std::string name;
int pos, len;
std::tie(name, pos, len) = k;
if (std::find(names.begin(), names.end(), name) != names.end()) {
std::string def = str.substr(pos, str.size() - pos);
int count, pos;
// skip over declaration
count = 1;
pos = def.find('(');
while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
count = 1;
pos = def.find('{', pos);
while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
if (std::find(names.begin(), names.end(), name) == names.end())
continue;
std::string def = str.substr(pos, str.size() - pos);
// skip over declaration
// by finding matching ')' for first '('
int count = 1;
pos = def.find('(');
while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
// by finding matching '}' for first '{'
count = 1;
pos = def.find('{', pos);
while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
}
return ret;
}
void init_triton(pybind11::module &m) {
pybind11::module subm = m.def_submodule("triton");
// bindings for triton classes
pybind11::enum_<rt::arg_type>(subm, "arg_type")
void init_triton_tools(py::module &&m) {
m.def("extract_kernels", &extract_kernels);
}
/*****************************************************************************/
/* Python bindings for triton::driver */
/*****************************************************************************/
void init_triton_driver(py::module &&m) {
// base device
py::class_<drv::device>(m, "device");
// cuda device
py::class_<drv::cu_device, driver::device>(m, "cu_device")
.def(py::init<CUdevice, bool>());
// host device
py::class_<drv::host_device, driver::device>(m, "host_device")
.def(py::init<>());
// base stream
py::class_<drv::stream>(m, "stream");
// host stream
py::class_<drv::host_stream, drv::stream>(m, "host_stream")
.def(py::init<>());
// cuda stream
py::class_<drv::cu_stream, drv::stream>(m, "cu_stream")
// py doesn't support opaque pointer (e.g., CUstream) so
// we assume it has been converted to uint64_t
.def(py::init([](uint64_t handle, bool take_ownership) {
return std::unique_ptr<driver::cu_stream>(new driver::cu_stream((CUstream)handle, take_ownership));
}));
}
/*****************************************************************************/
/* Python bindings for triton::runtime */
/*****************************************************************************/
void init_triton_runtime(py::module &&m) {
// argument type
py::enum_<rt::arg_type>(m, "arg_type")
.value("int1", rt::INT1_T)
.value("int8", rt::INT8_T)
.value("int16", rt::INT16_T)
@@ -147,23 +123,38 @@ void init_triton(pybind11::module &m) {
.value("float", rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);
pybind11::enum_<rt::asm_mode_t>(subm, "asm_mode")
// assembly mode
py::enum_<rt::asm_mode_t>(m, "asm_mode")
.value("ptx", rt::ASM_NV_PTX)
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<rt::options_t>(subm, "options", pybind11::dynamic_attr())
.def(pybind11::init<>())
// compilation options
py::class_<rt::options_t>(m, "options", py::dynamic_attr())
.def(py::init<>())
.def_readwrite("defines", &rt::options_t::defines)
.def_readwrite("num_warps", &rt::options_t::num_warps);
.def_readwrite("num_warps", &rt::options_t::num_warps)
.def("__getattr__", [](rt::options_t *opt, const std::string &name) {
return opt->D<int>(name);
});
// kernel
py::class_<rt::kernel>(m, "kernel")
.def("__call__", &rt::kernel::operator())
.def_readonly("opt", &rt::kernel::opt);
// tune conf
py::class_<rt::config>(m, "config")
.def(py::init<std::map<std::string, std::string>, int>(),
py::arg("defines") = std::map<std::string, std::string>(),
py::arg("num_warps"));
// hooks into triton constructs since frameworks may not use pybind11
subm.def("extract_kernels", &extract_kernels);
subm.def("get_fn_signature", &get_fn_signature);
subm.def("register_fn", &register_fn);
subm.def("delete_fn", &delete_fn);
subm.def("make_op_id", &make_op_id);
subm.def("cleanup", &cleanup);
subm.def("autotune", &autotune, pybind11::return_value_policy::reference);
subm.def("launch_kernel", &launch_kernel);
// function
py::class_<rt::function>(m, "function")
.def(py::init<const std::string &, const rt::options_t &, driver::device *, const std::vector<rt::config> &, const std::vector<std::string> &>())
.def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal)
.def("signature", &rt::function::get_signature);
}
void init_triton(py::module &m) {
py::module subm = m.def_submodule("triton");
init_triton_driver(std::move(subm.def_submodule("driver")));
init_triton_runtime(std::move(subm.def_submodule("runtime")));
init_triton_tools(std::move(subm.def_submodule("tools")));
}
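As a usage note, extract_kernels (now bound under the tools submodule) is what the Python-level triton.read helper calls so that macros meant for one kernel are not applied to every kernel in the same source file; a sketch with an illustrative path and kernel name:

import triton

# Read only the `forward` kernel out of a Triton-C source file;
# the other kernels (and their tunable macros) are left out.
src = triton.read("softmax.c", kernel_names=["forward"])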

View File

@@ -50,8 +50,9 @@ import torch
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
torch.manual_seed(0)
defines = {"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}
triton.ops._matmul._kernels = dict()
triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
triton.ops._matmul._CONFIGS = [triton.config(defines=defines, num_warps=NWARP)]
if M is None:
M = TM
if N is None:

View File

@@ -1,27 +1,21 @@
import os
import struct
from typing import Optional, Dict, List
import torch
# C bindings
import triton._C.libtriton.triton as _triton
import triton._C.libtriton.torch_utils as _torch_utils
# Make sure internal C resources are cleaned up upon exit
import atexit
@atexit.register
def cleanup():
_triton.cleanup()
codes = {
_triton.arg_type.int1: 'B', _triton.arg_type.int8: 'B', _triton.arg_type.int32: 'I', _triton.arg_type.int64: 'Q',
_triton.arg_type.half: 'H', _triton.arg_type.float: 'f', _triton.arg_type.double: 'd', _triton.arg_type.buffer: 'P'
_triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I',
_triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f',
_triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P'
}
def th_to_triton(obj):
tys = {
torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long', torch.float16: 'half',
torch.float32: 'float', torch.float64: 'double'
torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long',\
torch.float16: 'half', torch.float32: 'float', torch.float64: 'double'
}
if isinstance(obj, torch.dtype):
return tys[obj]
@@ -30,69 +24,54 @@ def th_to_triton(obj):
def cdiv(a, b):
return (a + b - 1) // b
def synchronize(device):
dev_id = device.index
dev_id = -1 if dev_id is None else dev_id
_torch_utils.synchronize(dev_id)
def read(path, kernel_names:Optional[List]=None):
def read(path, kernel_names: Optional[List] = None):
if kernel_names is None:
kernel_names = []
with open(path, 'r') as f:
source = f.read()
source = _triton.extract_kernels(source, kernel_names)
source = _triton.tools.extract_kernels(source, kernel_names)
return source
class kernel:
def __init__(self,
src,
device,
defines: Optional[Dict]=None,
num_warps:int=4,
autotune_vals:Optional[List]=None,
autotune_key:Optional[List]=None):
config = _triton.runtime.config
class kernel:
def __init__(self, src, device, defines: Optional[Dict] = None, num_warps: int = 4,
autotune_vals: Optional[List] = None, autotune_key: Optional[List] = None):
if defines is None:
defines = {}
if autotune_vals is None:
autotune_vals = []
if autotune_key is None:
autotune_key = []
# check if src is empty
if src == '':
raise ValueError('Kernel source code is empty')
self.src = src
self.opt = _triton.options()
self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
self.opt.num_warps = num_warps
# device
assert device.type in ['cuda', 'cpu']
if device.type == 'cuda':
self.device = torch.cuda.current_device() if device.index is None else device.index
self.device_id = torch.cuda.current_device() if device.index is None else device.index
self.device = _triton.driver.cu_device(_torch_utils.cu_device(self.device_id), False)
self.stream = _triton.driver.cu_stream(_torch_utils.cu_stream(self.device_id), False)
if device.type == 'cpu':
self.device = -1
_torch_utils.register_device(self.device)
_torch_utils.register_stream(self.device)
# C++ function wrapper
self.op_id = _triton.make_op_id()
_torch_utils.set_device(self.device)
_triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
# debug mode
self.is_debug = 'TRITON_DEBUG' in os.environ
# signature
arg_types = _triton.get_fn_signature(self.op_id)
self.tys = ''.join([codes[x] for x in arg_types])
self.device_id = -1
self.device = _triton.driver.host_device()
self.stream = _triton.driver.host_stream()
_torch_utils.set_device(self.device_id)
# function
self.opt = _triton.runtime.options()
self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
self.opt.num_warps = num_warps
# autotune_vals = [({}, 4)]
self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_vals, autotune_key)
self.tys = ''.join([codes[x] for x in self.fn.signature()])
def __call__(self, *args, grid):
_torch_utils.set_device(self.device)
# make sure that the executing thread is on the right device
_torch_utils.set_device(self.device_id)
# pack parameters into a byte buffer
params = struct.pack(self.tys, *args)
opt = _triton.autotune(self.op_id, self.device, params, grid)
kernel = self.fn.autotune(params, grid, self.stream)
# run kernel
grid = grid(opt)
grid_0 = grid[0]
grid_1 = 1 if len(grid) < 2 else grid[1]
grid_2 = 1 if len(grid) < 3 else grid[2]
_triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
grid = grid(kernel.opt)
kernel(params, self.stream, grid)
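For reference, the byte-buffer argument passing used by __call__ (the struct.pack trick credited to Scott Gray in the C++ bindings above) can be reproduced in isolation; a sketch with made-up pointer values, using the same format codes as the codes table:

import struct

# add(float* z, float* x, float* y, int N) packs as 'PPPI'
# ('P' = pointer/buffer, 'I' = 32-bit int, per the codes table above)
tys = 'PPPI'
z_ptr, x_ptr, y_ptr, n = 0x7f0000000000, 0x7f0000001000, 0x7f0000002000, 1024
params = struct.pack(tys, z_ptr, x_ptr, y_ptr, n)
# 28 bytes on a 64-bit platform: three pointers plus one int32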

View File

@@ -1,17 +1,17 @@
__global__ void NAME(TYPE *A __readonly __noalias __aligned(16),
TYPE *B __readonly __noalias __aligned(16),
TYPE *C __noalias __aligned(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
long stride_za __multipleof(8),
long stride_zb __multipleof(8),
long stride_zc __multipleof(8),
long stride_ha __multipleof(8),
long stride_hb __multipleof(8),
long stride_hc __multipleof(8),
__global__ void NAME(TYPE *A __readonly __noalias,
TYPE *B __readonly __noalias,
TYPE *C __noalias,
int lda,
int ldb,
int ldc,
long stride_za,
long stride_zb,
long stride_zc,
long stride_ha,
long stride_hb,
long stride_hc,
int DS0, int DS1,
int SDD_K __multipleof(16),
int SDD_K,
int SDD_off_width,
int *lut, int *locks, int nlocks) {
/* ---------------- */

View File

@@ -1,17 +1,16 @@
__global__ void forward(TYPE *X __readonly __noalias __aligned(16),
__global__ void forward(TYPE *X __readonly __noalias,
float scale,
int *LUT __readonly __noalias __aligned(16),
TYPE *RPE __readonly __noalias __aligned(16),
TYPE *KP_M __readonly __noalias __aligned(16),
TYPE *ATTN_M __readonly __noalias __aligned(16),
int *LUT __readonly __noalias,
TYPE *RPE __readonly __noalias,
TYPE *KP_M __readonly __noalias,
TYPE *ATTN_M __readonly __noalias,
int sizemax,
long stride_zx __multipleof(4),
long stride_zrpe __multipleof(BLOCK),
int stride_hrpe __multipleof(BLOCK),
int stride_srpe __multipleof(BLOCK),
int stride_zkpm __multipleof(BLOCK),
int stride_zattnm __multipleof(BLOCK))
{
long stride_zx,
long stride_zrpe,
int stride_hrpe,
int stride_srpe,
int stride_zkpm,
int stride_zattnm) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
@@ -97,14 +96,13 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
*? (check)px = y / ysum;
}
__global__ void backward(TYPE *X __readonly __noalias __aligned(16),
__global__ void backward(TYPE *X __readonly __noalias,
float scale,
TYPE *DX __readonly __noalias __aligned(16),
TYPE *DX __readonly __noalias,
int *LUT,
int sizemax,
long stride_zx __multipleof(BLOCK),
long stride_zdx __multipleof(BLOCK))
{
long stride_zx,
long stride_zdx) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges

View File

@@ -1,126 +1,131 @@
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
// equivalent matmul
int M, int N, int K,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
int *ADELTA,
// memory strides
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
__global__ void conv(TYPE *A __noalias __readonly,
TYPE *B __noalias __readonly,
TYPE *C __noalias,
float alpha,
// equivalent matmul
int M, int N, int K,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
int *ADELTA,
// memory strides
int lda_z, int lda_ci, int lda_h, int lda_w,
int ldb_ci, int ldb_r, int ldb_s, int ldb_co,
int ldc_z, int ldc_co, int ldc_p, int ldc_q) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
// unpack aggregate rows
// m = (z, p, q)
int rq[TM] = rm % QQ;
int rzp[TM] = rm / QQ;
int rp[TM] = rzp % PP;
int rz[TM] = rzp / PP;
// unpack aggregate reduction
// k = (ci, r, s)
int rs [TK] = rk % SS;
int rcir[TK] = rk / SS;
int rr [TK] = rcir % RR;
int rci [TK] = rcir / RR;
// unpack aggregate rows
// m = (z, p, q)
int rq[TM] = rm % QQ;
int rzp[TM] = rm / QQ;
int rp[TM] = rzp % PP;
int rz[TM] = rzp / PP;
// unpack aggregate reduction
// k = (ci, r, s)
int rs[TK] = rk % SS;
int rcir[TK] = rk / SS;
int rr[TK] = rcir % RR;
int rci[TK] = rcir / RR;
// padding / striding
int rh_0[TM] = rp * stride_h - pad_h;
int rw_0[TM] = rq * stride_w - pad_w;
int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];
// padding / striding
int rh_0[TM] = rp * stride_h - pad_h;
int rw_0[TM] = rq * stride_w - pad_w;
int rh[TM, TK] = rh_0[:, newaxis] + rr [newaxis, :];
int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :];
// pointers to lhs
int offa[TM, TK] = rz [:, newaxis] * lda_z +
rci[newaxis, :] * lda_ci +
rh * lda_h +
rw * 1;
TYPE* pa[TM, TK] = A + offa;
int* padelta[TK] = ADELTA + rk;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
rr [:, newaxis] * ldb_r +
rs [:, newaxis] * ldb_s +
rn [newaxis, :] * 1;
TYPE* pb[TK, TN] = B + offb;
// pointers to lhs
int offa[TM, TK] = rz[:, newaxis] * lda_z +
rci [newaxis, :] * lda_ci +
rh * lda_h +
rw * 1;
TYPE *pa[TM, TK] = A + offa;
int *padelta[TK] = ADELTA + rk;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
rr[:, newaxis] * ldb_r +
rs[:, newaxis] * ldb_s +
rn[newaxis, :] * 1;
TYPE *pb[TK, TN] = B + offb;
// prefetches operands
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
int total = 0;
// prefetches operands
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
int total = 0;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
acc += a @ b;
// increment A
int adelta[TK] = *padelta;
padelta += TK;
pa += adelta[newaxis, :];
// bounds-checking A
rk += TK;
rs = rk % SS;
rcir = rk / SS;
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr[newaxis, :];
rw = rw_0[:, newaxis] + rs[newaxis, :];
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkb[TK, TN] = k > TK;
a = checka ? *pa : 0;
b = *?(checkb)pb;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// reduction loop
float acc[TM, TN] = 0;
for (int k = K; k > 0; k -= TK) {
acc += a @b;
// increment A
int adelta[TK] = *padelta;
padelta += TK;
pa += adelta [newaxis, :];
// bounds-checking A
rk += TK;
rs = rk % SS;
rcir = rk / SS;
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr [newaxis, :];
rw = rw_0[:, newaxis] + rs [newaxis, :];
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkb[TK, TN] = k > TK;
a = checka ? *pa : 0;
b = *? (checkb)pb;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
rm = ridx * TM + 0 ... TM;
rn = ridy * TN + 0 ... TN;
rq = rm % QQ;
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz [:, newaxis] * ldc_z +
rn [newaxis, :] * ldc_co+
rp [:, newaxis] * ldc_p +
rq [:, newaxis] * 1;
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;
// epilogue
rm = ridx * TM + 0 ... TM;
rn = ridy * TN + 0 ... TN;
rq = rm % QQ;
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz[:, newaxis] * ldc_z +
rn [newaxis, :] * ldc_co +
rp[:, newaxis] * ldc_p +
rq[:, newaxis] * 1;
TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N;
#if (TZ==1)
*?(checkc) pc = c;
#if (TZ == 1)
*? (checkc)pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}

View File

@@ -1,8 +1,4 @@
__global__ void forward(TYPE *logit __aligned(16),
TYPE *modified_logit __aligned(16),
long *indices __readonly,
TYPE *result __aligned(16),
int n_cols __multipleof(N_COLS_MULT)) {
__global__ void forward(TYPE *logit, TYPE *modified_logit, long *indices, TYPE *result, int n_cols) {
int row = get_program_id(0);
bool check[TILE] = ((0 ... TILE) < n_cols);
@@ -19,10 +15,7 @@ __global__ void forward(TYPE *logit __aligned(16),
*(result + row) = *(modified_logit + (local_ind + n_cols * row));
}
__global__ void backward(TYPE *neg_logprobs __aligned(16),
long *indices __aligned(16),
TYPE *dneg_logprobs __aligned(16),
int n_cols __multipleof(N_COLS_MULT)) {
__global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs, int n_cols) {
int row = get_program_id(0);
// pointer arithmetic

View File

@@ -1,16 +1,12 @@
#define STM 8
#define STN 8
__global__ void matmul(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
__global__ void matmul(TYPE *A __noalias __readonly,
TYPE *B __noalias __readonly,
TYPE *C __noalias,
float alpha,
int M,
int N,
int K __multipleof(16),
int lda __multipleof(LDA_POW2_DIV),
int ldb __multipleof(LDB_POW2_DIV),
int ldc __multipleof(LDC_POW2_DIV),
int M, int N, int K,
int lda, int ldb, int ldc,
int *locks) {
// prologue
int pid = get_program_id(0);

View File

@@ -6,18 +6,18 @@ class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
_DEFAULT_CONFIGS = [
({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
({'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, 4),
({'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, 4),
({'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 2),
({'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 2),
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
triton.config(defines={"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
triton.config(defines={'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
]
_CONFIGS = _DEFAULT_CONFIGS

View File

@@ -1,4 +1,5 @@
import torch
import os
def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
@@ -77,8 +78,12 @@ class Mark:
df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))
def run(self, result_path, with_plot):
for bench in self.benchmarks:
self._run(bench, result_path, with_plot)
with open(os.path.join(result_path, "results.html"), "w") as html:
html.write("<html><body>\n")
for bench in self.benchmarks:
self._run(bench, result_path, with_plot)
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
html.write("</body></html>\n")
def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks)

View File

@@ -0,0 +1,76 @@
import torch
import triton
# source-code for Triton compute kernel
# here we just copy-paste the above code without the extensive comments.
# you may prefer to store it in a .c file and load it from there instead.
_src = """
__global__ void add(float* z, float* x, float* y, int N){
// program id
int pid = get_program_id(0);
// create arrays of pointers
int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
float* pz[BLOCK] = z + offset;
float* px[BLOCK] = x + offset;
float* py[BLOCK] = y + offset;
// bounds checking
bool check[BLOCK] = offset < N;
// write-back
*?(check)pz = *?(check)px + *?(check)py;
}
"""
# This function returns a callable `triton.kernel` object
# created from the above source code.
# For portability, we maintain a cache of kernels for different `torch.device`
# We compile the kernel with -DBLOCK=1024
_kernels = dict()
def make_add_kernel(device):
if device not in _kernels:
defines = {'BLOCK': 1024}
autotune_vals = [({'BLOCK': '1024'}, 4), ({'BLOCK': '2048'}, 4)]
autotune_key = ["N"]
_kernels[device] = triton.kernel(_src, device=device, defines=defines, autotune_vals=autotune_vals,
autotune_key=autotune_key)
return _kernels[device]
# This is a standard torch custom autograd Function
# The only difference is that we can now use the above kernel
# in the `forward` and `backward` functions.
class _add(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
# constraints of the op
assert x.dtype == torch.float32
# *allocate output*
z = torch.empty_like(x)
# *create launch grid*:
# this is a function which takes compilation parameters `opt`
# as input and returns a tuple of int (i.e., launch grid) for the kernel.
# triton.cdiv is a shortcut for ceil division:
# triton.cdiv(a, b) = (a + b - 1) // b
grid = lambda opt: (triton.cdiv(z.shape[0], opt.BLOCK), )
# *launch kernel*:
# pointer to the data of torch tensors can be retrieved with
# the `.data_ptr()` method
kernel = make_add_kernel(z.device)
kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), z.shape[0], grid=grid)
return z
# Just like with standard PyTorch ops,
# we use the `.apply` method to create a
# callable object for our function
add = _add.apply
torch.manual_seed(0)
x = torch.rand(32, device='cuda')
y = torch.rand(32, device='cuda')
za = x + y
zb = add(x, y)
print(za)
print(zb)
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')
th_ms = triton.testing.do_bench(lambda: x + y)
tr_ms = triton.testing.do_bench(lambda: add(x, y))
print(th_ms, tr_ms)