[PYTHON][TESTS][DOC] Various improvement of the API and code quality:

* Simplified `triton.kernel` API to achieve lower latency:
  > .data_ptr() must now be passed as kernel argument. No more implicit
conversion from torch.tensor
  > compilation options are now constant attributes, i.e., opt.d('VAR')
becomes opt.VAR
  > torch.device must now be passed explicitly to triton.kernel (no
longer inferred from torch.tensor arguments)
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in python/tutorials/
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* added python/triton/ops/ package for pre-written Triton ops
This commit is contained in:
Philippe Tillet
2021-01-29 17:27:16 -05:00
parent a5a477c36b
commit 269ebc12e5
63 changed files with 2255 additions and 3883 deletions

View File

@@ -1,39 +0,0 @@
import torch
import triton
class _add(torch.autograd.Function):
src = """
__global__ void add(float* z, float* x, float* y, int N) {
int pid = get_program_id(0);
int offset[TILE] = pid * TILE + 0 ... TILE;
float* pz[TILE] = z + offset;
float* px[TILE] = x + offset;
float* py[TILE] = y + offset;
bool check[TILE] = offset < N;
*pz = *px + *py;
}
"""
kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])
@staticmethod
def forward(ctx, x, y):
z = torch.empty_like(x).cuda()
N = x.numel()
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
_add.kernel(z,x,y, N, grid=grid)
return z
add = _add.apply
# test
torch.manual_seed(0)
x = torch.rand(900).cuda()
y = torch.rand(900).cuda()
za = x + y
zb = add(x, y)
print(torch.allclose(za,zb))

View File

@@ -1,70 +0,0 @@
import torch
import triton
class _copy(torch.autograd.Function):
src = """
__global__ void copy(TYPE * X, TYPE * Y,
int M __retune,
int N __retune,
int ldx __multipleof(8)) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D range along the two matrix's axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldx; //(6)
*py = *px;
}
"""
kernel = None ### initialize later when we know the sizes
@staticmethod
def forward(ctx, x):
M, N = x.shape
ldx = N;
dtype = x.dtype
y = torch.empty((M,N)).cuda()
defines= {
'TYPE' : dtype,
'TM' : [32,64,128],
'TN' : [32,64,128],
}
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
if _copy.kernel is None:
_copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4])
_copy.kernel(x, y, M, N, ldx, grid=grid)
return y
copy = _copy.apply
# test
torch.manual_seed(0)
x = torch.randn(8,4).cuda()
print(x)
ya = x
yb = copy(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))
print(ya == yb)

View File

@@ -1,143 +0,0 @@
import torch
import triton
class _dot(torch.autograd.Function):
src = """
#define STM 4
#define STN 4
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
float alpha,
int M __retune,
int N __retune,
int K __retune __multipleof(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
// prologue
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = M / TM;
int gridn = N / TN;
int stgridm = (gridm + STM - 1) / STM;
int stgridn = (gridn + STN - 1) / STN;
int stid = pid / (STM * STN);
int laneid = pid % (STM * STN);
int stm = stid / stgridn;
int stn = stid % stgridn;
int lanem = laneid / STN;
int lanen = laneid % STN;
int pidm = stm*STM + lanem;
int pidn = stn*STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = pidz * K + 0 ... TK;
// pointers to operands
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
TYPE* pa[TM, TK] = A + offa;
TYPE* pb[TK, TN] = B + offb;
// prefetches operands
bool checka[TM, TK] = rk[newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
acc += a @ b;
a = *?(checka)pa;
b = *?(checkb)pb;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rxm[TM] = pidm * TM + 0 ... TM;
int rxn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
"""
@staticmethod
def forward(ctx, a, b):
c = _dot._call(a,b)
return c
kernel = dict()
@staticmethod
def _call(a, b):
# create kernel if necessary
dtype = a.dtype
if dtype not in _dot.kernel:
defines = {
'TYPE' : dtype,
'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN',
'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
'TM' : [128],
'TN' : [128],
'TK' : [32],
'TZ' : [1]
}
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
kernel = _dot.kernel[dtype]
# allocate output
M, K = a.shape
K, N = b.shape
c = torch.empty([M,N], dtype=dtype, device=a.device)
print(kernel.asm('sass', c.device))
print(kernel.asm('ptx', c.device))
# enqueue
grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))]
time = kernel(a, b, c, 1., M, N, K,
a.stride(0), b.stride(0), c.stride(0), grid=grid)
return c
dot = _dot.apply
torch.manual_seed(0)
M, N, K = 4096, 4096, 4096
a = torch.rand((M, K)).cuda().half()
b = torch.rand((K, N)).cuda().half()
#a[:] = 1
#b[:] = 1
zc = torch.matmul(a,b)
zc_ = dot(a,b)
print(torch.allclose(zc, zc_))

View File

@@ -1,76 +0,0 @@
import torch
import triton
class _transpose(torch.autograd.Function):
src = """
__global__ void transpose(TYPE * X, TYPE * Y,
int M __retune,
int N __retune,
int ldx __multipleof(8), int ldy __multipleof(8)) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D range along the two matrix's axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
// create bounds-checking mask
bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b)
// conditional write-back using the conditional dereferencing operatior '*?()'
*?(checky)py = ^(*?(checkx)px); //(7)
}
"""
kernel = None ### initialize later when we know the sizes
@staticmethod
def forward(ctx, x):
M, N = x.shape
ldx = N
ldy = M
dtype = x.dtype
y = torch.empty((N,M)).cuda()
defines= {
'TYPE' : dtype,
'TM' : [32,64,128],
'TN' : [32,64,128],
}
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
if _transpose.kernel is None:
_transpose.kernel = triton.kernel(_transpose.src, defines=defines, num_warps=[4])
_transpose.kernel(x, y, M, N, ldx, ldy, grid=grid)
return y
transpose = _transpose.apply
# test
torch.manual_seed(0)
x = torch.randn(1024,128).cuda()
print(x)
ya = torch.t(x)
yb = transpose(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))
print(ya == yb)

View File

