[PYTHON][TENSORFLOW] Got rid of alloc_empty entirely; now doing
generating allocation code inside the tensorflow op
This commit is contained in:
@@ -89,9 +89,9 @@ inline std::string to_tf_ty(ir::type *ty) {
|
||||
if(ty->is_half_ty())
|
||||
return "float16";
|
||||
if(ty->is_float_ty())
|
||||
return "float32";
|
||||
return "float";
|
||||
if(ty->is_double_ty())
|
||||
return "float64";
|
||||
return "double";
|
||||
if(ty->is_pointer_ty())
|
||||
return "Tensor";
|
||||
throw std::runtime_error("unknown type");
|
||||
@@ -113,21 +113,50 @@ inline std::string ref_to_tf_ty(ir::type *ty) {
|
||||
}
|
||||
|
||||
|
||||
void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args, const std::vector<std::string>& outputs) {
|
||||
for(unsigned i = 0; i < args.size(); i++){
|
||||
ir::value *arg = args[i];
|
||||
std::string suffix = "";
|
||||
ir::type *tr_ty = arg->get_type();
|
||||
std::string tf_ty = ref_to_tf_ty(tr_ty);
|
||||
if(!tr_ty->is_pointer_ty())
|
||||
suffix = ".scalar<" + tf_ty + ">()()";
|
||||
os << " " << tf_ty << " " << arg->get_name() << " = context->input(" << i << ")" << suffix << ";\n ";
|
||||
const std::string& name = arg->get_name();
|
||||
std::string ty = to_tf_ty(arg->get_type());
|
||||
if(!arg->get_type()->is_pointer_ty())
|
||||
os << " " << ty << " " << name << " = context->input(" << i << ").scalar<" << ty << ">()();\n ";
|
||||
else if(std::find(outputs.begin(), outputs.end(), arg->get_name()) == outputs.end())
|
||||
os << " const Tensor* " << name << " = &context->input(" << i << ");\n ";
|
||||
else
|
||||
os << " Tensor* " << name << " = nullptr;\n ";
|
||||
}
|
||||
}
|
||||
|
||||
void gen_set_outputs(std::ostream &os, const std::vector<std::string>& outputs) {
|
||||
void gen_set_outputs(std::ostream &os, const std::vector<ir::argument*>& args, const std::vector<std::string>& outputs) {
|
||||
for(unsigned i = 0; i < outputs.size(); i++)
|
||||
os << " context->set_output(" << i << ", " << outputs[i] << ");\n ";
|
||||
os << " TensorShape shape" << i << ";\n ";
|
||||
// initialize shapes
|
||||
|
||||
std::vector<int> out_idx;
|
||||
for(size_t i = 0; i < outputs.size(); i++){
|
||||
std::string name = outputs[i];
|
||||
size_t idx;
|
||||
for(idx = 0; idx < args.size(); idx++)
|
||||
if(args[idx]->get_name() == name)
|
||||
break;
|
||||
if(idx == args.size())
|
||||
throw std::runtime_error("unknown output");
|
||||
out_idx.push_back(idx);
|
||||
}
|
||||
|
||||
for(unsigned i = 0; i < outputs.size(); i++)
|
||||
os << " const Tensor& " << outputs[i] << "_shape = context->input(" << out_idx[i] << ");\n ";
|
||||
for(unsigned i = 0; i < outputs.size(); i++)
|
||||
os << " const int32* " << outputs[i] << "_shape_data = (const int32*)" << outputs[i] << "_shape.tensor_data().data();\n ";
|
||||
for(unsigned i = 0; i < outputs.size(); i++)
|
||||
os << " size_t " << outputs[i] << "_rank = " << outputs[i] << "_shape.dim_size(0);\n ";
|
||||
for(unsigned i = 0; i < outputs.size(); i++)
|
||||
os << " for(size_t d = 0; d < " << outputs[i] << "_rank ; d++) "
|
||||
<< "shape" << i << ".AddDim(" << outputs[i] << "_shape_data[d]);\n ";
|
||||
|
||||
// allocate
|
||||
for(unsigned i = 0; i < outputs.size(); i++)
|
||||
os << " OP_REQUIRES_OK(context, context->allocate_output(" << i << ", shape" << i << ", &" << outputs[i] << "));\n ";
|
||||
}
|
||||
|
||||
void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
@@ -136,7 +165,7 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
|
||||
if(!arg->get_type()->is_pointer_ty())
|
||||
continue;
|
||||
const std::string& name = arg->get_name();
|
||||
os << " drv::cu_buffer cu_" + name + "(ctx, " + name + ".tensor_data().size(), (CUdeviceptr)" + name + ".tensor_data().data(), false);\n ";
|
||||
os << " drv::cu_buffer cu_" + name + "(ctx, " + name + "->tensor_data().size(), (CUdeviceptr)" + name + "->tensor_data().data(), false);\n ";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,7 +190,8 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
|
||||
|
||||
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
const std::string &opname,
|
||||
const std::vector<ir::argument*>& args){
|
||||
const std::vector<ir::argument*>& args,
|
||||
const std::vector<std::string>& outputs){
|
||||
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
|
||||
for(size_t i = 0; i < args.size(); i++){
|
||||
ir::argument *arg = args[i];
|
||||
@@ -171,20 +201,31 @@ void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
if(!arg->get_type()->is_pointer_ty())
|
||||
os << ".HostMemory(\"" + name + "\")";
|
||||
}
|
||||
for(size_t i = 0; i < outputs.size(); i++){
|
||||
std::string name = outputs[i];
|
||||
name[0] = std::tolower(name[0]);
|
||||
os << ".HostMemory(\"" << name << "_shape\")";
|
||||
}
|
||||
os << ", " + opname << ");\n";
|
||||
}
|
||||
|
||||
void gen_tf_register_op(std::ostream &os, const std::string &name,
|
||||
const std::vector<ir::argument*>& args,
|
||||
const std::vector<std::string>& outputs){
|
||||
|
||||
auto tolower = [](char c) { return std::tolower(c);};
|
||||
|
||||
os << "REGISTER_OP(\"" << name << "\")\n";
|
||||
for(size_t i = 0; i < args.size(); i++)
|
||||
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
|
||||
for(size_t i = 0; i < args.size(); i++){
|
||||
ir::argument *arg = args[i];
|
||||
std::string name = arg->get_name();
|
||||
auto tolower = [](char c) { return std::tolower(c);};
|
||||
std::transform(name.begin(), name.end(), name.begin(), tolower);
|
||||
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
|
||||
os << " .Input(\"" << name << ": T" << i << "\")\n";
|
||||
if(std::find(outputs.begin(), outputs.end(), arg->get_name()) == outputs.end())
|
||||
os << " .Input(\"" << name << ": T" << i << "\")\n";
|
||||
else
|
||||
os << " .Input(\"" << name << "_shape: int32\")\n";
|
||||
}
|
||||
std::vector<int> out_idx;
|
||||
for(size_t i = 0; i < outputs.size(); i++){
|
||||
@@ -197,15 +238,22 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
|
||||
throw std::runtime_error("unknown output");
|
||||
out_idx.push_back(idx);
|
||||
}
|
||||
for(size_t i = 0; i < out_idx.size(); i++)
|
||||
os << " .Output(\"out" << i << ": T" << out_idx[i] << "\")\n";
|
||||
for(size_t i = 0; i < out_idx.size(); i++){
|
||||
std::string name = outputs[i];
|
||||
std::transform(name.begin(), name.end(), name.begin(), tolower);
|
||||
os << " .Output(\"" << name << ": T" << out_idx[i] << "\")\n";
|
||||
}
|
||||
os << " .Attr(\"id: int\")\n";
|
||||
os << " .Attr(\"bench: int\")\n";
|
||||
os << " .Attr(\"bench_id: int\")\n";
|
||||
os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {\n";
|
||||
os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {\n";
|
||||
for(size_t i = 0; i < out_idx.size(); i++)
|
||||
os << " c->set_output(" << i << ", c->input(" << out_idx[i] << "));\n";
|
||||
os << " return Status::OK();\n";
|
||||
os << " shape_inference::ShapeHandle handle" << i << ";\n";
|
||||
for(size_t i = 0; i < out_idx.size(); i++)
|
||||
os << " ctx->MakeShapeFromShapeTensor(" << out_idx[i] << ", &handle" << i << ");\n";
|
||||
for(size_t i = 0; i < out_idx.size(); i++)
|
||||
os << " ctx->set_output(" << i << ", handle" << i << ");\n";
|
||||
os << " return Status::OK();\n";
|
||||
os << " })\n";
|
||||
|
||||
os << ";\n";
|
||||
@@ -237,6 +285,7 @@ std::tuple<std::string,
|
||||
ir::context ctx;
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
make_module(src, &*ir, opt);
|
||||
|
||||
// function
|
||||
ir::function* fn = ir->get_function_list().front();
|
||||
std::string name = fn->get_name();
|
||||
@@ -276,18 +325,20 @@ class )" << opname << R"(: public OpKernel {
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context){
|
||||
|
||||
// get device/stream
|
||||
GPUDevice device = context->eigen_device<GPUDevice>();
|
||||
drv::cu_stream sstream(device.stream(), false);
|
||||
drv::context* ctx = sstream.context();
|
||||
drv::stream* stream = &sstream;
|
||||
|
||||
// extract inputs
|
||||
)";
|
||||
gen_extract_inputs(oss, fn->args());
|
||||
gen_extract_inputs(oss, fn->args(), outputs);
|
||||
oss << R"(
|
||||
// set outputs
|
||||
)";
|
||||
gen_set_outputs(oss, outputs);
|
||||
gen_set_outputs(oss, fn->args(), outputs);
|
||||
oss << R"(
|
||||
// wrap tensors
|
||||
)";
|
||||
@@ -309,7 +360,7 @@ private:
|
||||
|
||||
// register kernel builder
|
||||
)";
|
||||
gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args());
|
||||
gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args(), outputs);
|
||||
oss << R"(
|
||||
// register op
|
||||
)";
|
||||
|
@@ -8,6 +8,7 @@ class AllocEmptyOp : public OpKernel {
|
||||
explicit AllocEmptyOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
std::cout << "executing allocempty" << std::endl;
|
||||
// fetch input
|
||||
const Tensor& x = context->input(0);
|
||||
const int32* x_data = (const int32*)x.tensor_data().data();
|
||||
|
@@ -44,12 +44,7 @@ class function(metaclass = function_meta):
|
||||
@classmethod
|
||||
def apply_tensorflow(cls, *args, **kwargs):
|
||||
ctx = OpContext()
|
||||
# Acquire a mutex here to ensure that calls to alloc_empty()
|
||||
# are handled properly
|
||||
mutex = fw.gen_resource_variable_ops.mutex_v2()
|
||||
lock = fw.gen_resource_variable_ops.mutex_lock(mutex)
|
||||
with fw.tensorflow.control_dependencies([lock]):
|
||||
result = cls.forward(ctx, *args, **kwargs)
|
||||
result = cls.forward(ctx, *args, **kwargs)
|
||||
# Find a mapping between ::forward arguments and tensorflow op arguments
|
||||
remap = dict()
|
||||
for i, ix in enumerate(result.op.inputs):
|
||||
|
@@ -222,11 +222,17 @@ class kernel:
|
||||
libtriton.register_grid(op_id, _make_grid(args))
|
||||
# id for the benchmark result
|
||||
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
|
||||
# create operands
|
||||
# call framework function
|
||||
if fw.has_tensorflow():
|
||||
args = [x for x in args[:-1]]
|
||||
ret = self.fw_op(*args, id=op_id, bench=bench, bench_id=bench_id)
|
||||
# operands
|
||||
operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args[:-1]]
|
||||
# output data types
|
||||
kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
|
||||
for i, x in enumerate(args[:-1]):
|
||||
if isinstance(x, triton.utils.tf_empty_proxy):
|
||||
kwargs['T' + str(i)] = x.dtype
|
||||
# launch
|
||||
ret = self.fw_op(*operands, **kwargs)
|
||||
if bench > 0:
|
||||
bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
|
||||
elif fw.has_torch():
|
||||
|
@@ -4,7 +4,7 @@ import math
|
||||
class _batchnorm(triton.function):
|
||||
|
||||
fwd_src = """
|
||||
void batchnormForward(float *Y, float *M, float *V,
|
||||
void fwdbatchnorm(float *Y, float *M, float *V,
|
||||
float *X, float *G, float *B,
|
||||
int N, float rcpN, float eps) {
|
||||
int rx[TM] = 0 ... TM;
|
||||
@@ -52,6 +52,58 @@ void batchnormForward(float *Y, float *M, float *V,
|
||||
|
||||
fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
|
||||
|
||||
bwd_src = """
|
||||
void batchnormBackward(float *DX, float *DG, float *DB,
|
||||
float *DY, float *X, float *G,
|
||||
float *M, float *V,
|
||||
int DHWN, float rcpDHWN, float epsilon) {
|
||||
int rx[TM] = 0 ... TM;
|
||||
int c = get_program_id(1);
|
||||
int offset = c*DHWN;
|
||||
float g = *(G + c);
|
||||
float mean = *(M + c);
|
||||
float var = *(V + c);
|
||||
float rstd = 1 / sqrtf(var + epsilon);
|
||||
float* px[TM];
|
||||
float* pdx[TM];
|
||||
float* pdy[TM];
|
||||
px = X + rx + offset;
|
||||
pdy = DY + rx + offset;
|
||||
float dg[TM] = 0;
|
||||
float db[TM] = 0;
|
||||
for(int i = 0; i < DHWN; i = i + TM){
|
||||
float x[TM] = *px;
|
||||
float dy[TM] = *pdy;
|
||||
dg = dg + dy*(x - mean)*rstd;
|
||||
db = db + dy;
|
||||
px = px + TM;
|
||||
pdy = pdy + TM;
|
||||
}
|
||||
float sdg = dg[+];
|
||||
float sdb = db[+];
|
||||
float *pdg = DG + c;
|
||||
float *pdb = DB + c;
|
||||
*pdg = sdg;
|
||||
*pdb = sdb;
|
||||
px = X + rx + offset;
|
||||
pdy = DY + rx + offset;
|
||||
pdx = DX + rx + offset;
|
||||
for(int i = 0; i < DHWN; i = i + TM){
|
||||
float x[TM] = *px;
|
||||
float dy[TM] = *pdy;
|
||||
float xhat[TM] = (x - mean) * rstd;
|
||||
float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
|
||||
float dx[TM] = (dy - xtmp) * rstd * g;
|
||||
*pdx = dx;
|
||||
px = px + TM;
|
||||
pdy = pdy + TM;
|
||||
pdx = pdx + TM;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, gamma, beta, eps):
|
||||
shape = triton.shape(x)
|
||||
@@ -63,13 +115,29 @@ void batchnormForward(float *Y, float *M, float *V,
|
||||
var = triton.empty([C], dtype=dtype)
|
||||
# execute kernels
|
||||
N = H*W*B
|
||||
_batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps,
|
||||
y, mean, var = _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps,
|
||||
lambda opt: [1, C],
|
||||
TM = 128)
|
||||
# save
|
||||
ctx.eps = eps
|
||||
ctx.save_for_backward(x, gamma, beta, mean, var)
|
||||
return y, mean, var
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy):
|
||||
eps = ctx.eps
|
||||
x, gamma, beta, mean, var = ctx.saved_tensors
|
||||
dx = triton.empty(x.shape, dtype=x.dtype)
|
||||
dgamma = triton.empty(gamma.shape, dtype=gamma.dtype)
|
||||
dbeta = triton.empty(beta.shape, dtype=beta.dtype)
|
||||
# launch
|
||||
C, H, W, B = x.shape
|
||||
N = H*W*B
|
||||
_batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
|
||||
x, gamma, mean, var,
|
||||
N, 1./N, eps,
|
||||
lambda opt: [1, C],
|
||||
TM = 128)
|
||||
return dx, dgamma, dbeta, None
|
||||
|
||||
batchnorm = _batchnorm.apply
|
@@ -7,17 +7,16 @@ def cdiv(a, b):
|
||||
|
||||
class tf_empty_proxy:
|
||||
|
||||
def __init__(self, args, dtype):
|
||||
self.args = args
|
||||
def __init__(self, shape, dtype):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
def empty(shapes, dtype):
|
||||
def empty(shape, dtype):
|
||||
if fw.has_tensorflow():
|
||||
#return fw.tensorflow.Variable(np.empty(shapes),shape=shapes, dtype=dtype)
|
||||
args = [x.handle if isinstance(x, scalar) else fw.tensorflow.constant(x) for x in shapes]
|
||||
args = fw.tensorflow.stack(args)
|
||||
#return tf_empty_proxy(args, dtype)
|
||||
return fw.tf_extra_ops.alloc_empty(args, T = dtype)
|
||||
shape = [x.handle if isinstance(x, scalar) else fw.tensorflow.constant(x) for x in shape]
|
||||
shape = fw.tensorflow.stack(shape)
|
||||
return tf_empty_proxy(shape, dtype)
|
||||
#return fw.tf_extra_ops.alloc_empty(args, T = dtype)
|
||||
elif fw.has_torch():
|
||||
return fw.torch.empty(*shapes).cuda()
|
||||
|
||||
|
Reference in New Issue
Block a user