From 9fda39f64c1fd2a29a91125780579a721fd8bcd6 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 5 Mar 2020 13:32:42 -0500
Subject: [PATCH] [PYTHON][EXAMPLES] Removed BlockSparse examples; see
 https://github.com/ptillet/torch-blocksparse.git

---
 python/examples/blocksparse.py | 364 ---------------------------------
 1 file changed, 364 deletions(-)
 delete mode 100644 python/examples/blocksparse.py

diff --git a/python/examples/blocksparse.py b/python/examples/blocksparse.py
deleted file mode 100644
index c7f1c451a..000000000
--- a/python/examples/blocksparse.py
+++ /dev/null
@@ -1,364 +0,0 @@
-import math
-
-import triton
-import torch
-from torch.nn import init
-
-class _linear(torch.autograd.Function):
-
-    src = '''
-    __global__ void main(TYPE* A __readonly __noalias __aligned(16),
-                         TYPE* B __readonly __noalias __aligned(16),
-                         TYPE* C __writeonly __noalias __aligned(16),
-                         int lda, int ldb, int ldc,
-                         int M, int Kmax,
-                         int* lut,
-                         int* locks, int nlocks) {
-      /* ---------------- */
-      /*     Prologue     */
-      /* ---------------- */
-      // program ids
-      int pid0 = get_program_id(0);
-      int pid1 = get_program_id(1);
-#ifdef DW
-      // load LUT header: (i, j) coordinates of the weight block
-      int *header = lut + pid0 * 2;
-      int i = *(header + 0);
-      int j = *(header + 1);
-      int K = Kmax / TZ;
-      int lockid = select(TZ > 1, 1, 0);
-      int offk = pid1 * K;
-      int offm = i * TM;
-      int offn = j * TN;
-      int maxid = get_num_programs(1);
-#else
-      // load LUT header: (offset, K, column, lockid, maxid)
-      int *header = lut + pid1 * 5;
-      int offset = *(header + 0);
-      int K = *(header + 1);
-      int column = *(header + 2);
-      int lockid = *(header + 3);
-      int maxid = *(header + 4);
-      int *pinc = lut + offset;
-      int offk = (*pinc) * TK;
-      int offm = pid0 * TM;
-      int offn = column * TN;
-#endif
-      // initialize a, b pointers
-      int rka[TK] = offk + 0 ... TK;
-      int rkb[TK] = offk + 0 ... TK;
-      int ram[TM] = offm + (0 ... TM);
-      int rbn[TN] = offn + (0 ... TN);
-      TYPE* pa[TM, TK] = A + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
-      TYPE* pb[TK, TN] = B + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
-      // pre-fetch
-      bool checka[TM, TK] = ram[:, newaxis] < M;
-      bool checkb[TK, TN] = 1;
-      TYPE a[TM, TK] = checka ? *pa : 0;
-      TYPE b[TK, TN] = checkb ? *pb : 0;
-
-      /* ---------------- */
-      /*    Inner Loop    */
-      /* ---------------- */
-      // create result tile
-      float acc[TM, TN] = 0;
-#ifdef DW
-      int step = TK;
-#else
-      int step = 1;
-#endif
-      for(int k = K; k > 0; k -= step) {
-        acc += a @ b;
-        // update pointers
-#ifdef DW
-        int inc_a = TK * STRIDE_AK;
-        int inc_b = TK * STRIDE_BK;
-#else
-        pinc += 1;
-        int inc_a = (*pinc) * TK * STRIDE_AK;
-        int inc_b = (*pinc) * TK * STRIDE_BK;
-#endif
-        pa += inc_a;
-        pb += inc_b;
-        // pre-fetch
-        bool checka[TM, TK] = k > 1;
-        bool checkb[TK, TN] = k > 1;
-        a = *?(checka)pa;
-        b = *?(checkb)pb;
-      }
-      TYPE c[TM, TN] = acc;
-
-      /* ---------------- */
-      /*     Epilogue     */
-      /* ---------------- */
-      // initialize c pointers
-      int rcm[TM] = offm + (0 ... TM);
-      int rcn[TN] = offn + (0 ... TN);
-      TYPE* pc[TM, TN] = C + rcm[:, newaxis]*ldc + rcn[newaxis, :];
-      bool checkc[TM, TN] = rcm[:, newaxis] < M;
-      // write-back directly
-      if(lockid == 0) {
-        *?(checkc) pc = c;
-      }
-      // accumulate partial result using spin-locks
-      else {
-        int *plock = locks + get_program_id(0)*nlocks + lockid - 1;
-        int *pcount = plock + get_num_programs(0)*nlocks;
-        for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
-        int count = *pcount;
-        if(count == 0)
-          *?(checkc) pc = c;
-        else
-          *?(checkc) pc = c + *?(checkc)pc;
-        atomic_xchg(pcount, (count + 1) % maxid);
-        atomic_xchg(plock, 0);
-      }
-    }
-'''
-
-    # dictionaries of cached triton kernels, keyed by (dtype, block_size)
-    y_kernel = dict()
-    dx_kernel = dict()
-    dw_kernel = dict()
-
-    # Given an array `sizes` holding the reduction size of each column
-    # of a block-sparse matrix multiplication, split the work into
-    # segments of at most `seg_size` blocks so that reductions are
-    # balanced more evenly across programs.
-    @staticmethod
-    def load_balance(sizes, seg_size=8):
-        div = sizes // seg_size
-        rem = sizes % seg_size
-        packs = div + (rem != 0).long()
-        width = int(packs.sum())
-        # split reduction into segments
-        segments = torch.empty(width, dtype=sizes.dtype)
-        column = torch.empty_like(segments)
-        lockid = torch.zeros_like(segments)
-        maxid = torch.zeros_like(segments)
-        nlocks = 0
-        current = 0
-        col_idx = 0
-        for i in range(len(sizes)):
-            d, r = int(div[i]), int(rem[i])
-            last = current + d + (r > 0)
-            # column id
-            column[current:last] = col_idx
-            # lock id: columns split across several segments need a lock
-            if d > 1 or (d == 1 and r > 0):
-                nlocks += 1
-                lockid[current:last] = nlocks
-                maxid[current:last] = last - current
-            # segment size
-            segments[current:current+d] = seg_size
-            if r > 0:
-                segments[current+d] = r
-            current = last
-            col_idx += 1
-        offsets = torch.zeros_like(segments)
-        offsets[1:] = torch.cumsum(segments[:-1], dim=0)
-        return segments, column, lockid, maxid, offsets
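
As an illustration of the load-balancing scheme above, here is what `load_balance` produces on a small hand-made input. This is an editorial sketch, not output from the original example; it assumes the `_linear` class defined in this file is in scope:

    import torch
    # three columns holding 3, 9 and 17 non-zero blocks respectively
    sizes = torch.tensor([3, 9, 17])
    segments, column, lockid, maxid, offsets = _linear.load_balance(sizes, seg_size=8)
    # column 0 fits in a single segment and needs no lock; columns 1 and 2
    # are split, so their partial results are combined under locks 1 and 2:
    #   segments: [3, 8, 1, 8, 8, 1]
    #   column:   [0, 1, 1, 2, 2, 2]
    #   lockid:   [0, 1, 1, 2, 2, 2]
    #   maxid:    [0, 2, 2, 3, 3, 3]
    #   offsets:  [0, 3, 11, 12, 20, 28]
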
-
-    # Given a binary mask of 0s and 1s, construct the look-up table
-    # used for efficient execution of Y = XW and dX = dY W^T on GPUs
-    @staticmethod
-    def make_ydx_lut(mask, block_size):
-        # reduction size of each column
-        sizes = torch.sum(mask, 0)
-        # load-balancing
-        segments, column, lockid, maxid, offsets = _linear.load_balance(sizes)
-        # delta-encoded pointer increments
-        nnz = torch.nonzero(mask.T)
-        idx = nnz[:, 1]
-        incs = idx.clone()
-        incs[1:] -= idx[:-1]
-        incs[offsets] = idx[offsets]
-        # create header: 5 int32 entries per segment
-        width = column.size(0)
-        offsets += 5*width
-        header = torch.stack((offsets, segments, column, lockid, maxid), dim=1).view(-1).contiguous()
-        # create lut
-        lut = torch.cat((header, incs)).type(torch.int32).cuda()
-        # create locks
-        num_locks = max(1, int(lockid.max()))
-        locks = torch.zeros((2*mask.size(0), num_locks), dtype=torch.int32).cuda()
-        return lut, locks, width
-
-    @staticmethod
-    def make_dw_lut(mask, depth, block_size):
-        nnz = torch.nonzero(mask)
-        # create lut: one (i, j) block coordinate per non-zero block
-        width = nnz.size(0)
-        i = nnz[:, 0]
-        j = nnz[:, 1]
-        lut = torch.stack((i, j), dim=1).view(-1).contiguous()
-        lut = lut.type(torch.int32).cuda()
-        # create locks
-        num_locks = 1
-        locks = torch.zeros((2*width, num_locks), dtype=torch.int32).cuda()
-        return lut, locks, width
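
The pointer increments deserve a word: rather than storing absolute block indices, `make_ydx_lut` delta-encodes them, so at each step the kernel advances its pointers by `(*pinc) * TK * STRIDE_AK` without re-reading absolute positions. A minimal sketch of the encoding, using only torch (editorial, not from the original file):

    import torch
    # non-zero blocks of one column sit at rows 1, 4 and 5 of the mask
    idx = torch.tensor([1, 4, 5])
    incs = idx.clone()
    incs[1:] -= idx[:-1]   # [1, 3, 1]: hop sizes between consecutive blocks
    # the entry at each segment start is then reset to the absolute index
    # (incs[offsets] = idx[offsets]), which the kernel reads once to
    # compute its initial offset offk
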
-
-    @staticmethod
-    def forward(ctx, x, w, block_size,
-                y_lut, y_locks, y_width,
-                dx_lut, dx_locks, dx_width,
-                dw_lut, dw_locks, dw_width):
-        M, Kx = x.size()
-        Kw, N = w.size()
-        dtype = x.dtype
-        # memory strides
-        lda = Kx
-        ldb = N
-        ldc = N
-        # create kernel
-        key = (dtype, block_size)
-        if key not in _linear.y_kernel:
-            defines = {'TM': 64, 'TN': block_size, 'TK': block_size, 'TYPE': dtype,
-                       'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
-                       'STRIDE_BN': '1', 'STRIDE_BK': 'ldb'}
-            _linear.y_kernel[key] = triton.kernel(_linear.src, defines=defines)
-        kernel = _linear.y_kernel[key]
-        # allocate output
-        y = torch.empty((M, N), dtype=dtype, device=x.device)
-        # launch kernel
-        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), y_width]
-        kernel(x, w, y, lda, ldb, ldc, M, Kx, y_lut, y_locks, y_locks.size(1), grid=grid)
-        # save information in context
-        ctx.dx_width = dx_width
-        ctx.dw_width = dw_width
-        ctx.kernel = kernel
-        ctx.block_size = block_size
-        ctx.save_for_backward(x, w, dx_lut, dx_locks, dw_lut, dw_locks)
-        return y
-
-    @staticmethod
-    def backward(ctx, dy):
-        # retrieve information from context
-        x, w, dx_lut, dx_locks, dw_lut, dw_locks = ctx.saved_tensors
-        dx_width = ctx.dx_width
-        dw_width = ctx.dw_width
-        block_size = ctx.block_size
-        kernel = ctx.kernel
-        # shapes
-        M, N = dy.size()
-        _, K = x.size()
-        dtype = x.dtype
-        ################
-        # input gradient
-        ################
-        dx = None
-        if ctx.needs_input_grad[0]:
-            # create kernel
-            key = (dtype, block_size)
-            if key not in _linear.dx_kernel:
-                defines = {'TM': 64, 'TN': block_size, 'TK': block_size, 'TYPE': dtype,
-                           'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
-                           'STRIDE_BN': 'ldb', 'STRIDE_BK': '1'}
-                _linear.dx_kernel[key] = triton.kernel(_linear.src, defines=defines)
-            kernel = _linear.dx_kernel[key]
-            # allocate output
-            dx = torch.empty_like(x)
-            # launch kernel
-            grid = lambda opt: [triton.cdiv(M, opt.d('TM')), dx_width]
-            kernel(dy, w, dx, N, N, K, M, N, dx_lut, dx_locks, dx_locks.size(1), grid=grid)
-        #################
-        # weight gradient
-        #################
-        dw = None
-        if ctx.needs_input_grad[1]:
-            # create kernel
-            key = (dtype, block_size)
-            if key not in _linear.dw_kernel:
-                defines = {'TM': block_size, 'TN': block_size, 'TK': 8, 'TYPE': dtype,
-                           'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
-                           'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
-                           'DW': True, 'TZ': 2}
-                _linear.dw_kernel[key] = triton.kernel(_linear.src, defines=defines)
-            kernel = _linear.dw_kernel[key]
-            # allocate output
-            dw = torch.zeros_like(w)
-            # launch kernel
-            grid = lambda opt: [dw_width, opt.d('TZ')]
-            kernel(x, dy, dw, K, N, N, K, M, dw_lut, dw_locks, dw_locks.size(1), grid=grid)
-        # done
-        return dx, dw, None,\
-               None, None, None,\
-               None, None, None,\
-               None, None, None
-linear = _linear.apply
-
-class Linear(torch.nn.Module):
-
-    def __init__(self, in_features, out_features, block_size, mask):
-        super(Linear, self).__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.block_size = block_size
-        self.weight = torch.nn.Parameter(torch.Tensor(in_features, out_features))
-        self.bias = None
-        self.reset_parameters()
-        # create look-up tables
-        self.y_lut, self.y_locks, self.y_width = _linear.make_ydx_lut(mask, block_size)
-        self.dx_lut, self.dx_locks, self.dx_width = _linear.make_ydx_lut(mask.T, block_size)
-        self.dw_lut, self.dw_locks, self.dw_width = _linear.make_dw_lut(mask, in_features, block_size)
-
-    def reset_parameters(self):
-        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-        if self.bias is not None:
-            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
-            bound = 1 / math.sqrt(fan_in)
-            init.uniform_(self.bias, -bound, bound)
-
-    def forward(self, input):
-        return linear(input, self.weight, self.block_size,
-                      self.y_lut, self.y_locks, self.y_width,
-                      self.dx_lut, self.dx_locks, self.dx_width,
-                      self.dw_lut, self.dw_locks, self.dw_width)
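
For reference, the three kernel specializations above differ only in their preprocessor defines; the dict below (an editorial summary with a hypothetical name, `specializations`, not code from the original file) collects them side by side. Swapping `STRIDE_BN`/`STRIDE_BK` amounts to reading `B` transposed, which is how the dX pass reuses the same Triton-C source, while `DW`/`TZ` switch on the split-K weight-gradient path:

    specializations = {
        # Y = X @ W
        'y':  {'STRIDE_AM': 'lda', 'STRIDE_AK': '1',   'STRIDE_BN': '1',   'STRIDE_BK': 'ldb'},
        # dX = dY @ W^T  (B strides swapped)
        'dx': {'STRIDE_AM': 'lda', 'STRIDE_AK': '1',   'STRIDE_BN': 'ldb', 'STRIDE_BK': '1'},
        # dW = X^T @ dY  (A read transposed; DW/TZ enable split-K)
        'dw': {'STRIDE_AM': '1',   'STRIDE_AK': 'lda', 'STRIDE_BN': '1',   'STRIDE_BK': 'ldb',
               'DW': True, 'TZ': 2},
    }
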
-
-def reference_dot(x, w, mask):
-    WS0, WS1 = w.size()
-    MS0, MS1 = mask.size()
-    assert WS0 % MS0 == 0
-    assert WS1 % MS1 == 0
-    block_size_0 = WS0 // MS0
-    block_size_1 = WS1 // MS1
-    assert block_size_0 == block_size_1
-    # zero out the masked blocks of w, then fall back to a dense matmul
-    maskedw = w.clone()
-    for bi, wi in enumerate(range(0, WS0, block_size_0)):
-        for bj, wj in enumerate(range(0, WS1, block_size_1)):
-            maskedw[wi : wi+block_size_0,
-                    wj : wj+block_size_1] *= mask[bi, bj]
-    return torch.matmul(x, maskedw)
-
-torch.manual_seed(0)
-# parameters
-M, N, K = 256, 256, 256
-BS = 16
-# initialize inputs
-mask = torch.randint(0, 2, (K//BS, N//BS))
-x = torch.rand((M, K), dtype=torch.float32, requires_grad=True).cuda()
-w = torch.rand((K, N), dtype=torch.float32, requires_grad=True).cuda()
-x.retain_grad()
-w.retain_grad()
-# reference result
-ry = reference_dot(x, w, mask)
-dy = torch.rand_like(ry)
-ry.backward(dy)
-rdx = x.grad.clone()
-rdw = w.grad.clone()
-# reset gradients
-x.grad.zero_()
-w.grad.zero_()
-# triton result
-y_lut, y_locks, y_width = _linear.make_ydx_lut(mask, BS)
-dx_lut, dx_locks, dx_width = _linear.make_ydx_lut(mask.T, BS)
-dw_lut, dw_locks, dw_width = _linear.make_dw_lut(mask, M, BS)
-ty = _linear.apply(x, w, BS,
-                   y_lut, y_locks, y_width,
-                   dx_lut, dx_locks, dx_width,
-                   dw_lut, dw_locks, dw_width)
-ty.backward(dy)
-tdx = x.grad.clone()
-tdw = w.grad.clone()
-# test: maximum absolute error against the dense reference
-print((ty - ry).abs().max())
-print((tdx - rdx).abs().max())
-print((tdw - rdw).abs().max())
\ No newline at end of file
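
Finally, a hypothetical usage sketch of the `Linear` module removed by this patch (editorial; it assumes a CUDA device and the 2020-era `triton.kernel` API that this example targeted):

    import torch
    K, N, BS = 256, 256, 16
    mask = torch.randint(0, 2, (K // BS, N // BS))
    layer = Linear(in_features=K, out_features=N, block_size=BS, mask=mask).cuda()
    x = torch.rand(64, K, device='cuda', requires_grad=True)
    y = layer(x)        # block-sparse forward
    y.sum().backward()  # block-sparse dX and dW via the cached kernels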