[PYTHON][EXAMPLES] Removed BlockSparse examples; see
https://github.com/ptillet/torch-blocksparse.git
committed by Philippe Tillet
parent 268894a5ce
commit 9fda39f64c
@@ -1,364 +0,0 @@
import math

import torch
from torch.nn import init

import triton


class _linear(torch.autograd.Function):

    src = '''
    __global__ void main (TYPE* A __readonly  __noalias __aligned(16),
                          TYPE* B __readonly  __noalias __aligned(16),
                          TYPE* C __writeonly __noalias __aligned(16),
                          int lda, int ldb, int ldc,
                          int M, int Kmax,
                          int* lut,
                          int* locks, int nlocks) {
      /* ---------------- */
      /*     Prologue     */
      /* ---------------- */
      // program ids
      int pid0 = get_program_id(0);
      int pid1 = get_program_id(1);
    #ifdef DW
      // DW (weight-gradient) mode: the LUT stores one (i, j) block
      // coordinate per non-zero block; the K reduction is split over TZ programs
      int *header = lut + pid0 * 2;
      int i = *(header + 0);
      int j = *(header + 1);
      int K = Kmax / TZ;
      int lockid = select(TZ > 1, 1, 0);
      int offk = pid1 * K;
      int offm = i * TM;
      int offn = j * TN;
      int maxid = get_num_programs(1);
    #else
      // Y/DX mode: the LUT stores a 5-integer header per reduction segment,
      // followed by the block-pointer increments of that segment
      int *header = lut + pid1 * 5;
      int offset = *(header + 0);
      int K      = *(header + 1);
      int column = *(header + 2);
      int lockid = *(header + 3);
      int maxid  = *(header + 4);
      int *pinc  = lut + offset;
      int offk = (*pinc) * TK;
      int offm = pid0 * TM;
      int offn = column * TN;
    #endif
      // initialize a, b pointers
      int rka[TK] = offk + 0 ... TK;
      int rkb[TK] = offk + 0 ... TK;
      int ram[TM] = offm + (0 ... TM);
      int rbn[TN] = offn + (0 ... TN);
      TYPE* pa[TM, TK] = A + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
      TYPE* pb[TK, TN] = B + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
      // pre-fetch
      bool checka[TM, TK] = ram[:, newaxis] < M;
      bool checkb[TK, TN] = 1;
      TYPE a[TM, TK] = checka ? *pa : 0;
      TYPE b[TK, TN] = checkb ? *pb : 0;

      /* ---------------- */
      /*    Inner Loop    */
      /* ---------------- */
      // create result tile
      float acc[TM, TN] = 0;
    #ifdef DW
      int step = TK;
    #else
      int step = 1;
    #endif
      for(int k = K; k > 0; k -= step) {
        acc += a @ b;
        // update pointers
    #ifdef DW
        int inc_a = TK * STRIDE_AK;
        int inc_b = TK * STRIDE_BK;
    #else
        pinc += 1;
        int inc_a = (*pinc) * TK * STRIDE_AK;
        int inc_b = (*pinc) * TK * STRIDE_BK;
    #endif
        pa += inc_a;
        pb += inc_b;
        // pre-fetch
        bool checka[TM, TK] = k > 1;
        bool checkb[TK, TN] = k > 1;
        a = *?(checka)pa;
        b = *?(checkb)pb;
      }
      TYPE c[TM, TN] = acc;

      /* ---------------- */
      /*     Epilogue     */
      /* ---------------- */
      // initialize c pointers
      int rcm[TM] = offm + (0 ... TM);
      int rcn[TN] = offn + (0 ... TN);
      TYPE* pc[TM, TN] = C + rcm[:, newaxis]*ldc + rcn[newaxis, :];
      bool checkc[TM, TN] = rcm[:, newaxis] < M;
      // write-back directly
      if(lockid == 0) {
        *?(checkc) pc = c;
      }
      // accumulate partial result using spin-locks
      else {
        int *plock = locks + get_program_id(0)*nlocks + lockid - 1;
        int *pcount = plock + get_num_programs(0)*nlocks;
        for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
        int count = *pcount;
        if(count == 0)
          *?(checkc) pc = c;
        else
          *?(checkc) pc = c + *?(checkc)pc;
        atomic_xchg(pcount, (count + 1) % maxid);
        atomic_xchg(plock, 0);
      }
    }
    '''
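
    # How the spin-lock epilogue above serializes partial results
    # (illustrative pseudo-Python, not part of the kernel): programs that
    # share an output tile contend on one lock, and a companion counter
    # tracks how many of the `maxid` partial sums have landed, so the
    # first writer stores and the others accumulate:
    #
    #   while not compare_and_swap(lock, 0, 1): pass   # acquire
    #   tile[:] = partial if counter == 0 else tile + partial
    #   counter = (counter + 1) % maxid                # reset for next use
    #   lock = 0                                       # release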

    # dictionaries for cached triton kernels
    y_kernel = dict()
    dx_kernel = dict()
    dw_kernel = dict()

    # Given an array `sizes` holding the reduction size of each column
    # of a block-sparse matrix multiplication, perform load-balancing
    # by splitting each reduction into segments of at most seg_size blocks
    @staticmethod
    def load_balance(sizes, seg_size=8):
        div = sizes // seg_size
        rem = sizes % seg_size
        packs = div + (rem != 0).long()
        width = int(packs.sum())
        # split reduction into segments
        segments = torch.empty(width, dtype=sizes.dtype)
        column = torch.empty_like(segments)
        lockid = torch.zeros_like(segments)
        maxid = torch.zeros_like(segments)
        nlocks = 0
        current = 0
        col_idx = 0
        for i in range(len(sizes)):
            d, r = int(div[i]), int(rem[i])
            last = current + d + (r > 0)
            # column id
            column[current:last] = col_idx
            # lock id: columns split across several segments need a lock
            if d > 1 or (d == 1 and r > 0):
                nlocks += 1
                lockid[current:last] = nlocks
                maxid[current:last] = last - current
            # segment size
            segments[current:current+d] = seg_size
            if r > 0:
                segments[current+d] = r
            current = last
            col_idx += 1
        offsets = torch.zeros_like(segments)
        offsets[1:] = torch.cumsum(segments[:-1], dim=0)
        return segments, column, lockid, maxid, offsets
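
    # Worked example (hand-computed, illustrative only): for
    # sizes = [3, 10] and seg_size = 8, column 0 fits in one segment
    # while column 1 is split into 8 + 2, so the two pieces of column 1
    # share lock 1 and count maxid = 2 partial results:
    #   segments = [3, 8, 2]
    #   column   = [0, 1, 1]
    #   lockid   = [0, 1, 1]
    #   maxid    = [0, 2, 2]
    #   offsets  = [0, 3, 11]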

    # Given a binary mask of 0s and 1s,
    # construct a look-up table for efficient execution on GPUs
    @staticmethod
    def make_ydx_lut(mask, block_size):
        # reduction size of each column, load-balanced into segments
        sizes = torch.sum(mask, 0)
        segments, column, lockid, maxid, offsets = _linear.load_balance(sizes)
        # pointer increments between consecutive non-zero blocks of a column;
        # at each segment start, reset to the absolute block index
        nnz = torch.nonzero(mask.T)
        idx = nnz[:, 1]
        incs = idx.clone()
        incs[1:] -= idx[:-1]
        incs[offsets] = idx[offsets]
        # create header
        width = column.size(0)
        offsets += 5*width
        header = torch.stack((offsets, segments, column, lockid, maxid), dim=1).view(-1).contiguous()
        # create lut
        lut = torch.cat((header, incs)).type(torch.int32).cuda()
        # create locks
        num_locks = max(1, int(lockid.max()))
        locks = torch.zeros((2*mask.size(0), num_locks), dtype=torch.int32).cuda()
        return lut, locks, width
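
    # Worked example (hand-computed, illustrative only): for the 2x2
    # block mask [[1, 0], [1, 1]] the column sizes are [2, 1], no column
    # needs a lock, and the resulting int32 lut is
    #   [10, 2, 0, 0, 0,  12, 1, 1, 0, 0,  0, 1, 1]
    # i.e. two 5-int headers (offset, K, column, lockid, maxid) followed
    # by the three block-row increments [0, 1, 1].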

    # Given a binary mask, construct the (much simpler) look-up table
    # for the weight gradient: one (i, j) coordinate per non-zero block
    @staticmethod
    def make_dw_lut(mask, block_size):
        nnz = torch.nonzero(mask)
        # create lut
        width = nnz.size(0)
        i = nnz[:, 0]
        j = nnz[:, 1]
        lut = torch.stack((i, j), dim=1).view(-1).contiguous()
        lut = lut.type(torch.int32).cuda()
        # create locks
        num_locks = 1
        locks = torch.zeros((2*width, num_locks), dtype=torch.int32).cuda()
        return lut, locks, width
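
    # For the same 2x2 block mask [[1, 0], [1, 1]], the dw lut is just
    # the flattened non-zero coordinates [0, 0,  1, 0,  1, 1]
    # (illustrative, hand-computed).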

    @staticmethod
    def forward(ctx, x, w, block_size,
                y_lut, y_locks, y_width,
                dx_lut, dx_locks, dx_width,
                dw_lut, dw_locks, dw_width):
        M, Kx = x.size()
        Kw, N = w.size()
        dtype = x.dtype
        # memory strides
        lda = Kx
        ldb = N
        ldc = N
        # create kernel
        key = (dtype, block_size)
        if key not in _linear.y_kernel:
            defines = {'TM': 64, 'TN': block_size, 'TK': block_size, 'TYPE': dtype,
                       'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
                       'STRIDE_BN': '1', 'STRIDE_BK': 'ldb'}
            _linear.y_kernel[key] = triton.kernel(_linear.src, defines=defines)
        kernel = _linear.y_kernel[key]
        # allocate output
        y = torch.empty((M, N), dtype=dtype, device=x.device)
        # launch kernel: axis 0 tiles the M dimension, axis 1 enumerates LUT segments
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), y_width]
        kernel(x, w, y, lda, ldb, ldc, M, Kx, y_lut, y_locks, y_locks.size(1), grid=grid)
        # save information in context
        ctx.dx_width = dx_width
        ctx.dw_width = dw_width
        ctx.kernel = kernel
        ctx.block_size = block_size
        ctx.save_for_backward(x, w, dx_lut, dx_locks, dw_lut, dw_locks)
        return y

    @staticmethod
    def backward(ctx, dy):
        # retrieve information from context
        x, w, dx_lut, dx_locks, dw_lut, dw_locks = ctx.saved_tensors
        dx_width = ctx.dx_width
        dw_width = ctx.dw_width
        block_size = ctx.block_size
        kernel = ctx.kernel
        # shapes
        M, N = dy.size()
        _, K = x.size()
        dtype = x.dtype
        ################
        # input gradient
        ################
        dx = None
        if ctx.needs_input_grad[0]:
            # create kernel
            key = (dtype, block_size)
            if key not in _linear.dx_kernel:
                defines = {'TM': 64, 'TN': block_size, 'TK': block_size, 'TYPE': dtype,
                           'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
                           'STRIDE_BN': 'ldb', 'STRIDE_BK': '1'}
                _linear.dx_kernel[key] = triton.kernel(_linear.src, defines=defines)
            kernel = _linear.dx_kernel[key]
            # allocate output
            dx = torch.empty_like(x)
            # launch kernel
            grid = lambda opt: [triton.cdiv(M, opt.d('TM')), dx_width]
            kernel(dy, w, dx, N, N, K, M, N, dx_lut, dx_locks, dx_locks.size(1), grid=grid)
        #################
        # weight gradient
        #################
        dw = None
        if ctx.needs_input_grad[1]:
            # create kernel
            key = (dtype, block_size)
            if key not in _linear.dw_kernel:
                defines = {'TM': block_size, 'TN': block_size, 'TK': 8, 'TYPE': dtype,
                           'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
                           'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
                           'DW': True, 'TZ': 2}
                _linear.dw_kernel[key] = triton.kernel(_linear.src, defines=defines)
            kernel = _linear.dw_kernel[key]
            # allocate output
            dw = torch.zeros_like(w)
            # launch kernel
            grid = lambda opt: [dw_width, opt.d('TZ')]
            kernel(x, dy, dw, K, N, N, K, M, dw_lut, dw_locks, dw_locks.size(1), grid=grid)
        # one gradient per forward() argument; the luts, locks and widths
        # are not differentiable
        return dx, dw, None, \
               None, None, None, \
               None, None, None, \
               None, None, None

linear = _linear.apply


class Linear(torch.nn.Module):

    def __init__(self, in_features, out_features, block_size, mask):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.block_size = block_size
        # weight is stored as (in_features, out_features), matching the
        # (K, N) layout expected by the kernel
        self.weight = torch.nn.Parameter(torch.Tensor(in_features, out_features))
        self.register_parameter('bias', None)
        self.reset_parameters()
        # create look-up tables
        self.y_lut, self.y_locks, self.y_width = _linear.make_ydx_lut(mask, block_size)
        self.dx_lut, self.dx_locks, self.dx_width = _linear.make_ydx_lut(mask.T, block_size)
        self.dw_lut, self.dw_locks, self.dw_width = _linear.make_dw_lut(mask, block_size)

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        return linear(input, self.weight, self.block_size,
                      self.y_lut, self.y_locks, self.y_width,
                      self.dx_lut, self.dx_locks, self.dx_width,
                      self.dw_lut, self.dw_locks, self.dw_width)
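
# Minimal usage sketch for the module above (hypothetical sizes, not run
# here): a block-sparse 256 -> 256 layer with 16x16 blocks in which
# roughly half the blocks are kept.
#   mask = torch.randint(0, 2, (256 // 16, 256 // 16))
#   layer = Linear(256, 256, 16, mask).cuda()
#   out = layer(torch.rand(64, 256).cuda())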


def reference_dot(x, w, mask):
    WS0, WS1 = w.size()
    MS0, MS1 = mask.size()
    assert WS0 % MS0 == 0
    assert WS1 % MS1 == 0
    block_size_0 = WS0 // MS0
    block_size_1 = WS1 // MS1
    assert block_size_0 == block_size_1
    # zero out the masked blocks of w, then fall back to a dense matmul
    maskedw = w.clone()
    for bi, wi in enumerate(range(0, WS0, block_size_0)):
        for bj, wj in enumerate(range(0, WS1, block_size_1)):
            maskedw[wi : wi+block_size_0,
                    wj : wj+block_size_1] *= mask[bi, bj]
    return torch.matmul(x, maskedw)


torch.manual_seed(0)
# parameters
M, N, K = 256, 256, 256
BS = 16
# initialize inputs
mask = torch.randint(0, 2, (K//BS, N//BS))
x = torch.rand((M, K), dtype=torch.float32, requires_grad=True).cuda()
w = torch.rand((K, N), dtype=torch.float32, requires_grad=True).cuda()
x.retain_grad()
w.retain_grad()
# reference result
ry = reference_dot(x, w, mask)
dy = torch.rand_like(ry)
ry.backward(dy)
rdx = x.grad.clone()
rdw = w.grad.clone()
# reset gradients
x.grad.zero_()
w.grad.zero_()
# triton result
y_lut, y_locks, y_width = _linear.make_ydx_lut(mask, BS)
dx_lut, dx_locks, dx_width = _linear.make_ydx_lut(mask.T, BS)
dw_lut, dw_locks, dw_width = _linear.make_dw_lut(mask, BS)
ty = _linear.apply(x, w, BS,
                   y_lut, y_locks, y_width,
                   dx_lut, dx_locks, dx_width,
                   dw_lut, dw_locks, dw_width)
ty.backward(dy)
tdx = x.grad.clone()
tdw = w.grad.clone()
# test: print max absolute error vs. the dense reference
print((ty - ry).abs().max())
print((tdx - rdx).abs().max())
print((tdw - rdw).abs().max())
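
# Optional stricter pass/fail check (tolerances are a guess for float32):
# assert torch.allclose(ty, ry, rtol=1e-3, atol=1e-4)
# assert torch.allclose(tdx, rdx, rtol=1e-3, atol=1e-4)
# assert torch.allclose(tdw, rdw, rtol=1e-3, atol=1e-4)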