@@ -95,25 +95,18 @@ class CMakeBuild(build_ext):
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
find_llvm()
directories = [x[0] for x in os.walk(os.path.join('src', 'include'))]
data = []
for d in directories:
for htype in ['h', 'hpp']:
files = glob.glob(os.path.join(d, f'*.{htype}'), recursive=False)
data += [os.path.relpath(f, 'src') for f in files]
setup(
name='triton',
version='0.3.0',
version='1.0.0',
author='Philippe Tillet',
author_email='ptillet@g.harvard.edu',
author_email='phil@openai.com',
description='A language and compiler for custom Deep Learning operations',
long_description='',
packages=['triton', 'triton/_C'],
install_requires=['numpy', 'torch', 'sympy'],
package_data={'': data},
packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'],
install_requires=['numpy', 'torch'],
package_data={'triton/ops': ['*.c'],
'triton/ops/blocksparse': ['*.c']},
include_package_data=True,
ext_modules=[CMakeExtension('triton', 'triton/_C/')],
cmdclass=dict(build_ext=CMakeBuild),
zip_safe=False,
@@ -122,7 +115,7 @@ setup(
url='https://github.com/ptillet/triton/',
download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz',
classifiers=[
'Development Status :: 4 - Beta', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
'Intended Audience :: Developers', # Define that your audience are developers
'Topic :: Software Development :: Build Tools',
'License :: OSI Approved :: MIT License', # Again, pick a license

View File

@@ -13,15 +13,19 @@
#include "triton/ir/function.h"
using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<int, int> map_key_t;
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
CUstream torch_get_cuda_stream(int64_t dev_id);
CUdevice torch_get_cuda_device(int64_t dev_id);
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<int, std::shared_ptr<rt::function>> id_fn_map;
std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
std::unordered_map<const rt::options_t*, pybind11::object> opt_cache_;
extern CUstream torch_get_cuda_stream(int64_t dev_id);
extern CUdevice torch_get_cuda_device(int64_t dev_id);
/* Grid utilities */
@@ -36,106 +40,123 @@ void delete_grid(const map_key_t& key) {
/* Function utilities */
void register_fn(const map_key_t& key,
void register_fn(int op_id,
int dev_id,
const std::string& src,
const rt::options_space_t& opt) {
if(id_fn_map.find(key) == id_fn_map.end())
id_fn_map[key].reset(new rt::function(src, opt, ""));
if(tt_devices.find(dev_id) == tt_devices.end()) {
driver::device* device;
driver::stream* stream;
if(dev_id >= 0){
device = new triton::driver::cu_device(torch_get_cuda_device(dev_id), false);
stream = new triton::driver::cu_stream(torch_get_cuda_stream(dev_id), false);
}
else{
device = new triton::driver::host_device();
stream = new triton::driver::host_stream();
}
tt_devices[dev_id].reset(device);
tt_streams[dev_id].reset(stream);
}
if(id_fn_map.find(op_id) == id_fn_map.end()){
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id]));
}
for(const auto& k: id_fn_map[op_id]->get_kernels()){
const rt::options_t* opt = &k.first;
pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
for(auto x: opt->defines)
if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
obj.attr(x.first.c_str()) = std::stoi(x.second);
opt_cache_[&k.second->opt] = obj;
}
}
void delete_fn(const map_key_t& key) {
id_fn_map.erase(key);
void delete_fn(int op_id) {
id_fn_map.erase(op_id);
}
std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
triton::driver::cu_device device(key.second, false);
return id_fn_map[key]->get_asm(mode, &device, opt);
}
void cleanup() {
id_grid_map.clear();
id_fn_map.clear();
opt_cache_.clear();
}
size_t make_op_id() {
return id_fn_map.size();
}
/* Function signature */
void make_module(const std::string& src, ir::module* ir,
const runtime::options_space_t& opt) {
std::string copy = triton::runtime::function::preheader() + src;
// pre-process
TokenSequence tokens;
Preprocessor cpp(&copy, true);
for(auto it: opt.defines){
cpp.AddMacro(it.first, &it.second[0]);
}
cpp.Process(tokens);
// parse
Parser parser(tokens);
parser.Parse();
Generator gen(&parser);
gen.Gen(ir);
std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
}
std::vector<rt::arg_type> get_fn_signature(const std::string& src,
const runtime::options_space_t& opt) {
// triton-ir code-gen
ir::context ctx;
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
make_module(src, &*ir, opt);
// function
ir::function* fn = ir->get_function_list().front();
// extract signature
std::vector<rt::arg_type> ret;
ir::function_type* ty = fn->get_fn_type();
for(size_t i = 0; i < ty->get_num_params(); i++)
ret.push_back(rt::convert(ty->get_param_ty(i)));
return ret;
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, size_t grid_0, size_t grid_1, size_t grid_2){
rt::function* fn = id_fn_map.at(op_id).get();
(*fn)((void**)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
// for(size_t n = 0; n < constant_names.size(); n++){
// const torch::Tensor& x = constant_vals[n];
// fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
}
typedef triton::runtime::options_t options_t;
typedef triton::runtime::options_space_t options_space_t;
pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args, const rt::function::grid_fn_ty& grid){
rt::function* fn = id_fn_map.at(op_id).get();
auto wrapper = [&grid](const rt::options_t& opt){
pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
for(auto x: opt.defines)
if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
obj.attr(x.first.c_str()) = std::stoi(x.second);
return grid(*obj.cast<rt::options_t*>());
};
rt::kernel* kernel = fn->autotune((void**)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
return opt_cache_.at(&kernel->opt);
}
void init_superblocking(pybind11::module &m);
void init_launch(pybind11::module &m);
PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
// bindings for triton classes
pybind11::enum_<rt::arg_type>(m, "arg_type")
.value("int1", rt::INT1_T)
.value("int8", rt::INT8_T)
.value("int16", rt::INT16_T)
.value("int32", rt::INT32_T)
.value("int64", rt::INT64_T)
.value("half", rt::HALF_T)
.value("float", rt::FLOAT_T)
.value("int1" , rt::INT1_T)
.value("int8" , rt::INT8_T)
.value("int16" , rt::INT16_T)
.value("int32" , rt::INT32_T)
.value("int64" , rt::INT64_T)
.value("half" , rt::HALF_T)
.value("float" , rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);
pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
.value("ptx", rt::ASM_NV_PTX)
.value("ptx" , rt::ASM_NV_PTX)
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<options_t>(m, "options")
.def(pybind11::init<>())
.def("d", &options_t::D<int>)
.def_readwrite("num_warps", &options_t::num_warps)
.def_readwrite("defines" , &options_t::defines);
pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
.def_readwrite("num_warps", &rt::options_t::num_warps)
.def_readwrite("defines" , &rt::options_t::defines);
pybind11::class_<options_space_t>(m, "options_space")
pybind11::class_<rt::options_space_t>(m, "options_space")
.def(pybind11::init<>())
.def_readwrite("defines", &options_space_t::defines)
.def_readwrite("num_warps", &options_space_t::num_warps);
.def_readwrite("num_warps", &rt::options_space_t::num_warps)
.def_readwrite("defines" , &rt::options_space_t::defines);
// hooks into triton constructs since frameworks may not use pybind11
m.def("get_fn_signature", &get_fn_signature);
m.def("get_fn_asm", &get_fn_asm);
// m.def("get_fn_asm", &get_fn_asm);
m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn);
m.def("delete_fn", &delete_fn);
m.def("make_op_id", &make_op_id);
m.def("cleanup", &cleanup);
;
m.def("autotune", &autotune, pybind11::return_value_policy::reference);
m.def("launch_kernel", &launch_kernel);
init_launch(m);
init_superblocking(m);
}

View File

@@ -1,95 +0,0 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
std::shared_ptr<drv::device> host_device;
std::shared_ptr<drv::context> host_context;
std::shared_ptr<drv::stream> host_stream;
int64_t cdiv_sum(torch::Tensor x, int64_t div){
TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
auto _x = x.accessor<int, 1>();
int64_t ret = 0;
for(size_t i = 0; i < x.size(0); i++)
ret += (_x[i] + div - 1) / div;
return ret;
}
void init_host_stream() {
if(!host_stream){
host_device.reset(new drv::host_device());
host_context.reset(drv::context::create(&*host_device));
host_stream.reset(drv::stream::create(host_context->backend()));
}
}
CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
CUdevice ret;
triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
return ret;
}
void synchronize(int64_t dev_id) {
if(dev_id == -1){
init_host_stream();
host_stream->synchronize();
}
else{
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
stream.synchronize();
}
}
torch::Tensor cuda_empty_like(torch::Tensor x){
if(x.nbytes() == 0)
return torch::empty_like(x);
void* data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
return ret;
}
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
for(size_t n = 0; n < constant_names.size(); n++){
const torch::Tensor& x = constant_vals[n];
fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
}
if(dev_id == -1){
init_host_stream();
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device);
}
else{
C10_CUDA_CHECK(cudaSetDevice(dev_id));
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
triton::driver::cu_device device(torch_get_cuda_device(dev_id), false);
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device);
}
}
static auto registry = torch::RegisterOperators()
.op("triton::launch_kernel", &launch_kernel)
.op("triton::cuda_empty_like", &cuda_empty_like)
.op("triton::cdiv_sum", &cdiv_sum)
.op("triton::synchronize", &synchronize);

View File

@@ -0,0 +1,83 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
namespace rt = triton::runtime;
namespace drv = triton::driver;
typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<drv::device>> tt_devices;
extern std::map<int, std::shared_ptr<drv::stream>> tt_streams;
int64_t cdiv(int64_t a, int64_t b) {
return (a + b - 1) / b;
}
int64_t largest_pow2_divisor(int64_t a){
if(a % 8 == 0) return 8;
if(a % 4 == 0) return 4;
if(a % 2 == 0) return 2;
return 1;
}
int64_t cdiv_sum(torch::Tensor x, int64_t div){
TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
auto _x = x.accessor<int, 1>();
int64_t ret = 0;
for(size_t i = 0; i < x.size(0); i++)
ret += (_x[i] + div - 1) / div;
return ret;
}
CUstream torch_get_cuda_stream(int64_t dev_id) {
return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}
CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
CUdevice ret;
triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
return ret;
}
void synchronize(int64_t dev_id) {
tt_streams[dev_id]->synchronize();
}
torch::Tensor cuda_empty_like(torch::Tensor x){
if(x.nbytes() == 0)
return torch::empty_like(x);
void* data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
return ret;
}
void cuda_set_device(int64_t dev_id) {
if(dev_id >= 0)
C10_CUDA_CHECK(cudaSetDevice(dev_id));
}
void init_launch(pybind11::module &m) {
m.def("cuda_set_device", &cuda_set_device);
m.def("cuda_empty_like", &cuda_empty_like);
m.def("largest_pow2_divisor", &largest_pow2_divisor);
m.def("cdiv", &cdiv);
m.def("cdiv_sum", &cdiv_sum);
m.def("synchronize", &synchronize);
}

View File

@@ -0,0 +1,117 @@
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){
size_t H = layout.size(0);
size_t M = layout.size(1);
size_t N = layout.size(2);
torch::Tensor tmp = torch::zeros_like(layout);
auto _tmp = tmp.accessor <int, 3>();
auto _layout = layout.accessor <int, 3>();
auto _idx = idx.accessor <int, 3>();
auto _scratch = scratch.accessor<int, 3>();
std::vector<int> current(H, 0);
#ifdef _OPENMP
#pragma omp parallel for
#endif
for(size_t h = 0; h < H; h++){
// surrounding indices
std::vector<int> ii_left(max_width, -1);
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
for(size_t m = 0; m < M; m++){
for(size_t n = 0; n < N; n++){
int v = _layout[h][m][n];
if(v == 0)
continue;
int n_left= ii_left[max_width-1];
int m_top = ii_top [max_width-1][n];
int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
int topleft = (m_top >=0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
int width = std::min(left, std::min(top, topleft)) + 1;
// reset width if blocks cannot be
// packed together (i.e., there's a 1 "in the middle")
for(int nn = n_left + 1; nn < n; nn++)
if(ii_top[max_width-1][nn] > ii_top[max_width-1][n])
width = 1;
_tmp[h][m][n] = width;
// update n_left ring buffer
for(int k = 0; k < max_width-1; k++)
ii_left[k] = ii_left[k+1];
ii_left[max_width-1] = n;
// update ii_top ring buffer
for(int k = 0; k < max_width-1; k++)
ii_top[k][n] = ii_top[k+1][n];
ii_top[max_width-1][n] = m;
// block is too small -- skip
if(width != max_width)
continue;
// retained blocks are set to zeros
for(size_t km = 0; km < max_width; km++)
for(size_t kn = 0; kn < max_width; kn++)
{
int mm = ii_top[km][n];
int nn = ii_left[kn];
if(mm < 0 || nn < 0)
continue;
_layout[h][mm][nn] = 0;
_tmp[h][mm][nn] = 0;
_scratch[h][current[h]][0] = (int)h;
_scratch[h][current[h]][1] = (int)mm;
_scratch[h][current[h]][2] = (int)nn;
_scratch[h][current[h]][3] = _idx[h][mm][nn];
current[h]++;
}
}
}
}
std::vector<torch::Tensor> to_cat;
for(size_t h = 0; h < H; h++)
if(current[h] > 0)
to_cat.push_back(scratch[h].slice(0, 0, current[h]));
if(!to_cat.empty())
ret.push_back(std::make_tuple(max_width, torch::cat(to_cat)));
}
ret_t superblock(torch::Tensor layout, int start_width) {
ret_t ret;
// block index
torch::Tensor idx = torch::zeros_like(layout);
int current = 0;
int64_t H = layout.size(0);
int64_t M = layout.size(1);
int64_t N = layout.size(2);
auto _layout = layout.accessor <int, 3>();
auto _idx = idx.accessor<int, 3>();
for(int64_t h = 0; h < H; h++)
for(int64_t m = 0; m < M; m++)
for(int64_t n = 0; n < N; n++){
if(_layout[h][m][n] == 0)
continue;
_idx[h][m][n] = current++;
}
// scratch memory
torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
for(int max_width = start_width; max_width > 0; max_width /= 2)
segment_blocks(layout, idx, scratch, max_width, ret);
return ret;
}
void init_superblocking(pybind11::module &m) {
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
}

View File

@@ -0,0 +1,50 @@
import itertools
import torch
import triton as tt
import pytest
def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
ret[:, idx, :, :] = x[:, h, i*block: (i+1)*block, j*block: (j+1)*block]
return ret
def mask_tensor(x, mask, block, value = 0):
ret = x.clone()
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
return ret
@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
for at in [False, True]\
for bt in [False, True]\
for block in [16, 32, 64]
]
)
def test_op(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384):
# set seed
torch.random.manual_seed(0)
# create inputs
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE]
layout = torch.randint(2, (H, shape[0]//BLOCK, shape[1]//BLOCK))
# triton result
op = tt.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
ra = sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
rb = sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b
rc = op(ra, rb)
# torch result
ta = mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
tb = mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b
ta = ta.transpose(2, 3) if TRANS_A else ta
tb = tb.transpose(2, 3) if TRANS_B else tb
tc = torch.matmul(ta, tb)
tc = mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
tc = sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
# compare
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(rc, tc, rtol=rtol, atol=atol)

17
python/tests/test_conv.py Normal file
View File

@@ -0,0 +1,17 @@
import torch
import triton
def test_op():
torch.manual_seed(0)
DTYPE = torch.float16
N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad, stride, = (1, 1), (1, 1)
dilation = (1, 1)
a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5
b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5
th_c = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, dilation)
tt_c = triton.ops.conv(a, b, pad, stride)
rtol, atol = {torch.float32: (1e-4, 1e-5),
torch.float16: (1e-2, 1e-3)}[DTYPE]
assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol)

View File

