[PYTHON][OPS][EINSUM] Added support for masked accumulator
This commit is contained in:
committed by
Philippe Tillet
parent
69326ae6de
commit
fa4ec7ea65
@@ -214,7 +214,12 @@ __global__ void {name}(
|
|||||||
// accumulate
|
// accumulate
|
||||||
float acc[TM, TN, TB] = 0;
|
float acc[TM, TN, TB] = 0;
|
||||||
for(int k = matmul_k; k > 0; k -= TK) {{
|
for(int k = matmul_k; k > 0; k -= TK) {{
|
||||||
acc += a @ b;"""
|
acc += a @ b;
|
||||||
|
#ifdef MASK
|
||||||
|
uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
|
||||||
|
acc = bitcast<TYPE[TM, TN, TB]>(bits & MASK);
|
||||||
|
#endif
|
||||||
|
"""
|
||||||
|
|
||||||
if not use_lut_a or not use_lut_b:
|
if not use_lut_a or not use_lut_b:
|
||||||
src += f"""
|
src += f"""
|
||||||
@@ -447,7 +452,7 @@ __global__ void {name}(
|
|||||||
locks = None
|
locks = None
|
||||||
kernel_cache = dict()
|
kernel_cache = dict()
|
||||||
|
|
||||||
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays):
|
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays, mask):
|
||||||
# parse symbols
|
# parse symbols
|
||||||
expr_a, expr_bc = einsum.split(",")
|
expr_a, expr_bc = einsum.split(",")
|
||||||
expr_b, expr_c = expr_bc.split("->")
|
expr_b, expr_c = expr_bc.split("->")
|
||||||
@@ -547,6 +552,8 @@ __global__ void {name}(
|
|||||||
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
|
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
|
||||||
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
|
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
|
||||||
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
|
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
|
||||||
|
if mask:
|
||||||
|
self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
|
||||||
# information on compute
|
# information on compute
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.flops = 2 * B * M * N * K
|
self.flops = 2 * B * M * N * K
|
||||||
@@ -578,6 +585,7 @@ __global__ void {name}(
|
|||||||
def forward(ctx, einsum, a, b, shape_c, **kwargs):
|
def forward(ctx, einsum, a, b, shape_c, **kwargs):
|
||||||
bench = kwargs['bench'] if 'bench' in kwargs else False
|
bench = kwargs['bench'] if 'bench' in kwargs else False
|
||||||
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
|
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
|
||||||
|
mask = kwargs['mask'] if 'mask' in kwargs else None
|
||||||
# allocate output
|
# allocate output
|
||||||
dtype = a.dtype
|
dtype = a.dtype
|
||||||
c = triton.empty(shape_c, dtype=dtype)
|
c = triton.empty(shape_c, dtype=dtype)
|
||||||
@@ -585,11 +593,13 @@ __global__ void {name}(
|
|||||||
cache = _einsum.instance_cache
|
cache = _einsum.instance_cache
|
||||||
key = (einsum, dtype,
|
key = (einsum, dtype,
|
||||||
a.stride(), b.stride(), c.stride(),
|
a.stride(), b.stride(), c.stride(),
|
||||||
a.shape, b.shape, c.shape)
|
a.shape, b.shape, c.shape,
|
||||||
|
mask)
|
||||||
if key not in cache:
|
if key not in cache:
|
||||||
cache[key] = _einsum.instance(einsum, dtype,
|
cache[key] = _einsum.instance(einsum, dtype,
|
||||||
a.stride(), b.stride(), c.stride(),
|
a.stride(), b.stride(), c.stride(),
|
||||||
a.shape, b.shape, c.shape, arrays)
|
a.shape, b.shape, c.shape, arrays,
|
||||||
|
mask)
|
||||||
instance = cache[key]
|
instance = cache[key]
|
||||||
speed = instance.run(a, b, c, bench)
|
speed = instance.run(a, b, c, bench)
|
||||||
# save information in context
|
# save information in context
|
||||||
@@ -611,7 +621,7 @@ __global__ void {name}(
|
|||||||
############################
|
############################
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
|
def sym_invert(sym_c, sym_x, prefix, renamed, inverse, mask):
|
||||||
for i, expr in enumerate(sym_x):
|
for i, expr in enumerate(sym_x):
|
||||||
if expr.is_symbol:
|
if expr.is_symbol:
|
||||||
continue
|
continue
|
||||||
|
Reference in New Issue
Block a user