[PYTHON] Merged blocksparse branch:

* Example for blocksparse matrix multiplication
* Simplified Triton kernel API
* Revived auto-tuning in einsum

committed by Philippe Tillet
parent ea37ba5d35
commit 268894a5ce
@@ -75,6 +75,7 @@ public:
  // accessors
  arg_type type() const { return ty_; }
  void* data() const { return (void*)&val_; }
  driver::buffer* buffer() const { return val_.buf; }

private:
@@ -3,6 +3,7 @@
#include <regex>
#include <functional>
#include <algorithm>
#include <memory>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h"
@@ -343,6 +344,17 @@ function::cache_key_t function::get_key(driver::stream *stream, const std::vecto
// returns program with best compilation options for given parameter
function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn,
                                     const std::vector<arg>& args) {
  // // copy buffer argument so that auto-tuning doesn't corrupt data
  // std::list<std::shared_ptr<driver::cu_buffer>> copies;
  // std::vector<arg> args = args;
  // for(arg& x: args)
  //   if(x.type() == BUFFER_T){
  //     driver::buffer* old = x.buffer();
  //     driver::context* ctx = old->context();
  //     size_t size = old->size();
  //     copies.push_back(std::make_shared<driver::cu_buffer>(ctx, size));
  //     x = arg(copies.back().get());
  //   }
  // fast path -- no autotuning necessary
  if(callers_.size() == 1)
    return &*callers_.begin()->second;
python/examples/blocksparse.py (new file, 364 lines)
@@ -0,0 +1,364 @@
import math

import torch
from torch.nn import init

import triton


class _linear(torch.autograd.Function):

    src = '''
    __global__ void main (TYPE* A __readonly  __noalias __aligned(16),
                          TYPE* B __readonly  __noalias __aligned(16),
                          TYPE* C __writeonly __noalias __aligned(16),
                          int lda, int ldb, int ldc,
                          int M, int Kmax,
                          int* lut,
                          int* locks, int nlocks) {
      /* ---------------- */
      /*     Prologue     */
      /* ---------------- */
      // program ids
      int pid0 = get_program_id(0);
      int pid1 = get_program_id(1);
    #ifdef DW
      // load LUT header
      int *header = lut + pid0 * 2;
      int i = *(header + 0);
      int j = *(header + 1);
      int K = Kmax / TZ;
      int lockid = select(TZ > 1, 1, 0);
      int offk = pid1 * K;
      int offm = i * TM;
      int offn = j * TN;
      int maxid = get_num_programs(1);
    #else
      // load LUT header
      int *header = lut + pid1 * 5;
      int offset = *(header + 0);
      int K      = *(header + 1);
      int column = *(header + 2);
      int lockid = *(header + 3);
      int maxid  = *(header + 4);
      int *pinc  = lut + offset;
      int offk = (*pinc) * TK;
      int offm = pid0 * TM;
      int offn = column * TN;
    #endif
      // initialize a, b pointers
      int rka[TK] = offk + 0 ... TK;
      int rkb[TK] = offk + 0 ... TK;
      int ram[TM] = offm + (0 ... TM);
      int rbn[TN] = offn + (0 ... TN);
      TYPE* pa[TM, TK] = A + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
      TYPE* pb[TK, TN] = B + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
      // pre-fetch
      bool checka[TM, TK] = ram[:, newaxis] < M;
      bool checkb[TK, TN] = 1;
      TYPE a[TM, TK] = checka ? *pa : 0;
      TYPE b[TK, TN] = checkb ? *pb : 0;

      /* ---------------- */
      /*    Inner Loop    */
      /* ---------------- */
      // create result tile
      float acc[TM, TN] = 0;
    #ifdef DW
      int step = TK;
    #else
      int step = 1;
    #endif
      for(int k = K; k > 0; k -= step) {
        acc += a @ b;
        // update pointers
    #ifdef DW
        int inc_a = TK * STRIDE_AK;
        int inc_b = TK * STRIDE_BK;
    #else
        pinc += 1;
        int inc_a = (*pinc) * TK * STRIDE_AK;
        int inc_b = (*pinc) * TK * STRIDE_BK;
    #endif
        pa += inc_a;
        pb += inc_b;
        // pre-fetch
        bool checka[TM, TK] = k > 1;
        bool checkb[TK, TN] = k > 1;
        a = *?(checka)pa;
        b = *?(checkb)pb;
      }
      TYPE c[TM, TN] = acc;

      /* ---------------- */
      /*     Epilogue     */
      /* ---------------- */
      // initialize c pointers
      int   rcm[TM] = offm + (0 ... TM);
      int   rcn[TN] = offn + (0 ... TN);
      TYPE* pc[TM, TN] = C + rcm[:, newaxis]*ldc + rcn[newaxis, :];
      bool  checkc[TM, TN] = rcm[:, newaxis] < M;
      // write-back directly
      if(lockid == 0) {
        *?(checkc) pc = c;
      }
      // accumulate partial result using spin-locks
      else {
        int *plock  = locks + get_program_id(0)*nlocks + lockid - 1;
        int *pcount = plock + get_num_programs(0)*nlocks;
        // spin until the lock is acquired (atomic_cas returns the old value)
        for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
        int count = *pcount;
        if(count == 0)
          *?(checkc) pc = c;
        else
          *?(checkc) pc = c + *?(checkc)pc;
        // maxid partial results are expected; reset the counter when all have landed
        atomic_xchg(pcount, (count + 1) % maxid);
        atomic_xchg(plock, 0);
      }
    }
    '''

    # dictionaries for cached triton kernels
    y_kernel = dict()
    dx_kernel = dict()
    dw_kernel = dict()

    # Given an array `sizes` containing the reduction size of each
    # column of a block-sparse matrix multiplication, perform
    # load-balancing by splitting each reduction into smaller
    # segments of at most seg_size blocks
    @staticmethod
    def load_balance(sizes, seg_size=8):
        div = sizes // seg_size
        rem = sizes % seg_size
        packs = div + (rem != 0).long()
        width = packs.sum()
        # split reduction into segments
        segments = torch.empty(width, dtype=sizes.dtype)
        column = torch.empty_like(segments)
        lockid = torch.zeros_like(segments)
        maxid = torch.zeros_like(segments)
        nlocks = 0
        current = 0
        col_idx = 0
        for i in range(len(sizes)):
            d, r = div[i], rem[i]
            last = current + d + (r > 0)
            # column id
            column[current:last] = col_idx
            # lock id
            if d > 1 or (d == 1 and r > 0):
                nlocks += 1
                lockid[current:last] = nlocks
                maxid[current:last] = last - current
            # segment size
            segments[current:current+d] = seg_size
            if r > 0:
                segments[current+d] = r
            current = last
            col_idx += 1
        offsets = torch.zeros_like(segments)
        offsets[1:] = torch.cumsum(segments[:-1], dim=0)
        return segments, column, lockid, maxid, offsets
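
    # As a quick illustration of the balancing scheme (editor's sketch,
    # not part of the commit): two columns whose reductions span 3 and 17
    # blocks, with the default seg_size = 8, come out as follows.
    #
    #   sizes = torch.tensor([3, 17])
    #   segments, column, lockid, maxid, offsets = _linear.load_balance(sizes)
    #   segments -> [3, 8, 8, 1]    pieces of at most seg_size blocks
    #   column   -> [0, 1, 1, 1]    output column each segment feeds
    #   lockid   -> [0, 1, 1, 1]    column 1's segments share spin-lock #1
    #   maxid    -> [0, 3, 3, 3]    partial results to accumulate under the lock
    #   offsets  -> [0, 3, 11, 19]  segment starts in the flattened block stream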

    # Given a binary mask of 0s and 1s,
    # construct a look-up table for efficient execution on GPUs
    @staticmethod
    def make_ydx_lut(mask, block_size):
        # offsets in look-up table
        sizes = torch.sum(mask, 0)
        offsets = torch.zeros_like(sizes)
        offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
        # load-balancing
        segments, column, lockid, maxid, offsets = _linear.load_balance(sizes)
        # pointer increments
        nnz = torch.nonzero(mask.T)
        idx = nnz[:, 1]
        incs = idx.clone()
        incs[1:] -= idx[:-1]
        # segment starts hold absolute block indices; the rest hold deltas
        incs[offsets] = idx[offsets]
        # create header: 5 ints per segment (offset, size, column, lockid, maxid)
        width = column.size(0)
        offsets += 5*width
        header = torch.stack((offsets, segments, column, lockid, maxid), dim=1).view(-1).contiguous()
        # create lut
        lut = torch.cat((header, incs)).type(torch.int32).cuda()
        # create locks
        num_locks = max(1, lockid.max())
        locks = torch.zeros((2*mask.size(0), num_locks), dtype=torch.int32).cuda()
        return lut, locks, width
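
    # To make the LUT layout concrete, a hedged worked example (editor's
    # sketch, not part of the commit): a 2x2 block mask with one empty block.
    #
    #   mask = torch.tensor([[1, 0], [1, 1]])
    #   lut, locks, width = _linear.make_ydx_lut(mask, block_size=16)
    #   width -> 2 (one segment per column; both small enough to skip locking)
    #   lut   -> [10, 2, 0, 0, 0,   # column 0: incs at offset 10, 2 blocks, no lock
    #             12, 1, 1, 0, 0,   # column 1: incs at offset 12, 1 block, no lock
    #             0, 1, 1]          # increments: absolute start, then deltas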

    @staticmethod
    def make_dw_lut(mask, depth, block_size):
        # note: depth is currently unused
        nnz = torch.nonzero(mask)
        # create lut
        width = nnz.size(0)
        i = nnz[:, 0]
        j = nnz[:, 1]
        lut = torch.stack((i, j), dim=1).view(-1).contiguous()
        lut = lut.type(torch.int32).cuda()
        # create locks
        num_locks = 1
        locks = torch.zeros((2*width, num_locks), dtype=torch.int32).cuda()
        return lut, locks, width

    @staticmethod
    def forward(ctx, x, w, block_size,
                y_lut, y_locks, y_width,
                dx_lut, dx_locks, dx_width,
                dw_lut, dw_locks, dw_width):
        M, Kx = x.size()
        Kw, N = w.size()
        dtype = x.dtype
        # memory strides
        lda = Kx
        ldb = N
        ldc = N
        # create kernel
        key = (dtype, block_size)
        if key not in _linear.y_kernel:
            defines = {'TM': 64, 'TN': block_size, 'TK': block_size, 'TYPE': dtype,
                       'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
                       'STRIDE_BN': '1', 'STRIDE_BK': 'ldb'}
            _linear.y_kernel[key] = triton.kernel(_linear.src, defines=defines)
        kernel = _linear.y_kernel[key]
        # allocate output
        y = torch.empty((M, N), dtype=dtype, device=x.device)
        # launch kernel
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), y_width]
        kernel(x, w, y, lda, ldb, ldc, M, Kx, y_lut, y_locks, y_locks.size(1), grid=grid)
        # save information in context
        ctx.dx_width = dx_width
        ctx.dw_width = dw_width
        ctx.kernel = kernel
        ctx.block_size = block_size
        ctx.save_for_backward(x, w, dx_lut, dx_locks, dw_lut, dw_locks)
        return y

    @staticmethod
    def backward(ctx, dy):
        # retrieve information from context
        x, w, dx_lut, dx_locks, dw_lut, dw_locks = ctx.saved_tensors
        dx_width = ctx.dx_width
        dw_width = ctx.dw_width
        block_size = ctx.block_size
        kernel = ctx.kernel
        # shapes
        M, N = dy.size()
        _, K = x.size()
        dtype = x.dtype
        ################
        # input gradient
        ################
        dx = None
        if ctx.needs_input_grad[0]:
            # create kernel
            key = (dtype, block_size)
            if key not in _linear.dx_kernel:
                defines = {'TM': 64, 'TN': block_size, 'TK': block_size, 'TYPE': dtype,
                           'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
                           'STRIDE_BN': 'ldb', 'STRIDE_BK': '1'}
                _linear.dx_kernel[key] = triton.kernel(_linear.src, defines=defines)
            kernel = _linear.dx_kernel[key]
            # allocate output
            dx = torch.empty_like(x)
            # launch kernel
            grid = lambda opt: [triton.cdiv(M, opt.d('TM')), dx_width]
            kernel(dy, w, dx, N, N, K, M, N, dx_lut, dx_locks, dx_locks.size(1), grid=grid)
        #################
        # weight gradient
        #################
        dw = None
        if ctx.needs_input_grad[1]:
            # create kernel
            key = (dtype, block_size)
            if key not in _linear.dw_kernel:
                defines = {'TM': block_size, 'TN': block_size, 'TK': 8, 'TYPE': dtype,
                           'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
                           'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
                           'DW': True, 'TZ': 2}
                _linear.dw_kernel[key] = triton.kernel(_linear.src, defines=defines)
            kernel = _linear.dw_kernel[key]
            # allocate output
            dw = torch.zeros_like(w)
            # launch kernel
            grid = lambda opt: [dw_width, opt.d('TZ')]
            kernel(x, dy, dw, K, N, N, K, M, dw_lut, dw_locks, dw_locks.size(1), grid=grid)
        # done
        return dx, dw, None, \
               None, None, None, \
               None, None, None, \
               None, None, None
linear = _linear.apply


class Linear(torch.nn.Module):

    def __init__(self, in_features, out_features, block_size, mask):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.block_size = block_size
        self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.bias = None  # no bias in this example
        self.reset_parameters()
        # create look-up tables
        self.y_lut, self.y_locks, self.y_width = _linear.make_ydx_lut(mask, block_size)
        self.dx_lut, self.dx_locks, self.dx_width = _linear.make_ydx_lut(mask.T, block_size)
        # the depth argument of make_dw_lut is unused; pass in_features
        self.dw_lut, self.dw_locks, self.dw_width = _linear.make_dw_lut(mask, in_features, block_size)

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        return linear(input, self.weight, self.block_size,
                      self.y_lut, self.y_locks, self.y_width,
                      self.dx_lut, self.dx_locks, self.dx_width,
                      self.dw_lut, self.dw_locks, self.dw_width)
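
# A hypothetical usage sketch of the module (editor's example, not part of
# the commit; shapes mirror the test script below, CUDA device assumed):
#
#   mask = torch.randint(0, 2, (256 // 16, 256 // 16))
#   layer = Linear(256, 256, 16, mask).cuda()
#   out = layer(torch.rand(64, 256).cuda())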


def reference_dot(x, w, mask):
    WS0, WS1 = w.size()
    MS0, MS1 = mask.size()
    assert WS0 % MS0 == 0
    assert WS1 % MS1 == 0
    block_size_0 = WS0 // MS0
    block_size_1 = WS1 // MS1
    assert block_size_0 == block_size_1
    maskedw = w.clone()
    for bi, wi in enumerate(range(0, WS0, block_size_0)):
        for bj, wj in enumerate(range(0, WS1, block_size_1)):
            maskedw[wi : wi+block_size_0,
                    wj : wj+block_size_1] *= mask[bi, bj]
    return torch.matmul(x, maskedw)

torch.manual_seed(0)
# parameters
M, N, K = 256, 256, 256
BS = 16
# initialize inputs
mask = torch.randint(0, 2, (K//BS, N//BS))
x = torch.rand((M, K), dtype=torch.float32, requires_grad=True).cuda()
w = torch.rand((K, N), dtype=torch.float32, requires_grad=True).cuda()
x.retain_grad()
w.retain_grad()
# reference result
ry = reference_dot(x, w, mask)
dy = torch.rand_like(ry)
ry.backward(dy)
rdx = x.grad.clone()
rdw = w.grad.clone()
# reset gradients
x.grad.zero_()
w.grad.zero_()
# triton result
y_lut, y_locks, y_width = _linear.make_ydx_lut(mask, BS)
dx_lut, dx_locks, dx_width = _linear.make_ydx_lut(mask.T, BS)
dw_lut, dw_locks, dw_width = _linear.make_dw_lut(mask, M, BS)
ty = _linear.apply(x, w, BS,
                   y_lut, y_locks, y_width,
                   dx_lut, dx_locks, dx_width,
                   dw_lut, dw_locks, dw_width)
ty.backward(dy)
tdx = x.grad.clone()
tdw = w.grad.clone()
# test: maximum absolute error against the dense reference
print((ty - ry).abs().max())
print((tdx - rdx).abs().max())
print((tdw - rdw).abs().max())
@@ -148,71 +148,41 @@ def _make_framework_op(arg_types):

class kernel:

    def __init__(self, src):
        self.fw_id = dict()
        self.fw_grids = dict()
        self.fw_op = None
    def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
        self.src = src
        # create constants
        self.cst = dict()
        # create triton op
        macros = []
        for k, v in defines.items():
            cvt = lambda x: _cvt_to_def_str(x)
            if(isinstance(v, list)):
                values = list(map(cvt, v))
            else:
                values = [cvt(v)]
            macros.append((k, values))
        opt = libtriton.options_space()
        opt.defines = macros
        opt.num_warps = num_warps
        self.op_id = libtriton.make_op_id()
        libtriton.register_fn(self.op_id, self.src, opt, os.path.realpath(libtriton.__file__))
        # create pytorch hook
        arg_types = libtriton.get_fn_signature(self.src, opt)
        self.fw_op = _make_framework_op(arg_types)

    def set_constant(self, name, value):
        self.cst[name] = value
        libtriton.register_cst(self.op_id, name, value)

    def __call__(self, *args, **kwargs):
        ########################
        # JIT Options
        ########################
        num_warps = kwargs['num_warps'] if 'num_warps' in kwargs else [2, 4, 8]
        defines = kwargs['defines'] if 'defines' in kwargs else dict()
        # launch options
        bench = kwargs['bench'] if 'bench' in kwargs else 0
        bench_id = libtriton.make_scalar_id() if bench > 0 else -1
        # launch grid
        if 'grid' not in kwargs:
            raise RuntimeError('Must provide grid for kernel launch')
        grid = kwargs['grid']

        ########################
        # cache
        ########################
        # create a new framework op when defines are different
        key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in defines.items()])
        if key not in self.fw_id.keys():
            # code generation options
            macros = []
            for k, v in defines.items():
                cvt = lambda x: _cvt_to_def_str(x)
                if(isinstance(v, list)):
                    values = list(map(cvt, v))
                else:
                    values = [cvt(v)]
                macros.append((k, values))
            opt = libtriton.options_space()
            opt.defines = macros
            opt.num_warps = num_warps
            # create triton function for this op
            op_id = libtriton.make_op_id()
            self.fw_id[key] = op_id
            libtriton.register_fn(op_id, self.src, opt, os.path.realpath(libtriton.__file__))
            for name, value in self.cst.items():
                libtriton.register_cst(op_id, name, value)
            # create pytorch hook for this op
            arg_types = libtriton.get_fn_signature(self.src, opt)
            if self.fw_op is None:
                self.fw_op = _make_framework_op(arg_types)

        ########################
        # initialize
        ########################
        op_id = self.fw_id[key]
        libtriton.register_grid(op_id, grid)
        bench_id = libtriton.make_scalar_id() if bench > 0 else -1

        ############################
        # call torch function
        ############################
        if fw.has_torch():
            self.fw_op(op_id, bench, bench_id, *args)
            if bench > 0:
                return libtriton.retrieve_scalar(bench_id)
        else:
            assert False
        libtriton.register_grid(self.op_id, grid)
        # launch
        self.fw_op(self.op_id, bench, bench_id, *args)
        if bench > 0:
            return libtriton.retrieve_scalar(bench_id)
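Read against the old code above, the simplified API binds defines and num_warps once at construction and drops the per-call cache keyed on defines. A minimal usage sketch, assuming the Triton-C dialect shown in blocksparse.py (the copy kernel and tensors here are the editor's illustration, not part of this commit):

import torch
import triton

copy_src = '''
__global__ void copy(TYPE* X, TYPE* Y, int N) {
  int idx[TM] = get_program_id(0) * TM + 0 ... TM;
  TYPE* px[TM] = X + idx;
  TYPE* py[TM] = Y + idx;
  bool check[TM] = idx < N;
  *?(check)py = *?(check)px;
}
'''
# compile-time macros and warp counts now bind at construction time...
copy = triton.kernel(copy_src, defines={'TM': 128, 'TYPE': 'float'}, num_warps=[4])
# ...so the call site only supplies arguments and a launch grid
x = torch.rand(1000).cuda()
y = torch.empty_like(x)
copy(x, y, x.numel(), grid=lambda opt: [triton.cdiv(x.numel(), opt.d('TM'))])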
@@ -125,7 +125,7 @@ def replace_conv2d(model, acc_bitmask = None):
#torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3])

if __name__ == '__main__':
    N, C, H, W, K, RS = 128, 64, 30, 30, 128, 1
    N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3
    #N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3
    #N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3
    pad, stride = 0, 1
@@ -42,7 +42,6 @@ void fwdbatchnorm(float *Y, float *M, float *V,
  }
}
"""
    fwd_kernel = triton.kernel(fwd_src)

    bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB,
@@ -89,10 +88,16 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
  }
}
"""
    bwd_kernel = triton.kernel(bwd_src)

    fwd_kernel = None
    bwd_kernel = None

    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        # lazy compilation of kernel
        if _batchnorm.fwd_kernel is None:
            _batchnorm.fwd_kernel = triton.kernel(fwd_src, defines = {'TM': 128})
        # shapes
        shape = triton.shape(x)
        dtype = x.dtype
        # allocate outputs
@@ -102,8 +107,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
        var = triton.empty([C], dtype=dtype)
        # execute kernels
        _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
                              grid = lambda opt: [1, C],
                              defines = {'TM': 128})
                              grid = lambda opt: [1, C])
        # save
        ctx.save_for_backward(x, gamma, beta, mean, var)
        ctx.eps = eps
@@ -111,6 +115,9 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,

    @staticmethod
    def backward(ctx, dy):
        # lazy compilation of kernel
        if _batchnorm.bwd_kernel is None:
            _batchnorm.bwd_kernel = triton.kernel(bwd_src, defines = {'TN': 128})
        # retrieve info
        x, gamma, beta, mean, var = ctx.saved_tensors
        eps = ctx.eps
@@ -123,8 +130,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
        _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
                              x, gamma, mean, var,
                              H*W*B, eps,
                              grid = lambda opt: [1, C],
                              defines = {'TM': 128})
                              grid = lambda opt: [1, C])
        return dx, dgamma, dbeta, None

batchnorm = _batchnorm.apply
@@ -69,7 +69,7 @@ class _einsum(torch.autograd.Function):
        ret = dict(zip(expr, ret))
        return ret

    def make_kernel(name,
    def make_kernel(name, dtype, mask,
                    expr_a, expr_b, expr_c,
                    axes_m, axes_n, axes_k, axes_b,
                    multipleof_a, multipleof_b, multipleof_c,
@@ -329,8 +329,15 @@ __global__ void {name}(
#endif
}
"""

        ret = triton.kernel(src)
        # compilation options
        TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
        TK = 16 if dtype==torch.float16 else 8
        defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
        if mask is not None:
            defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
        # create kernel
        ret = triton.kernel(src, defines=defines)
        # set constant
        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
            ret.set_constant('AD', delta_a)
        if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
@@ -541,14 +548,15 @@ __global__ void {name}(
        stride_a_last = stride_a[-1]
        stride_b_last = stride_b[-1]
        stride_c_last = stride_c[-1]
        name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
        name = f'{dtype}_{mask}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
               f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
               f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
        # recompile if necessary
        cache = _einsum.instance.kernel_cache
        if name not in cache:
            cachesize = len(cache)
            cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
            cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
                                              dtype, mask,
                                              sym_a, sym_b, sym_c,
                                              axes_m, axes_n, axes_k, axes_b,
                                              stride_a_multiple, stride_b_multiple, stride_c_multiple,
@@ -604,19 +612,6 @@ __global__ void {name}(
        self.pos_vars = len(self.args)
        self.varnames = varnames
        self.args += [None] * len(varnames)
        # tile size ranges
        MAX_GZ = triton.cdiv(K, 2048)
        TMs = [16] + [x for x in [32, 64, 128] if x <= M]
        TNs = [16] + [x for x in [32, 64, 128] if x <= N]
        TBs = [x for x in [1, 2, 4, 8] if x <= B]
        TZs = [x for x in [1, 2, 4, 8, 16, 32] if x <= MAX_GZ]
        # tile sizes
        TM, TN, TB, TZ = _einsum.instance._tile(M, N, B, TMs, TNs, TBs, TZs, TK)
        TM, TN, TB, TZ = 64, 128, 1, 1
        self.macros = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
        self.num_warps = [4]
        if mask is not None:
            self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
        # save information on the operation
        self.expr_a = expr_a
        self.expr_b = expr_b
@@ -634,7 +629,7 @@ __global__ void {name}(
        self.args[self.pos_c] = c
        for i, name in enumerate(self.varnames):
            self.args[self.pos_vars + i] = values[name]
        return self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros, num_warps=self.num_warps)
        return self.kernel(*self.args, grid=self.grid, bench=bench)
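Net effect of the einsum hunks, reading old against new: the per-call tile selection (including the hard-coded TM, TN, TB, TZ = 64, 128, 1, 1 that had disabled tuning) and the per-call defines are dropped; make_kernel instead passes list-valued macros at construction, so the options-space machinery in the new kernel class performs the auto-tuning. A rough sketch of the idea (editor's illustration, not a line from this commit):

kernel = triton.kernel(src,
                       defines={'TM': [16, 32, 64, 128], 'TN': [16, 32, 64, 128]},
                       num_warps=[2, 4, 8])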
@@ -4,7 +4,7 @@ import numpy as np
import weakref

def cdiv(a, b):
    return -(-a // b)
    return (a + b - 1) // b
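Both bodies compute the ceiling of a/b for a positive divisor: with a = 10 and b = 4, the old form gives -(-10 // 4) = -(-3) = 3 and the new form gives (10 + 4 - 1) // 4 = 13 // 4 = 3. The rewrite merely trades the double-negation idiom for the more common additive one.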

class tf_empty_proxy: