[PYTHON][TENSORFLOW] More work

This commit is contained in:
Philippe Tillet
2019-10-30 18:39:58 -04:00
parent fd09f9c99d
commit e0fe8d9058
6 changed files with 184 additions and 104 deletions

View File

@@ -45,11 +45,12 @@ if mode == MODE.TF:
fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
# execute
fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
fw_y = triton.ops.batchnorm(fw_x, fw_mean, fw_var, fw_gamma, fw_beta, 1e-4)
fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
# run
sess = tf.InteractiveSession()
feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
sess.run(tf.global_variables_initializer())
result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
print(result)

View File

@@ -112,51 +112,70 @@ inline std::string ref_to_tf_ty(ir::type *ty) {
return res;
}
std::string tf_normalize(const std::string& name) {
std::string ret = name;
auto tolower = [](char c) { return std::tolower(c);};
std::transform(ret.begin(), ret.end(), ret.begin(), tolower);
return ret;
}
void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args, const std::vector<std::string>& outputs) {
struct tf_alloc_t{
enum type_t{
OUTPUT,
TEMP
};
tf_alloc_t(const std::string& _name, type_t _type)
: name(_name), type(_type), tf_name(tf_normalize(_name)){ }
std::string tf_name;
std::string name;
type_t type;
size_t shape_id;
};
typedef std::vector<tf_alloc_t> alloc_map_t;
void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
for(unsigned i = 0; i < args.size(); i++){
ir::value *arg = args[i];
const std::string& name = arg->get_name();
std::string ty = to_tf_ty(arg->get_type());
if(!arg->get_type()->is_pointer_ty())
os << " " << ty << " " << name << " = context->input(" << i << ").scalar<" << ty << ">()();\n ";
else if(std::find(outputs.begin(), outputs.end(), arg->get_name()) == outputs.end())
else if(std::find_if(allocs.begin(), allocs.end(),
[&](tf_alloc_t x) {
return x.name == name;
}) == allocs.end())
os << " const Tensor* " << name << " = &context->input(" << i << ");\n ";
else
os << " Tensor* " << name << " = nullptr;\n ";
}
}
void gen_set_outputs(std::ostream &os, const std::vector<ir::argument*>& args, const std::vector<std::string>& outputs) {
for(unsigned i = 0; i < outputs.size(); i++)
os << " TensorShape shape" << i << ";\n ";
void gen_set_outputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
// initialize shapes
std::vector<int> out_idx;
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
size_t idx;
for(idx = 0; idx < args.size(); idx++)
if(args[idx]->get_name() == name)
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
out_idx.push_back(idx);
}
for(unsigned i = 0; i < outputs.size(); i++)
os << " const Tensor& " << outputs[i] << "_shape = context->input(" << out_idx[i] << ");\n ";
for(unsigned i = 0; i < outputs.size(); i++)
os << " const int32* " << outputs[i] << "_shape_data = (const int32*)" << outputs[i] << "_shape.tensor_data().data();\n ";
for(unsigned i = 0; i < outputs.size(); i++)
os << " size_t " << outputs[i] << "_rank = " << outputs[i] << "_shape.dim_size(0);\n ";
for(unsigned i = 0; i < outputs.size(); i++)
os << " for(size_t d = 0; d < " << outputs[i] << "_rank ; d++) "
<< "shape" << i << ".AddDim(" << outputs[i] << "_shape_data[d]);\n ";
for(const auto& x: allocs)
os << " TensorShape " << x.name << "_shape;\n ";
for(const auto& x: allocs)
os << " const Tensor& " << x.name << "_shape_tensor = context->input(" << x.shape_id << ");\n ";
for(const auto& x: allocs)
os << " const int32* " << x.name << "_shape_data = (const int32*)" << x.name << "_shape_tensor.tensor_data().data();\n ";
for(const auto& x: allocs)
os << " size_t " << x.name << "_rank = " << x.name << "_shape_tensor.dim_size(0);\n ";
for(const auto& x: allocs)
os << " for(size_t d = 0; d < " << x.name << "_rank ; d++) "
<< x.name << "_shape.AddDim(" << x.name << "_shape_data[d]);\n ";
// allocate
for(unsigned i = 0; i < outputs.size(); i++)
os << " OP_REQUIRES_OK(context, context->allocate_output(" << i << ", shape" << i << ", &" << outputs[i] << "));\n ";
int output = 0;
for(const auto& x: allocs){
if(x.type == tf_alloc_t::OUTPUT)
os << " OP_REQUIRES_OK(context, context->allocate_output(" << output++ << ", " << x.name << "_shape, &" << x.name << "));\n ";
else
os << " OP_REQUIRES_OK(context, context->allocate_temp(" << x.name << "_type, " << x.name << "_shape, " << x.name << "));\n ";
}
}
void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args) {
@@ -169,7 +188,7 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
}
}
void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vector<ir::argument*>& args) {
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at(id_))({";
for(unsigned i = 0; i < args.size() ; i++){
@@ -181,9 +200,9 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
os << ", ";
os << name;
}
os << "}, *id_grid_map.at(id_), stream);\n";
os << "}, *id_grid_map.at(id_), stream);\n ";
os << " };\n ";
os << " run();";
os << " run();\n ";
os << " if(bench_ > 0)\n ";
os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
}
@@ -191,69 +210,53 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
const std::string &opname,
const std::vector<ir::argument*>& args,
const std::vector<std::string>& outputs){
const alloc_map_t& allocs){
auto tolower = [](char c) { return std::tolower(c);};
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
for(size_t i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
std::string name = arg->get_name();
std::transform(name.begin(), name.end(), name.begin(), tolower);
std::string name = tf_normalize(arg->get_name());
if(!arg->get_type()->is_pointer_ty())
os << ".HostMemory(\"" + name + "\")";
}
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
std::transform(name.begin(), name.end(), name.begin(), tolower);
os << ".HostMemory(\"" << name << "_shape\")";
}
for(const auto& x: allocs)
os << ".HostMemory(\"" << x.tf_name << "_shape\")";
os << ", " + opname << ");\n";
}
void gen_tf_register_op(std::ostream &os, const std::string &name,
const std::vector<ir::argument*>& args,
const std::vector<std::string>& outputs){
const alloc_map_t& allocs){
auto tolower = [](char c) { return std::tolower(c);};
os << "REGISTER_OP(\"" << name << "\")\n";
for(size_t i = 0; i < args.size(); i++)
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
for(size_t i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
std::string name = arg->get_name();
std::transform(name.begin(), name.end(), name.begin(), tolower);
if(std::find(outputs.begin(), outputs.end(), arg->get_name()) == outputs.end())
std::string name = tf_normalize(arg->get_name());
if(std::find_if(allocs.begin(), allocs.end(),
[&](tf_alloc_t x) {
return name == x.tf_name;
}) == allocs.end())
os << " .Input(\"" << name << ": T" << i << "\")\n";
else
os << " .Input(\"" << name << "_shape: int32\")\n";
}
std::vector<int> out_idx;
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
size_t idx;
for(idx = 0; idx < args.size(); idx++)
if(args[idx]->get_name() == name)
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
out_idx.push_back(idx);
}
for(size_t i = 0; i < out_idx.size(); i++){
std::string name = outputs[i];
std::transform(name.begin(), name.end(), name.begin(), tolower);
os << " .Output(\"" << name << ": T" << out_idx[i] << "\")\n";
}
for(const auto& x: allocs)
if(x.type == tf_alloc_t::OUTPUT)
os << " .Output(\"" << x.tf_name << ": T" << x.shape_id << "\")\n";
os << " .Attr(\"id: int\")\n";
os << " .Attr(\"bench: int\")\n";
os << " .Attr(\"bench_id: int\")\n";
os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {\n";
for(size_t i = 0; i < out_idx.size(); i++)
os << " shape_inference::ShapeHandle handle" << i << ";\n";
for(size_t i = 0; i < out_idx.size(); i++)
os << " ctx->MakeShapeFromShapeTensor(" << out_idx[i] << ", &handle" << i << ");\n";
for(size_t i = 0; i < out_idx.size(); i++)
os << " ctx->set_output(" << i << ", handle" << i << ");\n";
size_t current = 0;
for(const auto& x: allocs)
if(x.type == tf_alloc_t::OUTPUT){
os << " shape_inference::ShapeHandle " << x.tf_name << "_handle;\n";
os << " ctx->MakeShapeFromShapeTensor(" << x.shape_id << ", &" << x.tf_name << "_handle);\n";
os << " ctx->set_output(" << current++ << ", " << x.tf_name << "_handle);\n";
}
os << " return Status::OK();\n";
os << " })\n";
@@ -280,6 +283,7 @@ void make_module(const std::string& src, ir::module* ir,
std::tuple<std::string,
std::string> make_tensorflow_src(const std::string& src,
const std::vector<std::string>& outputs,
const std::vector<std::string>& tmp,
const runtime::function::options_space_t& opt)
{
// triton-ir code-gen
@@ -289,10 +293,28 @@ std::tuple<std::string,
// function
ir::function* fn = ir->get_function_list().front();
const std::vector<ir::argument*>& args = fn->args();
std::string name = fn->get_name();
std::string cc_name = name;
cc_name[0] = static_cast<char>(std::toupper(cc_name[0]));
std::string opname = cc_name + "Op";
// allocation info
alloc_map_t allocs;
for(size_t i = 0; i < outputs.size(); i++)
allocs.push_back(tf_alloc_t(outputs[i], tf_alloc_t::OUTPUT));
for(size_t i = 0; i < tmp.size(); i++)
allocs.push_back(tf_alloc_t(tmp[i], tf_alloc_t::TEMP));
for(auto &x: allocs){
size_t idx;
for(idx = 0; idx < args.size(); idx++)
if(args[idx]->get_name() == x.name)
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
x.shape_id = idx;
}
std::ostringstream oss;
oss << R"(
@@ -323,6 +345,11 @@ class )" << opname << R"(: public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_));
OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_));
)";
for(const auto& alloc: allocs)
oss << " OP_REQUIRES_OK(context, context->GetAttr(\"T" << alloc.shape_id << "\", &" << alloc.name << "_type));\n ";
oss << R"(
}
void Compute(OpKernelContext* context){
@@ -335,21 +362,21 @@ class )" << opname << R"(: public OpKernel {
// extract inputs
)";
gen_extract_inputs(oss, fn->args(), outputs);
gen_extract_inputs(oss, args, allocs);
oss << R"(
// set outputs
)";
gen_set_outputs(oss, fn->args(), outputs);
gen_set_outputs(oss, args, allocs);
oss << R"(
// wrap tensors
)";
gen_make_handles(oss, fn->args());
gen_make_handles(oss, args);
oss << R"(
)";
oss << R"(
// launch function
)";
gen_make_launch_function(oss, outputs.size(), fn->args());
gen_make_launch_function(oss, args);
oss << R"(
}
@@ -357,15 +384,20 @@ private:
int id_;
int bench_;
int64 bench_id_;
)";
for(const auto& alloc: allocs)
oss << "DataType " << alloc.name << "_type;\n ";
oss << R"(
};
// register kernel builder
)";
gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args(), outputs);
gen_tf_register_kernel_builder(oss, cc_name, opname, args, allocs);
oss << R"(
// register op
)";
gen_tf_register_op(oss, cc_name, fn->args(), outputs);
gen_tf_register_op(oss, cc_name, args, allocs);
return {oss.str(), name};
}
@@ -517,6 +549,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
std::tuple<std::string,
std::string> make_torch_src(const std::string& src,
const std::vector<std::string>& outputs,
const std::vector<std::string>& tmp,
const runtime::function::options_space_t& opt) {
// triton-ir code-gen
ir::context ctx;

View File

@@ -41,24 +41,43 @@ class function(metaclass = function_meta):
return cls.backward(ctx, grad_output)
return TorchFunction.apply(*args, **kwargs)
@classmethod
def extract_tf_tensors(cls, lst, err):
for x in lst:
if x and not isinstance(x, triton.utils.tf_empty_proxy):
raise ValueError('Results of ' + err + ' must be created using triton.empty()')
if x and x.tensor is None:
raise ValueError('Empty tensor never filled during ' + err)
return [x.tensor if x else None for x in lst]
@classmethod
def apply_tensorflow(cls, *args, **kwargs):
ctx = OpContext()
result = cls.forward(ctx, *args, **kwargs)
op = result[0].op if isinstance(result, tuple) else result.op
# check that all the results stem from triton.empty
# and get the corresponding TF tensors if possible
result = result if isinstance(result, tuple) else (result, )
result = function.extract_tf_tensors(result, 'forward')
# Find a mapping between ::forward arguments and tensorflow op arguments
op = result[0].op
remap = dict()
for i, ix in enumerate(result.op.inputs):
for i, ix in enumerate(op.inputs):
for j, jx in enumerate(args):
if ix is jx:
remap[j] = i
# register backward
# Register backward pass
ctx_registry[op] = ctx
name = op.op_def.name
if not cls.registered:
@fw.tensorflow.RegisterGradient(name)
def gradient(op, *dys):
grad = cls.backward(ctx_registry[op], dys if len(dys) > 1 else dys[0])
def gradient(op, *dy):
dy = dy if len(dy) > 1 else dy[0]
grad = cls.backward(ctx_registry[op], dy)
grad = function.extract_tf_tensors(grad, 'backward')
# Remap gradient in the right order
ret = [None] * len(op.inputs)
for i in range(len(grad)):
@@ -67,7 +86,8 @@ class function(metaclass = function_meta):
# Return
return ret
cls.registered = True
# return result tensor
# Return tensor
return result
@classmethod

View File

@@ -15,11 +15,11 @@ import triton.frameworks as fw
import triton.utils
import triton._C.libtriton as libtriton
def _make_framework_src(src, out, grid):
def _make_framework_src(src, out, tmp, grid):
if fw.has_tensorflow():
return libtriton.make_tensorflow_src(src, out, grid)
return libtriton.make_tensorflow_src(src, out, tmp, grid)
elif fw.has_torch:
return libtriton.make_torch_src(src, out, grid)
return libtriton.make_torch_src(src, out, tmp, grid)
else:
assert False
@@ -152,8 +152,8 @@ def _cvt_to_def_str(obj):
return str(obj)
def _make_framework_op(src, outputs, options):
src, name = _make_framework_src(src, outputs, options)
def _make_framework_op(src, outputs, tmp, options):
src, name = _make_framework_src(src, outputs, tmp, options)
cache_path = _make_cache_path(src)
cpp, so = _write_bindings(src, cache_path)
_build(cpp, cache_path)
@@ -181,12 +181,13 @@ bench_registry = triton.utils.id_dict()
class kernel:
def __init__(self, src, outputs):
def __init__(self, src, outputs, tmp=[]):
self.fw_id = dict()
self.fw_grids = dict()
self.fw_op = None
self.src = src
self.outputs = outputs
self.tmp = tmp
def __call__(self, *args, **kwargs):
# create a new framework op when defines are different
@@ -210,7 +211,7 @@ class kernel:
# register function
libtriton.register_fn(op_id, self.src, opt)
if self.fw_op is None:
self.fw_op = _make_framework_op(self.src, self.outputs, opt)
self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt)
# benchmarking info
bench = 0
@@ -225,6 +226,7 @@ class kernel:
# call framework function
if fw.has_tensorflow():
# operands
outputs = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)]
operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args[:-1]]
# output data types
kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
@@ -233,13 +235,16 @@ class kernel:
kwargs['T' + str(i)] = x.dtype
# launch
ret = self.fw_op(*operands, **kwargs)
assert len(ret) == len(outputs)
# record results
for i in range(len(outputs)):
outputs[i].tensor = ret[i]
if bench > 0:
bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
elif fw.has_torch():
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
ret = self.fw_op(op_id, bench, bench_id, *args)
self.fw_op(op_id, bench, bench_id, *args)
if bench > 0:
bench_registry[ret] = libtriton.retrieve_scalar(op_id)
else:
assert False
return ret
assert False