@@ -0,0 +1,96 @@
import pytest
import itertools
import triton as tt
import torch as th
@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
[
# 1 warp
(16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE),
(128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE),
(128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE),
(128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE)
]
for DTYPE in ['float16']
for AT in [False, True]
for BT in [False, True]
]))
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
th.manual_seed(0)
tt.ops._matmul.kernel = dict()
tt.ops._matmul.TM = [TM]
tt.ops._matmul.TN = [TN]
tt.ops._matmul.TK = [TK]
tt.ops._matmul.num_warps = [NWARP]
if M is None: M = TM
if N is None: N = TN
if K is None: K = TK
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
a = a.t() if AT else a
b = b.t() if BT else b
th_c = th.matmul(a, b)
tt_c = tt.ops.matmul(a, b)
rtol, atol = {th.float32: (1e-4, 1e-5),
th.float16: (1e-2, 1e-3)}[DTYPE]
assert th.allclose(tt_c, th_c, atol=atol, rtol=rtol)
def do_bench(fn, flops = 0, warmup = 10, rep = 50):
start_event = th.cuda.Event(enable_timing=True)
end_event = th.cuda.Event(enable_timing=True)
ret = fn()
for i in range(warmup):
fn()
th.cuda.synchronize()
start_event.record()
for i in range(rep):
fn()
end_event.record()
th.cuda.synchronize()
time_ms = start_event.elapsed_time(end_event) / rep
return time_ms, flops/time_ms*1e-9, ret
def perf_op(dtype=th.float16, warmup=10, rep=50):
AT, BT = False, False
configs = [(N, N, N) for N in [128, 8192]]
for M, N, K in configs:
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
if AT: a = a.t()
if BT: b = b.t()
a = a[::,::]
b = b[::,::]
TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
print((M, N, K), TH_MS, TT_MS)

View File

@@ -0,0 +1,8 @@
import torch
import triton
def test_op(M = 1024, N = 1024, dtype = torch.float32):
x = torch.randn(M, N, dtype=dtype, device='cuda')
th_y = torch.softmax(x, dim=-1)
tt_y = triton.ops.softmax(x)
assert torch.allclose(tt_y, th_y)

View File

@@ -1,8 +1,13 @@
from .kernel import *
# TODO: torch needs to be imported first
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# clean-up libtriton resources
# libtriton resources
import atexit
import triton._C.libtriton as libtriton
@atexit.register
def cleanup():
libtriton.cleanup()
libtriton.cleanup()
from .kernel import *
from . import ops

View File

@@ -15,18 +15,6 @@ codes = {
libtriton.arg_type.buffer: 'P'
}
sizes = {
libtriton.arg_type.int1: 1,
libtriton.arg_type.int8: 1,
libtriton.arg_type.int32: 4,
libtriton.arg_type.int64: 8,
libtriton.arg_type.half: 2,
libtriton.arg_type.float: 4,
libtriton.arg_type.double: 8,
libtriton.arg_type.buffer: 8
}
def th_to_triton(obj):
tys = {
torch.int8: 'char',
@@ -43,92 +31,65 @@ def th_to_triton(obj):
return [th_to_triton(x)[0] for x in obj]
return [str(obj)]
def cdiv(a, b):
return (a + b - 1) // b
return libtriton.cdiv(a, b)
def cdiv_sum(a, b):
return torch.ops.triton.cdiv_sum(a, b)
def synchronize(device):
dev_id = device.index
dev_id = -1 if dev_id is None else dev_id
torch.ops.triton.synchronize(dev_id)
libtriton.synchronize(dev_id)
def read(path):
with open(path, 'r') as f:
source = f.read()
return source
class kernel:
def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
def __init__(self, src, device, defines = dict(), num_warps = [4]):
self.src = src
self.opt = libtriton.options_space()
self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
self.opt.num_warps = num_warps
# device
assert device.type in ['cuda', 'cpu']
if device.type == 'cuda':
self.device = torch.cuda.current_device() if device.index is None else device.index
if device.type == 'cpu':
self.device = -1
# C++ function wrapper
self.op_id = libtriton.make_op_id()
self.registered = set()
arg_types = libtriton.get_fn_signature(self.src, self.opt)
size = sum([sizes[x] for x in arg_types])
libtriton.register_fn(self.op_id, self.device, self.src, self.opt)
# debug mode
self.is_debug = 'TRITON_DEBUG' in os.environ
# signature
arg_types = libtriton.get_fn_signature(self.op_id)
self.tys = ''.join([codes[x] for x in arg_types])
def asm(self, mode, device, **kwargs):
dev_id = device.index
# assembly mode
supported = {
'ptx': libtriton.asm_mode.ptx,
'sass': libtriton.asm_mode.sass,
}
if mode not in supported:
raise('ASM mode must be in ', supported.keys())
mode = supported[mode]
# disambiguates #defines
libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
def _single_value_or_err(x, key):
if isinstance(x, list) and len(x) == 1:
return x[0]
if isinstance(x, list) and len(x) > 1:
if key in kwargs:
return kwargs[key]
raise ValueError(f'Parameter {key}={x} was auto-tuned during kernel creation. '
'Please supply an explicit value as a keyword argument.')
return str(x)
defines = dict()
for (D, V) in self.opt.defines:
defines[D] = _single_value_or_err(V, D)
opt = libtriton.options()
opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps')
opt.defines = defines
# run
return libtriton.get_fn_asm((self.op_id, dev_id), mode, opt)
def __call__(self, *args, **kwargs):
if 'TRITON_DEBUG_MODE' in os.environ:
def __call__(self, *args, grid):
# debug mode (initialize)
if self.is_debug:
_args = args
args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
for i in range(len(args)):
if isinstance(args[i], torch.Tensor):
args[i] = torch.ops.triton.cuda_empty_like(args[i])
args[i] = libtriton.cuda_empty_like(args[i])
args[i].copy_(_args[i])
torch.cuda.synchronize()
for x in args:
if isinstance(x, torch.Tensor):
device = x.device.index
device = -1 if device is None else device
break
# lazily register function for device
libtriton.register_fn((self.op_id, device), self.src, self.opt)
# launch grid
if 'grid' not in kwargs:
raise RuntimeError('Must provide grid for kernel launch')
grid = kwargs['grid']
libtriton.register_grid((self.op_id, device), grid)
# re-allocate buffers for auto-tuning
if 'autotune_buf' in kwargs:
pass
# launch
params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
names = list(kwargs['constants'].keys()) if 'constants' in kwargs else []
constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
if 'TRITON_DEBUG_MODE' in os.environ:
torch.cuda.synchronize()
# initialize cuda device if necessary
libtriton.cuda_set_device(self.device)
# pack parameters into a byte buffer
params = pack(self.tys, *args)
# auto-tune if necessary
opt = libtriton.autotune(self.op_id, self.device, params, grid)
# run kernel
grid = grid(opt)
grid_0 = grid[0]
grid_1 = 1 if len(grid) < 2 else grid[1]
grid_2 = 1 if len(grid) < 3 else grid[2]
libtriton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
# debug mode (finalize)
if self.is_debug:
for i in range(len(args)):
if isinstance(args[i], torch.Tensor):
_args[i].copy_(args[i].clone())

View File

@@ -0,0 +1,4 @@
from .conv import _conv, conv
from .matmul import _matmul, matmul
from .softmax import _softmax, softmax
from . import blocksparse

View File

@@ -0,0 +1 @@
from .matmul import matmul

View File

@@ -0,0 +1,198 @@
__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
TYPE* B __readonly __noalias __aligned(16),
TYPE* C __noalias __aligned(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
long stride_za __multipleof(8),
long stride_zb __multipleof(8),
long stride_zc __multipleof(8),
long stride_ha __multipleof(8),
long stride_hb __multipleof(8),
long stride_hc __multipleof(8),
int DS0, int DS1,
int SDD_K __multipleof(16),
int SDD_off_width,
int* lut, int* locks, int nlocks) {
/* ---------------- */
/* Prologue */
/* ---------------- */
// program ids
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pidz = get_program_id(2);
#ifdef SDD
// load LUT header
pid1 = pid1 + SDD_off_width;
int blockidm[TM] = (0 ... TM) / BLOCK;
int blockidn[TN] = (0 ... TN) / BLOCK;
int offlutm[TM] = blockidm*(TN/BLOCK)*4;
int offlutn[TN] = blockidn*4;
int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
int z = *(header + 0);
int i[TM] = *(header + 1 + offlutm);
int j[TN] = *(header + 2 + offlutn);
int AS1 = SDD_K / TZ;
int lockid = select(TZ > 1, 1, 0);
int offka = pid0 * AS1;
int offkb = pid0 * AS1;
int offmc = 0;
int offnc = 0;
int offpa = 0;
int offpb = 0;
int maxid = TZ;
int offhc = 0;
int offha = z;
int offhb = z;
int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
#else
// load LUT header
int *header = lut + pid0 * 6;
int offset = *(header + 0);
int AS1 = *(header + 1);
int column = *(header + 2);
int depth = *(header + 3);
int lockid = *(header + 4);
int maxid = *(header + 5);
int *pinc = lut + offset;
int offhc = depth;
#ifdef DSD
// output offset
int offnc = pid1 * TN;
int offmc = column * TM;
int offpc = 0;
// dense input offset
int offnb = pid1 * TN;
int offkb __multipleof(8) = *pinc;
int offpb = 0;
// sparse input offset
int offma = 0;
int offka = 0;
long offpa __multipleof(8) = *(pinc + 1);
offpa = offpa * BLOCK * BLOCK;
int offha = 0;
int offhb = depth;
#endif
#ifdef DDS
// output offset
int offmc = pid1 * TM;
int offnc = column * TN;
int offpc = 0;
// dense input offset
int offma = pid1 * TM;
int offka __multipleof(8) = *pinc;
int offpa = 0;
// sparse input offset
int offnb = 0;
int offkb = 0;
long offpb __multipleof(8) = *(pinc + 1);
offpb = offpb * BLOCK * BLOCK;
int offha = depth;
int offhb = 0;
#endif
int ram[TM] = offma + 0 ... TM;
int rbn[TN] = offnb + 0 ... TN;
#endif
// initialize a, b pointers
int rka[TK] = offka + 0 ... TK;
int rkb[TK] = offkb + 0 ... TK;
TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
// pre-fetch
#ifdef DDS
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
#else
bool checkbn[TK, TN] = AS1 > 0;
#endif
TYPE a[TM, TK] = checkam ? *pa : 0;
TYPE b[TK, TN] = checkbn ? *pb : 0;
/* ---------------- */
/* Inner Loop */
/* ---------------- */
// create result tile
float acc[TM, TN] = 0;
int step = TK;
for(int k = AS1; k > 0; k -= step) {
acc += a @ b;
// update pointers
#ifdef SDD
int inc_a = TK * STRIDE_AK;
int inc_b = TK * STRIDE_BK;
#else
pinc += 2;
#ifdef DSD
int inc_b __multipleof(8) = *pinc;
int inc_a __multipleof(8) = *(pinc + 1);
inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
int inc_a __multipleof(8) = *pinc;
int inc_b __multipleof(8) = *(pinc + 1);
inc_a = inc_a * STRIDE_AK;
#endif
#endif
pa += inc_a;
pb += inc_b;
// pre-fetch
bool checkak[TM, TK] = k > TK;
bool checkbk[TK, TN] = k > TK;
bool checka[TM, TK] = checkam && checkak;
bool checkb[TK, TN] = checkbk && checkbn;
a = *?(checka)pa;
b = *?(checkb)pb;
}
TYPE c[TM, TN] = acc;
/* ---------------- */
/* Epilogue */
/* ---------------- */
// initialize c pointers
#ifdef SDD
bool checkc[TM, TN] = 1;
// rematerialize
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
int rr_offlutn[TN] = rr_blockidn*4;
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
int bkid[TM, TN] = *(header + off_bkid);
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
// range within blocks
int rcm[TM] = (0 ... TM) % BLOCK;
int rcn[TN] = (0 ... TN) % BLOCK;
#else
int rcm[TM] = offmc + 0 ... TM;
int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
#endif
#ifdef DDS
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
// write-back directly
if(lockid == 0) {
*?(checkc) pc = c;
}
// accumulate partial result using spin-locks
else {
int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % maxid);
atomic_xchg(plock, 0);
}
}

