[examples/python/pytorch] added batchnorm cpp extension
This commit is contained in:
@@ -35,31 +35,31 @@ public:
|
||||
triton::driver::context* ctx = sstream.context();
|
||||
triton::driver::stream* stream = &sstream;
|
||||
// get inputs
|
||||
const Tensor& x = context->input(0);
|
||||
const Tensor& g = context->input(1);
|
||||
const Tensor& b = context->input(2);
|
||||
const Tensor& fw_x = context->input(0);
|
||||
const Tensor& fw_g = context->input(1);
|
||||
const Tensor& fw_b = context->input(2);
|
||||
// get sizes
|
||||
int C = x.dim_size(0);
|
||||
int H = x.dim_size(1);
|
||||
int W = x.dim_size(2);
|
||||
int B = x.dim_size(3);
|
||||
int C = fw_x.dim_size(0);
|
||||
int H = fw_x.dim_size(1);
|
||||
int W = fw_x.dim_size(2);
|
||||
int B = fw_x.dim_size(3);
|
||||
// allocate outputs
|
||||
Tensor* y = nullptr;
|
||||
Tensor* m = nullptr;
|
||||
Tensor* v = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &y));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(1, g.shape(), &m));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(2, g.shape(), &v));
|
||||
Tensor* fw_y = nullptr;
|
||||
Tensor* fw_m = nullptr;
|
||||
Tensor* fw_v = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, fw_x.shape(), &fw_y));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_m));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_v));
|
||||
// triton handles
|
||||
triton::driver::cu_buffer tx(ctx, (CUdeviceptr)x.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tg(ctx, (CUdeviceptr)g.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tb(ctx, (CUdeviceptr)b.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer ty(ctx, (CUdeviceptr)y->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tm(ctx, (CUdeviceptr)m->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer b(ctx, (CUdeviceptr)fw_b.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer y(ctx, (CUdeviceptr)fw_y->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v->flat<float>().data(), false);
|
||||
// create config
|
||||
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
|
||||
batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});
|
||||
batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -99,35 +99,35 @@ public:
|
||||
triton::driver::context* ctx = sstream.context();
|
||||
triton::driver::stream* stream = &sstream;
|
||||
// get inputs
|
||||
const Tensor& dy = context->input(0);
|
||||
const Tensor& x = context->input(1);
|
||||
const Tensor& g = context->input(2);
|
||||
const Tensor& m = context->input(3);
|
||||
const Tensor& v = context->input(4);
|
||||
const Tensor& fw_dy = context->input(0);
|
||||
const Tensor& fw_x = context->input(1);
|
||||
const Tensor& fw_g = context->input(2);
|
||||
const Tensor& fw_m = context->input(3);
|
||||
const Tensor& fw_v = context->input(4);
|
||||
// get sizes
|
||||
int C = x.dim_size(0);
|
||||
int H = x.dim_size(1);
|
||||
int W = x.dim_size(2);
|
||||
int B = x.dim_size(3);
|
||||
int C = fw_x.dim_size(0);
|
||||
int H = fw_x.dim_size(1);
|
||||
int W = fw_x.dim_size(2);
|
||||
int B = fw_x.dim_size(3);
|
||||
// allocate outputs
|
||||
Tensor* dx = nullptr;
|
||||
Tensor* dg = nullptr;
|
||||
Tensor* db = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &dx));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(1, g.shape(), &dg));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(2, g.shape(), &db));
|
||||
Tensor* fw_dx = nullptr;
|
||||
Tensor* fw_dg = nullptr;
|
||||
Tensor* fw_db = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, fw_x.shape(), &fw_dx));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_dg));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_db));
|
||||
// triton handles
|
||||
triton::driver::cu_buffer tdy(ctx, (CUdeviceptr)dy.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tx(ctx, (CUdeviceptr)x.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tg(ctx, (CUdeviceptr)g.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tm(ctx, (CUdeviceptr)m.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer dy(ctx, (CUdeviceptr)fw_dy.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer dx(ctx, (CUdeviceptr)fw_dx->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db->flat<float>().data(), false);
|
||||
// create config
|
||||
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
|
||||
batchnorm.enqueue(stream, {&tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv});
|
||||
batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
|
||||
}
|
||||
|
||||
private:
|
||||
|
Reference in New Issue
Block a user