[PYTHON][OPS][EINSUM] Added support for masked accumulator

This commit is contained in:
Philippe Tillet
2020-02-13 18:11:12 -05:00
committed by Philippe Tillet
parent 69326ae6de
commit fa4ec7ea65

View File

@@ -214,7 +214,12 @@ __global__ void {name}(
// accumulate // accumulate
float acc[TM, TN, TB] = 0; float acc[TM, TN, TB] = 0;
for(int k = matmul_k; k > 0; k -= TK) {{ for(int k = matmul_k; k > 0; k -= TK) {{
acc += a @ b;""" acc += a @ b;
#ifdef MASK
uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
acc = bitcast<TYPE[TM, TN, TB]>(bits & MASK);
#endif
"""
if not use_lut_a or not use_lut_b: if not use_lut_a or not use_lut_b:
src += f""" src += f"""
@@ -447,7 +452,7 @@ __global__ void {name}(
locks = None locks = None
kernel_cache = dict() kernel_cache = dict()
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays): def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays, mask):
# parse symbols # parse symbols
expr_a, expr_bc = einsum.split(",") expr_a, expr_bc = einsum.split(",")
expr_b, expr_c = expr_bc.split("->") expr_b, expr_c = expr_bc.split("->")
@@ -547,6 +552,8 @@ __global__ void {name}(
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256] if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2] TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype} self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
if mask:
self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
# information on compute # information on compute
self.dtype = dtype self.dtype = dtype
self.flops = 2 * B * M * N * K self.flops = 2 * B * M * N * K
@@ -578,6 +585,7 @@ __global__ void {name}(
def forward(ctx, einsum, a, b, shape_c, **kwargs): def forward(ctx, einsum, a, b, shape_c, **kwargs):
bench = kwargs['bench'] if 'bench' in kwargs else False bench = kwargs['bench'] if 'bench' in kwargs else False
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict() arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
mask = kwargs['mask'] if 'mask' in kwargs else None
# allocate output # allocate output
dtype = a.dtype dtype = a.dtype
c = triton.empty(shape_c, dtype=dtype) c = triton.empty(shape_c, dtype=dtype)
@@ -585,11 +593,13 @@ __global__ void {name}(
cache = _einsum.instance_cache cache = _einsum.instance_cache
key = (einsum, dtype, key = (einsum, dtype,
a.stride(), b.stride(), c.stride(), a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape) a.shape, b.shape, c.shape,
mask)
if key not in cache: if key not in cache:
cache[key] = _einsum.instance(einsum, dtype, cache[key] = _einsum.instance(einsum, dtype,
a.stride(), b.stride(), c.stride(), a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape, arrays) a.shape, b.shape, c.shape, arrays,
mask)
instance = cache[key] instance = cache[key]
speed = instance.run(a, b, c, bench) speed = instance.run(a, b, c, bench)
# save information in context # save information in context
@@ -611,7 +621,7 @@ __global__ void {name}(
############################ ############################
@staticmethod @staticmethod
def sym_invert(sym_c, sym_x, prefix, renamed, inverse): def sym_invert(sym_c, sym_x, prefix, renamed, inverse, mask):
for i, expr in enumerate(sym_x): for i, expr in enumerate(sym_x):
if expr.is_symbol: if expr.is_symbol:
continue continue