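[PYTHON][OPS][EINSUM] Added support for a masked accumulator (an optional `mask` is passed to the kernel as the `MASK` macro and AND-ed bitwise into the accumulator after each `acc += a @ b` step)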

This commit is contained in:
Philippe Tillet
2020-02-13 18:11:12 -05:00
committed by Philippe Tillet
parent 69326ae6de
commit fa4ec7ea65

View File

@@ -202,7 +202,7 @@ __global__ void {name}(
int *pbdelta[TK] = BD + offbdelta;"""
src += f"""
// prefetch
bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
@@ -214,7 +214,12 @@ __global__ void {name}(
// accumulate
float acc[TM, TN, TB] = 0;
for(int k = matmul_k; k > 0; k -= TK) {{
acc += a @ b;"""
acc += a @ b;
#ifdef MASK
uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
acc = bitcast<TYPE[TM, TN, TB]>(bits & MASK);
#endif
"""
if not use_lut_a or not use_lut_b:
src += f"""
@@ -447,7 +452,7 @@ __global__ void {name}(
locks = None
kernel_cache = dict()
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays):
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays, mask):
# parse symbols
expr_a, expr_bc = einsum.split(",")
expr_b, expr_c = expr_bc.split("->")
@@ -546,7 +551,9 @@ __global__ void {name}(
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
if mask:
self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
# information on compute
self.dtype = dtype
self.flops = 2 * B * M * N * K
@@ -578,6 +585,7 @@ __global__ void {name}(
def forward(ctx, einsum, a, b, shape_c, **kwargs):
bench = kwargs['bench'] if 'bench' in kwargs else False
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
mask = kwargs['mask'] if 'mask' in kwargs else None
# allocate output
dtype = a.dtype
c = triton.empty(shape_c, dtype=dtype)
@@ -585,11 +593,13 @@ __global__ void {name}(
cache = _einsum.instance_cache
key = (einsum, dtype,
a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape)
a.shape, b.shape, c.shape,
mask)
if key not in cache:
cache[key] = _einsum.instance(einsum, dtype,
a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape, arrays)
a.shape, b.shape, c.shape, arrays,
mask)
instance = cache[key]
speed = instance.run(a, b, c, bench)
# save information in context
@@ -611,7 +621,7 @@ __global__ void {name}(
############################
@staticmethod
def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
def sym_invert(sym_c, sym_x, prefix, renamed, inverse, mask):
for i, expr in enumerate(sym_x):
if expr.is_symbol:
continue