View File

@@ -0,0 +1,467 @@
import triton
import triton._C.libtriton as libtriton
import torch
import os
import math
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
##############
# MAIN API #
##############
class _matmul(torch.autograd.Function):
sdd_cache = dict()
dsd_cache = dict()
dds_cache = dict()
locks = dict()
# Given an array sizes representing reduction size for each
# column of a block-mode matrix multiplication,
# performs load-balancing to achieve more smaller reductions
# between `seg_size` elements
@staticmethod
def load_balance(sizes, block):
# segment size
# heuristics taken from OpenAI blocksparse code
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
max_size = sizes.max()
min_size = sizes[sizes != 0].min()
#if max_size > min_size * 2.0:
# seg_max = max(triton.cdiv(max_size, 4), min_size*2)
#else:
# seg_max = max_size
seg_max = max_size
seg_min = max(triton.cdiv(seg_max, 4), 4)
# split reduction into segments
div = sizes // seg_max
rem = sizes % seg_max
packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
width = packs.sum()
segments = torch.empty(width, dtype=sizes.dtype)
column = torch.empty_like(segments)
lockid = torch.zeros_like(segments)
maxid = torch.zeros_like(segments)
nlocks = 0
current = 0
col_idx = 0
for i in range(len(sizes)):
d, r = div[i], rem[i]
isempty = sizes[i] < seg_min
last = current + d + (r >= seg_min) + isempty
# column id
column[current:last] = col_idx
# lock id
if d > 1 or (d == 1 and r >= seg_min):
nlocks += 1
lockid[current:last] = nlocks
maxid[current:last] = last - current
# segment size
segments[current:current+d] = seg_max
if r < seg_min and not isempty:
segments[current+d-1] += r
if r >= seg_min or isempty:
segments[current+d] = r
current = last
col_idx += 1
offsets = torch.zeros_like(segments)
offsets[1:] = torch.cumsum(segments[:-1], dim=0)
return segments, column, lockid, maxid, offsets
@staticmethod
def get_locks(size, dev):
if dev not in _matmul.locks or \
size > _matmul.locks[dev].size(0):
_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
return _matmul.locks[dev]
##########################
# SPARSE = DENSE x DENSE #
##########################
@staticmethod
def make_sdd_lut(layout, block, dtype, device):
start_width = 64 // block
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
width = nnz.shape[0] // (size*size)
h = nnz[:, 0]
i = nnz[:, 1]
j = nnz[:, 2]
b = nnz[:, 3]
lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
luts.append(lut.type(torch.int32).to(device))
widths.append(width)
packs.append(size)
# create locks
return luts, None, widths, packs
@staticmethod
def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, luts, num_locks, widths, packs):
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
device = a.device
is_16_multiple = AS3 % 16 == 0
is_32_multiple = AS3 % 32 == 0
is_64_multiple = AS3 % 64 == 0
if not is_16_multiple:
raise ValueError('Reduction size for SDD must be a multiple of 16')
# create kernel
total_width = sum([width*pack*pack for width,pack in zip(widths, packs)])
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
for lut, width, pack in zip(luts, widths, packs):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
F32TK = [8, 16]
#F16TK = [16]
#F16TK += [32] if is_32_multiple else []
#F16TK += [64] if is_64_multiple else []
F16TK = [64]
TK = {torch.float32: F32TK,
torch.float16: F16TK}[dtype]
defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block,
'TK': TK, 'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=[1, 2, 4])
kernel = _matmul.sdd_cache[key]
# create output
locks = _matmul.get_locks(2*width*AS0*num_lock, a.device)
# maximum grid size is 65535
# so operation might be decomposed into multiple
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), b.stride(2), block,
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(0),
AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock,
grid = lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
# save for backward pass
return c
##########################
# DENSE = DENSE x SPARSE #
# DENSE = SPARSE x DENSE #
##########################
# Given a binary layout of 0s and 1s,
# Construct look-up table for efficient execution on GPUs
@staticmethod
def make_dxx_lut(layout, block, step, trans, device, transform = lambda idx: idx):
# load-balancing
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
segments = _empty.clone()
column = _empty.clone()
depth = _empty.clone()
lockid = _empty.clone()
maxid = _empty.clone()
offsets = _empty.clone()
current_offset = 0
current_maxid = 0
for z in range(layout.size(0)):
if trans:
sizes = torch.sum(layout[z, :, :], 1)
else:
sizes = torch.sum(layout[z, :, :], 0)
z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
z_depth = z * torch.ones_like(z_segments)
z_lockid[z_lockid > 0] += current_maxid
current_maxid = z_lockid.max()
# concatenate depth
segments = torch.cat((segments, z_segments))
column = torch.cat((column, z_column))
depth = torch.cat((depth, z_depth))
maxid = torch.cat((maxid, z_maxid))
offsets = torch.cat((offsets, current_offset + z_offsets))
lockid = torch.cat((lockid, z_lockid))
current_offset += layout[z, :, :].sum()
segments *= step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.min(offsets, (num_blocks - 1)*torch.ones_like(offsets))
idx = transform(nnz[:, 2]*block)
xincs = idx.clone()
xincs[1:] -= idx[:-1]
# divide block into multiple steps
div = block // step
xincs = xincs.view(-1, 1).repeat(1, div)
xincs[:, 1:] = step
xincs[:, 0 ] -= (div-1)*step
# first increment for each reduction is actually the offset
xincs[offsets[segments>0], 0] = idx[offsets[segments>0]]
xincs = xincs.view(-1)
# block-mode input increments
if trans:
widx = torch.arange(num_blocks)
else:
widx = _empty.clone()
current_offset = 0
for z in range(layout.size(0)):
layoutw = layout[z, :, :].clone()
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum)
widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
widx = widx
wincs = widx*block*block
wincs[1:] -= widx[:-1]*block*block
wincs = wincs.view(-1, 1).repeat(1, div)
if trans:
wincs[:, 1:] = step
wincs[:, 0] -= (div-1)*step
else:
wincs[:, 1:] = step*block
wincs[:, 0] -= (div - 1)*step*block
wincs[offsets[segments>0], 0] = widx[offsets[segments>0]]
wincs = wincs.view(-1)
# adjust offset and segment size
offsets *= 2*div
segments *= div
# create header
width = column.size(0)
offsets += 6*width
header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
num_locks = max(1, lockid.max())
return lut, num_locks, width, None
@staticmethod
def _dds_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
AS3 = a.size(2 if trans_a else 3)
BS0 = spdims[0]
BS1 = block * spdims[2 if trans_b else 1]
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
TK = [8] if dtype == torch.float32 else [16]
defines = {'TM': TM, 'TN': block, 'TK': TK,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1,
'STRIDE_BN': block if trans_b else 1,
'STRIDE_BK': 1 if trans_b else block,
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel',
'DDS': True}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
kernel = _matmul.dds_cache[key]
# output
CS0 = AS0
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
locks = _matmul.get_locks(2*AS0*AS2//32*num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), block, c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
return c
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = spdims[0]
AS1 = block * spdims[2 if trans_a else 1]
AS2 = block * spdims[1 if trans_a else 2]
BS0 = b.size(0)
BS1 = b.size(1)
BS2 = b.size(3 if trans_b else 2)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
TN = [64, 128] if dtype == torch.float32 else [64, 128]
TK = [8] if dtype == torch.float32 else [16]
defines = {'TM': block, 'TN': TN, 'TK': TK,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1,
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dsd_kernel',
'DSD': True}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
kernel = _matmul.dsd_cache[key]
# output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
locks = _matmul.get_locks(2*BS0*BS3//32*num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
block, b.stride(2), c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
return c
fn = {'sdd': _sdd_matmul.__get__(object),
'dsd': _dsd_matmul.__get__(object),
'dds': _dds_matmul.__get__(object)}
@staticmethod
def forward(ctx, a, b, trans_a, trans_b, trans_c,
mode, spdims, block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block,
c_lut, c_num_locks, c_width, c_packs)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_num_locks = da_num_locks
ctx.da_lut = da_lut
ctx.da_width = da_width
ctx.da_packs = da_packs
ctx.db_lut = db_lut
ctx.db_num_locks = db_num_locks
ctx.db_width = db_width
ctx.db_packs = db_packs
ctx.mode = mode
ctx.spdims = spdims
ctx.block = block
ctx.trans_a = trans_a
ctx.trans_b = trans_b
return c
@staticmethod
def backward(ctx, dc):
# saved for backward
a, b = ctx.saved_tensors
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,
ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs)
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None,\
None, None, None, None, None, None
class matmul:
def make_lut(self, dtype, device):
key = (dtype, device)
if key in self.lut_cache:
return self.lut_cache[key]
# C look-up table
layout, block = self.layout, self.block
step = 8 if dtype == torch.float32 else 16
if self.mode == 'sdd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dsd':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
elif self.mode == 'dds':
c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
# DA look-up table
if self.mode == 'sdd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
elif self.mode == 'dsd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dds':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
# DB look-up table
if self.mode == 'sdd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
elif self.mode == 'dsd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
elif self.mode == 'dds':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs)
return self.lut_cache[key]
def __init__(self, layout, block, mode, trans_a = False, trans_b = False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
# look-up table cache
self.lut_cache = dict()
# attributes
self.trans_a = trans_a
self.trans_b = trans_b
self.mode = mode
self.spdims = layout.shape
self.block = block
self.layout = layout
# pad shapes of a tensor to make it
# compatible with kernel calls
@staticmethod
def _pad_shape(x, is_sparse):
max_dim = 3 if is_sparse else 4
for i in range(max_dim - x.dim()):
x = x.unsqueeze(0)
return x
def __call__(self, a, b):
c_lut, c_num_locks, c_width, c_packs,\
da_lut, da_num_locks, da_width, da_packs,\
db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
# pad shapes with ones
a = matmul._pad_shape(a, self.mode == 'dsd')
b = matmul._pad_shape(b, self.mode == 'dds')
# execute
c = _matmul.apply(a, b, self.trans_a, self.trans_b, False,
self.mode, self.spdims, self.block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs)
return c

View File

@@ -1,16 +1,9 @@
import torch
import triton
class _conv(torch.autograd.Function):
src = """
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
// equivalent matmul
int M __retune,
int N __retune,
int K __retune,
int M, int N, int K,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
@@ -130,73 +123,4 @@ class _conv(torch.autograd.Function):
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
"""
kernel = dict()
@staticmethod
def unpack(IDX, CI, R, S):
s = IDX % S
cr = IDX // S
r = cr % R
ci = cr // R
return ci, r, s
@staticmethod
def forward(ctx, a, b, pad, stride, time):
# create kernel if necessary
dtype = a.dtype
# shapes
Z, CI, H, W = a.shape
_, R, S, CO = b.shape
P = (H + 2*pad[0] - R)//stride[0] + 1
Q = (W + 2*pad[1] - S)//stride[1] + 1
# compile kernel
if dtype not in _conv.kernel:
TK = 8
defines = {
'TYPE' : dtype,
'TM' : [16, 32, 64, 128],
'TN' : [16, 32, 64, 128],
'TK' : [TK],
'TZ' : [1],
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
}
idx = torch.arange(CI*R*S)
ci, r, s = _conv.unpack(idx, CI, R, S)
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = torch.empty([Z, CO, P, Q], dtype=dtype)
# enqueue
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
triton.cdiv(CO, opt.d('TN'))]
time[0] = kernel(a, b, c, 1., Z*P*Q, CO, CI*R*S,
pad[0], pad[1], stride[0], stride[1],
delta,
a.stride(0), a.stride(1), a.stride(2), a.stride(3),
b.stride(0), b.stride(1), b.stride(2), b.stride(3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
grid=grid, bench=100)
return c
conv = _conv.apply
torch.manual_seed(0)
Z, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad = (1, 1)
stride = (1, 1)
a = torch.rand((Z, CI, H, W)).cuda()
b = torch.rand((CI, R, S, CO)).cuda()
time = [None]
cc = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, [1, 1])
c = conv(a, b, pad, stride, time)
print((cc - c).abs().max() / max(cc.max(), c.max()))
print(time[0], 2*Z*H*W*CI*CO*R*S/(time[0]*1e-9)*1e-12)
#zc = torch.matmul(a,b)
#zc_ = dot(a,b)
}

57
python/triton/ops/conv.py Normal file
View File

@@ -0,0 +1,57 @@
import torch
import triton
import os
class _conv(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c'))
kernel = dict()
@staticmethod
def unpack(IDX, CI, R, S):
s = IDX % S
cr = IDX // S
r = cr % R
ci = cr // R
return ci, r, s
@staticmethod
def forward(ctx, a, b, pad, stride):
# create kernel if necessary
dtype = a.dtype
device = a.device
# shapes
Z, CI, H, W = a.shape
_, R, S, CO = b.shape
P = (H + 2*pad[0] - R)//stride[0] + 1
Q = (W + 2*pad[1] - S)//stride[1] + 1
# compile kernel
if (dtype, device) not in _conv.kernel:
TK = 16
defines = {
'TYPE' : dtype,
'TM' : [32, 64, 128],
'TN' : [32, 64, 128],
'TK' : [TK],
'TZ' : [1],
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
}
idx = torch.arange(CI*R*S)
ci, r, s = _conv.unpack(idx, CI, R, S)
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
delta = delta.type(torch.int32).cuda()
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
# enqueue
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 1., Z*P*Q, CO, CI*R*S,
pad[0], pad[1], stride[0], stride[1],
delta.data_ptr(),
a.stride(0), a.stride(1), a.stride(2), a.stride(3),
b.stride(0), b.stride(1), b.stride(2), b.stride(3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.TM), triton.cdiv(CO, opt.TN)])
return c
conv = _conv.apply

