[dnn] better specification of recompilation key

commit d9945692a9
parent 3b92ddf7e6
Author: Philippe Tillet
Date: 2019-08-02 17:42:48 -07:00

31 changed files with 418 additions and 428 deletions
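The change itself: each DNN op previously defined an operator< over its shape fields (via std::tie) so compiled kernels could be looked up in an ordered container; this commit replaces that with a retune_params() method returning the integers that should trigger recompilation as a flat std::vector<int64_t>. A minimal sketch of how such a key can drive a kernel cache, assuming a plain std::map and an opaque kernel handle (neither is the actual Triton runtime API):

#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical sketch: cache compiled kernels by (op name, retune parameters).
// std::vector<int64_t> already orders lexicographically, so each op no longer
// needs to hand-write a field-by-field comparison.
struct kernel;  // opaque handle to a compiled kernel (assumption)

using retune_key = std::pair<std::string, std::vector<int64_t>>;
static std::map<retune_key, kernel*> kernel_cache;

kernel* get_or_compile(const std::string& name,
                       const std::vector<int64_t>& retune_params,
                       kernel* (*compile)()) {
  retune_key key{name, retune_params};
  auto it = kernel_cache.find(key);
  if (it != kernel_cache.end())
    return it->second;    // same op, same shape parameters: reuse
  kernel* k = compile();  // recompile only for unseen keys
  kernel_cache.emplace(key, k);
  return k;
}

This also explains the renamed keys in the hunks below (base("batchnorm") becomes base("batchnorm_forward") and base("batchnorm_backward")): once the key is (name, parameters), forward and backward kernels need distinct names so their cache entries cannot collide.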


@@ -30,7 +30,7 @@ namespace dnn{
  * --------------- */
 batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
-  : base("batchnorm"),
+  : base("batchnorm_forward"),
     C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) {
   DHWB_ = D_*H_*W_*B_;
   rcpDHWB_ = (float)1 / DHWB_;
@@ -40,12 +40,9 @@ size_t batchnorm_forward::num_flops() const {
   return C_*DHWB_;
 }
-bool batchnorm_forward::operator <(const base& other) const {
-  auto *y = dynamic_cast<const batchnorm_forward*>(&other);
-  if(!y)
-    return true;
-  return std::tie(C_, D_, H_, W_, B_, ty_)
-       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
+std::vector<int64_t> batchnorm_forward::retune_params() const {
+  return {C_, D_, H_, W_, B_};
 }
 base* batchnorm_forward::clone() const {
@@ -74,50 +71,50 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
 void batchnorm_forward::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int32 TM = {32, 64, 128};
+const tunable int TM = {32, 64, 128};
 
-void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
-               restrict read_only fp32 *X,
-               restrict read_only fp32 *G,
-               restrict read_only fp32 *B,
-               int32 DHWN,
-               fp32 rcpDHWN, fp32 eps) {
-  int32 rx[TM] = 0 ... TM;
-  fp32 *px[TM];
-  fp32 x[TM];
-  int32 c = get_range_id(1);
-  fp32 g = *(G + c);
-  fp32 b = *(B + c);
+void batchnorm_forward(float *Y, float *M, float *V,
+                       restrict read_only float *X,
+                       restrict read_only float *G,
+                       restrict read_only float *B,
+                       int DHWN,
+                       float rcpDHWN, float eps) {
+  int rx[TM] = 0 ... TM;
+  float *px[TM];
+  float x[TM] = 0;
+  int c = get_range_id(1);
+  float g = *(G + c);
+  float b = *(B + c);
 
-  fp32 mean[TM] = 0;
+  float mean[TM] = 0;
   px = X + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
     mean = mean + x;
     px = px + TM;
   }
-  fp32 *pm = M + c;
-  fp32 m = __sum(mean) * rcpDHWN;
+  float *pm = M + c;
+  float m = __sum(mean) * rcpDHWN;
   *pm = m;
-  fp32 var[TM] = 0;
+  float var[TM] = 0;
   px = X + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
     x = x - m;
     var = var + x*x;
     px = px + TM;
   }
-  fp32 v = __sum(var) * rcpDHWN;
-  fp32 *pv = V + c;
+  float v = __sum(var) * rcpDHWN;
+  float *pv = V + c;
   *pv = v;
-  fp32 rstdg = 1 / sqrt(v + eps) * g;
+  float rstdg = 1 / sqrt(v + eps) * g;
   px = X + rx + c*DHWN;
-  fp32* py[TM] = Y + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  float* py[TM] = Y + rx + c*DHWN;
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
-    fp32 y[TM] = (x - m)*rstdg + b;
+    float y[TM] = (x - m)*rstdg + b;
     *py = y;
     px = px + TM;
     py = py + TM;
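For orientation, the forward kernel above makes three passes over the D*H*W*N elements of one channel: reduce to a mean, reduce to a variance, then normalize with the learned scale g and shift b. A scalar C++ sketch of the same math, as a reading aid only (the function name and signature are made up; the kernel itself processes TM elements per iteration and reduces across lanes with __sum):

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference for one channel's forward pass.
// x: the channel's DHWN values; g, b: per-channel scale and shift.
void batchnorm_forward_ref(const std::vector<float>& x,
                           float g, float b, float eps,
                           std::vector<float>& y,
                           float& mean_out, float& var_out) {
  float rcpDHWN = 1.0f / static_cast<float>(x.size());
  float m = 0;
  for (float v : x) m += v;                    // pass 1: mean
  m *= rcpDHWN;
  float var = 0;
  for (float v : x) var += (v - m) * (v - m);  // pass 2: variance
  var *= rcpDHWN;
  float rstdg = g / std::sqrt(var + eps);      // fused 1/sqrt(var+eps) * g
  y.resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)   // pass 3: normalize
    y[i] = (x[i] - m) * rstdg + b;
  mean_out = m;
  var_out = var;
}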
@@ -130,7 +127,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
  * --------------- */
 batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
-  : base("batchnorm"),
+  : base("batchnorm_backward"),
     C_(C), D_(D), H_(H), W_(W), B_(B),
     ty_(ty), eps_(eps)
 { }
@@ -139,12 +136,8 @@ size_t batchnorm_backward::num_flops() const {
   return C_*D_*H_*W_*B_;
 }
-bool batchnorm_backward::operator <(const base& other) const {
-  auto *y = dynamic_cast<const batchnorm_backward*>(&other);
-  if(!y)
-    return true;
-  return std::tie(C_, D_, H_, W_, B_, ty_)
-       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
+std::vector<int64_t> batchnorm_backward::retune_params() const {
+  return {C_, D_, H_, W_, B_};
 }
 base* batchnorm_backward::clone() const {
@@ -174,54 +167,54 @@ void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *ke
 void batchnorm_backward::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int32 TM = {32, 64, 128};
+const tunable int TM = {32, 64, 128};
 
-void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
-               restrict read_only fp32 *DY,
-               restrict read_only fp32 *X,
-               restrict read_only fp32 *G,
-               restrict read_only fp32 *M,
-               restrict read_only fp32 *V,
-               int32 DHWN, fp32 rcpDHWN, fp32 epsilon) {
-  int32 rx[TM] = 0 ... TM;
-  int32 c = get_range_id(1);
-  int32 offset = c*DHWN;
-  fp32 g = *(G + c);
-  fp32 mean = *(M + c);
-  fp32 var = *(V + c);
-  fp32 rstd = 1 / sqrt(var + epsilon);
-  fp32* px[TM];
-  fp32* pdx[TM];
-  fp32* pdy[TM];
+void batchnorm_backward(float *DX, float *DG, float *DB,
+                        restrict read_only float *DY,
+                        restrict read_only float *X,
+                        restrict read_only float *G,
+                        restrict read_only float *M,
+                        restrict read_only float *V,
+                        int DHWN, float rcpDHWN, float epsilon) {
+  int rx[TM] = 0 ... TM;
+  int c = get_range_id(1);
+  int offset = c*DHWN;
+  float g = *(G + c);
+  float mean = *(M + c);
+  float var = *(V + c);
+  float rstd = 1 / sqrt(var + epsilon);
+  float* px[TM];
+  float* pdx[TM];
+  float* pdy[TM];
   px = X + rx + offset;
   pdy = DY + rx + offset;
-  fp32 dg[TM] = 0;
-  fp32 db[TM] = 0;
-  for(int32 i = 0; i < DHWN; i = i + TM){
-    fp32 x[TM] = *px;
-    fp32 dy[TM] = *pdy;
+  float dg[TM] = 0;
+  float db[TM] = 0;
+  for(int i = 0; i < DHWN; i = i + TM){
+    float x[TM] = *px;
+    float dy[TM] = *pdy;
     dg = dg + dy*(x - mean)*rstd;
     db = db + dy;
     px = px + TM;
     pdy = pdy + TM;
   }
-  fp32 sdg = __sum(dg);
-  fp32 sdb = __sum(db);
-  fp32 *pdg = DG + c;
-  fp32 *pdb = DB + c;
+  float sdg = __sum(dg);
+  float sdb = __sum(db);
+  float *pdg = DG + c;
+  float *pdb = DB + c;
   *pdg = sdg;
   *pdb = sdb;
   px = X + rx + offset;
   pdy = DY + rx + offset;
   pdx = DX + rx + offset;
-  for(int32 i = 0; i < DHWN; i = i + TM){
-    fp32 x[TM] = *px;
-    fp32 dy[TM] = *pdy;
-    fp32 xhat[TM] = (x - mean) * rstd;
-    fp32 xtmp[TM] = (xhat * dg + db) * rcpDHWN;
-    fp32 dx[TM] = (dy - xtmp) * rstd * g;
+  for(int i = 0; i < DHWN; i = i + TM){
+    float x[TM] = *px;
+    float dy[TM] = *pdy;
+    float xhat[TM] = (x - mean) * rstd;
+    float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
+    float dx[TM] = (dy - xtmp) * rstd * g;
     *pdx = dx;
     px = px + TM;
     pdy = pdy + TM;
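The backward kernel follows the same two-pass shape: the first loop reduces the gradients of the scale and shift, dg = sum(dy * xhat) and db = sum(dy), and the second loop uses them to form dx = (dy - (xhat*dg + db)/DHWN) * rstd * g. A scalar sketch of that computation, again as a reading aid with made-up names (it uses the fully reduced sums where the kernel accumulates per lane and reduces with __sum):

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference for one channel's backward pass.
void batchnorm_backward_ref(const std::vector<float>& x,
                            const std::vector<float>& dy,
                            float g, float mean, float var, float eps,
                            std::vector<float>& dx,
                            float& dg_out, float& db_out) {
  float rcpDHWN = 1.0f / static_cast<float>(x.size());
  float rstd = 1.0f / std::sqrt(var + eps);
  float dg = 0, db = 0;
  for (std::size_t i = 0; i < x.size(); ++i) { // pass 1: reduce dg, db
    dg += dy[i] * (x[i] - mean) * rstd;
    db += dy[i];
  }
  dx.resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) { // pass 2: input gradient
    float xhat = (x[i] - mean) * rstd;
    float xtmp = (xhat * dg + db) * rcpDHWN;
    dx[i] = (dy[i] - xtmp) * rstd * g;
  }
  dg_out = dg;
  db_out = db;
}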