[PYTHON][TESTS][DOC] Various improvements of the API and code quality:
* Simplified `triton.kernel` API to achieve lower latency (a usage sketch follows this list):
  > `.data_ptr()` must now be passed as a kernel argument. No more implicit conversion from `torch.tensor`
  > compilation options are now constant attributes, i.e., `opt.d('VAR')` becomes `opt.VAR`
  > `torch.device` must now be passed explicitly to `triton.kernel` (no longer inferred from `torch.tensor` arguments)
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* Added `python/triton/ops/` package for pre-written Triton ops
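To illustrate the new calling convention, here is a minimal vector-add sketch written against the updated API. This snippet is not part of the commit itself: the kernel source and the `add` name mirror the deleted tutorial below, and the tensor sizes are illustrative.

    import torch
    import triton

    src = """
    __global__ void add(float* z, float* x, float* y, int N) {
        int pid = get_program_id(0);
        int offset[TILE] = pid * TILE + 0 ... TILE;
        float* pz[TILE] = z + offset;
        float* px[TILE] = x + offset;
        float* py[TILE] = y + offset;
        bool check[TILE] = offset < N;
        *?(check)pz = *?(check)px + *?(check)py;
    }
    """

    device = torch.device('cuda')
    # torch.device is now passed explicitly at construction time
    kernel = triton.kernel(src, device, defines={'TILE': 1024}, num_warps=[4])

    x = torch.rand(900, device=device)
    y = torch.rand(900, device=device)
    z = torch.empty_like(x)
    N = x.numel()
    # compilation options are now constant attributes: opt.TILE, not opt.d('TILE')
    grid = lambda opt: (triton.cdiv(N, opt.TILE),)
    # tensors are no longer converted implicitly; raw .data_ptr() is passed
    kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid=grid)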
@@ -1,39 +0,0 @@
import torch
import triton

class _add(torch.autograd.Function):
    src = """
__global__ void add(float* z, float* x, float* y, int N) {
    int pid = get_program_id(0);
    int offset[TILE] = pid * TILE + 0 ... TILE;
    float* pz[TILE] = z + offset;
    float* px[TILE] = x + offset;
    float* py[TILE] = y + offset;
    bool check[TILE] = offset < N;
    *pz = *px + *py;
}
    """

    kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])

    @staticmethod
    def forward(ctx, x, y):
        z = torch.empty_like(x).cuda()
        N = x.numel()
        grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
        _add.kernel(z, x, y, N, grid=grid)
        return z

add = _add.apply

# test
torch.manual_seed(0)
x = torch.rand(900).cuda()
y = torch.rand(900).cuda()
za = x + y
zb = add(x, y)
print(torch.allclose(za, zb))
@@ -1,70 +0,0 @@
import torch
import triton

class _copy(torch.autograd.Function):
    src = """
__global__ void copy(TYPE * X, TYPE * Y,
                     int M __retune,
                     int N __retune,
                     int ldx __multipleof(8)) {
    // extract program ID
    int pidm = get_program_id(0); //(1)
    int pidn = get_program_id(1); //(2)

    // create 1D range along the two matrix's axes
    int rm[TM] = pidm * TM + 0 ... TM; //(3)
    int rn[TN] = pidn * TN + 0 ... TN; //(4)

    // create 2D array of pointers
    TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
    TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldx; //(6)

    *py = *px;
}
    """

    kernel = None  ### initialize later when we know the sizes

    @staticmethod
    def forward(ctx, x):
        M, N = x.shape
        ldx = N
        dtype = x.dtype
        y = torch.empty((M, N)).cuda()
        defines = {
            'TYPE': dtype,
            'TM': [32, 64, 128],
            'TN': [32, 64, 128],
        }
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
        if _copy.kernel is None:
            _copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4])
        _copy.kernel(x, y, M, N, ldx, grid=grid)
        return y

copy = _copy.apply

# test
torch.manual_seed(0)
x = torch.randn(8, 4).cuda()
print(x)
ya = x
yb = copy(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))
print(ya == yb)
@@ -1,143 +0,0 @@
import torch
import triton

class _dot(torch.autograd.Function):
    src = """
#define STM 4
#define STN 4

__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                    TYPE * B __noalias __readonly __aligned(16),
                    TYPE * C __noalias __aligned(16),
                    float alpha,
                    int M __retune,
                    int N __retune,
                    int K __retune __multipleof(16),
                    int lda __multipleof(8),
                    int ldb __multipleof(8),
                    int ldc __multipleof(8)) {
    // prologue
    int pid = get_program_id(0);
    int pidz = get_program_id(2);
    int gridm = M / TM;
    int gridn = N / TN;
    int stgridm = (gridm + STM - 1) / STM;
    int stgridn = (gridn + STN - 1) / STN;
    int stid = pid / (STM * STN);
    int laneid = pid % (STM * STN);
    int stm = stid / stgridn;
    int stn = stid % stgridn;
    int lanem = laneid / STN;
    int lanen = laneid % STN;
    int pidm = stm*STM + lanem;
    int pidn = stn*STN + lanen;
    int rm[TM] = pidm * TM + 0 ... TM;
    int rn[TN] = pidn * TN + 0 ... TN;

    // reduction splitting
    K = K / TZ;
    int rk[TK] = pidz * K + 0 ... TK;

    // pointers to operands
    int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
    int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
    TYPE* pa[TM, TK] = A + offa;
    TYPE* pb[TK, TN] = B + offb;

    // prefetches operands
    bool checka[TM, TK] = rk[newaxis, :] < K;
    bool checkb[TK, TN] = rk[:, newaxis] < K;
    TYPE a[TM, TK] = checka ? *pa : 0;
    TYPE b[TK, TN] = checkb ? *pb : 0;

    // reduction loop
    float acc[TM, TN] = 0;
    for(int k = K; k > 0; k -= TK){
        bool checka[TM, TK] = k > TK;
        bool checkb[TK, TN] = k > TK;
        pa += TK * STRIDE_AK;
        pb += TK * STRIDE_BK;
        acc += a @ b;
        a = *?(checka)pa;
        b = *?(checkb)pb;
    }
    acc = acc * alpha;
    TYPE c[TM, TN] = acc;

    // epilogue
    int rxm[TM] = pidm * TM + 0 ... TM;
    int rxn[TN] = pidn * TN + 0 ... TN;
    int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
    TYPE* pc[TM, TN] = C + offc;
    bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);

#if (TZ==1)
    *?(checkc) pc = c;
#else
    // accumulate partial result using spin-locks
    int *plock = locks + pid;
    int *pcount = plock + get_num_programs(0) * get_num_programs(1);
    for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
    int count = *pcount;
    if(count == 0)
        *?(checkc) pc = c;
    else
        *?(checkc) pc = c + *?(checkc)pc;
    atomic_xchg(pcount, (count + 1) % TZ);
    atomic_xchg(plock, 0);
#endif
}
    """

    @staticmethod
    def forward(ctx, a, b):
        c = _dot._call(a, b)
        return c

    kernel = dict()

    @staticmethod
    def _call(a, b):
        # create kernel if necessary
        dtype = a.dtype
        if dtype not in _dot.kernel:
            defines = {
                'TYPE': dtype,
                'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN',
                'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
                'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
                'TM': [128],
                'TN': [128],
                'TK': [32],
                'TZ': [1]
            }
            _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
        kernel = _dot.kernel[dtype]
        # allocate output
        M, K = a.shape
        K, N = b.shape
        c = torch.empty([M, N], dtype=dtype, device=a.device)
        print(kernel.asm('sass', c.device))
        print(kernel.asm('ptx', c.device))
        # enqueue
        grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))]
        time = kernel(a, b, c, 1., M, N, K,
                      a.stride(0), b.stride(0), c.stride(0), grid=grid)
        return c

dot = _dot.apply

torch.manual_seed(0)

M, N, K = 4096, 4096, 4096
a = torch.rand((M, K)).cuda().half()
b = torch.rand((K, N)).cuda().half()

#a[:] = 1
#b[:] = 1

zc = torch.matmul(a, b)
zc_ = dot(a, b)
print(torch.allclose(zc, zc_))
@@ -1,76 +0,0 @@
import torch
import triton

class _transpose(torch.autograd.Function):
    src = """
__global__ void transpose(TYPE * X, TYPE * Y,
                          int M __retune,
                          int N __retune,
                          int ldx __multipleof(8), int ldy __multipleof(8)) {
    // extract program ID
    int pidm = get_program_id(0); //(1)
    int pidn = get_program_id(1); //(2)

    // create 1D range along the two matrix's axes
    int rm[TM] = pidm * TM + 0 ... TM; //(3)
    int rn[TN] = pidn * TN + 0 ... TN; //(4)

    // create 2D array of pointers
    TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
    TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)

    // create bounds-checking mask
    bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
    bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b)

    // conditional write-back using the conditional dereferencing operator '*?()'
    *?(checky)py = ^(*?(checkx)px); //(7)
}
    """

    kernel = None  ### initialize later when we know the sizes

    @staticmethod
    def forward(ctx, x):
        M, N = x.shape
        ldx = N
        ldy = M
        dtype = x.dtype
        y = torch.empty((N, M)).cuda()
        defines = {
            'TYPE': dtype,
            'TM': [32, 64, 128],
            'TN': [32, 64, 128],
        }
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
        if _transpose.kernel is None:
            _transpose.kernel = triton.kernel(_transpose.src, defines=defines, num_warps=[4])
        _transpose.kernel(x, y, M, N, ldx, ldy, grid=grid)
        return y

transpose = _transpose.apply

# test
torch.manual_seed(0)
x = torch.randn(1024, 128).cuda()
print(x)
ya = torch.t(x)
yb = transpose(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))
print(ya == yb)
@@ -95,25 +95,18 @@ class CMakeBuild(build_ext):
     subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)


 find_llvm()

-directories = [x[0] for x in os.walk(os.path.join('src', 'include'))]
-data = []
-for d in directories:
-    for htype in ['h', 'hpp']:
-        files = glob.glob(os.path.join(d, f'*.{htype}'), recursive=False)
-        data += [os.path.relpath(f, 'src') for f in files]
-
 setup(
     name='triton',
-    version='0.3.0',
+    version='1.0.0',
     author='Philippe Tillet',
-    author_email='ptillet@g.harvard.edu',
+    author_email='phil@openai.com',
     description='A language and compiler for custom Deep Learning operations',
     long_description='',
-    packages=['triton', 'triton/_C'],
-    install_requires=['numpy', 'torch', 'sympy'],
-    package_data={'': data},
+    packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'],
+    install_requires=['numpy', 'torch'],
+    package_data={'triton/ops': ['*.c'],
+                  'triton/ops/blocksparse': ['*.c']},
+    include_package_data=True,
     ext_modules=[CMakeExtension('triton', 'triton/_C/')],
     cmdclass=dict(build_ext=CMakeBuild),
     zip_safe=False,
@@ -122,7 +115,7 @@ setup(
     url='https://github.com/ptillet/triton/',
     download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz',
     classifiers=[
-        'Development Status :: 4 - Beta',  # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
+        'Development Status :: 3 - Alpha',  # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
         'Intended Audience :: Developers',  # Define that your audience are developers
         'Topic :: Software Development :: Build Tools',
         'License :: OSI Approved :: MIT License',  # Again, pick a license
@@ -13,15 +13,19 @@
 #include "triton/ir/function.h"

 using namespace triton;

 namespace rt = triton::runtime;
 namespace drv = triton::driver;

 typedef std::pair<int, int> map_key_t;
-std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
-std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
-
-CUstream torch_get_cuda_stream(int64_t dev_id);
-CUdevice torch_get_cuda_device(int64_t dev_id);
+std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
+std::map<int, std::shared_ptr<rt::function>> id_fn_map;
+std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
+std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
+std::unordered_map<const rt::options_t*, pybind11::object> opt_cache_;
+extern CUstream torch_get_cuda_stream(int64_t dev_id);
+extern CUdevice torch_get_cuda_device(int64_t dev_id);


 /* Grid utilities */

@@ -36,106 +40,123 @@ void delete_grid(const map_key_t& key) {

 /* Function utilities */

-void register_fn(const map_key_t& key,
-                 const std::string& src,
-                 const rt::options_space_t& opt) {
-  if(id_fn_map.find(key) == id_fn_map.end())
-    id_fn_map[key].reset(new rt::function(src, opt, ""));
-}
+void register_fn(int op_id,
+                 int dev_id,
+                 const std::string& src,
+                 const rt::options_space_t& opt) {
+  if(tt_devices.find(dev_id) == tt_devices.end()) {
+    driver::device* device;
+    driver::stream* stream;
+    if(dev_id >= 0){
+      device = new triton::driver::cu_device(torch_get_cuda_device(dev_id), false);
+      stream = new triton::driver::cu_stream(torch_get_cuda_stream(dev_id), false);
+    }
+    else{
+      device = new triton::driver::host_device();
+      stream = new triton::driver::host_stream();
+    }
+    tt_devices[dev_id].reset(device);
+    tt_streams[dev_id].reset(stream);
+  }
+  if(id_fn_map.find(op_id) == id_fn_map.end()){
+    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id]));
+  }
+  for(const auto& k: id_fn_map[op_id]->get_kernels()){
+    const rt::options_t* opt = &k.first;
+    pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
+    for(auto x: opt->defines)
+      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
+        obj.attr(x.first.c_str()) = std::stoi(x.second);
+    opt_cache_[&k.second->opt] = obj;
+  }
+}

-void delete_fn(const map_key_t& key) {
-  id_fn_map.erase(key);
+void delete_fn(int op_id) {
+  id_fn_map.erase(op_id);
 }

-std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
-  triton::driver::cu_device device(key.second, false);
-  return id_fn_map[key]->get_asm(mode, &device, opt);
-}
-
 void cleanup() {
   id_grid_map.clear();
   id_fn_map.clear();
+  opt_cache_.clear();
 }

 size_t make_op_id() {
   return id_fn_map.size();
 }

 /* Function signature */
-void make_module(const std::string& src, ir::module* ir,
-                 const runtime::options_space_t& opt) {
-  std::string copy = triton::runtime::function::preheader() + src;
-  // pre-process
-  TokenSequence tokens;
-  Preprocessor cpp(&copy, true);
-  for(auto it: opt.defines){
-    cpp.AddMacro(it.first, &it.second[0]);
-  }
-  cpp.Process(tokens);
-  // parse
-  Parser parser(tokens);
-  parser.Parse();
-  Generator gen(&parser);
-  gen.Gen(ir);
-}
-
-std::vector<rt::arg_type> get_fn_signature(const std::string& src,
-                                           const runtime::options_space_t& opt) {
-  // triton-ir code-gen
-  ir::context ctx;
-  auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
-  make_module(src, &*ir, opt);
-  // function
-  ir::function* fn = ir->get_function_list().front();
-  // extract signature
-  std::vector<rt::arg_type> ret;
-  ir::function_type* ty = fn->get_fn_type();
-  for(size_t i = 0; i < ty->get_num_params(); i++)
-    ret.push_back(rt::convert(ty->get_param_ty(i)));
-  return ret;
-}
+std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
+  return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
+}

-typedef triton::runtime::options_t options_t;
-typedef triton::runtime::options_space_t options_space_t;
+void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, size_t grid_0, size_t grid_1, size_t grid_2){
+  rt::function* fn = id_fn_map.at(op_id).get();
+  (*fn)((void**)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
+  // for(size_t n = 0; n < constant_names.size(); n++){
+  //   const torch::Tensor& x = constant_vals[n];
+  //   fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
+}
+
+pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string& args, const rt::function::grid_fn_ty& grid){
+  rt::function* fn = id_fn_map.at(op_id).get();
+  auto wrapper = [&grid](const rt::options_t& opt){
+    pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
+    for(auto x: opt.defines)
+      if(std::all_of(x.second.begin(), x.second.end(), ::isdigit))
+        obj.attr(x.first.c_str()) = std::stoi(x.second);
+    return grid(*obj.cast<rt::options_t*>());
+  };
+  rt::kernel* kernel = fn->autotune((void**)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
+  return opt_cache_.at(&kernel->opt);
+}
+
+
+void init_superblocking(pybind11::module &m);
+void init_launch(pybind11::module &m);

 PYBIND11_MODULE(libtriton, m) {
   m.doc() = "Python bindings to the C++ Triton API";

   // bindings for triton classes
   pybind11::enum_<rt::arg_type>(m, "arg_type")
-      .value("int1", rt::INT1_T)
-      .value("int8", rt::INT8_T)
-      .value("int16", rt::INT16_T)
-      .value("int32", rt::INT32_T)
-      .value("int64", rt::INT64_T)
-      .value("half", rt::HALF_T)
-      .value("float", rt::FLOAT_T)
+      .value("int1"  , rt::INT1_T)
+      .value("int8"  , rt::INT8_T)
+      .value("int16" , rt::INT16_T)
+      .value("int32" , rt::INT32_T)
+      .value("int64" , rt::INT64_T)
+      .value("half"  , rt::HALF_T)
+      .value("float" , rt::FLOAT_T)
       .value("double", rt::DOUBLE_T)
       .value("buffer", rt::BUFFER_T);

   pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
-      .value("ptx", rt::ASM_NV_PTX)
+      .value("ptx" , rt::ASM_NV_PTX)
       .value("sass", rt::ASM_NV_SASS);

-  pybind11::class_<options_t>(m, "options")
-      .def(pybind11::init<>())
-      .def("d", &options_t::D<int>)
-      .def_readwrite("num_warps", &options_t::num_warps)
-      .def_readwrite("defines"  , &options_t::defines);
+  pybind11::class_<rt::options_t>(m, "options", pybind11::dynamic_attr())
+      .def_readwrite("num_warps", &rt::options_t::num_warps)
+      .def_readwrite("defines"  , &rt::options_t::defines);

-  pybind11::class_<options_space_t>(m, "options_space")
+  pybind11::class_<rt::options_space_t>(m, "options_space")
       .def(pybind11::init<>())
-      .def_readwrite("defines", &options_space_t::defines)
-      .def_readwrite("num_warps", &options_space_t::num_warps);
+      .def_readwrite("num_warps", &rt::options_space_t::num_warps)
+      .def_readwrite("defines"  , &rt::options_space_t::defines);

   // hooks into triton constructs since frameworks may not use pybind11
   m.def("get_fn_signature", &get_fn_signature);
-  m.def("get_fn_asm", &get_fn_asm);
+  // m.def("get_fn_asm", &get_fn_asm);
   m.def("register_grid", &register_grid);
   m.def("delete_grid", &delete_grid);
   m.def("register_fn", &register_fn);
   m.def("delete_fn", &delete_fn);
   m.def("make_op_id", &make_op_id);
   m.def("cleanup", &cleanup);
-  ;
+  m.def("autotune", &autotune, pybind11::return_value_policy::reference);
+  m.def("launch_kernel", &launch_kernel);
+
+  init_launch(m);
+  init_superblocking(m);
 }
@@ -1,95 +0,0 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python

#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>


namespace rt = triton::runtime;
namespace drv = triton::driver;

typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
std::shared_ptr<drv::device> host_device;
std::shared_ptr<drv::context> host_context;
std::shared_ptr<drv::stream> host_stream;

int64_t cdiv_sum(torch::Tensor x, int64_t div){
  TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
  auto _x = x.accessor<int, 1>();
  int64_t ret = 0;
  for(size_t i = 0; i < x.size(0); i++)
    ret += (_x[i] + div - 1) / div;
  return ret;
}

void init_host_stream() {
  if(!host_stream){
    host_device.reset(new drv::host_device());
    host_context.reset(drv::context::create(&*host_device));
    host_stream.reset(drv::stream::create(host_context->backend()));
  }
}

CUstream torch_get_cuda_stream(int64_t dev_id) {
  return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}

CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
  CUdevice ret;
  triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
  return ret;
}

void synchronize(int64_t dev_id) {
  if(dev_id == -1){
    init_host_stream();
    host_stream->synchronize();
  }
  else{
    triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
    stream.synchronize();
  }
}

torch::Tensor cuda_empty_like(torch::Tensor x){
  if(x.nbytes() == 0)
    return torch::empty_like(x);
  void* data;
  cudaMalloc(&data, x.nbytes());
  auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
  return ret;
}

void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
                   const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
  rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
  for(size_t n = 0; n < constant_names.size(); n++){
    const torch::Tensor& x = constant_vals[n];
    fn->set_cst(constant_names[n].c_str(), (char*)x.data_ptr(), x.numel()*x.element_size());
  }
  if(dev_id == -1){
    init_host_stream();
    (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device);
  }
  else{
    C10_CUDA_CHECK(cudaSetDevice(dev_id));
    triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
    triton::driver::cu_device device(torch_get_cuda_device(dev_id), false);
    (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device);
  }
}


static auto registry = torch::RegisterOperators()
                       .op("triton::launch_kernel", &launch_kernel)
                       .op("triton::cuda_empty_like", &cuda_empty_like)
                       .op("triton::cdiv_sum", &cdiv_sum)
                       .op("triton::synchronize", &synchronize);
python/src/torch/launch.cc (new file, 83 lines)
@@ -0,0 +1,83 @@
// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>


namespace rt = triton::runtime;
namespace drv = triton::driver;

typedef std::pair<int, int> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<drv::device>> tt_devices;
extern std::map<int, std::shared_ptr<drv::stream>> tt_streams;


int64_t cdiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int64_t largest_pow2_divisor(int64_t a){
  if(a % 8 == 0) return 8;
  if(a % 4 == 0) return 4;
  if(a % 2 == 0) return 2;
  return 1;
}

int64_t cdiv_sum(torch::Tensor x, int64_t div){
  TORCH_CHECK(!x.is_cuda(), "Argument of cdiv_sum must be a CPU tensor")
  auto _x = x.accessor<int, 1>();
  int64_t ret = 0;
  for(size_t i = 0; i < x.size(0); i++)
    ret += (_x[i] + div - 1) / div;
  return ret;
}

CUstream torch_get_cuda_stream(int64_t dev_id) {
  return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}

CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
  CUdevice ret;
  triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
  return ret;
}

void synchronize(int64_t dev_id) {
  tt_streams[dev_id]->synchronize();
}

torch::Tensor cuda_empty_like(torch::Tensor x){
  if(x.nbytes() == 0)
    return torch::empty_like(x);
  void* data;
  cudaMalloc(&data, x.nbytes());
  auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
  return ret;
}

void cuda_set_device(int64_t dev_id) {
  if(dev_id >= 0)
    C10_CUDA_CHECK(cudaSetDevice(dev_id));
}


void init_launch(pybind11::module &m) {
  m.def("cuda_set_device", &cuda_set_device);
  m.def("cuda_empty_like", &cuda_empty_like);
  m.def("largest_pow2_divisor", &largest_pow2_divisor);
  m.def("cdiv", &cdiv);
  m.def("cdiv_sum", &cdiv_sum);
  m.def("synchronize", &synchronize);
}
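As an aside, the struct.pack idea credited at the top of this file works roughly as follows (a minimal sketch, not code from this commit; the type string and argument values are illustrative). Each kernel argument is encoded on the Python side into a single byte string according to a per-type format code ('P' for buffer pointers, 'i' for int32, 'f' for float32, mirroring the `codes` table in triton/kernel.py), and the C++ side reinterprets that string as the raw parameter block handed to the driver:

    import struct

    # e.g. add(float* z, float* x, float* y, int N): three pointers and one int32
    tys = 'PPPi'

    def pack(tys, *args):
        # native ('@') sizes and alignment, so the C++ side can reinterpret the
        # byte string as a raw parameter block without per-argument marshalling
        return struct.pack('@' + tys, *args)

    # pointer values would come from tensor.data_ptr() in practice
    params = pack(tys, 0x7f0000000000, 0x7f0000010000, 0x7f0000020000, 900)
    # `params` is then passed to libtriton.launch_kernel(op_id, dev_id, params, ...)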
python/src/torch/superblock.cc (new file, 117 lines)
@@ -0,0 +1,117 @@
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;

void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){
  size_t H = layout.size(0);
  size_t M = layout.size(1);
  size_t N = layout.size(2);
  torch::Tensor tmp = torch::zeros_like(layout);
  auto _tmp     = tmp.accessor<int, 3>();
  auto _layout  = layout.accessor<int, 3>();
  auto _idx     = idx.accessor<int, 3>();
  auto _scratch = scratch.accessor<int, 3>();
  std::vector<int> current(H, 0);
#ifdef _OPENMP
#pragma omp parallel for
#endif
  for(size_t h = 0; h < H; h++){
    // surrounding indices
    std::vector<int> ii_left(max_width, -1);
    std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));

    for(size_t m = 0; m < M; m++){
      for(size_t n = 0; n < N; n++){
        int v = _layout[h][m][n];
        if(v == 0)
          continue;
        int n_left = ii_left[max_width-1];
        int m_top  = ii_top [max_width-1][n];
        int top     = (m_top >= 0)                ? _tmp[h][m_top][n]      : 0;
        int left    = (n_left >= 0)               ? _tmp[h][m][n_left]     : 0;
        int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
        int width = std::min(left, std::min(top, topleft)) + 1;

        // reset width if blocks cannot be
        // packed together (i.e., there's a 1 "in the middle")
        for(int nn = n_left + 1; nn < n; nn++)
          if(ii_top[max_width-1][nn] > ii_top[max_width-1][n])
            width = 1;
        _tmp[h][m][n] = width;

        // update n_left ring buffer
        for(int k = 0; k < max_width-1; k++)
          ii_left[k] = ii_left[k+1];
        ii_left[max_width-1] = n;

        // update ii_top ring buffer
        for(int k = 0; k < max_width-1; k++)
          ii_top[k][n] = ii_top[k+1][n];
        ii_top[max_width-1][n] = m;

        // block is too small -- skip
        if(width != max_width)
          continue;

        // retained blocks are set to zeros
        for(size_t km = 0; km < max_width; km++)
        for(size_t kn = 0; kn < max_width; kn++)
        {
          int mm = ii_top[km][n];
          int nn = ii_left[kn];
          if(mm < 0 || nn < 0)
            continue;
          _layout[h][mm][nn] = 0;
          _tmp[h][mm][nn] = 0;
          _scratch[h][current[h]][0] = (int)h;
          _scratch[h][current[h]][1] = (int)mm;
          _scratch[h][current[h]][2] = (int)nn;
          _scratch[h][current[h]][3] = _idx[h][mm][nn];
          current[h]++;
        }
      }
    }
  }
  std::vector<torch::Tensor> to_cat;
  for(size_t h = 0; h < H; h++)
    if(current[h] > 0)
      to_cat.push_back(scratch[h].slice(0, 0, current[h]));
  if(!to_cat.empty())
    ret.push_back(std::make_tuple(max_width, torch::cat(to_cat)));
}


ret_t superblock(torch::Tensor layout, int start_width) {
  ret_t ret;
  // block index
  torch::Tensor idx = torch::zeros_like(layout);
  int current = 0;
  int64_t H = layout.size(0);
  int64_t M = layout.size(1);
  int64_t N = layout.size(2);
  auto _layout = layout.accessor<int, 3>();
  auto _idx    = idx.accessor<int, 3>();
  for(int64_t h = 0; h < H; h++)
  for(int64_t m = 0; m < M; m++)
  for(int64_t n = 0; n < N; n++){
    if(_layout[h][m][n] == 0)
      continue;
    _idx[h][m][n] = current++;
  }
  // scratch memory
  torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
  for(int max_width = start_width; max_width > 0; max_width /= 2)
    segment_blocks(layout, idx, scratch, max_width, ret);
  return ret;
}


void init_superblocking(pybind11::module &m) {
  m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
}
python/tests/test_blocksparse.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import itertools
import torch
import triton as tt
import pytest

def sparsify_tensor(x, mask, block):
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i*block: (i+1)*block, j*block: (j+1)*block]
    return ret

def mask_tensor(x, mask, block, value=0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
    return ret

@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
    [
        (mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
                              for at in [False, True]\
                              for bt in [False, True]\
                              for block in [16, 32, 64]
    ]
)
def test_op(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
    # set seed
    torch.random.manual_seed(0)
    # create inputs
    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
    shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE]
    layout = torch.randint(2, (H, shape[0]//BLOCK, shape[1]//BLOCK))
    # triton result
    op = tt.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
    ra = sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
    rb = sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b
    rc = op(ra, rb)
    # torch result
    ta = mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
    tb = mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b
    ta = ta.transpose(2, 3) if TRANS_A else ta
    tb = tb.transpose(2, 3) if TRANS_B else tb
    tc = torch.matmul(ta, tb)
    tc = mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
    tc = sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
    # compare
    rtol, atol = {torch.float32: (1e-4, 1e-5),
                  torch.float16: (1e-2, 1e-3)}[DTYPE]
    assert torch.allclose(rc, tc, rtol=rtol, atol=atol)
python/tests/test_conv.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import torch
import triton


def test_op():
    torch.manual_seed(0)
    DTYPE = torch.float16
    N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
    pad, stride = (1, 1), (1, 1)
    dilation = (1, 1)
    a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5
    b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5
    th_c = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), None, stride, pad, dilation)
    tt_c = triton.ops.conv(a, b, pad, stride)
    rtol, atol = {torch.float32: (1e-4, 1e-5),
                  torch.float16: (1e-2, 1e-3)}[DTYPE]
    assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol)
python/tests/test_matmul.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import pytest
import itertools
import triton as tt
import torch as th

@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
    [
        # 1 warp
        (16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
        (32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
        (16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
        (32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
        (16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
        (64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
        (16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
        # 2 warp
        (64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
        (32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
        (64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
        (32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
        (128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
        (32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
        # 4 warp
        (128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
        (64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
        (128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
        (32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
        (128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
        (32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
        # 8 warp
        (128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
        (256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
        (256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
        # variable input
        (128, 128, 32, 4, 256, 256, 256, AT, BT, DTYPE),
        (128, 128, 32, 4, 384, 128, 640, AT, BT, DTYPE),
        (128, 128, 32, 4, 107, 233, 256, AT, BT, DTYPE),
        (128, 128, 32, 4, 107, 233, 311, AT, BT, DTYPE)
    ]
    for DTYPE in ['float16']
    for AT in [False, True]
    for BT in [False, True]
]))
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
    DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
    th.manual_seed(0)
    tt.ops._matmul.kernel = dict()
    tt.ops._matmul.TM = [TM]
    tt.ops._matmul.TN = [TN]
    tt.ops._matmul.TK = [TK]
    tt.ops._matmul.num_warps = [NWARP]
    if M is None: M = TM
    if N is None: N = TN
    if K is None: K = TK
    a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
    b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
    a = a.t() if AT else a
    b = b.t() if BT else b
    th_c = th.matmul(a, b)
    tt_c = tt.ops.matmul(a, b)
    rtol, atol = {th.float32: (1e-4, 1e-5),
                  th.float16: (1e-2, 1e-3)}[DTYPE]
    assert th.allclose(tt_c, th_c, atol=atol, rtol=rtol)


def do_bench(fn, flops=0, warmup=10, rep=50):
    start_event = th.cuda.Event(enable_timing=True)
    end_event = th.cuda.Event(enable_timing=True)
    ret = fn()
    for i in range(warmup):
        fn()
    th.cuda.synchronize()
    start_event.record()
    for i in range(rep):
        fn()
    end_event.record()
    th.cuda.synchronize()
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms, flops/time_ms*1e-9, ret


def perf_op(dtype=th.float16, warmup=10, rep=50):
    AT, BT = False, False
    configs = [(N, N, N) for N in [128, 8192]]
    for M, N, K in configs:
        a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
        b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
        if AT: a = a.t()
        if BT: b = b.t()
        a = a[::, ::]
        b = b[::, ::]
        TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops=M*N*K*2, warmup=warmup, rep=rep)
        TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops=M*N*K*2, warmup=warmup, rep=rep)
        print((M, N, K), TH_MS, TT_MS)
python/tests/test_softmax.py (new file, 8 lines)
@@ -0,0 +1,8 @@
import torch
import triton

def test_op(M=1024, N=1024, dtype=torch.float32):
    x = torch.randn(M, N, dtype=dtype, device='cuda')
    th_y = torch.softmax(x, dim=-1)
    tt_y = triton.ops.softmax(x)
    assert torch.allclose(tt_y, th_y)
@@ -1,8 +1,13 @@
-from .kernel import *
+# TODO: torch needs to be imported first
+# or pybind11 shows `munmap_chunk(): invalid pointer`
+import torch

-# clean-up libtriton resources
+# libtriton resources
 import atexit
 import triton._C.libtriton as libtriton
 @atexit.register
 def cleanup():
   libtriton.cleanup()
+
+from .kernel import *
+from . import ops
@@ -15,18 +15,6 @@ codes = {
   libtriton.arg_type.buffer: 'P'
 }

-sizes = {
-  libtriton.arg_type.int1: 1,
-  libtriton.arg_type.int8: 1,
-  libtriton.arg_type.int32: 4,
-  libtriton.arg_type.int64: 8,
-  libtriton.arg_type.half: 2,
-  libtriton.arg_type.float: 4,
-  libtriton.arg_type.double: 8,
-  libtriton.arg_type.buffer: 8
-}
-
 def th_to_triton(obj):
   tys = {
     torch.int8: 'char',
@@ -43,92 +31,65 @@ def th_to_triton(obj):
     return [th_to_triton(x)[0] for x in obj]
   return [str(obj)]

 def cdiv(a, b):
-  return (a + b - 1) // b
+  return libtriton.cdiv(a, b)

 def cdiv_sum(a, b):
   return torch.ops.triton.cdiv_sum(a, b)

 def synchronize(device):
   dev_id = device.index
   dev_id = -1 if dev_id is None else dev_id
-  torch.ops.triton.synchronize(dev_id)
+  libtriton.synchronize(dev_id)

 def read(path):
   with open(path, 'r') as f:
     source = f.read()
   return source


 class kernel:

-  def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
+  def __init__(self, src, device, defines = dict(), num_warps = [4]):
     self.src = src
     self.opt = libtriton.options_space()
     self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()]
     self.opt.num_warps = num_warps
+    # device
+    assert device.type in ['cuda', 'cpu']
+    if device.type == 'cuda':
+      self.device = torch.cuda.current_device() if device.index is None else device.index
+    if device.type == 'cpu':
+      self.device = -1
     # C++ function wrapper
     self.op_id = libtriton.make_op_id()
-    self.registered = set()
-    arg_types = libtriton.get_fn_signature(self.src, self.opt)
-    size = sum([sizes[x] for x in arg_types])
+    libtriton.register_fn(self.op_id, self.device, self.src, self.opt)
+    # debug mode
+    self.is_debug = 'TRITON_DEBUG' in os.environ
+    # signature
+    arg_types = libtriton.get_fn_signature(self.op_id)
     self.tys = ''.join([codes[x] for x in arg_types])

-  def asm(self, mode, device, **kwargs):
-    dev_id = device.index
-    # assembly mode
-    supported = {
-      'ptx': libtriton.asm_mode.ptx,
-      'sass': libtriton.asm_mode.sass,
-    }
-    if mode not in supported:
-      raise('ASM mode must be in ', supported.keys())
-    mode = supported[mode]
-    # disambiguates #defines
-    libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
-    def _single_value_or_err(x, key):
-      if isinstance(x, list) and len(x) == 1:
-        return x[0]
-      if isinstance(x, list) and len(x) > 1:
-        if key in kwargs:
-          return kwargs[key]
-        raise ValueError(f'Parameter {key}={x} was auto-tuned during kernel creation. '
-                         'Please supply an explicit value as a keyword argument.')
-      return str(x)
-    defines = dict()
-    for (D, V) in self.opt.defines:
-      defines[D] = _single_value_or_err(V, D)
-    opt = libtriton.options()
-    opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps')
-    opt.defines = defines
-    # run
-    return libtriton.get_fn_asm((self.op_id, dev_id), mode, opt)

-  def __call__(self, *args, **kwargs):
-    if 'TRITON_DEBUG_MODE' in os.environ:
+  def __call__(self, *args, grid):
+    # debug mode (initialize)
+    if self.is_debug:
       _args = args
       args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
       for i in range(len(args)):
         if isinstance(args[i], torch.Tensor):
-          args[i] = torch.ops.triton.cuda_empty_like(args[i])
+          args[i] = libtriton.cuda_empty_like(args[i])
           args[i].copy_(_args[i])
-      torch.cuda.synchronize()
-    for x in args:
-      if isinstance(x, torch.Tensor):
-        device = x.device.index
-        device = -1 if device is None else device
-        break
-    # lazily register function for device
-    libtriton.register_fn((self.op_id, device), self.src, self.opt)
-    # launch grid
-    if 'grid' not in kwargs:
-      raise RuntimeError('Must provide grid for kernel launch')
-    grid = kwargs['grid']
-    libtriton.register_grid((self.op_id, device), grid)
-    # re-allocate buffers for auto-tuning
-    if 'autotune_buf' in kwargs:
-      pass
-    # launch
-    params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
-    names = list(kwargs['constants'].keys()) if 'constants' in kwargs else []
-    constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
-    torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
-    if 'TRITON_DEBUG_MODE' in os.environ:
-      torch.cuda.synchronize()
+    # initialize cuda device if necessary
+    libtriton.cuda_set_device(self.device)
+    # pack parameters into a byte buffer
+    params = pack(self.tys, *args)
+    # auto-tune if necessary
+    opt = libtriton.autotune(self.op_id, self.device, params, grid)
+    # run kernel
+    grid = grid(opt)
+    grid_0 = grid[0]
+    grid_1 = 1 if len(grid) < 2 else grid[1]
+    grid_2 = 1 if len(grid) < 3 else grid[2]
+    libtriton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
+    # debug mode (finalize)
+    if self.is_debug:
+      for i in range(len(args)):
+        if isinstance(args[i], torch.Tensor):
+          _args[i].copy_(args[i].clone())
python/triton/ops/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .conv import _conv, conv
from .matmul import _matmul, matmul
from .softmax import _softmax, softmax
from . import blocksparse
python/triton/ops/blocksparse/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .matmul import matmul
python/triton/ops/blocksparse/matmul.c (new file, 198 lines)
@@ -0,0 +1,198 @@
__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
                      TYPE* B __readonly __noalias __aligned(16),
                      TYPE* C __noalias __aligned(16),
                      int lda __multipleof(8),
                      int ldb __multipleof(8),
                      int ldc __multipleof(8),
                      long stride_za __multipleof(8),
                      long stride_zb __multipleof(8),
                      long stride_zc __multipleof(8),
                      long stride_ha __multipleof(8),
                      long stride_hb __multipleof(8),
                      long stride_hc __multipleof(8),
                      int DS0, int DS1,
                      int SDD_K __multipleof(16),
                      int SDD_off_width,
                      int* lut, int* locks, int nlocks) {
  /* ---------------- */
  /*    Prologue      */
  /* ---------------- */
  // program ids
  int pid0 = get_program_id(0);
  int pid1 = get_program_id(1);
  int pidz = get_program_id(2);
#ifdef SDD
  // load LUT header
  pid1 = pid1 + SDD_off_width;
  int blockidm[TM] = (0 ... TM) / BLOCK;
  int blockidn[TN] = (0 ... TN) / BLOCK;
  int offlutm[TM] = blockidm*(TN/BLOCK)*4;
  int offlutn[TN] = blockidn*4;
  int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
  int z = *(header + 0);
  int i[TM] = *(header + 1 + offlutm);
  int j[TN] = *(header + 2 + offlutn);
  int AS1 = SDD_K / TZ;
  int lockid = select(TZ > 1, 1, 0);
  int offka = pid0 * AS1;
  int offkb = pid0 * AS1;
  int offmc = 0;
  int offnc = 0;
  int offpa = 0;
  int offpb = 0;
  int maxid = TZ;
  int offhc = 0;
  int offha = z;
  int offhb = z;
  int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
  int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
#else
  // load LUT header
  int *header = lut + pid0 * 6;
  int offset = *(header + 0);
  int AS1    = *(header + 1);
  int column = *(header + 2);
  int depth  = *(header + 3);
  int lockid = *(header + 4);
  int maxid  = *(header + 5);
  int *pinc  = lut + offset;
  int offhc = depth;
#ifdef DSD
  // output offset
  int offnc = pid1 * TN;
  int offmc = column * TM;
  int offpc = 0;
  // dense input offset
  int offnb = pid1 * TN;
  int offkb __multipleof(8) = *pinc;
  int offpb = 0;
  // sparse input offset
  int offma = 0;
  int offka = 0;
  long offpa __multipleof(8) = *(pinc + 1);
  offpa = offpa * BLOCK * BLOCK;
  int offha = 0;
  int offhb = depth;
#endif
#ifdef DDS
  // output offset
  int offmc = pid1 * TM;
  int offnc = column * TN;
  int offpc = 0;
  // dense input offset
  int offma = pid1 * TM;
  int offka __multipleof(8) = *pinc;
  int offpa = 0;
  // sparse input offset
  int offnb = 0;
  int offkb = 0;
  long offpb __multipleof(8) = *(pinc + 1);
  offpb = offpb * BLOCK * BLOCK;
  int offha = depth;
  int offhb = 0;
#endif
  int ram[TM] = offma + 0 ... TM;
  int rbn[TN] = offnb + 0 ... TN;
#endif
  // initialize a, b pointers
  int rka[TK] = offka + 0 ... TK;
  int rkb[TK] = offkb + 0 ... TK;
  TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
  TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
  // pre-fetch
#ifdef DDS
  bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
  bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
  bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
#else
  bool checkbn[TK, TN] = AS1 > 0;
#endif
  TYPE a[TM, TK] = checkam ? *pa : 0;
  TYPE b[TK, TN] = checkbn ? *pb : 0;

  /* ---------------- */
  /*    Inner Loop    */
  /* ---------------- */
  // create result tile
  float acc[TM, TN] = 0;
  int step = TK;
  for(int k = AS1; k > 0; k -= step) {
    acc += a @ b;
    // update pointers
#ifdef SDD
    int inc_a = TK * STRIDE_AK;
    int inc_b = TK * STRIDE_BK;
#else
    pinc += 2;
#ifdef DSD
    int inc_b __multipleof(8) = *pinc;
    int inc_a __multipleof(8) = *(pinc + 1);
    inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
    int inc_a __multipleof(8) = *pinc;
    int inc_b __multipleof(8) = *(pinc + 1);
    inc_a = inc_a * STRIDE_AK;
#endif
#endif
    pa += inc_a;
    pb += inc_b;
    // pre-fetch
    bool checkak[TM, TK] = k > TK;
    bool checkbk[TK, TN] = k > TK;
    bool checka[TM, TK] = checkam && checkak;
    bool checkb[TK, TN] = checkbk && checkbn;
    a = *?(checka)pa;
    b = *?(checkb)pb;
  }
  TYPE c[TM, TN] = acc;

  /* ---------------- */
  /*    Epilogue      */
  /* ---------------- */
  // initialize c pointers
#ifdef SDD
  bool checkc[TM, TN] = 1;
  // rematerialize
  int rr_blockidm[TM] = (0 ... TM) / BLOCK;
  int rr_blockidn[TN] = (0 ... TN) / BLOCK;
  int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
  int rr_offlutn[TN] = rr_blockidn*4;
  int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
  int bkid[TM, TN] = *(header + off_bkid);
  long offpc[TM, TN] = bkid * BLOCK * BLOCK;
  // range within blocks
  int rcm[TM] = (0 ... TM) % BLOCK;
  int rcn[TN] = (0 ... TN) % BLOCK;
#else
  int rcm[TM] = offmc + 0 ... TM;
  int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
  bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
#endif
#ifdef DDS
  bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
  TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
  // write-back directly
  if(lockid == 0) {
    *?(checkc) pc = c;
  }
  // accumulate partial result using spin-locks
  else {
    int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
    int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
    for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
    int count = *pcount;
    if(count == 0)
      *?(checkc) pc = c;
    else
      *?(checkc) pc = c + *?(checkc)pc;
    atomic_xchg(pcount, (count + 1) % maxid);
    atomic_xchg(plock, 0);
  }
}
python/triton/ops/blocksparse/matmul.py (new file, 467 lines; truncated below)
@@ -0,0 +1,467 @@
|
||||
import triton
|
||||
import triton._C.libtriton as libtriton
|
||||
import torch
|
||||
import os
|
||||
import math
|
||||
|
||||
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
|
||||
|
||||
##############
|
||||
# MAIN API #
|
||||
##############
|
||||
class _matmul(torch.autograd.Function):
|
||||
|
||||
sdd_cache = dict()
|
||||
dsd_cache = dict()
|
||||
dds_cache = dict()
|
||||
locks = dict()
|
||||
|
||||
# Given an array sizes representing reduction size for each
|
||||
# column of a block-mode matrix multiplication,
|
||||
# performs load-balancing to achieve more smaller reductions
|
||||
# between `seg_size` elements
|
||||
@staticmethod
|
||||
def load_balance(sizes, block):
|
||||
# segment size
|
||||
# heuristics taken from OpenAI blocksparse code
|
||||
# https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
|
||||
max_size = sizes.max()
|
||||
min_size = sizes[sizes != 0].min()
|
||||
#if max_size > min_size * 2.0:
|
||||
# seg_max = max(triton.cdiv(max_size, 4), min_size*2)
|
||||
#else:
|
||||
# seg_max = max_size
|
||||
seg_max = max_size
|
||||
seg_min = max(triton.cdiv(seg_max, 4), 4)
|
||||
# split reduction into segments
|
||||
div = sizes // seg_max
|
||||
rem = sizes % seg_max
|
||||
packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
|
||||
width = packs.sum()
|
||||
segments = torch.empty(width, dtype=sizes.dtype)
|
||||
column = torch.empty_like(segments)
|
||||
lockid = torch.zeros_like(segments)
|
||||
maxid = torch.zeros_like(segments)
|
||||
nlocks = 0
|
||||
current = 0
|
||||
col_idx = 0
|
||||
for i in range(len(sizes)):
|
||||
d, r = div[i], rem[i]
|
||||
isempty = sizes[i] < seg_min
|
||||
last = current + d + (r >= seg_min) + isempty
|
||||
# column id
|
||||
column[current:last] = col_idx
|
||||
# lock id
|
||||
if d > 1 or (d == 1 and r >= seg_min):
|
||||
nlocks += 1
|
||||
lockid[current:last] = nlocks
|
||||
maxid[current:last] = last - current
|
||||
# segment size
|
||||
segments[current:current+d] = seg_max
|
||||
if r < seg_min and not isempty:
|
||||
segments[current+d-1] += r
|
||||
if r >= seg_min or isempty:
|
||||
segments[current+d] = r
|
||||
current = last
|
||||
col_idx += 1
|
||||
offsets = torch.zeros_like(segments)
|
||||
offsets[1:] = torch.cumsum(segments[:-1], dim=0)
|
||||
return segments, column, lockid, maxid, offsets
|
||||
|
||||
@staticmethod
|
||||
def get_locks(size, dev):
|
||||
if dev not in _matmul.locks or \
|
||||
size > _matmul.locks[dev].size(0):
|
||||
_matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
|
||||
return _matmul.locks[dev]
|
||||
|
||||
##########################
|
||||
# SPARSE = DENSE x DENSE #
|
||||
##########################
|
||||
|
||||
@staticmethod
|
||||
def make_sdd_lut(layout, block, dtype, device):
|
||||
start_width = 64 // block
|
||||
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
|
||||
luts, widths, packs = [], [], []
|
||||
for size, nnz in superblocks:
|
||||
width = nnz.shape[0] // (size*size)
|
||||
h = nnz[:, 0]
|
||||
i = nnz[:, 1]
|
||||
j = nnz[:, 2]
|
||||
b = nnz[:, 3]
|
||||
lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
|
||||
luts.append(lut.type(torch.int32).to(device))
|
||||
widths.append(width)
|
||||
packs.append(size)
|
||||
# create locks
|
||||
return luts, None, widths, packs

    @staticmethod
    def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
                    spdims, block, luts, num_locks, widths, packs):
        if trans_c:
            a, b = b, a
            trans_a, trans_b = not trans_b, not trans_a
        AS0 = a.size(0)
        AS1 = a.size(1)
        AS2 = a.size(3 if trans_a else 2)
        AS3 = a.size(2 if trans_a else 3)
        BS0 = b.size(0)
        BS1 = b.size(1)
        BS2 = b.size(3 if trans_b else 2)
        BS3 = b.size(2 if trans_b else 3)
        dtype = a.dtype
        device = a.device
        is_16_multiple = AS3 % 16 == 0
        is_32_multiple = AS3 % 32 == 0
        is_64_multiple = AS3 % 64 == 0
        if not is_16_multiple:
            raise ValueError('Reduction size for SDD must be a multiple of 16')
        # create kernel
        total_width = sum([width*pack*pack for width, pack in zip(widths, packs)])
        c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
        for lut, width, pack in zip(luts, widths, packs):
            num_lock = 1
            key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
            if key not in _matmul.sdd_cache:
                F32TK = [8, 16]
                #F16TK = [16]
                #F16TK += [32] if is_32_multiple else []
                #F16TK += [64] if is_64_multiple else []
                F16TK = [64]
                TK = {torch.float32: F32TK,
                      torch.float16: F16TK}[dtype]
                defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block,
                           'TK': TK, 'TYPE': dtype,
                           'STRIDE_AM': '1' if trans_a else 'lda',
                           'STRIDE_AK': 'lda' if trans_a else '1',
                           'STRIDE_BN': 'ldb' if trans_b else '1',
                           'STRIDE_BK': '1' if trans_b else 'ldb',
                           'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
                           'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
                _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=[1, 2, 4])

            kernel = _matmul.sdd_cache[key]
            # create output
            locks = _matmul.get_locks(2*width*AS0*num_lock, a.device)
            # the maximum grid size is 65535, so the operation might be
            # decomposed into multiple kernel calls
            max_width = 49152
            for off_width in range(0, width, max_width):
                kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
                       a.stride(2), b.stride(2), block,
                       a.stride(0), b.stride(0), c.stride(0),
                       a.stride(1), b.stride(1), c.stride(0),
                       AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock,
                       grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
        # save for backward pass
        return c

    ##########################
    # DENSE = DENSE x SPARSE #
    # DENSE = SPARSE x DENSE #
    ##########################

    # Given a binary layout of 0s and 1s,
    # construct a look-up table for efficient execution on GPUs
    @staticmethod
    def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
        # load-balancing
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        segments = _empty.clone()
        column = _empty.clone()
        depth = _empty.clone()
        lockid = _empty.clone()
        maxid = _empty.clone()
        offsets = _empty.clone()
        current_offset = 0
        current_maxid = 0
        for z in range(layout.size(0)):
            if trans:
                sizes = torch.sum(layout[z, :, :], 1)
            else:
                sizes = torch.sum(layout[z, :, :], 0)
            z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
            z_depth = z * torch.ones_like(z_segments)
            z_lockid[z_lockid > 0] += current_maxid
            current_maxid = z_lockid.max()
            # concatenate depth
            segments = torch.cat((segments, z_segments))
            column = torch.cat((column, z_column))
            depth = torch.cat((depth, z_depth))
            maxid = torch.cat((maxid, z_maxid))
            offsets = torch.cat((offsets, current_offset + z_offsets))
            lockid = torch.cat((lockid, z_lockid))
            current_offset += layout[z, :, :].sum()
        segments *= step
        # pointer increments
        if trans:
            nnz = layout.nonzero(as_tuple=False)
        else:
            nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
        num_blocks = nnz.size(0)
        offsets = torch.min(offsets, (num_blocks - 1)*torch.ones_like(offsets))
        idx = transform(nnz[:, 2]*block)
        xincs = idx.clone()
        xincs[1:] -= idx[:-1]
        # divide block into multiple steps
        div = block // step
        xincs = xincs.view(-1, 1).repeat(1, div)
        xincs[:, 1:] = step
        xincs[:, 0] -= (div-1)*step
        # first increment for each reduction is actually the offset
        xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
        xincs = xincs.view(-1)
        # block-mode input increments
        if trans:
            widx = torch.arange(num_blocks)
        else:
            widx = _empty.clone()
            current_offset = 0
            for z in range(layout.size(0)):
                layoutw = layout[z, :, :].clone()
                msum = layoutw.sum()
                layoutw[layoutw > 0] = 1 + torch.arange(msum)
                widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
                current_offset += msum
        wincs = widx*block*block
        wincs[1:] -= widx[:-1]*block*block
        wincs = wincs.view(-1, 1).repeat(1, div)
        if trans:
            wincs[:, 1:] = step
            wincs[:, 0] -= (div-1)*step
        else:
            wincs[:, 1:] = step*block
            wincs[:, 0] -= (div - 1)*step*block
        wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
        wincs = wincs.view(-1)
        # adjust offset and segment size
        offsets *= 2*div
        segments *= div
        # create header
        width = column.size(0)
        offsets += 6*width
        header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
        incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
        incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
        # create lut
        lut = torch.cat((header, incs))
        lut = lut.type(torch.int32).to(device)
        # create locks
        num_locks = max(1, lockid.max())
        return lut, num_locks, width, None
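    # Descriptive note (added comment, not in the original source): the LUT
    # begins with a header of 6*width int32s -- one (offset, size, column,
    # depth, lockid, maxid) record per segment -- followed by interleaved
    # (x-increment, w-increment) pairs that the kernel consumes one reduction
    # step at a time; `offsets += 6*width` above makes each segment's offset
    # point past the header into that increment region.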

    @staticmethod
    def _dds_matmul(a, b, trans_a, trans_b, trans_c,
                    spdims, block, lut, num_locks, width, packs):
        # shapes / dtypes
        AS0 = a.size(0)
        AS1 = a.size(1)
        AS2 = a.size(3 if trans_a else 2)
        AS3 = a.size(2 if trans_a else 3)
        BS0 = spdims[0]
        BS1 = block * spdims[2 if trans_b else 1]
        BS2 = block * spdims[1 if trans_b else 2]
        dtype = a.dtype
        # kernel
        key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
        if key not in _matmul.dds_cache:
            TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
            TK = [8] if dtype == torch.float32 else [16]
            defines = {'TM': TM, 'TN': block, 'TK': TK,
                       'BLOCK': block,
                       'TYPE': dtype,
                       'STRIDE_AM': 1 if trans_a else 'lda',
                       'STRIDE_AK': 'lda' if trans_a else 1,
                       'STRIDE_BN': block if trans_b else 1,
                       'STRIDE_BK': 1 if trans_b else block,
                       'STRIDE_CM': '1' if trans_c else 'ldc',
                       'STRIDE_CN': 'ldc' if trans_c else '1',
                       'NAME': 'dds_kernel',
                       'DDS': True}
            _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
        kernel = _matmul.dds_cache[key]
        # output
        CS0 = AS0
        CS1 = AS1
        CS2 = BS2 if trans_c else AS2
        CS3 = AS2 if trans_c else BS2
        locks = _matmul.get_locks(2*AS0*AS2//32*num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
               a.stride(2), block, c.stride(2),
               a.stride(0), b.stride(0), c.stride(0),
               a.stride(1), b.stride(1), c.stride(1),
               AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
               grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
        return c

    @staticmethod
    def _dsd_matmul(a, b, trans_a, trans_b, trans_c,
                    spdims, block, lut, num_locks, width, packs):
        # shapes / dtypes
        AS0 = spdims[0]
        AS1 = block * spdims[2 if trans_a else 1]
        AS2 = block * spdims[1 if trans_a else 2]
        BS0 = b.size(0)
        BS1 = b.size(1)
        BS2 = b.size(3 if trans_b else 2)
        BS3 = b.size(2 if trans_b else 3)
        dtype = a.dtype
        # kernel
        key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
        if key not in _matmul.dsd_cache:
            TN = [64, 128] if dtype == torch.float32 else [64, 128]
            TK = [8] if dtype == torch.float32 else [16]
            defines = {'TM': block, 'TN': TN, 'TK': TK,
                       'BLOCK': block,
                       'TYPE': dtype,
                       'STRIDE_AM': 1 if trans_a else block,
                       'STRIDE_AK': block if trans_a else 1,
                       'STRIDE_BN': 'ldb' if trans_b else '1',
                       'STRIDE_BK': '1' if trans_b else 'ldb',
                       'STRIDE_CM': '1' if trans_c else 'ldc',
                       'STRIDE_CN': 'ldc' if trans_c else '1',
                       'NAME': 'dsd_kernel',
                       'DSD': True}
            _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
        kernel = _matmul.dsd_cache[key]
        # output
        CS0 = BS0
        CS1 = BS1
        CS2 = BS3 if trans_c else AS1
        CS3 = AS1 if trans_c else BS3
        locks = _matmul.get_locks(2*BS0*BS3//32*num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
               block, b.stride(2), c.stride(2),
               a.stride(0), b.stride(0), c.stride(0),
               a.stride(1), b.stride(1), c.stride(1),
               BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
               grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
        return c

    fn = {'sdd': _sdd_matmul.__get__(object),
          'dsd': _dsd_matmul.__get__(object),
          'dds': _dds_matmul.__get__(object)}

    @staticmethod
    def forward(ctx, a, b, trans_a, trans_b, trans_c,
                mode, spdims, block,
                c_lut, c_num_locks, c_width, c_packs,
                da_lut, da_num_locks, da_width, da_packs,
                db_lut, db_num_locks, db_width, db_packs):
        c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block,
                             c_lut, c_num_locks, c_width, c_packs)
        # save for backward
        ctx.save_for_backward(a, b)
        ctx.da_num_locks = da_num_locks
        ctx.da_lut = da_lut
        ctx.da_width = da_width
        ctx.da_packs = da_packs
        ctx.db_lut = db_lut
        ctx.db_num_locks = db_num_locks
        ctx.db_width = db_width
        ctx.db_packs = db_packs
        ctx.mode = mode
        ctx.spdims = spdims
        ctx.block = block
        ctx.trans_a = trans_a
        ctx.trans_b = trans_b
        return c

    @staticmethod
    def backward(ctx, dc):
        # saved for backward
        a, b = ctx.saved_tensors
        mode = ctx.mode
        # initialize gradients to None so they are defined even when the
        # corresponding input does not require grad
        da, db = None, None
        # gradients w.r.t. a
        if ctx.needs_input_grad[0]:
            mode_da = mode[1] + mode[0] + mode[2]
            da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
                                     ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs)
        # gradients w.r.t. b
        if ctx.needs_input_grad[1]:
            mode_db = mode[2] + mode[1] + mode[0]
            db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,
                                     ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs)
        return da, db, None, None, None,\
               None, None, None, None,\
               None, None, None, None, None, None,\
               None, None, None, None, None, None,\
               None, None, None, None, None, None

class matmul:

    def make_lut(self, dtype, device):
        key = (dtype, device)
        if key in self.lut_cache:
            return self.lut_cache[key]
        # C look-up table
        layout, block = self.layout, self.block
        step = 8 if dtype == torch.float32 else 16
        if self.mode == 'sdd':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dsd':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
        elif self.mode == 'dds':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
        # DA look-up table
        if self.mode == 'sdd':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
        elif self.mode == 'dsd':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dds':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
        # DB look-up table
        if self.mode == 'sdd':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
        elif self.mode == 'dsd':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
        elif self.mode == 'dds':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
                               da_lut, da_num_locks, da_width, da_packs,
                               db_lut, db_num_locks, db_width, db_packs)
        return self.lut_cache[key]

    def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
        if mode not in ['sdd', 'dsd', 'dds']:
            raise NotImplementedError('Supported modes are: sdd, dsd, dds')
        # look-up table cache
        self.lut_cache = dict()
        # attributes
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.mode = mode
        self.spdims = layout.shape
        self.block = block
        self.layout = layout

    # pad the shape of a tensor to make it
    # compatible with kernel calls
    @staticmethod
    def _pad_shape(x, is_sparse):
        max_dim = 3 if is_sparse else 4
        for i in range(max_dim - x.dim()):
            x = x.unsqueeze(0)
        return x

    def __call__(self, a, b):
        c_lut, c_num_locks, c_width, c_packs,\
        da_lut, da_num_locks, da_width, da_packs,\
        db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
        # pad shapes with ones
        a = matmul._pad_shape(a, self.mode == 'dsd')
        b = matmul._pad_shape(b, self.mode == 'dds')
        # execute
        c = _matmul.apply(a, b, self.trans_a, self.trans_b, False,
                          self.mode, self.spdims, self.block,
                          c_lut, c_num_locks, c_width, c_packs,
                          da_lut, da_num_locks, da_width, da_packs,
                          db_lut, db_num_locks, db_width, db_packs)
        return c
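A minimal usage sketch of the wrapper above (added for illustration; the shapes, device, dtype, and the availability of this class as `matmul` in scope are assumptions, not part of this commit):

# hypothetical example: blocksparse 'sdd' product driven by a 0/1 block mask
import torch
block = 16
layout = torch.randint(0, 2, (2, 4, 4))   # (heads, M//block, N//block) mask
dot = matmul(layout, block, 'sdd')
a = torch.rand((1, 2, 64, 64), dtype=torch.float16, device='cuda')
b = torch.rand((1, 2, 64, 64), dtype=torch.float16, device='cuda')
c = dot(a, b)   # one (block x block) tile per non-zero entry of `layout`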
@@ -1,16 +1,9 @@
import torch
import triton

class _conv(torch.autograd.Function):
    src = """
    __global__ void conv(TYPE *A __noalias __readonly __aligned(16),
                         TYPE *B __noalias __readonly __aligned(16),
                         TYPE *C __noalias __aligned(16),
                         float alpha,
                         // equivalent matmul
                         int M __retune,
                         int N __retune,
                         int K __retune,
                         int M, int N, int K,
                         // convolution properties
                         int pad_h, int pad_w, int stride_h, int stride_w,
                         // pointer increment
@@ -130,73 +123,4 @@ class _conv(torch.autograd.Function):
        atomic_xchg(pcount, (count + 1) % TZ);
        atomic_xchg(plock, 0);
    #endif
    }
    """

    kernel = dict()

    @staticmethod
    def unpack(IDX, CI, R, S):
        s = IDX % S
        cr = IDX // S
        r = cr % R
        ci = cr // R
        return ci, r, s

    @staticmethod
    def forward(ctx, a, b, pad, stride, time):
        # create kernel if necessary
        dtype = a.dtype
        # shapes
        Z, CI, H, W = a.shape
        _, R, S, CO = b.shape
        P = (H + 2*pad[0] - R)//stride[0] + 1
        Q = (W + 2*pad[1] - S)//stride[1] + 1
        # compile kernel
        if dtype not in _conv.kernel:
            TK = 8
            defines = {
                'TYPE' : dtype,
                'TM'   : [16, 32, 64, 128],
                'TN'   : [16, 32, 64, 128],
                'TK'   : [TK],
                'TZ'   : [1],
                'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
            }
            idx = torch.arange(CI*R*S)
            ci, r, s = _conv.unpack(idx, CI, R, S)
            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
            delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
            delta = delta.type(torch.int32).cuda()
            _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
        delta, kernel = _conv.kernel[dtype]
        # allocate output
        c = torch.empty([Z, CO, P, Q], dtype=dtype)
        # enqueue
        grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
                            triton.cdiv(CO, opt.d('TN'))]
        time[0] = kernel(a, b, c, 1., Z*P*Q, CO, CI*R*S,
                         pad[0], pad[1], stride[0], stride[1],
                         delta,
                         a.stride(0), a.stride(1), a.stride(2), a.stride(3),
                         b.stride(0), b.stride(1), b.stride(2), b.stride(3),
                         c.stride(0), c.stride(1), c.stride(2), c.stride(3),
                         grid=grid, bench=100)
        return c


conv = _conv.apply
torch.manual_seed(0)
Z, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
pad = (1, 1)
stride = (1, 1)
a = torch.rand((Z, CI, H, W)).cuda()
b = torch.rand((CI, R, S, CO)).cuda()
time = [None]
cc = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, [1, 1])
c = conv(a, b, pad, stride, time)
print((cc - c).abs().max() / max(cc.max(), c.max()))
print(time[0], 2*Z*H*W*CI*CO*R*S/(time[0]*1e-9)*1e-12)
#zc = torch.matmul(a,b)
#zc_ = dot(a,b)
57  python/triton/ops/conv.py  Normal file
@@ -0,0 +1,57 @@
import torch
import triton
import os

class _conv(torch.autograd.Function):
    src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c'))
    kernel = dict()

    @staticmethod
    def unpack(IDX, CI, R, S):
        s = IDX % S
        cr = IDX // S
        r = cr % R
        ci = cr // R
        return ci, r, s
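    # e.g. (added illustration): with CI = 2, R = 3, S = 3,
    # unpack(17, 2, 3, 3) == (1, 2, 2), since 17 == 1*(R*S) + 2*S + 2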

    @staticmethod
    def forward(ctx, a, b, pad, stride):
        # create kernel if necessary
        dtype = a.dtype
        device = a.device
        # shapes
        Z, CI, H, W = a.shape
        _, R, S, CO = b.shape
        P = (H + 2*pad[0] - R)//stride[0] + 1
        Q = (W + 2*pad[1] - S)//stride[1] + 1
        # compile kernel
        if (dtype, device) not in _conv.kernel:
            TK = 16
            defines = {
                'TYPE' : dtype,
                'TM'   : [32, 64, 128],
                'TN'   : [32, 64, 128],
                'TK'   : [TK],
                'TZ'   : [1],
                'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
            }
            idx = torch.arange(CI*R*S)
            ci, r, s = _conv.unpack(idx, CI, R, S)
            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
            delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
            delta = delta.type(torch.int32).cuda()
            _conv.kernel[(dtype, device)] = (delta, triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines))
        delta, kernel = _conv.kernel[(dtype, device)]
        # allocate output
        c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
        # enqueue
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 1., Z*P*Q, CO, CI*R*S,
               pad[0], pad[1], stride[0], stride[1],
               delta.data_ptr(),
               a.stride(0), a.stride(1), a.stride(2), a.stride(3),
               b.stride(0), b.stride(1), b.stride(2), b.stride(3),
               c.stride(0), c.stride(1), c.stride(2), c.stride(3),
               grid = lambda opt: [triton.cdiv(Z*P*Q, opt.TM), triton.cdiv(CO, opt.TN)])
        return c

conv = _conv.apply
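A minimal usage sketch (added for illustration; the sizes are assumptions that mirror the spirit of the test removed from the old tutorial above, and a CUDA device is assumed):

# hypothetical example -- sizes are illustrative only
a = torch.rand((16, 64, 28, 28)).cuda()      # input in (Z, CI, H, W) layout
b = torch.rand((64, 3, 3, 128)).cuda()       # filters in (CI, R, S, CO) layout
c = conv(a, b, (1, 1), (1, 1))               # pad, stride
ref = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), None, (1, 1), (1, 1))
print((ref - c).abs().max())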
97  python/triton/ops/matmul.c  Normal file
@@ -0,0 +1,97 @@
#define STM 8
#define STN 8

__global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
                       TYPE * B __noalias __readonly __aligned(16),
                       TYPE * C __noalias __aligned(16),
                       float alpha,
                       int M,
                       int N,
                       int K __multipleof(16),
                       int lda __multipleof(LDA_POW2_DIV),
                       int ldb __multipleof(LDB_POW2_DIV),
                       int ldc __multipleof(LDC_POW2_DIV),
                       int* locks) {
  // prologue
  int pid = get_program_id(0);
  int pidz = get_program_id(2);
  int gridm = (M + TM - 1) / TM;
  int gridn = (N + TN - 1) / TN;

  // swizzle for better L2 performance
  int width = STM*gridn;
  int stm = pid / width;
  int RSTM = min(gridm - stm*STM, STM);
  int stn = (pid % width) / (RSTM*STN);
  int RSTN = min(gridn - stn*STN, STN);
  int laneid = pid % (RSTM * RSTN);
  int lanem = laneid / RSTN;
  int lanen = laneid % RSTN;
  int pidm = stm*STM + lanem;
  int pidn = stn*STN + lanen;
  int rm[TM] = pidm * TM + 0 ... TM;
  int rn[TN] = pidn * TN + 0 ... TN;

  // split-k for better parallelism
  K = K / TZ;
  int rk[TK] = 0 ... TK;
  // pointers to operands
  int offa[TM, TK] = (pidz*K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
  int offb[TK, TN] = (pidz*K + rk[:, newaxis]) * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
  TYPE* pa[TM, TK] = A + offa;
  TYPE* pb[TK, TN] = B + offb;

  // prefetch operands
  bool checka[TM, TK] = rk[newaxis, :] < K;
  bool checkb[TK, TN] = rk[:, newaxis] < K;
  TYPE a[TM, TK] = checka ? *pa : 0;
  TYPE b[TK, TN] = checkb ? *pb : 0;
  pa += TK * STRIDE_AK;
  pb += TK * STRIDE_BK;

  // reduction loop
  float acc[TM, TN] = 0;
  for(int k = K; k > 0; k -= TK){
#if (IS_TK_DIV_K==1)
    bool checkk[TK] = k > TK;
#else
    bool checkk[TK] = rk < k - TK;
#endif
    bool checka[TM, TK] = checkk[newaxis, :];
    bool checkb[TK, TN] = checkk[:, newaxis];
    acc += a @ b;
#if (IS_TK_DIV_K==1)
    a = *?(checka)pa;
    b = *?(checkb)pb;
#else
    a = checka ? *pa : 0;
    b = checkb ? *pb : 0;
#endif
    pa += TK * STRIDE_AK;
    pb += TK * STRIDE_BK;
  }
  acc = acc * alpha;
  TYPE c[TM, TN] = acc;

  // epilogue
  int rcm[TM] = pidm * TM + 0 ... TM;
  int rcn[TN] = pidn * TN + 0 ... TN;
  int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
  TYPE* pc[TM, TN] = C + offc;
  bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N;
#if (TZ==1)
  *?(checkc) pc = c;
#else
  // accumulate partial result using spin-locks
  int *plock = locks + rid;
  int *pcount = plock + get_num_programs(0) * get_num_programs(1);
  for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
  int count = *pcount;
  if(count == 0)
    *?(checkc) pc = c;
  else
    *?(checkc) pc = c + *?(checkc)pc;
  atomic_xchg(pcount, (count + 1) % TZ);
  atomic_xchg(plock, 0);
#endif
}
80  python/triton/ops/matmul.py  Normal file
@@ -0,0 +1,80 @@
import torch
import triton
import os

class _matmul(torch.autograd.Function):
    src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))

    TM = [128]
    TN = [128]
    TK = [32]
    TZ = 1
    num_warps = [4]

    @staticmethod
    def largest_pow2_divisor(N):
        if N % 8 == 0: return 8
        if N % 4 == 0: return 4
        if N % 2 == 0: return 2
        return 1
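    # e.g. (added illustration): largest_pow2_divisor(24) == 8,
    # largest_pow2_divisor(20) == 4, largest_pow2_divisor(7) == 1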

    _locks = dict()
    _kernels = dict()

    @staticmethod
    def _call(a, b):
        dtype = a.dtype
        device = a.device
        # allocate output
        M, K = a.shape
        K, N = b.shape
        c = torch.empty((M, N), dtype=dtype, device=device)
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous()
        # kernel hash
        is_a_row = a.stride(1) == 1
        is_b_row = b.stride(1) == 1
        lda = a.stride(0) if is_a_row else a.stride(1)
        ldb = b.stride(0) if is_b_row else b.stride(1)
        ldc = c.stride(0)
        lda_pow2_div = _matmul.largest_pow2_divisor(lda)
        ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
        ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
        is_tk_div_k = K % 32 == 0
        key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
        if key not in _matmul._kernels:
            defines = {
                'TYPE'        : dtype,
                'STRIDE_AM'   : 'lda' if is_a_row else '1',
                'STRIDE_AK'   : '1' if is_a_row else 'lda',
                'STRIDE_BK'   : 'ldb' if is_b_row else '1',
                'STRIDE_BN'   : '1' if is_b_row else 'ldb',
                'LDA_POW2_DIV': lda_pow2_div,
                'LDB_POW2_DIV': ldb_pow2_div,
                'LDC_POW2_DIV': ldc_pow2_div,
                'TM'          : _matmul.TM,
                'TN'          : _matmul.TN,
                'TK'          : _matmul.TK,
                'TZ'          : _matmul.TZ,
                'IS_TK_DIV_K' : is_tk_div_k
            }
            _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines)
        kernel = _matmul._kernels[key]
        # locks for split-k
        if device not in _matmul._locks:
            _matmul._locks[device] = torch.zeros(1024*1024, dtype=torch.int32, device=device)
        locks = _matmul._locks[device]
        # enqueue
        alpha = 1.
        args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
        grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, 1]
        kernel(*args, grid=grid)
        return c

    @staticmethod
    def forward(ctx, a, b):
        c = _matmul._call(a, b)
        return c

matmul = _matmul.apply
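A minimal usage sketch (added for illustration; the sizes and device are assumptions, and K is chosen to satisfy the kernel's multiple-of-16 constraint):

# hypothetical example -- sizes are illustrative only
a = torch.rand((512, 256), device='cuda')
b = torch.rand((256, 1024), device='cuda')
c = matmul(a, b)
print(torch.allclose(c, torch.matmul(a, b), atol=1e-3))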
8  python/triton/ops/softmax.c  Normal file
@@ -0,0 +1,8 @@
__global__ void forward(TYPE* X, TYPE* Y) {
  int pid = get_program_id(0);
  int off[BLOCK] = pid * BLOCK + 0 ... BLOCK;
  float x[BLOCK] = *(X + off);
  float shifted[BLOCK] = exp(x - x[max]);
  float sum = shifted[+];
  *(Y + off) = shifted / sum;
}
27  python/triton/ops/softmax.py  Normal file
@@ -0,0 +1,27 @@
import torch
import triton
import os

kernels = dict()
def get_kernel(block, dtype, device):
    key = (block, dtype, device)
    if key not in kernels:
        src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'))
        defines = {'BLOCK': block, 'TYPE': dtype}
        kernels[key] = triton.kernel(src, device=device, defines=defines)
    return kernels[key]


class _softmax(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        y = torch.empty_like(x)
        M, N = x.shape
        kernel = get_kernel(N, x.dtype, x.device)
        kernel(x.data_ptr(), y.data_ptr(), grid=lambda opt: [M, ])
        return y

softmax = _softmax.apply
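A minimal usage sketch (added for illustration; the sizes and device are assumptions -- note that BLOCK is specialized to the row length N, so each row is reduced by a single program instance):

# hypothetical example -- sizes are illustrative only
x = torch.rand((64, 1024), device='cuda')
y = softmax(x)
print(torch.allclose(y, torch.softmax(x, dim=1), atol=1e-6))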
335  python/tutorials/01-vector-add.ipynb  Normal file
@@ -0,0 +1,335 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "induced-zoning",
   "metadata": {},
   "source": [
    "# Getting Started"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "median-malaysia",
   "metadata": {},
   "source": [
    "In this tutorial, we will see how to construct a simple, high-performance vector addition using Triton. You will learn:\n",
    "* The basic syntax of the Triton programming language\n",
    "* The best practices for creating PyTorch custom operators using the `triton.kernel` Python API\n",
    "* The best practices for validating and benchmarking custom ops against native reference implementations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "identical-conditions",
   "metadata": {},
   "source": [
    "# Writing the Compute Kernel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "collectible-belle",
   "metadata": {},
   "source": [
    "Each compute kernel is declared using the `__global__` attribute, and executed many times in parallel on different chunks of data (see the [Single Program, Multiple Data](https://en.wikipedia.org/wiki/SPMD) programming model for more details).\n",
    "\n",
    "\n",
    "```c\n",
    "__global__ void add(float* z, float* x, float* y, int N){\n",
    "    // `get_program_id(i)` returns the i-th coordinate\n",
    "    // of the program in the overarching SPMD context\n",
    "    // (a.k.a. launch grid). This is what allows us to process\n",
    "    // different chunks of data in parallel.\n",
    "    // For those familiar with CUDA, `get_program_id({0,1,2})`\n",
    "    // is similar to blockIdx.{x,y,z}\n",
    "    int pid = get_program_id(0);\n",
    "    // In Triton, arrays are first-class citizens. In other words,\n",
    "    // they are primitive data-types and are -- contrary to C and\n",
    "    // CUDA -- not implemented as pointers to contiguous chunks of\n",
    "    // memory.\n",
    "    // In the few lines below, we create an array of `BLOCK` pointers\n",
    "    // whose memory values are, e.g.:\n",
    "    // [z + pid*BLOCK + 0, z + pid*BLOCK + 1, ..., z + pid*BLOCK + BLOCK - 1]\n",
    "    // Note: here BLOCK is expected to be a pre-processor macro defined at compile-time\n",
    "    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
    "    float* pz [BLOCK] = z + offset;\n",
    "    float* px [BLOCK] = x + offset;\n",
    "    float* py [BLOCK] = y + offset;\n",
    "    // Simple element-wise control-flow for load/store operations can\n",
    "    // be achieved using the ternary operator `cond ? val_true : val_false`\n",
    "    // or the conditional dereferencing operator `*?(cond)ptr`.\n",
    "    // Here, we make sure that we do not access memory out-of-bounds when we\n",
    "    // write-back `z`\n",
    "    bool check[BLOCK] = offset < N;\n",
    "    *?(check)pz = *?(check)px + *?(check)py;\n",
    "}\n",
    "```\n",
    "\n",
    "The existence of arrays as a primitive data-type for Triton comes with a number of advantages that are highlighted in the [MAPL'2019 Triton paper](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "forbidden-wednesday",
   "metadata": {},
   "source": [
    "# Writing the Torch bindings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "numerical-agency",
   "metadata": {},
   "source": [
    "The only thing that matters when it comes to Triton and Torch is the `triton.kernel` class. This allows you to transform the above C-like function into a callable python object that can be used to modify `torch.tensor` objects.\n",
    "\n",
    "To create a `triton.kernel`, you only need three things:\n",
    "* `source: string`: the source-code of the kernel you want to create\n",
    "* `device: torch.device`: the device you want to compile this code for\n",
    "* `defines: dict`: the set of macros that you want the pre-processor to `#define` for you\n",
    "\n",
    "Note: The constructor of `triton.kernel` does some just-in-time compilation, so expect some overhead there. For this reason, I personally like to initialize kernels lazily in a cache (see `_kernels` variable below). This also makes it possible to choose the compilation device dynamically based on the type of the operator's inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "sporting-keyboard",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import triton\n",
    "\n",
    "# source-code for Triton compute kernel\n",
    "# here we just copy-paste the above code without the extensive comments.\n",
    "# you may prefer to store it in a .c file and load it from there instead.\n",
    "_src = \"\"\"\n",
    "__global__ void add(float* z, float* x, float* y, int N){\n",
    "    // program id\n",
    "    int pid = get_program_id(0);\n",
    "    // create arrays of pointers\n",
    "    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;\n",
    "    float* pz[BLOCK] = z + offset;\n",
    "    float* px[BLOCK] = x + offset;\n",
    "    float* py[BLOCK] = y + offset;\n",
    "    // bounds checking\n",
    "    bool check[BLOCK] = offset < N;\n",
    "    // write-back\n",
    "    *?(check)pz = *?(check)px + *?(check)py;\n",
    "}\n",
    "\"\"\"\n",
    "# This function returns a callable `triton.kernel` object\n",
    "# created from the above source code.\n",
    "# For portability, we maintain a cache of kernels for different `torch.device`.\n",
    "# We compile the kernel with -DBLOCK=1024\n",
    "_kernels = dict()\n",
    "def make_add_kernel(device):\n",
    "    if device not in _kernels:\n",
    "        defines = {'BLOCK': 1024}\n",
    "        _kernels[device] = triton.kernel(_src, device=device, defines=defines)\n",
    "    return _kernels[device]\n",
    "\n",
    "# This is a standard torch custom autograd Function.\n",
    "# The only difference is that we can now use the above kernel\n",
    "# in the `forward` and `backward` functions.\n",
    "class _add(torch.autograd.Function):\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, x, y):\n",
    "        # constraints of the op\n",
    "        assert x.dtype == torch.float32\n",
    "        # *allocate output*\n",
    "        z = torch.empty_like(x)\n",
    "        # *create launch grid*:\n",
    "        # this is a function which takes compilation parameters `opt`\n",
    "        # as input and returns a tuple of int (i.e., launch grid) for the kernel.\n",
    "        # triton.cdiv is a shortcut for ceil division:\n",
    "        # triton.cdiv(a, b) = (a + b - 1) // b\n",
    "        N = z.shape[0]\n",
    "        grid = lambda opt: (triton.cdiv(N, opt.BLOCK), )\n",
    "        # *launch kernel*:\n",
    "        # pointers to the data of torch tensors can be retrieved with\n",
    "        # the `.data_ptr()` method\n",
    "        kernel = make_add_kernel(z.device)\n",
    "        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), N, grid = grid)\n",
    "        return z\n",
    "\n",
    "# Just like with standard PyTorch ops,\n",
    "# we use the `.apply` method to create a\n",
    "# callable object for our function\n",
    "add = _add.apply"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "separated-polyester",
   "metadata": {},
   "source": [
    "At this point `add(x, y)` is equivalent to `x + y` for contiguous tensors. Now let's test and benchmark it!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "exclusive-salvation",
   "metadata": {},
   "source": [
    "# Writing a Unit Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "supported-ribbon",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
      "tensor([1.3713, 1.3076, 0.4940, ..., 0.6682, 1.1984, 1.2696], device='cuda:0')\n",
      "The maximum difference between torch and triton is 0.0\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "x = torch.rand(98432, device='cuda')\n",
    "y = torch.rand(98432, device='cuda')\n",
    "za = x + y\n",
    "zb = add(x, y)\n",
    "print(za)\n",
    "print(zb)\n",
    "print(f'The maximum difference between torch and triton is '\n",
    "      f'{torch.max(torch.abs(za - zb))}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "otherwise-canadian",
   "metadata": {},
   "source": [
    "Seems to work!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "polished-australia",
   "metadata": {},
   "source": [
    "# Writing a Benchmark"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "historic-glass",
   "metadata": {},
   "source": [
    "The performance of our GPU code can be benchmarked using the `torch.cuda.Event(enable_timing=True)` wrapper. Below is a simple function that benchmarks `rep` runs of our kernels after `warmup` \"cold\" runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "strange-luxembourg",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We now want to benchmark the performance of `add`\n",
    "# against that of PyTorch for increasing vector sizes\n",
    "def do_bench(fn, warmup = 10, rep = 50):\n",
    "    start_event = torch.cuda.Event(enable_timing=True)\n",
    "    end_event = torch.cuda.Event(enable_timing=True)\n",
    "    ret = fn()\n",
    "    for i in range(warmup):\n",
    "        fn()\n",
    "    torch.cuda.synchronize()\n",
    "    start_event.record()\n",
    "    for i in range(rep):\n",
    "        fn()\n",
    "    end_event.record()\n",
    "    torch.cuda.synchronize()\n",
    "    time_ms = start_event.elapsed_time(end_event) / rep\n",
    "    return time_ms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "hairy-claim",
   "metadata": {},
   "source": [
    "We can now benchmark our custom op for vectors of increasing sizes to get a sense of how it does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "pleasant-valley",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "131072 0.020 0.003\n",
      "262144 0.019 0.004\n",
      "524288 0.016 0.016\n",
      "1048576 0.033 0.033\n",
      "2097152 0.071 0.070\n",
      "4194304 0.142 0.144\n",
      "8388608 0.287 0.286\n",
      "16777216 0.572 0.568\n",
      "33554432 1.139 1.110\n"
     ]
    }
   ],
   "source": [
    "for N in [2**i for i in range(17, 26, 1)]:\n",
    "    x = torch.rand(N, device='cuda')\n",
    "    y = torch.rand(N, device='cuda')\n",
    "    triton_ms = do_bench(lambda: add(x, y))\n",
    "    torch_ms = do_bench(lambda: x + y)\n",
    "    # print the performance of triton and torch as well as the achieved bandwidth\n",
    "    print(f'{N} {triton_ms:.3f} {torch_ms:.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "juvenile-supplement",
   "metadata": {},
   "source": [
    "Our op is on-par with Torch's vectorized element-wise kernel when the vectors are large enough. One caveat is that the latency of PyTorch is much smaller for small vectors (3us vs 18-20us). This is something we are actively working on reducing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "agreed-backing",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
Block a user