[PYTHON][OPS][EINSUM] Added support for masked accumulator
This commit is contained in:
committed by
Philippe Tillet
parent
69326ae6de
commit
fa4ec7ea65
@@ -202,7 +202,7 @@ __global__ void {name}(
|
||||
int *pbdelta[TK] = BD + offbdelta;"""
|
||||
|
||||
src += f"""
|
||||
|
||||
|
||||
// prefetch
|
||||
bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
|
||||
bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
|
||||
@@ -214,7 +214,12 @@ __global__ void {name}(
|
||||
// accumulate
|
||||
float acc[TM, TN, TB] = 0;
|
||||
for(int k = matmul_k; k > 0; k -= TK) {{
|
||||
acc += a @ b;"""
|
||||
acc += a @ b;
|
||||
#ifdef MASK
|
||||
uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
|
||||
acc = bitcast<TYPE[TM, TN, TB]>(bits & MASK);
|
||||
#endif
|
||||
"""
|
||||
|
||||
if not use_lut_a or not use_lut_b:
|
||||
src += f"""
|
||||
@@ -447,7 +452,7 @@ __global__ void {name}(
|
||||
locks = None
|
||||
kernel_cache = dict()
|
||||
|
||||
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays):
|
||||
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays, mask):
|
||||
# parse symbols
|
||||
expr_a, expr_bc = einsum.split(",")
|
||||
expr_b, expr_c = expr_bc.split("->")
|
||||
@@ -546,7 +551,9 @@ __global__ void {name}(
|
||||
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
|
||||
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
|
||||
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
|
||||
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
|
||||
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
|
||||
if mask:
|
||||
self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
|
||||
# information on compute
|
||||
self.dtype = dtype
|
||||
self.flops = 2 * B * M * N * K
|
||||
@@ -578,6 +585,7 @@ __global__ void {name}(
|
||||
def forward(ctx, einsum, a, b, shape_c, **kwargs):
|
||||
bench = kwargs['bench'] if 'bench' in kwargs else False
|
||||
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
|
||||
mask = kwargs['mask'] if 'mask' in kwargs else None
|
||||
# allocate output
|
||||
dtype = a.dtype
|
||||
c = triton.empty(shape_c, dtype=dtype)
|
||||
@@ -585,11 +593,13 @@ __global__ void {name}(
|
||||
cache = _einsum.instance_cache
|
||||
key = (einsum, dtype,
|
||||
a.stride(), b.stride(), c.stride(),
|
||||
a.shape, b.shape, c.shape)
|
||||
a.shape, b.shape, c.shape,
|
||||
mask)
|
||||
if key not in cache:
|
||||
cache[key] = _einsum.instance(einsum, dtype,
|
||||
a.stride(), b.stride(), c.stride(),
|
||||
a.shape, b.shape, c.shape, arrays)
|
||||
a.shape, b.shape, c.shape, arrays,
|
||||
mask)
|
||||
instance = cache[key]
|
||||
speed = instance.run(a, b, c, bench)
|
||||
# save information in context
|
||||
@@ -611,7 +621,7 @@ __global__ void {name}(
|
||||
############################
|
||||
|
||||
@staticmethod
|
||||
def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
|
||||
def sym_invert(sym_c, sym_x, prefix, renamed, inverse, mask):
|
||||
for i, expr in enumerate(sym_x):
|
||||
if expr.is_symbol:
|
||||
continue
|
||||
|
Reference in New Issue
Block a user