View File

@@ -0,0 +1,97 @@
#define STM 8
#define STN 8
__global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
float alpha,
int M,
int N,
int K __multipleof(16),
int lda __multipleof(LDA_POW2_DIV),
int ldb __multipleof(LDB_POW2_DIV),
int ldc __multipleof(LDC_POW2_DIV),
int* locks) {
// prologue
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = (M + TM - 1) / TM;
int gridn = (N + TN - 1) / TN;
// swizzle for better L2 performance
int width = STM*gridn;
int stm = pid / width;
int RSTM = min(gridm - stm*STM, STM);
int stn = (pid % width) / (RSTM*STN);
int RSTN = min(gridn - stn*STN, STN);
int laneid = pid % (RSTM * RSTN);
int lanem = laneid / RSTN;
int lanen = laneid % RSTN;
int pidm = stm*STM + lanem;
int pidn = stn*STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// split-k for better parrallelism
K = K / TZ;
int rk[TK] = 0 ... TK;
// pointers to operands
int offa[TM, TK] = (pidz*K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = (pidz*K + rk[:, newaxis]) * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
TYPE* pa[TM, TK] = A + offa;
TYPE* pb[TK, TN] = B + offb;
// prefetches operands
bool checka[TM, TK] = rk[newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
#if (IS_TK_DIV_K==1)
bool checkk[TK] = k > TK;
#else
bool checkk[TK] = rk < k - TK;
#endif
bool checka[TM, TK] = checkk[newaxis, :];
bool checkb[TK, TN] = checkk[:, newaxis];
acc += a @ b;
#if (IS_TK_DIV_K==1)
a = *?(checka)pa;
b = *?(checkb)pb;
#else
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
#endif
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
int rcm[TM] = pidm * TM + 0 ... TM;
int rcn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N;
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}

View File

@@ -0,0 +1,80 @@
import torch
import triton
import os
class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
TM = [128]
TN = [128]
TK = [32]
TZ = 1
num_warps = [4]
@staticmethod
def largest_pow2_divisor(N):
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
return 1
_locks = dict()
_kernels = dict()
@staticmethod
def _call(a, b):
dtype = a.dtype
device = a.device
# allocate output
M, K = a.shape
K, N = b.shape
c = torch.empty((M, N), dtype=dtype, device=device)
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous()
# kernel hash
is_a_row = a.stride(1) == 1
is_b_row = b.stride(1) == 1
lda = a.stride(0) if is_a_row else a.stride(1)
ldb = b.stride(0) if is_b_row else b.stride(1)
ldc = c.stride(0)
lda_pow2_div = _matmul.largest_pow2_divisor(lda)
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
is_tk_div_k = K % 32 == 0
key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
if key not in _matmul._kernels:
defines = {
'TYPE' : dtype,
'STRIDE_AM' : 'lda' if is_a_row else '1',
'STRIDE_AK' : '1' if is_a_row else 'lda',
'STRIDE_BK' : 'ldb' if is_b_row else '1',
'STRIDE_BN' : '1' if is_b_row else 'ldb',
'LDA_POW2_DIV': lda_pow2_div,
'LDB_POW2_DIV': ldb_pow2_div,
'LDC_POW2_DIV': ldc_pow2_div,
'TM' : _matmul.TM,
'TN' : _matmul.TN,
'TK' : _matmul.TK,
'TZ' : _matmul.TZ,
'IS_TK_DIV_K' : is_tk_div_k
}
_matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines)
kernel = _matmul._kernels[key]
# # locks for split-k
if device not in _matmul._locks:
_matmul._locks[device] = torch.zeros(1024*1024, dtype=torch.int32, device=device)
locks = _matmul._locks[device]
# enqueue
alpha = 1.
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, 1]
kernel(*args, grid=grid)
return c
@staticmethod
def forward(ctx, a, b):
c = _matmul._call(a,b)
return c
matmul = _matmul.apply

View File

@@ -0,0 +1,8 @@
__global__ void forward(TYPE* X, TYPE* Y) {
int pid = get_program_id(0);
int off[BLOCK] = pid * BLOCK + 0 ... BLOCK;
float x[BLOCK] = *(X + off);
float shifted[BLOCK] = exp(x - x[max]);
float sum = shifted[+];
*(Y + off) = shifted / sum;
}

View File

