[dnn]: Now implementing all existing DNN routines using common base template and auto-tuner
This commit is contained in:
@@ -29,17 +29,36 @@ namespace dnn{
|
||||
* Forward
|
||||
* --------------- */
|
||||
|
||||
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty)
|
||||
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(1e-5) {
|
||||
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
|
||||
: base("batchnorm"),
|
||||
C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) {
|
||||
DHWB_ = D_*H_*W_*B_;
|
||||
rcpDHWB_ = (float)1 / DHWB_;
|
||||
}
|
||||
|
||||
void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *y, driver::buffer *m, driver::buffer *v,
|
||||
driver::buffer *x, driver::buffer *g, driver::buffer *b,
|
||||
size_t, size_t nthreads) {
|
||||
size_t batchnorm_forward::num_flops() const {
|
||||
return C_*DHWB_;
|
||||
}
|
||||
|
||||
bool batchnorm_forward::operator <(const base& other) const {
|
||||
auto *y = dynamic_cast<const batchnorm_forward*>(&other);
|
||||
if(!y)
|
||||
return true;
|
||||
return std::tie(C_, D_, H_, W_, B_, ty_)
|
||||
< std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
|
||||
}
|
||||
|
||||
base* batchnorm_forward::clone() const {
|
||||
return new batchnorm_forward(*this);
|
||||
}
|
||||
|
||||
void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
const std::vector<unsigned>&,
|
||||
size_t nthreads)
|
||||
{
|
||||
driver::buffer *y = args[0], *m = args[1], *v = args[2];
|
||||
driver::buffer *x = args[3], *g = args[4], *b = args[5];
|
||||
std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
|
||||
kernel->setArg(0, y);
|
||||
kernel->setArg(1, m);
|
||||
@@ -53,7 +72,7 @@ void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
void batchnorm_forward::src(std::ostream &os) {
|
||||
void batchnorm_forward::triton_c_src(std::ostream &os) const {
|
||||
os <<
|
||||
R"(
|
||||
const tunable int32 TM = {32, 64, 128};
|
||||
@@ -113,14 +132,32 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
|
||||
* --------------- */
|
||||
|
||||
batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
|
||||
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps)
|
||||
: base("batchnorm"),
|
||||
C_(C), D_(D), H_(H), W_(W), B_(B),
|
||||
ty_(ty), eps_(eps)
|
||||
{ }
|
||||
|
||||
void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
|
||||
driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v,
|
||||
size_t, size_t nthreads) {
|
||||
size_t batchnorm_backward::num_flops() const {
|
||||
return C_*D_*H_*W_*B_;
|
||||
}
|
||||
|
||||
bool batchnorm_backward::operator <(const base& other) const {
|
||||
auto *y = dynamic_cast<const batchnorm_backward*>(&other);
|
||||
if(!y)
|
||||
return true;
|
||||
return std::tie(C_, D_, H_, W_, B_, ty_)
|
||||
< std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
|
||||
}
|
||||
|
||||
base* batchnorm_backward::clone() const {
|
||||
return new batchnorm_backward(*this);
|
||||
}
|
||||
|
||||
void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer *> args,
|
||||
const std::vector<unsigned> &, size_t nthreads) {
|
||||
driver::buffer *dx = args[0], *dg = args[1], *db = args[2], *dy = args[3];
|
||||
driver::buffer *x = args[4], *g = args[5], *m = args[6], *v = args[7];
|
||||
std::array<size_t, 3> grid = {1, (size_t)C_, 1};
|
||||
kernel->setArg(0, dx);
|
||||
kernel->setArg(1, dg);
|
||||
@@ -136,7 +173,7 @@ void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
void batchnorm_backward::src(std::ostream &os) {
|
||||
void batchnorm_backward::triton_c_src(std::ostream &os) const {
|
||||
os <<
|
||||
R"(
|
||||
const tunable int32 TM = {32, 64, 128};
|
||||
|
Reference in New Issue
Block a user