testing some register gradient
@@ -84,7 +84,7 @@ def cdiv(a, b):
 class dot:

-    def __init__(self, trans_a = False, trans_b = True):
+    def __init__(self, trans_a = False, trans_b = False):
         self.dot = triton.op(src, ['C'])
         self.trans_a = trans_a
         self.trans_b = trans_b
@@ -102,26 +102,36 @@ class dot:
         return self.dot(a, b, c, M, N, K, lda, ldb, ldc,
                         lambda opt: [cdiv(M, opt.d('TM')), cdiv(N, opt.d('TN'))],
                         AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
-                        TM = [32, 64, 128], TN = [32, 64, 128], TK = [32])
+                        TM = [128], TN = [128], TK = [32])

-dot_tn = dot()
+dot_nt = dot(False, True)
+dot_nn = dot(False, False)
+dot_tn = dot(True, False)
+dot_tt = dot(True, True)
+
+@triton.register_gradient(dot)
+def _dot_grad(op, dy):
+    a = op.inputs[0]
+    b = op.inputs[1]
+    return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]

 def run_dot():
     M, N, K = 128, 128, 128
     a = tf.placeholder(tf.float16, shape=[M, K])
     b = tf.placeholder(tf.float16, shape=[N, K])
     # c = tf.matmul(a, b, transpose_a=True)
-    c = dot_tn(a, b)
+    c = dot_nn(a, b)
+    grads = tf.gradients(c, [a])
     # Reference
     ha = np.random.rand(M, K).astype(np.float16)
     hb = np.random.rand(N, K).astype(np.float16)
     # Run
     sess = tf.InteractiveSession()
     sess.run(tf.global_variables_initializer())
-    result = sess.run([c], feed_dict = {a: ha,
+    result = sess.run([grads], feed_dict = {a: ha,
                                         b: hb})[0]
     # Test
-    hresult = np.dot(ha.T, hb).T
+    hresult = np.dot(ha.T, hb.T).T
     dif = np.abs(result - hresult)
     np.savetxt('dif.dat', dif, '%2.4f')
     print(hresult)
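The gradient registered above follows the standard matmul rule: for C = A·B with upstream gradient dY, dA = dY·Bᵀ and dB = Aᵀ·dY, so the backward pass is just two more dot ops with flipped transpose flags (the exact flags depend on the layouts the forward op was built with); the trailing `None`s line up with the op's non-differentiable inputs (shapes, strides, op id). A minimal NumPy sketch of the identity, with a finite-difference check (illustrative names, not from the commit):

import numpy as np

# C = A @ B with a scalar loss L = sum(C * dY); then dL/dA = dY @ B.T
# and dL/dB = A.T @ dY -- the contractions the transposed dot variants
# compute in the backward pass.
M, N, K = 4, 5, 3
A  = np.random.rand(M, K)
B  = np.random.rand(K, N)
dY = np.random.rand(M, N)          # upstream gradient w.r.t. C

dA = dY @ B.T
dB = A.T @ dY

# finite-difference check on one entry of A
eps = 1e-6
Ap = A.copy(); Ap[0, 0] += eps
num = ((Ap @ B - A @ B) * dY).sum() / eps
assert abs(num - dA[0, 0]) < 1e-3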
@@ -24,21 +24,19 @@ namespace rt = triton::runtime;

 /* TF triton op properties */

-std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
-std::map<size_t, rt::function*> id_fn_map;
+std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
+std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
 std::map<size_t, int64_t> i64scalar_map;

 void register_grid(size_t id,
                    const rt::function::grid_fn_ty& grid_fn) {
-  id_grid_map[id] = grid_fn;
+  id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
 }

 void register_fn(size_t id,
                 const std::string& src,
                 const rt::function::options_space_t& opt) {
-  bool is_inserted = id_fn_map.insert({id, new rt::function(src, opt)}).second;
-  if(!is_inserted)
-    assert(false);
+  id_fn_map[id].reset(new rt::function(src, opt));
 }

 size_t make_op_id() {
@@ -135,7 +133,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
       os << ", ";
     os << name;
   }
-  os << "}, id_grid_map.at(id_), stream); \n";
+  os << "}, *id_grid_map.at(id_), stream); \n";
 }

 void gen_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -214,7 +212,7 @@ std::tuple<std::string,
   parser.Parse();
   // triton-ir code-gen
   ir::context ctx;
-  auto ir = std::unique_ptr<ir::module>(new ir::module("", ctx));
+  auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
   Generator gen(&parser);
   gen.Gen(&*ir);
   // function
@@ -245,8 +243,8 @@ using GPUDevice = Eigen::GpuDevice;
 namespace rt = triton::runtime;
 namespace drv = triton::driver;

-extern std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
-extern std::map<size_t, rt::function*> id_fn_map;
+extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
+extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;


 class )" << opname << R"(: public OpKernel {
@@ -294,6 +292,7 @@ oss << R"(
 )";
 gen_register_op(oss, cc_name, fn->args(), outputs);
+
 return {oss.str(), name};
 }
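Two semantic changes ride along with the shared_ptr switch above: registering the same id twice now overwrites the previous entry instead of asserting, and the old rt::function is freed automatically when its shared_ptr is reset. A sketch of the before/after registration semantics, modeled with a Python dict since the real maps live in C++:

# Illustrative only: the real registries are C++ std::map<size_t, shared_ptr<...>>.
registry = {}

def register_fn_old(op_id, fn):
    # old semantics: a duplicate id was a hard error (insert + assert(false))
    assert op_id not in registry, "duplicate op id"
    registry[op_id] = fn

def register_fn_new(op_id, fn):
    # new semantics: reset() -- overwrite, dropping (freeing) the old entry
    registry[op_id] = fn

register_fn_new(0, "fn_v1")
register_fn_new(0, "fn_v2")      # fine: overwrites
register_fn_old(1, "fn_v1")
# register_fn_old(1, "fn_v2")    # would raise: duplicate op id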
@@ -13,6 +13,8 @@ import setuptools
 import libtriton
 # frameworks
 import tensorflow as tf
+from tensorflow.python.framework import ops
+

 extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')

@@ -230,13 +232,24 @@ class op:
         op = self.fw_ops[key]
         # register grid
         grid = _make_grid(args)
-        libtriton.register_grid(op_id, grid)
+        self.fw_grids[key] = grid
+        libtriton.register_grid(op_id, self.fw_grids[key])
         # create operands
         op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
         # call framework op
         return op(*op_args, id=op_id)


+class register_gradient:
+
+    def __init__(self, op):
+        self.op = op
+
+    def __call__(self, f):
+        name = 'Dot'
+        ops.RegisterGradient(name)(f)


 def empty(shapes):
     args = [x.handle if isinstance(x, scalar) else x for x in shapes]
     args = tf.stack(args)
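The register_gradient wrapper is a thin shim over TensorFlow's gradient registry: tf.gradients looks the backward function up by the op's type name as it walks the graph, which is why the hard-coded name = 'Dot' works for the dot op generated in this commit. A standalone sketch of the same mechanism, assuming the TF 1.x API used throughout this file:

import tensorflow as tf
from tensorflow.python.framework import ops

# Register a backward function for graph nodes whose op type is 'Dot'.
# tf.gradients() consults this registry when it reaches such a node.
@ops.RegisterGradient('Dot')
def _dot_grad(op, dy):
    a, b = op.inputs[0], op.inputs[1]
    # one gradient per input; None marks non-differentiable inputs
    return [tf.matmul(dy, b, transpose_b=True),
            tf.matmul(a, dy, transpose_a=True)]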