diff --git a/python/triton/ops/batchnorm.py b/python/triton/ops/batchnorm.py index fb6d94017..f21faabdb 100644 --- a/python/triton/ops/batchnorm.py +++ b/python/triton/ops/batchnorm.py @@ -13,22 +13,9 @@ void fwdbatchnorm(float *Y, float *M, float *V, float *px[TM] = X + rm + c*N; float* py[TM] = Y + rm + c*N; - // compute mean - float accm[TM] = 0; - for(int i = 0; i < N; i = i + TM) - accm = accm + *(px + i); - float mean = (float)accm[+] / N; - *(M + c) = mean; - - // compute variance - float accv[TM] = 0; - for(int i = 0; i < N; i = i + TM){ - float x[TM] = *(px + i); - x = x - mean; - accv = accv + x*x; - } - float var = (float)accv[+] / N; - *(V + c) = var; + // fetch mean/var + float mean = *(M + c); + float var = *(V + c); // Normalize batch float gamma = *(G + c);