[python] refactoring in anticipation of pytorch support
This commit is contained in:
@@ -34,19 +34,19 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
|
||||
/* pointers for A */
|
||||
#if AT == 1
|
||||
TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
|
||||
TYPE* pa[TK, TM] = A + rka[:, newaxis]*lda + rxa[newaxis, :];
|
||||
TYPE a[TK, TM] = *pa;
|
||||
#else
|
||||
TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
|
||||
TYPE* pa[TM, TK] = A + rka[newaxis, :] + rxa[:, newaxis]*lda;
|
||||
TYPE a[TM, TK] = *pa;
|
||||
#endif
|
||||
|
||||
/* pointers for B */
|
||||
#if BT == 1
|
||||
TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
|
||||
TYPE* pb[TN, TK] = B + rkb[newaxis, :] + ryb[:, newaxis]*ldb;
|
||||
TYPE b[TN, TK] = *pb;
|
||||
#else
|
||||
TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
|
||||
TYPE* pb[TK, TN] = B + rkb[:, newaxis]*ldb + ryb[newaxis, :];
|
||||
TYPE b[TK, TN] = *pb;
|
||||
#endif
|
||||
|
||||
@@ -54,14 +54,14 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = USEA @ USEB + xc;
|
||||
#if AT == 1
|
||||
pa = pa + TK;
|
||||
#else
|
||||
pa = pa + TK*lda;
|
||||
#else
|
||||
pa = pa + TK;
|
||||
#endif
|
||||
#if BT == 1
|
||||
pb = pb + TK*ldb;
|
||||
#else
|
||||
pb = pb + TK;
|
||||
#else
|
||||
pb = pb + TK*ldb;
|
||||
#endif
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
@@ -70,19 +70,19 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
/* epilogue */
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
TYPE* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
|
||||
TYPE c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
*pc = c;
|
||||
*?(checkc) pc = c;
|
||||
}
|
||||
"""
|
||||
|
||||
def cdiv(a, b):
|
||||
return -(-a // b)
|
||||
|
||||
class dot:
|
||||
class dot_op:
|
||||
|
||||
def __init__(self, trans_a = False, trans_b = False):
|
||||
self.dot = triton.op(src, ['C'])
|
||||
@@ -93,10 +93,18 @@ class dot:
|
||||
shape_a = triton.shape(a)
|
||||
shape_b = triton.shape(b)
|
||||
M = shape_a[0]
|
||||
K = shape_a[1]
|
||||
N = shape_b[0]
|
||||
lda = M
|
||||
ldb = K
|
||||
Ka = shape_a[1]
|
||||
Kb = shape_b[0]
|
||||
N = shape_b[1]
|
||||
# transpose shapes
|
||||
if self.trans_a:
|
||||
M, Ka = Ka, M
|
||||
if self.trans_b:
|
||||
Kb, N = N, Kb
|
||||
K = Ka
|
||||
# contiguous dimensions
|
||||
lda = Ka
|
||||
ldb = N
|
||||
ldc = N
|
||||
c = triton.empty([M, N])
|
||||
return self.dot(a, b, c, M, N, K, lda, ldb, ldc,
|
||||
@@ -104,34 +112,34 @@ class dot:
|
||||
AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
|
||||
TM = [128], TN = [ 128], TK = [32])
|
||||
|
||||
dot_nt = dot(False, True)
|
||||
dot_nn = dot(False, False)
|
||||
dot_tn = dot(True, False)
|
||||
dot_tt = dot(True, True)
|
||||
dot_nt = dot_op(False, True)
|
||||
dot_nn = dot_op(False, False)
|
||||
dot_tn = dot_op(True, False)
|
||||
dot_tt = dot_op(True, True)
|
||||
|
||||
@triton.register_gradient(dot)
|
||||
def _dot_grad(op, dy):
|
||||
a = op.inputs[0]
|
||||
b = op.inputs[1]
|
||||
return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
|
||||
# @triton.register_gradient(dot_op)
|
||||
# def _dot_grad(op, dy):
|
||||
# a = op.inputs[0]
|
||||
# b = op.inputs[1]
|
||||
# return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
|
||||
|
||||
def run_dot():
|
||||
M, N, K = 128, 128, 128
|
||||
a = tf.placeholder(tf.float16, shape=[M, K])
|
||||
b = tf.placeholder(tf.float16, shape=[N, K])
|
||||
# c = tf.matmul(a, b, transpose_a=True)
|
||||
c = dot_nn(a, b)
|
||||
grads = tf.gradients(c, [a])
|
||||
c = dot_nt(a, b)
|
||||
# grads = tf.gradients(c, [a])
|
||||
# Reference
|
||||
ha = np.random.rand(M, K).astype(np.float16)
|
||||
hb = np.random.rand(N, K).astype(np.float16)
|
||||
hb = np.random.rand(K, N).astype(np.float16)
|
||||
# Run
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([grads], feed_dict = {a: ha,
|
||||
result = sess.run([c], feed_dict = {a: ha,
|
||||
b: hb})[0]
|
||||
# Test
|
||||
hresult = np.dot(ha.T, hb.T).T
|
||||
hresult = np.dot(ha, hb.T)
|
||||
dif = np.abs(result - hresult)
|
||||
np.savetxt('dif.dat', dif, '%2.4f')
|
||||
print(hresult)
|
||||
|
@@ -1,7 +0,0 @@
|
||||
#include <stdio.h>
|
||||
|
||||
int main(){
|
||||
const char* TEST = "test\n";
|
||||
const char* LOL = "lol\n";
|
||||
printf("%s\n",DTYPE);
|
||||
}
|
@@ -136,7 +136,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
|
||||
os << "}, *id_grid_map.at(id_), stream); \n";
|
||||
}
|
||||
|
||||
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
const std::string &opname,
|
||||
const std::vector<ir::argument*>& args){
|
||||
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
|
||||
@@ -151,7 +151,7 @@ void gen_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
os << ", " + opname << ");\n";
|
||||
}
|
||||
|
||||
void gen_register_op(std::ostream &os, const std::string &name,
|
||||
void gen_tf_register_op(std::ostream &os, const std::string &name,
|
||||
const std::vector<ir::argument*>& args,
|
||||
const std::vector<std::string>& outputs){
|
||||
os << "REGISTER_OP(\"" << name << "\")\n";
|
||||
@@ -195,15 +195,12 @@ extern int get_program_id(int);
|
||||
)";
|
||||
}
|
||||
|
||||
std::tuple<std::string,
|
||||
std::string> make_tensorflow_src(std::string src,
|
||||
const std::vector<std::string>& outputs,
|
||||
const runtime::function::options_space_t& opt)
|
||||
{
|
||||
src = preheader() + src;
|
||||
void make_module(const std::string& src, ir::module* ir,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
std::string copy = preheader() + src;
|
||||
// pre-process
|
||||
TokenSequence tokens;
|
||||
Preprocessor cpp(&src, true);
|
||||
Preprocessor cpp(©, true);
|
||||
for(auto it: opt.defines){
|
||||
cpp.AddMacro(it.first, &it.second[0]);
|
||||
}
|
||||
@@ -211,11 +208,19 @@ std::tuple<std::string,
|
||||
// parse
|
||||
Parser parser(tokens);
|
||||
parser.Parse();
|
||||
Generator gen(&parser);
|
||||
gen.Gen(ir);
|
||||
}
|
||||
|
||||
std::tuple<std::string,
|
||||
std::string> make_tensorflow_src(const std::string& src,
|
||||
const std::vector<std::string>& outputs,
|
||||
const runtime::function::options_space_t& opt)
|
||||
{
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
Generator gen(&parser);
|
||||
gen.Gen(&*ir);
|
||||
make_module(src, &*ir, opt);
|
||||
// function
|
||||
ir::function* fn = ir->get_function_list().front();
|
||||
std::string name = fn->get_name();
|
||||
@@ -287,16 +292,145 @@ private:
|
||||
|
||||
// register kernel builder
|
||||
)";
|
||||
gen_register_kernel_builder(oss, cc_name, opname, fn->args());
|
||||
gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args());
|
||||
oss << R"(
|
||||
// register op
|
||||
)";
|
||||
gen_register_op(oss, cc_name, fn->args(), outputs);
|
||||
|
||||
gen_tf_register_op(oss, cc_name, fn->args(), outputs);
|
||||
|
||||
return {oss.str(), name};
|
||||
}
|
||||
|
||||
|
||||
inline std::string to_torch_ty(ir::type *ty) {
|
||||
if(ty->is_integer_ty(1))
|
||||
return "bool";
|
||||
if(ty->is_integer_ty(8))
|
||||
return "int8";
|
||||
if(ty->is_integer_ty(16))
|
||||
return "int16";
|
||||
if(ty->is_integer_ty(32))
|
||||
return "int32";
|
||||
if(ty->is_integer_ty(64))
|
||||
return "int64";
|
||||
if(ty->is_half_ty())
|
||||
return "float16";
|
||||
if(ty->is_float_ty())
|
||||
return "float32";
|
||||
if(ty->is_double_ty())
|
||||
return "float64";
|
||||
if(ty->is_pointer_ty())
|
||||
return "Tensor";
|
||||
throw std::runtime_error("unknown type");
|
||||
}
|
||||
|
||||
|
||||
|
||||
void gen_torch_signature(std::ostringstream& oss,
|
||||
ir::function* fn,
|
||||
const std::vector<std::string>& outputs,
|
||||
const std::string& name) {
|
||||
const auto& args = fn->args();
|
||||
std::vector<ir::type*> out_types;
|
||||
for(const std::string& out: outputs) {
|
||||
auto it = std::find_if(args.begin(), args.end(),
|
||||
[&](ir::argument* arg) { return arg->get_name() == out; });
|
||||
if(it == args.end())
|
||||
throw std::runtime_error("unknown argument");
|
||||
out_types.push_back((*it)->get_type());
|
||||
}
|
||||
|
||||
oss << "std::tuple<";
|
||||
for(size_t i = 0; i < out_types.size(); i++){
|
||||
if(i > 0)
|
||||
oss << ", ";
|
||||
oss << to_torch_ty(out_types[i]);
|
||||
}
|
||||
oss << "> ";
|
||||
oss << name << "(";
|
||||
oss << "int64 id" << std::endl;
|
||||
for(size_t i = 0; i < args.size(); i++) {
|
||||
ir::argument* arg = args[i];
|
||||
if(i > 0)
|
||||
oss << ", ";
|
||||
oss << to_torch_ty(arg->get_type()) << " " << arg->get_name();
|
||||
}
|
||||
oss << ")";
|
||||
}
|
||||
|
||||
void gen_torch_init_driver(std::ostringstream &oss) {
|
||||
oss << " // Wrap CUDA handles" << std::endl;
|
||||
oss << " c10::DeviceIndex device = torcha.storage().device().index();" << std::endl;
|
||||
oss << " // Get stream" << std::endl;
|
||||
oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl;
|
||||
oss << " triton::driver::cu_stream stream(custream, false);" << std::endl;
|
||||
oss << " triton::driver::context* ctx = stream.context();" << std::endl;
|
||||
}
|
||||
|
||||
void gen_torch_make_handles(std::ostream &os,
|
||||
const std::vector<ir::argument*>& args) {
|
||||
for(unsigned i = 0; i < args.size(); i++){
|
||||
ir::argument *arg = args[i];
|
||||
if(!arg->get_type()->is_pointer_ty())
|
||||
continue;
|
||||
const std::string& name = arg->get_name();
|
||||
os << " drv::cu_buffer cu_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage.data(), false);\n ";
|
||||
}
|
||||
}
|
||||
|
||||
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
os << " (*id_fn_map.at(id))({";
|
||||
for(unsigned i = 0; i < args.size() ; i++){
|
||||
ir::argument *arg = args[i];
|
||||
std::string name = arg->get_name();
|
||||
if(arg->get_type()->is_pointer_ty())
|
||||
name = "&cu_" + name;
|
||||
if(i > 0)
|
||||
os << ", ";
|
||||
os << name;
|
||||
}
|
||||
os << "}, *id_grid_map.at(id), stream); \n";
|
||||
}
|
||||
|
||||
|
||||
std::tuple<std::string,
|
||||
std::string> make_pytorch_src(const std::string& src,
|
||||
const std::vector<std::string>& outputs,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
make_module(src, &*ir, opt);
|
||||
// function
|
||||
ir::function* fn = ir->get_function_list().front();
|
||||
std::string name = fn->get_name();
|
||||
// generate framework code
|
||||
std::ostringstream oss;
|
||||
oss << R"(
|
||||
#include "triton/driver/buffer.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/runtime/function.h"
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||
|
||||
)";
|
||||
|
||||
gen_torch_signature(oss, fn, outputs, name);
|
||||
oss << " {" << std::endl;
|
||||
gen_torch_init_driver(oss);
|
||||
gen_torch_make_handles(oss, fn->args());
|
||||
gen_torch_make_launch_function(oss, fn->args());
|
||||
oss << std::endl << "}";
|
||||
|
||||
oss << "static auto registry = torch::jit::RegisterOperators(\"triton::" << name << "\", &" << name << ");" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
typedef triton::runtime::function::options_t options_t;
|
||||
typedef triton::runtime::function::options_space_t options_space_t;
|
||||
|
||||
@@ -307,6 +441,8 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
m.def("make_tensorflow_src", &make_tensorflow_src,
|
||||
"Creates C++ source code for a custom Tensorflow op "
|
||||
"corresponding to the specified Triton kernel");
|
||||
m.def("make_pytorch_src", &make_pytorch_src,
|
||||
"Creates C++ source code for a custom PyTorch op ");
|
||||
|
||||
// bindings for triton classes
|
||||
pybind11::class_<options_t>(m, "options")
|
||||
|
@@ -11,18 +11,60 @@ import setuptools.command.build_ext
|
||||
import setuptools
|
||||
# triton
|
||||
import libtriton
|
||||
# frameworks
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
|
||||
torch_id = 'torch'
|
||||
tensorflow_id = 'tensorflow'
|
||||
|
||||
torch = None
|
||||
tensorflow = None
|
||||
tf_extra_ops = None
|
||||
|
||||
def _import_torch():
|
||||
global torch
|
||||
if torch is None:
|
||||
import torch
|
||||
|
||||
def _import_tensorflow():
|
||||
global tensorflow
|
||||
if tensorflow is None:
|
||||
import tensorflow
|
||||
|
||||
def _import_tf_extra_ops():
|
||||
global tf_extra_ops
|
||||
if tf_extra_ops is None:
|
||||
path = os.path.dirname(libtriton.__file__)
|
||||
path = os.path.join(path, 'libextra_tf_ops.so')
|
||||
_import_tensorflow()
|
||||
tf_extra_ops = tensorflow.load_op_library(path)
|
||||
|
||||
|
||||
def make_bindings(src, out, grid):
|
||||
return libtriton.make_tensorflow_src(src, out, grid)
|
||||
def _find_framework(default = None):
|
||||
is_tf_imported = 'tensorflow' in sys.modules
|
||||
is_torch_imported = 'torch' in sys.modules
|
||||
if default:
|
||||
if default not in [tensorflow_id, torch_id]:
|
||||
raise ValueError('unsupported framework')
|
||||
else:
|
||||
return default
|
||||
elif is_tf_imported and not is_torch_imported:
|
||||
return tensorflow_id
|
||||
elif is_torch_imported and not is_tf_imported:
|
||||
return torch_id
|
||||
else:
|
||||
raise ValueError('cannot determine imported framework, '
|
||||
'please provide framework argument')
|
||||
|
||||
def make_cache_path(src):
|
||||
|
||||
def _make_framework_src(src, out, grid, framework):
|
||||
if framework == tensorflow_id:
|
||||
return libtriton.make_tensorflow_src(src, out, grid)
|
||||
elif framework == torch_id:
|
||||
return libtriton.make_torch_src(src, out, grid)
|
||||
else:
|
||||
assert False
|
||||
|
||||
def _make_cache_path(src):
|
||||
md5 = hashlib.sha1(src.encode())
|
||||
hexhash = md5.hexdigest()
|
||||
home = os.path.expanduser('~')
|
||||
@@ -32,10 +74,10 @@ def make_cache_path(src):
|
||||
os.makedirs(cachepath)
|
||||
return cachepath
|
||||
|
||||
def write_bindings(src, root):
|
||||
cpp = os.path.join(root, 'tensorflow.cpp')
|
||||
def _write_bindings(src, root, framework):
|
||||
cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
|
||||
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||
so = os.path.join(root, 'tensorflow{suffix}'.format(suffix=suffix))
|
||||
so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
|
||||
recompile = False
|
||||
# recompile if .so does not exist
|
||||
if not os.path.exists(cpp) or not os.path.exists(so):
|
||||
@@ -50,18 +92,32 @@ def write_bindings(src, root):
|
||||
# return path of cpp file
|
||||
return (cpp, so)
|
||||
|
||||
def build(src, path):
|
||||
def _build(src, path, framework):
|
||||
# include directories
|
||||
triton_include_dirs = ['/home/philippe/development/triton/include']
|
||||
tensorflow_include_dirs = [tf.sysconfig.get_include()]
|
||||
cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
|
||||
include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
|
||||
include_dirs = triton_include_dirs
|
||||
# library directories
|
||||
triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
|
||||
tensorflow_library_dirs = [tf.sysconfig.get_lib()]
|
||||
library_dirs = triton_library_dirs + tensorflow_library_dirs
|
||||
library_dirs = triton_library_dirs
|
||||
# libraries
|
||||
libraries = ['tensorflow_framework', 'triton']
|
||||
libraries = ['triton']
|
||||
# add framework
|
||||
if framework == tensorflow_id:
|
||||
_import_tensorflow()
|
||||
library_dirs += [tensorflow.sysconfig.get_lib()]
|
||||
include_dirs += [tensorflow.sysconfig.get_lib()]
|
||||
libraries += ['tensorflow_framework']
|
||||
elif framework == torch_id:
|
||||
_import_torch()
|
||||
prefix = os.path.dirname(torch.__file__)
|
||||
library_dirs += [os.path.join(prefix, 'lib')]
|
||||
include_dirs += [os.path.join(prefix, 'lib', 'include'),
|
||||
os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
|
||||
os.path.join(prefix, 'include'),
|
||||
os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
|
||||
libraries += ['torch']
|
||||
else:
|
||||
assert False
|
||||
# extra arguments
|
||||
extra_compile_args = []
|
||||
extra_link_args = []
|
||||
@@ -93,25 +149,138 @@ def build(src, path):
|
||||
setuptools.setup(**args)
|
||||
shutil.rmtree(tmp)
|
||||
|
||||
def _cvt_to_def_str(obj):
|
||||
def _cvt_to_def_str(obj, framework):
|
||||
# bool
|
||||
if isinstance(obj, bool):
|
||||
return str(int(obj))
|
||||
if isinstance(obj, tf.DType):
|
||||
return {tf.int8: 'char',
|
||||
tf.int16: 'short',
|
||||
tf.int32: 'int',
|
||||
tf.int64: 'long',
|
||||
tf.float16: 'half',
|
||||
tf.float32: 'float',
|
||||
tf.float64: 'double'}[obj]
|
||||
# tensorflow type
|
||||
if framework == tensorflow_id:
|
||||
_import_tensorflow()
|
||||
if isinstance(obj, tensorflow.DType):
|
||||
return {tensorflow.int8: 'char',
|
||||
tensorflow.int16: 'short',
|
||||
tensorflow.int32: 'int',
|
||||
tensorflow.int64: 'long',
|
||||
tensorflow.float16: 'half',
|
||||
tensorflow.float32: 'float',
|
||||
tensorflow.float64: 'double'}[obj]
|
||||
# torch type
|
||||
elif framework == torch_id:
|
||||
_import_torch()
|
||||
if isinstance(obj, torch.dtype):
|
||||
return {torch.int8: 'char',
|
||||
torch.int16: 'short',
|
||||
torch.int32: 'int',
|
||||
torch.int64: 'long',
|
||||
torch.float16: 'half',
|
||||
torch.float32: 'float',
|
||||
torch.float64: 'double'}[obj]
|
||||
else:
|
||||
assert False
|
||||
# default
|
||||
return str(obj)
|
||||
|
||||
|
||||
def _make_framework_op(src, outputs, options, framework):
|
||||
src, name = _make_framework_src(src, outputs, options, framework)
|
||||
cache_path = _make_cache_path(src)
|
||||
cpp, so = _write_bindings(src, cache_path, framework)
|
||||
_build(cpp, cache_path, framework)
|
||||
if framework == tensorflow_id:
|
||||
_import_tensorflow()
|
||||
return tensorflow.load_op_library(so).__dict__[name]
|
||||
elif framework == torch_id:
|
||||
_import_torch()
|
||||
torch.ops.load_library(so)
|
||||
return torch.ops.triton.__dict__[name]
|
||||
else:
|
||||
assert False
|
||||
|
||||
def _make_grid(args) :
|
||||
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
|
||||
def grid(opt):
|
||||
for x in scalars:
|
||||
x.set_assume_initialized()
|
||||
result = args[-1](opt)
|
||||
for x in scalars:
|
||||
x.unset_assume_initialized()
|
||||
return result
|
||||
return grid
|
||||
|
||||
class op:
|
||||
|
||||
def __init__(self, src, outputs, framework = None):
|
||||
self.fw_id = dict()
|
||||
self.fw_ops = dict()
|
||||
self.fw_grids = dict()
|
||||
self.src = src
|
||||
self.outputs = outputs
|
||||
self.framework = _find_framework(None)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# create a new op when defines are different
|
||||
key = zip(kwargs.keys(), kwargs.values())
|
||||
if key not in self.fw_ops:
|
||||
# code generation options
|
||||
defines = []
|
||||
for k, v in kwargs.items():
|
||||
cvt = lambda x: _cvt_to_def_str(x, self.framework)
|
||||
try:
|
||||
values = list(map(cvt, v))
|
||||
except TypeError:
|
||||
values = [cvt(v)]
|
||||
defines.append((k, values))
|
||||
opt = libtriton.options_space()
|
||||
opt.defines = defines
|
||||
opt.num_warps = [1, 2, 4, 8]
|
||||
# create unique id for this op
|
||||
op_id = libtriton.make_op_id()
|
||||
self.fw_id[key] = op_id
|
||||
# register function
|
||||
libtriton.register_fn(op_id, self.src, opt)
|
||||
self.fw_ops[key] = _make_framework_op(self.src, self.outputs, opt, self.framework)
|
||||
|
||||
# retrieve framework op
|
||||
op_id = self.fw_id[key]
|
||||
op = self.fw_ops[key]
|
||||
# register grid
|
||||
grid = _make_grid(args)
|
||||
self.fw_grids[key] = grid
|
||||
libtriton.register_grid(op_id, self.fw_grids[key])
|
||||
# create operands
|
||||
op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
|
||||
# call framework op
|
||||
return op(*op_args, id=op_id)
|
||||
|
||||
|
||||
# class register_gradient:
|
||||
|
||||
# def __init__(self, op):
|
||||
# self.op = op
|
||||
|
||||
# def __call__(self, f):
|
||||
# name = 'Dot'
|
||||
# ops.RegisterGradient(name)(f)
|
||||
|
||||
|
||||
def empty(shapes, framework = None):
|
||||
framework = _find_framework(framework)
|
||||
if framework == tensorflow_id:
|
||||
_import_tensorflow()
|
||||
_import_tf_extra_ops
|
||||
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
|
||||
args = tensorflow.stack(args)
|
||||
return tf_extra_ops.alloc_empty(args)
|
||||
elif framework == torch_id:
|
||||
_import_torch()
|
||||
return torch.empty(*shapes)
|
||||
|
||||
class scalar:
|
||||
|
||||
def __init__(self, x):
|
||||
_import_tf_extra_ops()
|
||||
self.id = libtriton.make_scalar_id()
|
||||
self.handle = extra_ops.register_scalar(x, id=self.id)
|
||||
self.handle = tf_extra_ops.register_scalar(x, id=self.id)
|
||||
self.assume_initialized = False
|
||||
|
||||
def set_assume_initialized(self):
|
||||
@@ -174,83 +343,6 @@ class lazy_shape:
|
||||
return scalar(self.shape[key])
|
||||
|
||||
def shape(A) :
|
||||
return lazy_shape(tf.shape(A))
|
||||
_import_tensorflow()
|
||||
return lazy_shape(tensorflow.shape(A))
|
||||
|
||||
def _make_tensorflow_op(src, outputs, options):
|
||||
src, name = make_bindings(src, outputs, options)
|
||||
cache_path = make_cache_path(src)
|
||||
cpp, so = write_bindings(src, cache_path)
|
||||
build(cpp, cache_path)
|
||||
result = tf.load_op_library(so)
|
||||
return result.__dict__[name]
|
||||
|
||||
def _make_grid(args) :
|
||||
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
|
||||
def grid(opt):
|
||||
for x in scalars:
|
||||
x.set_assume_initialized()
|
||||
result = args[-1](opt)
|
||||
for x in scalars:
|
||||
x.unset_assume_initialized()
|
||||
return result
|
||||
return grid
|
||||
|
||||
class op:
|
||||
|
||||
def __init__(self, src, outputs):
|
||||
self.fw_id = dict()
|
||||
self.fw_ops = dict()
|
||||
self.fw_grids = dict()
|
||||
self.src = src
|
||||
self.outputs = outputs
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# create a new op when defines are different
|
||||
key = zip(kwargs.keys(), kwargs.values())
|
||||
if key not in self.fw_ops:
|
||||
# code generation options
|
||||
defines = []
|
||||
for k, v in kwargs.items():
|
||||
try:
|
||||
values = list(map(_cvt_to_def_str, v))
|
||||
except TypeError:
|
||||
values = [_cvt_to_def_str(v)]
|
||||
defines.append((k, values))
|
||||
opt = libtriton.options_space()
|
||||
opt.defines = defines
|
||||
opt.num_warps = [1, 2, 4, 8]
|
||||
# create unique id for this op
|
||||
op_id = libtriton.make_op_id()
|
||||
self.fw_id[key] = op_id
|
||||
# register function
|
||||
libtriton.register_fn(op_id, self.src, opt)
|
||||
self.fw_ops[key] = _make_tensorflow_op(self.src, self.outputs, opt)
|
||||
|
||||
# retrieve framework op
|
||||
op_id = self.fw_id[key]
|
||||
op = self.fw_ops[key]
|
||||
# register grid
|
||||
grid = _make_grid(args)
|
||||
self.fw_grids[key] = grid
|
||||
libtriton.register_grid(op_id, self.fw_grids[key])
|
||||
# create operands
|
||||
op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
|
||||
# call framework op
|
||||
return op(*op_args, id=op_id)
|
||||
|
||||
|
||||
class register_gradient:
|
||||
|
||||
def __init__(self, op):
|
||||
self.op = op
|
||||
|
||||
def __call__(self, f):
|
||||
name = 'Dot'
|
||||
ops.RegisterGradient(name)(f)
|
||||
|
||||
|
||||
def empty(shapes):
|
||||
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
|
||||
args = tf.stack(args)
|
||||
return extra_ops.alloc_empty(args)
|
||||
|
Reference in New Issue
Block a user