@@ -0,0 +1,27 @@
import torch
import triton
import os
kernels = dict()
def get_kernel(block, dtype, device):
key = (block, dtype, device)
if key not in kernels:
src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'))
defines = {'BLOCK': block, 'TYPE': dtype}
kernels[key] = triton.kernel(src, device = device, defines = defines)
return kernels[key]
class _softmax(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
y = torch.empty_like(x)
M, N = x.shape
kernel = get_kernel(N, x.dtype, x.device)
kernel(x.data_ptr(), y.data_ptr(), grid = lambda opt: [M, ])
return y
softmax = _softmax.apply

View File

@@ -0,0 +1,335 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "induced-zoning",
"metadata": {},
"source": [
"# Getting Started"
]
},
{
"cell_type": "markdown",
"id": "median-malaysia",
"metadata": {},
"source": [
"In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
"* The basic syntax of the Triton programming language\n",
"* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
"* The best practices for validating and benchmarking custom ops against native reference implementations"
]
},
{
"cell_type": "markdown",
"id": "identical-conditions",
"metadata": {},
"source": [
"# Writing the Compute Kernel"
]
},
{
"cell_type": "markdown",
"id": "collectible-belle",
"metadata": {},
"source": [
"Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (See the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
"\n",
"\n",
"```c\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
" // The `get_program_id(i)` returns the i-th coordinate\n",
" // of the program in the overaching SPMD context\n",
" // (a.k.a launch grid). This is what allows us to process\n",
" // different chunks of data in parallel.\n",
" // For those similar with CUDA, `get_program_id({0,1,2})`\n",
" // is similar to blockIdx.{x,y,z}\n",
" int pid = get_program_id(0);\n",
" // In Triton, arrays are first-class citizen. In other words,\n",
" // they are primitives data-types and are -- contrary to C and\n",
" // CUDA -- not implemented as pointers to contiguous chunks of\n",
" // memory.\n",
" // In the few lines below, we create an array of `BLOCK` pointers\n",
" // whose memory values are, e.g.:\n",
" // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
" // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
" float* pz [BLOCK] = z + offset;\n",
" float* px [BLOCK] = x + offset;\n",
" float* py [BLOCK] = y + offset;\n",
" // Simple element-wise control-flow for load/store operations can\n",
" // be achieved using the the ternary operator `cond ? val_true : val_false`\n",
" // or the conditional dereferencing operator `*?(cond)ptr\n",
" // Here, we make sure that we do not access memory out-of-bounds when we\n",
" // write-back `z`\n",
" bool check[BLOCK] = offset < N;\n",
" *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
"```\n",
"\n",
"The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
]
},
{
"cell_type": "markdown",
"id": "forbidden-wednesday",
"metadata": {},
"source": [
"# Writing the Torch bindings"
]
},
{
"cell_type": "markdown",
"id": "numerical-agency",
"metadata": {},
"source": [
"The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
"\n",
"To create a `triton.kernel`, you only need three things:\n",
"* `source: string`: the source-code of the kernel you want to create\n",
"* `device: torch.device`: the device you want to compile this code for\n",
"* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
"\n",
"Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "sporting-keyboard",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import triton\n",
"\n",
"# source-code for Triton compute kernel\n",
"# here we just copy-paste the above code without the extensive comments.\n",
"# you may prefer to store it in a .c file and load it from there instead.\n",
"_src = \"\"\"\n",
"__global__ void add(float* z, float* x, float* y, int N){\n",
" // program id\n",
" int pid = get_program_id(0);\n",
" // create arrays of pointers\n",
" int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
" float* pz[BLOCK] = z + offset;\n",
" float* px[BLOCK] = x + offset;\n",
" float* py[BLOCK] = y + offset;\n",
" // bounds checking\n",
" bool check[BLOCK] = offset < N;\n",
" // write-back\n",
" *?(check)pz = *?(check)px + *?(check)py;\n",
"}\n",
" \"\"\"\n",
"# This function returns a callable `triton.kernel` object\n",
"# created from the above source code.\n",
"# For portability, we maintain a cache of kernels for different `torch.device`\n",
"# We compile the kernel with -DBLOCK=1024\n",
"_kernels = dict()\n",
"def make_add_kernel(device):\n",
" if device not in _kernels:\n",
" defines = {'BLOCK': 1024}\n",
" _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
" return _kernels[device]\n",
"\n",
"# This is a standard torch custom autograd Function\n",
"# The only difference is that we can now use the above kernel\n",
"# in the `forward` and `backward` functions.`\n",
"class _add(torch.autograd.Function):\n",
" \n",
" @staticmethod\n",
" def forward(ctx, x, y):\n",
" # constraints of the op\n",
" assert x.dtype == torch.float32\n",
" # *allocate output*\n",
" z = torch.empty_like(x)\n",
" # *create launch grid*:\n",
" # this is a function which takes compilation parameters `opt`\n",
" # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
" # triton.cdiv is a shortcut for ceil division:\n",
" # triton.cdiv(a, b) = (a + b - 1) // b\n",
" N = z.shape[0]\n",
" grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
" # *launch kernel*:\n",
" # pointer to the data of torch tensors can be retrieved with\n",
" # the `.data_ptr()` method\n",
" kernel = make_add_kernel(z.device)\n",
" kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
" return z\n",
"# Just like we standard PyTorch ops\n",
"# We use the `.apply` method to create a \n",
"# callable object for our function\n",
"add = _add.apply"
]
},
{
"cell_type": "markdown",
"id": "separated-polyester",
"metadata": {},
"source": [
"At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
]
},
{
"cell_type": "markdown",
"id": "exclusive-salvation",
"metadata": {},
"source": [
"# Writing a Unit Test"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "supported-ribbon",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
"tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
"The maximum difference between torch and triton is 0.0\n"
]
}
],
"source": [
"torch.manual_seed(0)\n",
"x = torch.rand(98432, device='cuda')\n",
"y = torch.rand(98432, device='cuda')\n",
"za = x + y\n",
"zb = add(x, y)\n",
"print(za)\n",
"print(zb)\n",
"print(f'The maximum difference between torch and triton is '\n",
" f'{torch.max(torch.abs(za - zb))}')"
]
},
{
"cell_type": "markdown",
"id": "otherwise-canadian",
"metadata": {},
"source": [
"Seems to work!"
]
},
{
"cell_type": "markdown",
"id": "polished-australia",
"metadata": {},
"source": [
"# Writing a Benchmark"
]
},
{
"cell_type": "markdown",
"id": "historic-glass",
"metadata": {},
"source": [
"The performance of our GPU code can be benchmark using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "strange-luxembourg",
"metadata": {},
"outputs": [],
"source": [
"# We now want to benchmark the performance of `add`\n",
"# Against that of PyTorch for increasing vector sizes\n",
"def do_bench(fn, warmup = 10, rep = 50):\n",
" start_event = torch.cuda.Event(enable_timing=True)\n",
" end_event = torch.cuda.Event(enable_timing=True)\n",
" ret = fn()\n",
" for i in range(warmup):\n",
" fn()\n",
" torch.cuda.synchronize()\n",
" start_event.record()\n",
" for i in range(rep):\n",
" fn()\n",
" end_event.record()\n",
" torch.cuda.synchronize()\n",
" time_ms = start_event.elapsed_time(end_event) / rep\n",
" return time_ms"
]
},
{
"cell_type": "markdown",
"id": "hairy-claim",
"metadata": {},
"source": [
"We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "pleasant-valley",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"131072 0.020 0.003\n",
"262144 0.019 0.004\n",
"524288 0.016 0.016\n",
"1048576 0.033 0.033\n",
"2097152 0.071 0.070\n",
"4194304 0.142 0.144\n",
"8388608 0.287 0.286\n",
"16777216 0.572 0.568\n",
"33554432 1.139 1.110\n"
]
}
],
"source": [
"for N in [2**i for i in range(17, 26, 1)]:\n",
" x = torch.rand(N, device='cuda')\n",
" y = torch.rand(N, device='cuda')\n",
" triton_ms = do_bench(lambda: add(x, y))\n",
" torch_ms = do_bench(lambda: x + y)\n",
" # print the performance of triton and torch as well as the achieved bandwidth\n",
" print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
]
},
{
"cell_type": "markdown",
"id": "juvenile-supplement",
"metadata": {},
"source": [
"Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on to reduce."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "agreed-backing",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}