From fa4ec7ea65a4a040dc495e6a4e0f1951ac9e2d1b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 13 Feb 2020 18:11:12 -0500 Subject: [PATCH] [PYTHON][OPS][EINSUM] Added support for masked accumulator --- python/triton/ops/einsum.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 8ebf9f439..4bce9729d 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -202,7 +202,7 @@ __global__ void {name}( int *pbdelta[TK] = BD + offbdelta;""" src += f""" - + // prefetch bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m; bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n; @@ -214,7 +214,12 @@ __global__ void {name}( // accumulate float acc[TM, TN, TB] = 0; for(int k = matmul_k; k > 0; k -= TK) {{ - acc += a @ b;""" + acc += a @ b; + #ifdef MASK + uint32 bits[TM, TN, TB] = bitcast(acc); + acc = bitcast(bits & MASK); + #endif + """ if not use_lut_a or not use_lut_b: src += f""" @@ -447,7 +452,7 @@ __global__ void {name}( locks = None kernel_cache = dict() - def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays): + def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays, mask): # parse symbols expr_a, expr_bc = einsum.split(",") expr_b, expr_c = expr_bc.split("->") @@ -546,7 +551,9 @@ __global__ void {name}( TZ = [x for x in [1, 2, 4, 8, 16, 32] \ if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256] TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2] - self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype } + self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype} + if mask: + self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10) # information on compute self.dtype = dtype self.flops = 2 * B * M * N * K @@ -578,6 +585,7 @@ __global__ void {name}( def forward(ctx, einsum, a, b, shape_c, **kwargs): bench = kwargs['bench'] if 'bench' in kwargs else False arrays = kwargs['arrays'] if 'arrays' in kwargs else dict() + mask = kwargs['mask'] if 'mask' in kwargs else None # allocate output dtype = a.dtype c = triton.empty(shape_c, dtype=dtype) @@ -585,11 +593,13 @@ __global__ void {name}( cache = _einsum.instance_cache key = (einsum, dtype, a.stride(), b.stride(), c.stride(), - a.shape, b.shape, c.shape) + a.shape, b.shape, c.shape, + mask) if key not in cache: cache[key] = _einsum.instance(einsum, dtype, a.stride(), b.stride(), c.stride(), - a.shape, b.shape, c.shape, arrays) + a.shape, b.shape, c.shape, arrays, + mask) instance = cache[key] speed = instance.run(a, b, c, bench) # save information in context @@ -611,7 +621,7 @@ __global__ void {name}( ############################ @staticmethod - def sym_invert(sym_c, sym_x, prefix, renamed, inverse): + def sym_invert(sym_c, sym_x, prefix, renamed, inverse, mask): for i, expr in enumerate(sym_x): if expr.is_symbol: continue