diff --git a/include/triton/runtime/arg.h b/include/triton/runtime/arg.h
index 1a741077c..4c932b8a4 100644
--- a/include/triton/runtime/arg.h
+++ b/include/triton/runtime/arg.h
@@ -75,6 +75,7 @@ public:
   // accessors
   arg_type type() const { return ty_; }
   void* data() const { return (void*)&val_; }
+  driver::buffer* buffer() const { return val_.buf; }
 
 private:
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index c5b109f6a..85424acde 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include <memory>
 #include "triton/codegen/analysis/axes.h"
 #include "triton/codegen/analysis/allocation.h"
 #include "triton/codegen/analysis/liveness.h"
@@ -343,6 +344,17 @@ function::cache_key_t function::get_key(driver::stream *stream, const std::vector<arg>& args)
 // returns program with best compilation options for given parameter
 function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector<arg>& args) {
+//  // copy buffer arguments so that auto-tuning doesn't corrupt user data
+//  std::list<std::shared_ptr<driver::buffer>> copies;
+//  std::vector<arg> _args = args;
+//  for(arg& x: _args)
+//    if(x.type() == BUFFER_T){
+//      driver::buffer* old = x.buffer();
+//      driver::context* ctx = old->context();
+//      size_t size = old->size();
+//      copies.push_back(std::make_shared<driver::buffer>(ctx, size));
+//      x = arg(copies.back().get());
+//    }
   // fast path -- no autotuning necessary
   if(callers_.size() == 1)
     return &*callers_.begin()->second;
diff --git a/python/examples/blocksparse.py b/python/examples/blocksparse.py
new file mode 100644
index 000000000..c7f1c451a
--- /dev/null
+++ b/python/examples/blocksparse.py
@@ -0,0 +1,364 @@
+import math
+
+import torch
+from torch.nn import init
+
+import triton
+
+class _linear(torch.autograd.Function):
+
+    src = '''
+    __global__ void main (TYPE* A __readonly  __noalias __aligned(16),
+                          TYPE* B __readonly  __noalias __aligned(16),
+                          TYPE* C __writeonly __noalias __aligned(16),
+                          int lda, int ldb, int ldc,
+                          int M, int Kmax,
+                          int* lut,
+                          int* locks, int nlocks) {
+      /* ---------------- */
+      /*     Prologue     */
+      /* ---------------- */
+      // program ids
+      int pid0 = get_program_id(0);
+      int pid1 = get_program_id(1);
+#ifdef DW
+      // load LUT header
+      int *header = lut + pid0 * 2;
+      int i = *(header + 0);
+      int j = *(header + 1);
+      int K = Kmax / TZ;
+      int lockid = select(TZ > 1, 1, 0);
+      int offk = pid1 * K;
+      int offm = i * TM;
+      int offn = j * TN;
+      int maxid = get_num_programs(1);
+#else
+      // load LUT header
+      int *header = lut + pid1 * 5;
+      int offset = *(header + 0);
+      int K      = *(header + 1);
+      int column = *(header + 2);
+      int lockid = *(header + 3);
+      int maxid  = *(header + 4);
+      int *pinc  = lut + offset;
+      int offk = (*pinc) * TK;
+      int offm = pid0 * TM;
+      int offn = column * TN;
+#endif
+      // initialize a, b pointers
+      int rka[TK] = offk + 0 ... TK;
+      int rkb[TK] = offk + 0 ... TK;
+      int ram[TM] = offm + (0 ... TM);
+      int rbn[TN] = offn + (0 ... TN);
+      TYPE* pa[TM, TK] = A + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
+      TYPE* pb[TK, TN] = B + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
+      // pre-fetch
+      bool checka[TM, TK] = ram[:, newaxis] < M;
+      bool checkb[TK, TN] = 1;
+      TYPE a[TM, TK] = checka ? *pa : 0;
+      TYPE b[TK, TN] = checkb ? *pb : 0;
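+      // NOTE: the two LUT layouts above differ on purpose:
+      //  - the Y/DX path reads a 5-int header (offset, K, column, lockid,
+      //    maxid) followed by delta-encoded block indices, so each loop
+      //    iteration fetches its next pointer increment from the LUT
+      //    (see make_ydx_lut below);
+      //  - the DW path only needs the (i, j) coordinates of its block and
+      //    strides densely over the reduction axis in steps of TK.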
+
+      /* ---------------- */
+      /*    Inner Loop    */
+      /* ---------------- */
+      // create result tile
+      float acc[TM, TN] = 0;
+#ifdef DW
+      int step = TK;
+#else
+      int step = 1;
+#endif
+      for(int k = K; k > 0; k -= step) {
+        acc += a @ b;
+        // update pointers
+#ifdef DW
+        int inc_a = TK * STRIDE_AK;
+        int inc_b = TK * STRIDE_BK;
+#else
+        pinc += 1;
+        int inc_a = (*pinc) * TK * STRIDE_AK;
+        int inc_b = (*pinc) * TK * STRIDE_BK;
+#endif
+        pa += inc_a;
+        pb += inc_b;
+        // pre-fetch (guard with `step` so the DW path does not read past
+        // the end of its segment on the last iteration)
+        bool checka[TM, TK] = k > step;
+        bool checkb[TK, TN] = k > step;
+        a = *?(checka)pa;
+        b = *?(checkb)pb;
+      }
+      TYPE c[TM, TN] = acc;
+
+      /* ---------------- */
+      /*     Epilogue     */
+      /* ---------------- */
+      // initialize c pointers
+      int rcm[TM] = offm + (0 ... TM);
+      int rcn[TN] = offn + (0 ... TN);
+      TYPE* pc[TM, TN] = C + rcm[:, newaxis]*ldc + rcn[newaxis, :];
+      bool checkc[TM, TN] = rcm[:, newaxis] < M;
+      // write-back directly
+      if(lockid == 0) {
+        *?(checkc) pc = c;
+      }
+      // accumulate partial result using spin-locks
+      else {
+        int *plock = locks + get_program_id(0)*nlocks + lockid - 1;
+        int *pcount = plock + get_num_programs(0)*nlocks;
+        // spin until this program owns the lock
+        for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
+        int count = *pcount;
+        if(count == 0)
+          *?(checkc) pc = c;
+        else
+          *?(checkc) pc = c + *?(checkc)pc;
+        atomic_xchg(pcount, (count + 1) % maxid);
+        atomic_xchg(plock, 0);
+      }
+    }
+'''
+
+    # dictionaries of cached triton kernels, keyed by (dtype, block_size)
+    y_kernel = dict()
+    dx_kernel = dict()
+    dw_kernel = dict()
+
+    # Given an array `sizes` holding the reduction size of each column of
+    # a block-sparse matrix multiplication, split every reduction into
+    # segments of at most `seg_size` blocks so that work is balanced
+    # across programs (a worked example follows make_ydx_lut below).
+    @staticmethod
+    def load_balance(sizes, seg_size=8):
+        div = sizes // seg_size
+        rem = sizes % seg_size
+        packs = div + (rem != 0).long()
+        width = int(packs.sum())
+        # split reduction into segments
+        segments = torch.empty(width, dtype=sizes.dtype)
+        column = torch.empty_like(segments)
+        lockid = torch.zeros_like(segments)
+        maxid = torch.zeros_like(segments)
+        nlocks = 0
+        current = 0
+        col_idx = 0
+        for i in range(len(sizes)):
+            d, r = div[i], rem[i]
+            last = current + d + (r > 0)
+            # column id
+            column[current:last] = col_idx
+            # lock id: columns split over several segments need a lock
+            if d > 1 or (d == 1 and r > 0):
+                nlocks += 1
+                lockid[current:last] = nlocks
+                maxid[current:last] = last - current
+            # segment size
+            segments[current:current+d] = seg_size
+            if r > 0:
+                segments[current+d] = r
+            current = last
+            col_idx += 1
+        offsets = torch.zeros_like(segments)
+        offsets[1:] = torch.cumsum(segments[:-1], dim=0)
+        return segments, column, lockid, maxid, offsets
+
+    # Given a binary mask of 0s and 1s,
+    # construct a look-up table for efficient execution on GPUs
+    @staticmethod
+    def make_ydx_lut(mask, block_size):
+        # load-balancing
+        sizes = torch.sum(mask, 0)
+        segments, column, lockid, maxid, offsets = _linear.load_balance(sizes)
+        # pointer increments (delta-encoded block indices; the first entry
+        # of each segment is reset to an absolute index)
+        nnz = torch.nonzero(mask.T)
+        idx = nnz[:, 1]
+        incs = idx.clone()
+        incs[1:] -= idx[:-1]
+        incs[offsets] = idx[offsets]
+        # create header
+        width = column.size(0)
+        offsets += 5*width
+        header = torch.stack((offsets, segments, column, lockid, maxid), dim=1).view(-1).contiguous()
+        # create lut
+        lut = torch.cat((header, incs)).type(torch.int32).cuda()
+        # create locks
+        num_locks = max(1, int(lockid.max()))
+        locks = torch.zeros((2*mask.size(0), num_locks), dtype=torch.int32).cuda()
+        return lut, locks, width
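+
+    # Worked example of load_balance (for intuition):
+    #   load_balance(torch.tensor([3, 17]), seg_size=8)
+    # splits the two column reductions into four segments and returns
+    #   segments = [3, 8, 8, 1]    (segment sizes, in blocks)
+    #   column   = [0, 1, 1, 1]    (owning column of each segment)
+    #   lockid   = [0, 1, 1, 1]    (0 = direct write-back, >0 = spin-lock id)
+    #   maxid    = [0, 3, 3, 3]    (partial results accumulated per lock)
+    #   offsets  = [0, 3, 11, 19]  (start of each segment in the increments)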
+
+    @staticmethod
+    def make_dw_lut(mask, block_size):
+        # create lut: flat (i, j) block coordinates of the nonzero blocks
+        nnz = torch.nonzero(mask)
+        width = nnz.size(0)
+        i = nnz[:, 0]
+        j = nnz[:, 1]
+        lut = torch.stack((i, j), dim=1).view(-1).contiguous()
+        lut = lut.type(torch.int32).cuda()
+        # create locks
+        num_locks = 1
+        locks = torch.zeros((2*width, num_locks), dtype=torch.int32).cuda()
+        return lut, locks, width
+
+    @staticmethod
+    def forward(ctx, x, w, block_size,
+                y_lut, y_locks, y_width,
+                dx_lut, dx_locks, dx_width,
+                dw_lut, dw_locks, dw_width):
+        M, Kx = x.size()
+        Kw, N = w.size()
+        dtype = x.dtype
+        # memory strides
+        lda = Kx
+        ldb = N
+        ldc = N
+        # create kernel
+        key = (dtype, block_size)
+        if key not in _linear.y_kernel:
+            defines = {'TM': 64, 'TN': block_size, 'TK': block_size, 'TYPE': dtype,
+                       'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
+                       'STRIDE_BN': '1', 'STRIDE_BK': 'ldb'}
+            _linear.y_kernel[key] = triton.kernel(_linear.src, defines=defines)
+        kernel = _linear.y_kernel[key]
+        # allocate output
+        y = torch.empty((M, N), dtype=dtype, device=x.device)
+        # launch kernel
+        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), y_width]
+        kernel(x, w, y, lda, ldb, ldc, M, Kx, y_lut, y_locks, y_locks.size(1), grid=grid)
+        # save information in context
+        ctx.dx_width = dx_width
+        ctx.dw_width = dw_width
+        ctx.kernel = kernel
+        ctx.block_size = block_size
+        ctx.save_for_backward(x, w, dx_lut, dx_locks, dw_lut, dw_locks)
+        return y
+
+    @staticmethod
+    def backward(ctx, dy):
+        # retrieve information from context
+        x, w, dx_lut, dx_locks, dw_lut, dw_locks = ctx.saved_tensors
+        dx_width = ctx.dx_width
+        dw_width = ctx.dw_width
+        block_size = ctx.block_size
+        kernel = ctx.kernel
+        # shapes
+        M, N = dy.size()
+        _, K = x.size()
+        dtype = x.dtype
+        ################
+        # input gradient
+        ################
+        dx = None
+        if ctx.needs_input_grad[0]:
+            # create kernel
+            key = (dtype, block_size)
+            if key not in _linear.dx_kernel:
+                defines = {'TM': 64, 'TN': block_size, 'TK': block_size, 'TYPE': dtype,
+                           'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
+                           'STRIDE_BN': 'ldb', 'STRIDE_BK': '1'}
+                _linear.dx_kernel[key] = triton.kernel(_linear.src, defines=defines)
+            kernel = _linear.dx_kernel[key]
+            # allocate output
+            dx = torch.empty_like(x)
+            # launch kernel
+            grid = lambda opt: [triton.cdiv(M, opt.d('TM')), dx_width]
+            kernel(dy, w, dx, N, N, K, M, N, dx_lut, dx_locks, dx_locks.size(1), grid=grid)
+        #################
+        # weight gradient
+        #################
+        dw = None
+        if ctx.needs_input_grad[1]:
+            # create kernel
+            key = (dtype, block_size)
+            if key not in _linear.dw_kernel:
+                defines = {'TM': block_size, 'TN': block_size, 'TK': 8, 'TYPE': dtype,
+                           'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
+                           'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
+                           'DW': True, 'TZ': 2}
+                _linear.dw_kernel[key] = triton.kernel(_linear.src, defines=defines)
+            kernel = _linear.dw_kernel[key]
+            # allocate output
+            dw = torch.zeros_like(w)
+            # launch kernel
+            grid = lambda opt: [dw_width, opt.d('TZ')]
+            kernel(x, dy, dw, K, N, N, K, M, dw_lut, dw_locks, dw_locks.size(1), grid=grid)
+        # done
+        return dx, dw, None,\
+               None, None, None,\
+               None, None, None,\
+               None, None, None
+
+linear = _linear.apply
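+
+# All three products reuse the same kernel source, specialized at compile
+# time through preprocessor defines:
+#   y  = x @ w     row-major strides for A and B;
+#   dx = dy @ w^T  STRIDE_BN/STRIDE_BK swapped, i.e. B is read transposed;
+#   dw = x^T @ dy  DW is defined, selecting the dense-increment path and
+#                  splitting the reduction across TZ programs per block.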
+
+class Linear(torch.nn.Module):
+
+    def __init__(self, in_features, out_features, block_size, mask):
+        super(Linear, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.block_size = block_size
+        self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
+        self.bias = None
+        self.reset_parameters()
+        # create look-up tables
+        self.y_lut, self.y_locks, self.y_width = _linear.make_ydx_lut(mask, block_size)
+        self.dx_lut, self.dx_locks, self.dx_width = _linear.make_ydx_lut(mask.T, block_size)
+        self.dw_lut, self.dw_locks, self.dw_width = _linear.make_dw_lut(mask, block_size)
+
+    def reset_parameters(self):
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / math.sqrt(fan_in)
+            init.uniform_(self.bias, -bound, bound)
+
+    def forward(self, input):
+        return linear(input, self.weight, self.block_size,
+                      self.y_lut, self.y_locks, self.y_width,
+                      self.dx_lut, self.dx_locks, self.dx_width,
+                      self.dw_lut, self.dw_locks, self.dw_width)
+
+
+def reference_dot(x, w, mask):
+    WS0, WS1 = w.size()
+    MS0, MS1 = mask.size()
+    assert WS0 % MS0 == 0
+    assert WS1 % MS1 == 0
+    block_size_0 = WS0 // MS0
+    block_size_1 = WS1 // MS1
+    assert block_size_0 == block_size_1
+    maskedw = w.clone()
+    for bi, wi in enumerate(range(0, WS0, block_size_0)):
+        for bj, wj in enumerate(range(0, WS1, block_size_1)):
+            maskedw[wi : wi+block_size_0,
+                    wj : wj+block_size_1] *= mask[bi, bj]
+    return torch.matmul(x, maskedw)
+
+torch.manual_seed(0)
+# parameters
+M, N, K = 256, 256, 256
+BS = 16
+# initialize inputs
+mask = torch.randint(0, 2, (K//BS, N//BS))
+x = torch.rand((M, K), dtype=torch.float32, requires_grad=True).cuda()
+w = torch.rand((K, N), dtype=torch.float32, requires_grad=True).cuda()
+x.retain_grad()
+w.retain_grad()
+# reference result
+ry = reference_dot(x, w, mask)
+dy = torch.rand_like(ry)
+ry.backward(dy)
+rdx = x.grad.clone()
+rdw = w.grad.clone()
+# reset gradients
+x.grad.zero_()
+w.grad.zero_()
+# triton result
+y_lut, y_locks, y_width = _linear.make_ydx_lut(mask, BS)
+dx_lut, dx_locks, dx_width = _linear.make_ydx_lut(mask.T, BS)
+dw_lut, dw_locks, dw_width = _linear.make_dw_lut(mask, BS)
+ty = _linear.apply(x, w, BS,
+                   y_lut, y_locks, y_width,
+                   dx_lut, dx_locks, dx_width,
+                   dw_lut, dw_locks, dw_width)
+ty.backward(dy)
+tdx = x.grad.clone()
+tdw = w.grad.clone()
+# test: maximum absolute error against the dense reference
+print((ty - ry).abs().max())
+print((tdx - rdx).abs().max())
+print((tdw - rdw).abs().max())
\ No newline at end of file
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index fe5e0aabd..8b6011337 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -148,71 +148,41 @@ def _make_framework_op(arg_types):
 
 class kernel:
 
-  def __init__(self, src):
-    self.fw_id = dict()
-    self.fw_grids = dict()
-    self.fw_op = None
+  def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]):
     self.src = src
+    # create constants
     self.cst = dict()
+    # create triton op
+    macros = []
+    for k, v in defines.items():
+      cvt = lambda x: _cvt_to_def_str(x)
+      if(isinstance(v, list)):
+        values = list(map(cvt, v))
+      else:
+        values = [cvt(v)]
+      macros.append((k, values))
+    opt = libtriton.options_space()
+    opt.defines = macros
+    opt.num_warps = num_warps
+    self.op_id = libtriton.make_op_id()
+    libtriton.register_fn(self.op_id, self.src, opt, os.path.realpath(libtriton.__file__))
+    # create pytorch hook
+    arg_types = libtriton.get_fn_signature(self.src, opt)
+    self.fw_op = _make_framework_op(arg_types)
 
   def set_constant(self, name, value):
-    self.cst[name] = value
+    libtriton.register_cst(self.op_id, name, value)
 
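+  # NOTE: a define may map to a list of candidate values, e.g.
+  #   triton.kernel(src, defines={'TM': [32, 64, 128]}, num_warps=[2, 4])
+  # in which case every combination becomes part of the compilation option
+  # space that the runtime auto-tunes over when the kernel is launched.
+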
   def __call__(self, *args, **kwargs):
-    ########################
-    # JIT Options
-    ########################
-    num_warps = kwargs['num_warps'] if 'num_warps' in kwargs else [2, 4, 8]
-    defines = kwargs['defines'] if 'defines' in kwargs else dict()
+    # launch options
     bench = kwargs['bench'] if 'bench' in kwargs else 0
+    bench_id = libtriton.make_scalar_id() if bench > 0 else -1
+    # launch grid
     if 'grid' not in kwargs:
       raise RuntimeError('Must provide grid for kernel launch')
     grid = kwargs['grid']
-
-
-    #########################
-    # cache
-    ########################
-    # create a new framework op when defines are different
-    key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in defines.items()])
-    if key not in self.fw_id.keys():
-      # code generation options
-      macros = []
-      for k, v in defines.items():
-        cvt = lambda x: _cvt_to_def_str(x)
-        if(isinstance(v, list)):
-          values = list(map(cvt, v))
-        else:
-          values = [cvt(v)]
-        macros.append((k, values))
-      opt = libtriton.options_space()
-      opt.defines = macros
-      opt.num_warps = num_warps
-      # create triton function for this op
-      op_id = libtriton.make_op_id()
-      self.fw_id[key] = op_id
-      libtriton.register_fn(op_id, self.src, opt, os.path.realpath(libtriton.__file__))
-      for name, value in self.cst.items():
-        libtriton.register_cst(op_id, name, value)
-      # create pytorch hook for this op
-      arg_types = libtriton.get_fn_signature(self.src, opt)
-      if self.fw_op is None:
-        self.fw_op = _make_framework_op(arg_types)
-
-    ########################
-    # initialize
-    ########################
-    op_id = self.fw_id[key]
-    libtriton.register_grid(op_id, grid)
-    bench_id = libtriton.make_scalar_id() if bench > 0 else -1
-
-    ############################
-    # call torch function
-    ############################
-    if fw.has_torch():
-      self.fw_op(op_id, bench, bench_id, *args)
-      if bench > 0:
-        return libtriton.retrieve_scalar(bench_id)
-
-    else:
-      assert False
\ No newline at end of file
+    libtriton.register_grid(self.op_id, grid)
+    # launch
+    self.fw_op(self.op_id, bench, bench_id, *args)
+    if bench > 0:
+      return libtriton.retrieve_scalar(bench_id)
\ No newline at end of file
diff --git a/python/triton/nn/conv.py b/python/triton/nn/conv.py
index 8a6c744d8..a6966f6ae 100644
--- a/python/triton/nn/conv.py
+++ b/python/triton/nn/conv.py
@@ -125,7 +125,7 @@ def replace_conv2d(model, acc_bitmask = None):
 
 #torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3])
 if __name__ == '__main__':
-  N, C, H, W, K, RS = 128, 64, 30, 30, 128, 1
+  N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3
   #N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3
   #N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3
   pad, stride = 0, 1
diff --git a/python/triton/ops/batchnorm.py b/python/triton/ops/batchnorm.py
index 5b1cc806d..cc4c23ee2 100644
--- a/python/triton/ops/batchnorm.py
+++ b/python/triton/ops/batchnorm.py
@@ -42,7 +42,6 @@ void fwdbatchnorm(float *Y, float *M, float *V,
   }
 }
 """
-  fwd_kernel = triton.kernel(fwd_src)
 
   bwd_src = """
 void bwdbatchnorm(float *DX, float *DG, float *DB,
@@ -89,10 +88,16 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
   }
 }
 """
-  bwd_kernel = triton.kernel(bwd_src)
+
+  fwd_kernel = None
+  bwd_kernel = None
 
   @staticmethod
   def forward(ctx, x, gamma, beta, eps):
+    # lazy compilation of kernel
+    if _batchnorm.fwd_kernel is None:
+      _batchnorm.fwd_kernel = triton.kernel(_batchnorm.fwd_src, defines = {'TM': 128})
+    # shapes
     shape = triton.shape(x)
     dtype = x.dtype
     # allocate outputs
@@ -102,8 +107,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
     var = triton.empty([C], dtype=dtype)
     # execute kernels
     _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
-                          grid = lambda opt: [1, C],
-                          defines = {'TM': 128})
+                          grid = lambda opt: [1, C])
     # save
     ctx.save_for_backward(x, gamma, beta, mean, var)
     ctx.eps = eps
@@ -111,6 +115,9 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
 
   @staticmethod
   def backward(ctx, dy):
+    # lazy compilation of kernel
+    if _batchnorm.bwd_kernel is None:
+      _batchnorm.bwd_kernel = triton.kernel(_batchnorm.bwd_src, defines = {'TM': 128})
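+    # (compiling on first use keeps `import triton` cheap; the compiled
+    #  kernel is cached on the class, so later calls skip this branch)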
     # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
     eps = ctx.eps
@@ -123,8 +130,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
     _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
                           x, gamma, mean, var,
                           H*W*B, eps,
-                          grid = lambda opt: [1, C],
-                          defines = {'TM': 128})
+                          grid = lambda opt: [1, C])
     return dx, dgamma, dbeta, None
 
 batchnorm = _batchnorm.apply
\ No newline at end of file
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index bda5d3b16..464d588ca 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -69,7 +69,7 @@ class _einsum(torch.autograd.Function):
     ret = dict(zip(expr, ret))
     return ret
 
-  def make_kernel(name,
+  def make_kernel(name, dtype, mask,
                   expr_a, expr_b, expr_c,
                   axes_m, axes_n, axes_k, axes_b,
                   multipleof_a, multipleof_b, multipleof_c,
@@ -329,8 +329,15 @@ __global__ void {name}(
 #endif
 }
 """
-
-    ret = triton.kernel(src)
+    # compilation options
+    TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
+    TK = 16 if dtype==torch.float16 else 8
+    defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
+    if mask is not None:
+      defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
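+      # ('{0:#0{1}x}' renders the mask as a zero-padded hex literal of
+      #  width 10, e.g. 65535 -> 0x0000ffff, ready to paste into the source)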
+    # create kernel
+    ret = triton.kernel(src, defines=defines)
+    # set constant
     if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
       ret.set_constant('AD', delta_a)
     if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
@@ -541,14 +548,15 @@ __global__ void {name}(
       stride_a_last = stride_a[-1]
       stride_b_last = stride_b[-1]
       stride_c_last = stride_c[-1]
-      name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
+      name = f'{dtype}_{mask}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
             f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
             f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
       # recompile if necessary
       cache = _einsum.instance.kernel_cache
       if name not in cache:
         cachesize = len(cache)
-        cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
+        cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
+                                          dtype, mask,
                                           sym_a, sym_b, sym_c,
                                           axes_m, axes_n, axes_k, axes_b,
                                           stride_a_multiple, stride_b_multiple, stride_c_multiple,
@@ -604,19 +612,6 @@ __global__ void {name}(
       self.pos_vars = len(self.args)
       self.varnames = varnames
       self.args += [None] * len(varnames)
-      # tile size ranges
-      MAX_GZ = triton.cdiv(K, 2048)
-      TMs = [16] + [x for x in [32, 64, 128] if x <= M]
-      TNs = [16] + [x for x in [32, 64, 128] if x <= N]
-      TBs = [x for x in [1, 2, 4, 8] if x <= B]
-      TZs = [x for x in [1, 2, 4, 8, 16, 32] if x <= MAX_GZ]
-      # tile sizes
-      TM, TN, TB, TZ = _einsum.instance._tile(M, N, B, TMs, TNs, TBs, TZs, TK)
-      TM, TN, TB, TZ = 64, 128, 1, 1
-      self.macros = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
-      self.num_warps = [4]
-      if mask is not None:
-        self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
       # save information on the operation
       self.expr_a = expr_a
       self.expr_b = expr_b
@@ -634,7 +629,7 @@ __global__ void {name}(
         self.args[self.pos_c] = c
         for i, name in enumerate(self.varnames):
           self.args[self.pos_vars + i] = values[name]
-        return self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros, num_warps=self.num_warps)
+        return self.kernel(*self.args, grid=self.grid, bench=bench)
 
 
diff --git a/python/triton/utils.py b/python/triton/utils.py
index da8a1e8f9..7112870d2 100644
--- a/python/triton/utils.py
+++ b/python/triton/utils.py
@@ -4,7 +4,7 @@ import numpy as np
 import weakref
 
 def cdiv(a, b):
-  return -(-a // b)
+  return (a + b - 1) // b
 
 class tf_empty_proxy: