[PYTHON][TENSORFLOW] More work
@@ -45,11 +45,12 @@ if mode == MODE.TF:
    fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
    fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
    # execute
    fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
    fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
    fw_y = triton.ops.batchnorm(fw_x, fw_mean, fw_var, fw_gamma, fw_beta, 1e-4)
    fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
    # run
    sess = tf.InteractiveSession()
    feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
    sess.run(tf.global_variables_initializer())
    result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
    print(result)
@@ -112,51 +112,70 @@ inline std::string ref_to_tf_ty(ir::type *ty) {
  return res;
}

std::string tf_normalize(const std::string& name) {
  std::string ret = name;
  auto tolower = [](char c) { return std::tolower(c);};
  std::transform(ret.begin(), ret.end(), ret.begin(), tolower);
  return ret;
}

void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args, const std::vector<std::string>& outputs) {
struct tf_alloc_t{
  enum type_t{
    OUTPUT,
    TEMP
  };

  tf_alloc_t(const std::string& _name, type_t _type)
    : name(_name), type(_type), tf_name(tf_normalize(_name)){ }

  std::string tf_name;
  std::string name;
  type_t type;
  size_t shape_id;
};

typedef std::vector<tf_alloc_t> alloc_map_t;


void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
  for(unsigned i = 0; i < args.size(); i++){
    ir::value *arg = args[i];
    const std::string& name = arg->get_name();
    std::string ty = to_tf_ty(arg->get_type());
    if(!arg->get_type()->is_pointer_ty())
      os << " " << ty << " " << name << " = context->input(" << i << ").scalar<" << ty << ">()();\n ";
    else if(std::find(outputs.begin(), outputs.end(), arg->get_name()) == outputs.end())
    else if(std::find_if(allocs.begin(), allocs.end(),
                         [&](tf_alloc_t x) {
                           return x.name == name;
                         }) == allocs.end())
      os << " const Tensor* " << name << " = &context->input(" << i << ");\n ";
    else
      os << " Tensor* " << name << " = nullptr;\n ";
  }
}

void gen_set_outputs(std::ostream &os, const std::vector<ir::argument*>& args, const std::vector<std::string>& outputs) {
  for(unsigned i = 0; i < outputs.size(); i++)
    os << " TensorShape shape" << i << ";\n ";
void gen_set_outputs(std::ostream &os, const std::vector<ir::argument*>& args, const alloc_map_t& allocs) {
  // initialize shapes

  std::vector<int> out_idx;
  for(size_t i = 0; i < outputs.size(); i++){
    std::string name = outputs[i];
    size_t idx;
    for(idx = 0; idx < args.size(); idx++)
      if(args[idx]->get_name() == name)
        break;
    if(idx == args.size())
      throw std::runtime_error("unknown output");
    out_idx.push_back(idx);
  }

  for(unsigned i = 0; i < outputs.size(); i++)
    os << " const Tensor& " << outputs[i] << "_shape = context->input(" << out_idx[i] << ");\n ";
  for(unsigned i = 0; i < outputs.size(); i++)
    os << " const int32* " << outputs[i] << "_shape_data = (const int32*)" << outputs[i] << "_shape.tensor_data().data();\n ";
  for(unsigned i = 0; i < outputs.size(); i++)
    os << " size_t " << outputs[i] << "_rank = " << outputs[i] << "_shape.dim_size(0);\n ";
  for(unsigned i = 0; i < outputs.size(); i++)
    os << " for(size_t d = 0; d < " << outputs[i] << "_rank ; d++) "
       << "shape" << i << ".AddDim(" << outputs[i] << "_shape_data[d]);\n ";
  for(const auto& x: allocs)
    os << " TensorShape " << x.name << "_shape;\n ";
  for(const auto& x: allocs)
    os << " const Tensor& " << x.name << "_shape_tensor = context->input(" << x.shape_id << ");\n ";
  for(const auto& x: allocs)
    os << " const int32* " << x.name << "_shape_data = (const int32*)" << x.name << "_shape_tensor.tensor_data().data();\n ";
  for(const auto& x: allocs)
    os << " size_t " << x.name << "_rank = " << x.name << "_shape_tensor.dim_size(0);\n ";
  for(const auto& x: allocs)
    os << " for(size_t d = 0; d < " << x.name << "_rank ; d++) "
       << x.name << "_shape.AddDim(" << x.name << "_shape_data[d]);\n ";

  // allocate
  for(unsigned i = 0; i < outputs.size(); i++)
    os << " OP_REQUIRES_OK(context, context->allocate_output(" << i << ", shape" << i << ", &" << outputs[i] << "));\n ";
  int output = 0;
  for(const auto& x: allocs){
    if(x.type == tf_alloc_t::OUTPUT)
      os << " OP_REQUIRES_OK(context, context->allocate_output(" << output++ << ", " << x.name << "_shape, &" << x.name << "));\n ";
    else
      os << " OP_REQUIRES_OK(context, context->allocate_temp(" << x.name << "_type, " << x.name << "_shape, " << x.name << "));\n ";
  }
}
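The generated code above relies on a convention that is easiest to see from the Python side: every OUTPUT or TEMP argument reaches the op not as a tensor but as an int32 `<name>_shape` input, and the op allocates the buffer itself. A self-contained sketch of that hand-off (the mock class mirrors triton.utils.tf_empty_proxy from this commit; the driver function is illustrative, not part of the diff):

class tf_empty_proxy:
    # mirrors triton.utils.tf_empty_proxy: shape + dtype now, tensor later
    def __init__(self, shape, dtype):
        self.shape, self.dtype, self.tensor = shape, dtype, None

def make_operands(args):
    # outputs/temporaries travel as their shapes ("y" becomes "y_shape"),
    # everything else is forwarded unchanged
    return [a.shape if isinstance(a, tf_empty_proxy) else a for a in args]

y = tf_empty_proxy([32, 1024], 'float32')
print(make_operands(['x_tensor', y, 128]))  # ['x_tensor', [32, 1024], 128]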
void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args) {
@@ -169,7 +188,7 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
  }
}

void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vector<ir::argument*>& args) {
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
  os << " std::function<void()> run = [&](){\n ";
  os << " (*id_fn_map.at(id_))({";
  for(unsigned i = 0; i < args.size() ; i++){
@@ -181,9 +200,9 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
      os << ", ";
    os << name;
  }
  os << "}, *id_grid_map.at(id_), stream);\n";
  os << "}, *id_grid_map.at(id_), stream);\n ";
  os << " };\n ";
  os << " run();";
  os << " run();\n ";
  os << " if(bench_ > 0)\n ";
  os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
}
@@ -191,69 +210,53 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
                                    const std::string &opname,
                                    const std::vector<ir::argument*>& args,
                                    const std::vector<std::string>& outputs){
                                    const alloc_map_t& allocs){

  auto tolower = [](char c) { return std::tolower(c);};
  os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
  for(size_t i = 0; i < args.size(); i++){
    ir::argument *arg = args[i];
    std::string name = arg->get_name();
    std::transform(name.begin(), name.end(), name.begin(), tolower);
    std::string name = tf_normalize(arg->get_name());
    if(!arg->get_type()->is_pointer_ty())
      os << ".HostMemory(\"" + name + "\")";
  }
  for(size_t i = 0; i < outputs.size(); i++){
    std::string name = outputs[i];
    std::transform(name.begin(), name.end(), name.begin(), tolower);
    os << ".HostMemory(\"" << name << "_shape\")";
  }
  for(const auto& x: allocs)
    os << ".HostMemory(\"" << x.tf_name << "_shape\")";
  os << ", " + opname << ");\n";
}

void gen_tf_register_op(std::ostream &os, const std::string &name,
                        const std::vector<ir::argument*>& args,
                        const std::vector<std::string>& outputs){
                        const alloc_map_t& allocs){

  auto tolower = [](char c) { return std::tolower(c);};

  os << "REGISTER_OP(\"" << name << "\")\n";
  for(size_t i = 0; i < args.size(); i++)
    os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
  for(size_t i = 0; i < args.size(); i++){
    ir::argument *arg = args[i];
    std::string name = arg->get_name();
    std::transform(name.begin(), name.end(), name.begin(), tolower);
    if(std::find(outputs.begin(), outputs.end(), arg->get_name()) == outputs.end())
    std::string name = tf_normalize(arg->get_name());
    if(std::find_if(allocs.begin(), allocs.end(),
                    [&](tf_alloc_t x) {
                      return name == x.tf_name;
                    }) == allocs.end())
      os << " .Input(\"" << name << ": T" << i << "\")\n";
    else
      os << " .Input(\"" << name << "_shape: int32\")\n";
  }
  std::vector<int> out_idx;
  for(size_t i = 0; i < outputs.size(); i++){
    std::string name = outputs[i];
    size_t idx;
    for(idx = 0; idx < args.size(); idx++)
      if(args[idx]->get_name() == name)
        break;
    if(idx == args.size())
      throw std::runtime_error("unknown output");
    out_idx.push_back(idx);
  }
  for(size_t i = 0; i < out_idx.size(); i++){
    std::string name = outputs[i];
    std::transform(name.begin(), name.end(), name.begin(), tolower);
    os << " .Output(\"" << name << ": T" << out_idx[i] << "\")\n";
  }
  for(const auto& x: allocs)
    if(x.type == tf_alloc_t::OUTPUT)
      os << " .Output(\"" << x.tf_name << ": T" << x.shape_id << "\")\n";
  os << " .Attr(\"id: int\")\n";
  os << " .Attr(\"bench: int\")\n";
  os << " .Attr(\"bench_id: int\")\n";
  os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {\n";
  for(size_t i = 0; i < out_idx.size(); i++)
    os << " shape_inference::ShapeHandle handle" << i << ";\n";
  for(size_t i = 0; i < out_idx.size(); i++)
    os << " ctx->MakeShapeFromShapeTensor(" << out_idx[i] << ", &handle" << i << ");\n";
  for(size_t i = 0; i < out_idx.size(); i++)
    os << " ctx->set_output(" << i << ", handle" << i << ");\n";
  size_t current = 0;
  for(const auto& x: allocs)
    if(x.type == tf_alloc_t::OUTPUT){
      os << " shape_inference::ShapeHandle " << x.tf_name << "_handle;\n";
      os << " ctx->MakeShapeFromShapeTensor(" << x.shape_id << ", &" << x.tf_name << "_handle);\n";
      os << " ctx->set_output(" << current++ << ", " << x.tf_name << "_handle);\n";
    }
  os << " return Status::OK();\n";
  os << " })\n";
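Put together, REGISTER_OP above gives each kernel argument either a plain `<name>` input or a `<name>_shape: int32` input, one type attribute `T<i>` per argument, plus the bookkeeping attributes id/bench/bench_id. A hedged sketch of the resulting Python-side call, for a hypothetical kernel `void add(float* X, float* Y, int N)` compiled with outputs=['Y'] (the kernel and all names here are made up; the real call site is kernel.__call__ further down):

def call_generated_add(fw_op, x_tensor, y_proxy, n, op_id, bench=0, bench_id=0):
    # X is a pointer argument that is not an alloc -> plain input "x"
    # Y is an OUTPUT alloc  -> int32 input "y_shape", becomes op output 0
    # N is a scalar         -> host-memory input "n"
    ret = fw_op(x_tensor, y_proxy.shape, n,
                id=op_id, bench=bench, bench_id=bench_id,
                T0='float32', T1='float32', T2='int32')
    y_proxy.tensor = ret  # single output; kernel.__call__ does this per proxy
    return y_proxy.tensor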

@@ -280,6 +283,7 @@ void make_module(const std::string& src, ir::module* ir,
std::tuple<std::string,
           std::string> make_tensorflow_src(const std::string& src,
                                            const std::vector<std::string>& outputs,
                                            const std::vector<std::string>& tmp,
                                            const runtime::function::options_space_t& opt)
{
  // triton-ir code-gen
@@ -289,10 +293,28 @@ std::tuple<std::string,

  // function
  ir::function* fn = ir->get_function_list().front();
  const std::vector<ir::argument*>& args = fn->args();
  std::string name = fn->get_name();
  std::string cc_name = name;
  cc_name[0] = static_cast<char>(std::toupper(cc_name[0]));
  std::string opname = cc_name + "Op";

  // allocation info
  alloc_map_t allocs;
  for(size_t i = 0; i < outputs.size(); i++)
    allocs.push_back(tf_alloc_t(outputs[i], tf_alloc_t::OUTPUT));
  for(size_t i = 0; i < tmp.size(); i++)
    allocs.push_back(tf_alloc_t(tmp[i], tf_alloc_t::TEMP));

  for(auto &x: allocs){
    size_t idx;
    for(idx = 0; idx < args.size(); idx++)
      if(args[idx]->get_name() == x.name)
        break;
    if(idx == args.size())
      throw std::runtime_error("unknown output");
    x.shape_id = idx;
  }

  std::ostringstream oss;
  oss << R"(
@@ -323,6 +345,11 @@ class )" << opname << R"(: public OpKernel {
    OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
    OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_));
    OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_));
)";
  for(const auto& alloc: allocs)
    oss << " OP_REQUIRES_OK(context, context->GetAttr(\"T" << alloc.shape_id << "\", &" << alloc.name << "_type));\n ";

  oss << R"(
  }

  void Compute(OpKernelContext* context){
@@ -335,21 +362,21 @@ class )" << opname << R"(: public OpKernel {

    // extract inputs
)";
  gen_extract_inputs(oss, fn->args(), outputs);
  gen_extract_inputs(oss, args, allocs);
  oss << R"(
    // set outputs
)";
  gen_set_outputs(oss, fn->args(), outputs);
  gen_set_outputs(oss, args, allocs);
  oss << R"(
    // wrap tensors
)";
  gen_make_handles(oss, fn->args());
  gen_make_handles(oss, args);
  oss << R"(
)";
  oss << R"(
    // launch function
)";
  gen_make_launch_function(oss, outputs.size(), fn->args());
  gen_make_launch_function(oss, args);
  oss << R"(
  }

@@ -357,15 +384,20 @@ private:
  int id_;
  int bench_;
  int64 bench_id_;
)";
  for(const auto& alloc: allocs)
    oss << "DataType " << alloc.name << "_type;\n ";

  oss << R"(
};

// register kernel builder
)";
  gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args(), outputs);
  gen_tf_register_kernel_builder(oss, cc_name, opname, args, allocs);
  oss << R"(
// register op
)";
  gen_tf_register_op(oss, cc_name, fn->args(), outputs);
  gen_tf_register_op(oss, cc_name, args, allocs);

  return {oss.str(), name};
}
@@ -517,6 +549,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
std::tuple<std::string,
           std::string> make_torch_src(const std::string& src,
                                       const std::vector<std::string>& outputs,
                                       const std::vector<std::string>& tmp,
                                       const runtime::function::options_space_t& opt) {
  // triton-ir code-gen
  ir::context ctx;

@@ -41,24 +41,43 @@ class function(metaclass = function_meta):
                return cls.backward(ctx, grad_output)
        return TorchFunction.apply(*args, **kwargs)

    @classmethod
    def extract_tf_tensors(cls, lst, err):
        for x in lst:
            if x and not isinstance(x, triton.utils.tf_empty_proxy):
                raise ValueError('Results of ' + err + ' must be created using triton.empty()')
            if x and x.tensor is None:
                raise ValueError('Empty tensor never filled during ' + err)
        return [x.tensor if x else None for x in lst]

    @classmethod
    def apply_tensorflow(cls, *args, **kwargs):
        ctx = OpContext()
        result = cls.forward(ctx, *args, **kwargs)
        op = result[0].op if isinstance(result, tuple) else result.op

        # check that all the results stem from triton.empty
        # and get the corresponding TF tensors if possible
        result = result if isinstance(result, tuple) else (result, )
        result = function.extract_tf_tensors(result, 'forward')

        # Find a mapping between ::forward arguments and tensorflow op arguments
        op = result[0].op
        remap = dict()
        for i, ix in enumerate(result.op.inputs):
        for i, ix in enumerate(op.inputs):
            for j, jx in enumerate(args):
                if ix is jx:
                    remap[j] = i
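The double loop above pairs each forward() argument with the TF op input it became, by object identity (`ix is jx`), since the op's input order need not match the argument order. A standalone toy version of the same remapping:

x, gamma, beta = object(), object(), object()   # stand-ins for TF tensors
args = [x, gamma, beta]                         # as passed to cls.forward
op_inputs = [beta, x, gamma]                    # as recorded on the TF op
remap = dict()
for i, ix in enumerate(op_inputs):
    for j, jx in enumerate(args):
        if ix is jx:
            remap[j] = i
print(remap)  # {2: 0, 0: 1, 1: 2}: forward argument j is op input remap[j]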
        # register backward

        # Register backward pass
        ctx_registry[op] = ctx
        name = op.op_def.name
        if not cls.registered:
            @fw.tensorflow.RegisterGradient(name)
            def gradient(op, *dys):
                grad = cls.backward(ctx_registry[op], dys if len(dys) > 1 else dys[0])
            def gradient(op, *dy):
                dy = dy if len(dy) > 1 else dy[0]
                grad = cls.backward(ctx_registry[op], dy)
                grad = function.extract_tf_tensors(grad, 'backward')

                # Remap gradient in the right order
                ret = [None] * len(op.inputs)
                for i in range(len(grad)):
@@ -67,7 +86,8 @@ class function(metaclass = function_meta):
                # Return
                return ret
            cls.registered = True
        # return result tensor

        # Return tensor
        return result

    @classmethod

@@ -15,11 +15,11 @@ import triton.frameworks as fw
import triton.utils
import triton._C.libtriton as libtriton

def _make_framework_src(src, out, grid):
def _make_framework_src(src, out, tmp, grid):
    if fw.has_tensorflow():
        return libtriton.make_tensorflow_src(src, out, grid)
        return libtriton.make_tensorflow_src(src, out, tmp, grid)
    elif fw.has_torch:
        return libtriton.make_torch_src(src, out, grid)
        return libtriton.make_torch_src(src, out, tmp, grid)
    else:
        assert False

@@ -152,8 +152,8 @@ def _cvt_to_def_str(obj):
    return str(obj)


def _make_framework_op(src, outputs, options):
    src, name = _make_framework_src(src, outputs, options)
def _make_framework_op(src, outputs, tmp, options):
    src, name = _make_framework_src(src, outputs, tmp, options)
    cache_path = _make_cache_path(src)
    cpp, so = _write_bindings(src, cache_path)
    _build(cpp, cache_path)
@@ -181,12 +181,13 @@ bench_registry = triton.utils.id_dict()

class kernel:

    def __init__(self, src, outputs):
    def __init__(self, src, outputs, tmp=[]):
        self.fw_id = dict()
        self.fw_grids = dict()
        self.fw_op = None
        self.src = src
        self.outputs = outputs
        self.tmp = tmp

    def __call__(self, *args, **kwargs):
        # create a new framework op when defines are different
@@ -210,7 +211,7 @@ class kernel:
        # register function
        libtriton.register_fn(op_id, self.src, opt)
        if self.fw_op is None:
            self.fw_op = _make_framework_op(self.src, self.outputs, opt)
            self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt)

        # benchmarking info
        bench = 0
@@ -225,6 +226,7 @@ class kernel:
        # call framework function
        if fw.has_tensorflow():
            # operands
            outputs = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)]
            operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args[:-1]]
            # output data types
            kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
@@ -233,13 +235,16 @@ class kernel:
                kwargs['T' + str(i)] = x.dtype
            # launch
            ret = self.fw_op(*operands, **kwargs)
            assert len(ret) == len(outputs)
            # record results
            for i in range(len(outputs)):
                outputs[i].tensor = ret[i]
            if bench > 0:
                bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
        elif fw.has_torch():
            args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
            ret = self.fw_op(op_id, bench, bench_id, *args)
            self.fw_op(op_id, bench, bench_id, *args)
            if bench > 0:
                bench_registry[ret] = libtriton.retrieve_scalar(op_id)
        else:
            assert False
        return ret
        assert False
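With the new `tmp` parameter threaded through `_make_framework_src` and `_make_framework_op`, a kernel can now declare scratch buffers that the generated op allocates (as TEMP allocs) but never returns to the caller. A minimal sketch, assuming a hypothetical Triton-C kernel that needs a scratch area (the source and all names below are made up; per `__init__` above, constructing a kernel only records the strings):

import triton

copy_src = """
void copy(float* X, float* Y, float* SCRATCH, int N) {
  ...
}
"""
# Y is returned to the caller; SCRATCH is allocated by the op, then dropped.
copy_kernel = triton.kernel(copy_src, ['Y'], tmp=['SCRATCH'])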
@@ -13,9 +13,22 @@ void fwdbatchnorm(float *Y, float *M, float *V,
  float *px[TM] = X + rm + c*N;
  float* py[TM] = Y + rm + c*N;

  // fetch mean/var
  float mean = *(M + c);
  float var = *(V + c);
  // compute mean
  float accm[TM] = 0;
  for(int i = 0; i < N; i = i + TM)
    accm = accm + *(px + i);
  float mean = (float)accm[+] / N;
  *(M + c) = mean;

  // compute variance
  float accv[TM] = 0;
  for(int i = 0; i < N; i = i + TM){
    float x[TM] = *(px + i);
    x = x - mean;
    accv = accv + x*x;
  }
  float var = (float)accv[+] / N;
  *(V + c) = var;

  // Normalize batch
  float gamma = *(G + c);
@@ -28,7 +41,7 @@ void fwdbatchnorm(float *Y, float *M, float *V,
  }
}
"""
fwd_kernel = triton.kernel(fwd_src, ['Y'])
fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
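For reference, the per-channel arithmetic the rewritten fwdbatchnorm performs (the `[+]` postfix sums across the TM tile, so `accm[+] / N` is the channel mean). A NumPy sketch, assuming the usual gamma*(x - mean)/sqrt(var + eps) + beta normalization in the part of the kernel the hunk does not show, and a (C, N) layout with N = H*W*B as implied by `px = X + rm + c*N`:

import numpy as np

def fwdbatchnorm_ref(x, gamma, beta, eps=1e-4):
    # x: (C, N) with N = H*W*B
    N = x.shape[1]
    mean = x.sum(axis=1) / N                 # accm[+] / N
    xc = x - mean[:, None]
    var = (xc * xc).sum(axis=1) / N          # accv[+] / N
    y = gamma[:, None] * xc / np.sqrt(var[:, None] + eps) + beta[:, None]
    return y, mean, var                      # Y, M, V: the three op outputs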

bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB,
@@ -78,23 +91,26 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])

@staticmethod
def forward(ctx, x, mean, var, gamma, beta, eps):
def forward(ctx, x, gamma, beta, eps):
    shape = triton.shape(x)
    dtype = x.dtype
    # allocate outputs
    C, H, W, B = shape[0], shape[1], shape[2], shape[3]
    y = triton.empty(shape, dtype=dtype)
    mean = triton.empty([C], dtype=dtype)
    var = triton.empty([C], dtype=dtype)
    # execute kernels
    y = _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
                              lambda opt: [1, C],
                              TM = 128)
    _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
                          lambda opt: [1, C],
                          TM = 128)
    # save
    ctx.save_for_backward(x, gamma, beta, mean, var)
    ctx.save_for_backward(x, gamma, beta, mean.tensor, var.tensor)
    ctx.eps = eps
    return y

@staticmethod
def backward(ctx, dy):
def backward(ctx, grads):
    dy, dmean, dvar = grads
    # retrieve info
    x, gamma, beta, mean, var = ctx.saved_tensors
    eps = ctx.eps
@@ -104,11 +120,11 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
    dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
    # execute
    C, H, W, B = triton.shape(x)
    dx, dgamma, dbeta = _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
    _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
                          x, gamma, mean, var,
                          H*W*B, eps,
                          lambda opt: [1, C],
                          TM = 128)
    return dx, None, None, dgamma, dbeta, None
    return dx, dgamma, dbeta, None

batchnorm = _batchnorm.apply
@@ -10,6 +10,11 @@ class tf_empty_proxy:
    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype
        self.tensor = None

    def to_tensor(self):
        assert self.tensor
        return self.tensor

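The proxy's life cycle, matching the `outputs[i].tensor = ret[i]` loop in kernel.__call__ above (a standalone sketch using the class as defined here; the string stands in for a real TF tensor):

proxy = tf_empty_proxy([64, 64], 'float32')  # what triton.empty returns under TF
assert proxy.tensor is None                  # no TF tensor exists yet
proxy.tensor = 'filled-by-generated-op'      # kernel.__call__: outputs[i].tensor = ret[i]
print(proxy.to_tensor())                     # only now is unwrapping safe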
def empty(shape, dtype):
    if fw.has_tensorflow():