View File

@@ -13,9 +13,22 @@ void fwdbatchnorm(float *Y, float *M, float *V,
float *px[TM] = X + rm + c*N;
float* py[TM] = Y + rm + c*N;
// fetch mean/var
float mean = *(M + c);
float var = *(V + c);
// compute mean
float accm[TM] = 0;
for(int i = 0; i < N; i = i + TM)
accm = accm + *(px + i);
float mean = (float)accm[+] / N;
*(M + c) = mean;
// compute variance
float accv[TM] = 0;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
x = x - mean;
accv = accv + x*x;
}
float var = (float)accv[+] / N;
*(V + c) = var;
// Normalize batch
float gamma = *(G + c);
@@ -28,7 +41,7 @@ void fwdbatchnorm(float *Y, float *M, float *V,
}
}
"""
fwd_kernel = triton.kernel(fwd_src, ['Y'])
fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB,
@@ -78,23 +91,26 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])
@staticmethod
def forward(ctx, x, mean, var, gamma, beta, eps):
def forward(ctx, x, gamma, beta, eps):
shape = triton.shape(x)
dtype = x.dtype
# allocate outputs
C, H, W, B = shape[0], shape[1], shape[2], shape[3]
y = triton.empty(shape, dtype=dtype)
mean = triton.empty([C], dtype=dtype)
var = triton.empty([C], dtype=dtype)
# execute kernels
y = _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
lambda opt: [1, C],
TM = 128)
_batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
lambda opt: [1, C],
TM = 128)
# save
ctx.save_for_backward(x, gamma, beta, mean, var)
ctx.save_for_backward(x, gamma, beta, mean.tensor, var.tensor)
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
def backward(ctx, grads):
dy, dmean, dvar = grads
# retrieve info
x, gamma, beta, mean, var = ctx.saved_tensors
eps = ctx.eps
@@ -104,11 +120,11 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
# execute
C, H, W, B = triton.shape(x)
dx, dgamma, dbeta = _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
_batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
x, gamma, mean, var,
H*W*B, eps,
lambda opt: [1, C],
TM = 128)
return dx, None, None, dgamma, dbeta, None
return dx, dgamma, dbeta, None
batchnorm = _batchnorm.apply

View File

@@ -10,6 +10,11 @@ class tf_empty_proxy:
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
self.tensor = None
def to_tensor(self):
assert self.tensor
return self.tensor
def empty(shape, dtype):
if fw.has_tensorflow():