diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index a24e422dc..bda5d3b16 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -247,7 +247,7 @@ __global__ void {name}( acc += a @ b; #ifdef MASK uint32 bits[TM, TN, TB] = bitcast(acc); - acc = bitcast(bits & MASK); + acc = bitcast(bits & MASK); #endif checkk = k > TK;