[dnn] better specification of recompilation key
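
The diff below replaces each op's operator< (a std::tie comparison over the shape parameters) with a retune_params() method that returns those parameters as a std::vector<int64_t>, i.e. an explicit recompilation key. As a rough illustration only, here is a minimal sketch of how such a key could drive a compilation cache; the cached_kernel type, compile() helper and get_or_compile() function are hypothetical stand-ins and not part of this commit. An ordered map works because std::vector<int64_t> already defines lexicographic operator<, which is what the removed per-op comparisons were emulating.

#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for the real JIT machinery (not shown in this diff).
struct cached_kernel { std::string binary; };
static cached_kernel compile(const std::string& src) { return {src}; }

// Cache keyed directly on the vector returned by retune_params().
static std::map<std::vector<int64_t>, cached_kernel> cache;

static const cached_kernel& get_or_compile(const std::vector<int64_t>& key,
                                           const std::string& src) {
  auto it = cache.find(key);
  if(it == cache.end())
    it = cache.emplace(key, compile(src)).first;  // recompile only for unseen shapes
  return it->second;
}

int main() {
  // e.g. batchnorm_forward::retune_params() returns {C, D, H, W, B}
  get_or_compile({64, 1, 32, 32, 16}, "batchnorm_forward source...");
  get_or_compile({64, 1, 32, 32, 16}, "batchnorm_forward source...");  // cache hit
}
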
@@ -30,7 +30,7 @@ namespace dnn{
  * --------------- */

 batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
-  : base("batchnorm"),
+  : base("batchnorm_forward"),
   C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) {
   DHWB_ = D_*H_*W_*B_;
   rcpDHWB_ = (float)1 / DHWB_;
@@ -40,12 +40,9 @@ size_t batchnorm_forward::num_flops() const {
   return C_*DHWB_;
 }

-bool batchnorm_forward::operator <(const base& other) const {
-  auto *y = dynamic_cast<const batchnorm_forward*>(&other);
-  if(!y)
-    return true;
-  return std::tie(C_, D_, H_, W_, B_, ty_)
-       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
+std::vector<int64_t> batchnorm_forward::retune_params() const {
+  return {C_, D_, H_, W_, B_};
 }

 base* batchnorm_forward::clone() const {
@@ -74,50 +71,50 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
 void batchnorm_forward::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int32 TM = {32, 64, 128};
+const tunable int TM = {32, 64, 128};

-void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
-               restrict read_only fp32 *X,
-               restrict read_only fp32 *G,
-               restrict read_only fp32 *B,
-               int32 DHWN,
-               fp32 rcpDHWN, fp32 eps) {
-  int32 rx[TM] = 0 ... TM;
-  fp32 *px[TM];
-  fp32 x[TM];
-  int32 c = get_range_id(1);
-  fp32 g = *(G + c);
-  fp32 b = *(B + c);
+void batchnorm_forward(float *Y, float *M, float *V,
+                       restrict read_only float *X,
+                       restrict read_only float *G,
+                       restrict read_only float *B,
+                       int DHWN,
+                       float rcpDHWN, float eps) {
+  int rx[TM] = 0 ... TM;
+  float *px[TM];
+  float x[TM] = 0;
+  int c = get_range_id(1);
+  float g = *(G + c);
+  float b = *(B + c);

-  fp32 mean[TM] = 0;
+  float mean[TM] = 0;
   px = X + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
     mean = mean + x;
     px = px + TM;
   }
-  fp32 *pm = M + c;
-  fp32 m = __sum(mean) * rcpDHWN;
+  float *pm = M + c;
+  float m = __sum(mean) * rcpDHWN;
   *pm = m;

-  fp32 var[TM] = 0;
+  float var[TM] = 0;
   px = X + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
     x = x - m;
     var = var + x*x;
     px = px + TM;
   }
-  fp32 v = __sum(var) * rcpDHWN;
-  fp32 *pv = V + c;
+  float v = __sum(var) * rcpDHWN;
+  float *pv = V + c;
   *pv = v;
-  fp32 rstdg = 1 / sqrt(v + eps) * g;
+  float rstdg = 1 / sqrt(v + eps) * g;

   px = X + rx + c*DHWN;
-  fp32* py[TM] = Y + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  float* py[TM] = Y + rx + c*DHWN;
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
-    fp32 y[TM] = (x - m)*rstdg + b;
+    float y[TM] = (x - m)*rstdg + b;
     *py = y;
     px = px + TM;
     py = py + TM;
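
For reference, here is a plain C++ sketch of the per-channel computation the forward kernel above performs. It is illustrative only, not part of the commit; the contiguous C x DHWN memory layout is assumed from the X + rx + c*DHWN pointer arithmetic in the Triton-C source.

#include <cmath>
#include <cstddef>

void batchnorm_forward_ref(float* Y, float* M, float* V,
                           const float* X, const float* G, const float* B,
                           int C, int DHWN, float eps) {
  float rcpDHWN = 1.0f / DHWN;
  for(int c = 0; c < C; ++c) {
    const float* x = X + (size_t)c * DHWN;
    float* y = Y + (size_t)c * DHWN;
    // mean over the D*H*W*N elements of channel c
    float m = 0;
    for(int i = 0; i < DHWN; ++i) m += x[i];
    m *= rcpDHWN;
    M[c] = m;
    // (biased) variance over the same elements
    float v = 0;
    for(int i = 0; i < DHWN; ++i) { float d = x[i] - m; v += d * d; }
    v *= rcpDHWN;
    V[c] = v;
    // normalize, then scale by G[c] and shift by B[c]
    float rstdg = 1.0f / std::sqrt(v + eps) * G[c];
    for(int i = 0; i < DHWN; ++i) y[i] = (x[i] - m) * rstdg + B[c];
  }
}
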
@@ -130,7 +127,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
  * --------------- */

 batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
-  : base("batchnorm"),
+  : base("batchnorm_backward"),
   C_(C), D_(D), H_(H), W_(W), B_(B),
   ty_(ty), eps_(eps)
 { }
@@ -139,12 +136,8 @@ size_t batchnorm_backward::num_flops() const {
   return C_*D_*H_*W_*B_;
 }

-bool batchnorm_backward::operator <(const base& other) const {
-  auto *y = dynamic_cast<const batchnorm_backward*>(&other);
-  if(!y)
-    return true;
-  return std::tie(C_, D_, H_, W_, B_, ty_)
-       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
+std::vector<int64_t> batchnorm_backward::retune_params() const {
+  return {C_, D_, H_, W_, B_};
 }

 base* batchnorm_backward::clone() const {
@@ -174,54 +167,54 @@ void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *ke
 void batchnorm_backward::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int32 TM = {32, 64, 128};
+const tunable int TM = {32, 64, 128};

-void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
-               restrict read_only fp32 *DY,
-               restrict read_only fp32 *X,
-               restrict read_only fp32 *G,
-               restrict read_only fp32 *M,
-               restrict read_only fp32 *V,
-               int32 DHWN, fp32 rcpDHWN, fp32 epsilon) {
-  int32 rx[TM] = 0 ... TM;
-  int32 c = get_range_id(1);
-  int32 offset = c*DHWN;
-  fp32 g = *(G + c);
-  fp32 mean = *(M + c);
-  fp32 var = *(V + c);
-  fp32 rstd = 1 / sqrt(var + epsilon);
-  fp32* px[TM];
-  fp32* pdx[TM];
-  fp32* pdy[TM];
+void batchnorm_backward(float *DX, float *DG, float *DB,
+                        restrict read_only float *DY,
+                        restrict read_only float *X,
+                        restrict read_only float *G,
+                        restrict read_only float *M,
+                        restrict read_only float *V,
+                        int DHWN, float rcpDHWN, float epsilon) {
+  int rx[TM] = 0 ... TM;
+  int c = get_range_id(1);
+  int offset = c*DHWN;
+  float g = *(G + c);
+  float mean = *(M + c);
+  float var = *(V + c);
+  float rstd = 1 / sqrt(var + epsilon);
+  float* px[TM];
+  float* pdx[TM];
+  float* pdy[TM];

   px = X + rx + offset;
   pdy = DY + rx + offset;
-  fp32 dg[TM] = 0;
-  fp32 db[TM] = 0;
-  for(int32 i = 0; i < DHWN; i = i + TM){
-    fp32 x[TM] = *px;
-    fp32 dy[TM] = *pdy;
+  float dg[TM] = 0;
+  float db[TM] = 0;
+  for(int i = 0; i < DHWN; i = i + TM){
+    float x[TM] = *px;
+    float dy[TM] = *pdy;
     dg = dg + dy*(x - mean)*rstd;
     db = db + dy;
     px = px + TM;
     pdy = pdy + TM;
   }
-  fp32 sdg = __sum(dg);
-  fp32 sdb = __sum(db);
-  fp32 *pdg = DG + c;
-  fp32 *pdb = DB + c;
+  float sdg = __sum(dg);
+  float sdb = __sum(db);
+  float *pdg = DG + c;
+  float *pdb = DB + c;
   *pdg = sdg;
   *pdb = sdb;

   px = X + rx + offset;
   pdy = DY + rx + offset;
   pdx = DX + rx + offset;
-  for(int32 i = 0; i < DHWN; i = i + TM){
-    fp32 x[TM] = *px;
-    fp32 dy[TM] = *pdy;
-    fp32 xhat[TM] = (x - mean) * rstd;
-    fp32 xtmp[TM] = (xhat * dg + db) * rcpDHWN;
-    fp32 dx[TM] = (dy - xtmp) * rstd * g;
+  for(int i = 0; i < DHWN; i = i + TM){
+    float x[TM] = *px;
+    float dy[TM] = *pdy;
+    float xhat[TM] = (x - mean) * rstd;
+    float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
+    float dx[TM] = (dy - xtmp) * rstd * g;
     *pdx = dx;
     px = px + TM;
     pdy = pdy + TM;
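
Similarly, a reference-only C++ sketch of the backward kernel's math (same assumed C x DHWN layout; dg and db here are the full per-channel sums that the tiled kernel accumulates lane-wise and reduces with __sum). It is an illustration, not part of the commit.

#include <cmath>
#include <cstddef>

void batchnorm_backward_ref(float* DX, float* DG, float* DB,
                            const float* DY, const float* X,
                            const float* G, const float* M, const float* V,
                            int C, int DHWN, float eps) {
  float rcpDHWN = 1.0f / DHWN;
  for(int c = 0; c < C; ++c) {
    const float* x  = X  + (size_t)c * DHWN;
    const float* dy = DY + (size_t)c * DHWN;
    float* dx = DX + (size_t)c * DHWN;
    float g = G[c], mean = M[c], var = V[c];
    float rstd = 1.0f / std::sqrt(var + eps);
    // reductions: dg = sum(dy * xhat), db = sum(dy)
    float dg = 0, db = 0;
    for(int i = 0; i < DHWN; ++i) {
      dg += dy[i] * (x[i] - mean) * rstd;
      db += dy[i];
    }
    DG[c] = dg;
    DB[c] = db;
    // dx = (dy - (xhat*dg + db) / DHWN) * rstd * g
    for(int i = 0; i < DHWN; ++i) {
      float xhat = (x[i] - mean) * rstd;
      float xtmp = (xhat * dg + db) * rcpDHWN;
      dx[i] = (dy[i] - xtmp) * rstd * g;
    }
  }
}
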