[python] refactoring in anticipation of pytorch support
This commit is contained in:
@@ -34,19 +34,19 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
|
|||||||
|
|
||||||
/* pointers for A */
|
/* pointers for A */
|
||||||
#if AT == 1
|
#if AT == 1
|
||||||
TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
|
TYPE* pa[TK, TM] = A + rka[:, newaxis]*lda + rxa[newaxis, :];
|
||||||
TYPE a[TK, TM] = *pa;
|
TYPE a[TK, TM] = *pa;
|
||||||
#else
|
#else
|
||||||
TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
|
TYPE* pa[TM, TK] = A + rka[newaxis, :] + rxa[:, newaxis]*lda;
|
||||||
TYPE a[TM, TK] = *pa;
|
TYPE a[TM, TK] = *pa;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/* pointers for B */
|
/* pointers for B */
|
||||||
#if BT == 1
|
#if BT == 1
|
||||||
TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
|
TYPE* pb[TN, TK] = B + rkb[newaxis, :] + ryb[:, newaxis]*ldb;
|
||||||
TYPE b[TN, TK] = *pb;
|
TYPE b[TN, TK] = *pb;
|
||||||
#else
|
#else
|
||||||
TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
|
TYPE* pb[TK, TN] = B + rkb[:, newaxis]*ldb + ryb[newaxis, :];
|
||||||
TYPE b[TK, TN] = *pb;
|
TYPE b[TK, TN] = *pb;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -54,14 +54,14 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
|
|||||||
for(int k = K; k > 0; k = k - TK){
|
for(int k = K; k > 0; k = k - TK){
|
||||||
xc = USEA @ USEB + xc;
|
xc = USEA @ USEB + xc;
|
||||||
#if AT == 1
|
#if AT == 1
|
||||||
pa = pa + TK;
|
|
||||||
#else
|
|
||||||
pa = pa + TK*lda;
|
pa = pa + TK*lda;
|
||||||
|
#else
|
||||||
|
pa = pa + TK;
|
||||||
#endif
|
#endif
|
||||||
#if BT == 1
|
#if BT == 1
|
||||||
pb = pb + TK*ldb;
|
|
||||||
#else
|
|
||||||
pb = pb + TK;
|
pb = pb + TK;
|
||||||
|
#else
|
||||||
|
pb = pb + TK*ldb;
|
||||||
#endif
|
#endif
|
||||||
a = *pa;
|
a = *pa;
|
||||||
b = *pb;
|
b = *pb;
|
||||||
@@ -70,19 +70,19 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
|
|||||||
/* epilogue */
|
/* epilogue */
|
||||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||||
TYPE* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
|
||||||
TYPE c[TM, TN] = xc;
|
TYPE c[TM, TN] = xc;
|
||||||
bool checkc0[TM] = rxc < M;
|
bool checkc0[TM] = rxc < M;
|
||||||
bool checkc1[TN] = ryc < N;
|
bool checkc1[TN] = ryc < N;
|
||||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||||
*pc = c;
|
*?(checkc) pc = c;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def cdiv(a, b):
|
def cdiv(a, b):
|
||||||
return -(-a // b)
|
return -(-a // b)
|
||||||
|
|
||||||
class dot:
|
class dot_op:
|
||||||
|
|
||||||
def __init__(self, trans_a = False, trans_b = False):
|
def __init__(self, trans_a = False, trans_b = False):
|
||||||
self.dot = triton.op(src, ['C'])
|
self.dot = triton.op(src, ['C'])
|
||||||
@@ -93,10 +93,18 @@ class dot:
|
|||||||
shape_a = triton.shape(a)
|
shape_a = triton.shape(a)
|
||||||
shape_b = triton.shape(b)
|
shape_b = triton.shape(b)
|
||||||
M = shape_a[0]
|
M = shape_a[0]
|
||||||
K = shape_a[1]
|
Ka = shape_a[1]
|
||||||
N = shape_b[0]
|
Kb = shape_b[0]
|
||||||
lda = M
|
N = shape_b[1]
|
||||||
ldb = K
|
# transpose shapes
|
||||||
|
if self.trans_a:
|
||||||
|
M, Ka = Ka, M
|
||||||
|
if self.trans_b:
|
||||||
|
Kb, N = N, Kb
|
||||||
|
K = Ka
|
||||||
|
# contiguous dimensions
|
||||||
|
lda = Ka
|
||||||
|
ldb = N
|
||||||
ldc = N
|
ldc = N
|
||||||
c = triton.empty([M, N])
|
c = triton.empty([M, N])
|
||||||
return self.dot(a, b, c, M, N, K, lda, ldb, ldc,
|
return self.dot(a, b, c, M, N, K, lda, ldb, ldc,
|
||||||
@@ -104,34 +112,34 @@ class dot:
|
|||||||
AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
|
AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
|
||||||
TM = [128], TN = [ 128], TK = [32])
|
TM = [128], TN = [ 128], TK = [32])
|
||||||
|
|
||||||
dot_nt = dot(False, True)
|
dot_nt = dot_op(False, True)
|
||||||
dot_nn = dot(False, False)
|
dot_nn = dot_op(False, False)
|
||||||
dot_tn = dot(True, False)
|
dot_tn = dot_op(True, False)
|
||||||
dot_tt = dot(True, True)
|
dot_tt = dot_op(True, True)
|
||||||
|
|
||||||
@triton.register_gradient(dot)
|
# @triton.register_gradient(dot_op)
|
||||||
def _dot_grad(op, dy):
|
# def _dot_grad(op, dy):
|
||||||
a = op.inputs[0]
|
# a = op.inputs[0]
|
||||||
b = op.inputs[1]
|
# b = op.inputs[1]
|
||||||
return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
|
# return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
|
||||||
|
|
||||||
def run_dot():
|
def run_dot():
|
||||||
M, N, K = 128, 128, 128
|
M, N, K = 128, 128, 128
|
||||||
a = tf.placeholder(tf.float16, shape=[M, K])
|
a = tf.placeholder(tf.float16, shape=[M, K])
|
||||||
b = tf.placeholder(tf.float16, shape=[N, K])
|
b = tf.placeholder(tf.float16, shape=[N, K])
|
||||||
# c = tf.matmul(a, b, transpose_a=True)
|
# c = tf.matmul(a, b, transpose_a=True)
|
||||||
c = dot_nn(a, b)
|
c = dot_nt(a, b)
|
||||||
grads = tf.gradients(c, [a])
|
# grads = tf.gradients(c, [a])
|
||||||
# Reference
|
# Reference
|
||||||
ha = np.random.rand(M, K).astype(np.float16)
|
ha = np.random.rand(M, K).astype(np.float16)
|
||||||
hb = np.random.rand(N, K).astype(np.float16)
|
hb = np.random.rand(K, N).astype(np.float16)
|
||||||
# Run
|
# Run
|
||||||
sess = tf.InteractiveSession()
|
sess = tf.InteractiveSession()
|
||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer())
|
||||||
result = sess.run([grads], feed_dict = {a: ha,
|
result = sess.run([c], feed_dict = {a: ha,
|
||||||
b: hb})[0]
|
b: hb})[0]
|
||||||
# Test
|
# Test
|
||||||
hresult = np.dot(ha.T, hb.T).T
|
hresult = np.dot(ha, hb.T)
|
||||||
dif = np.abs(result - hresult)
|
dif = np.abs(result - hresult)
|
||||||
np.savetxt('dif.dat', dif, '%2.4f')
|
np.savetxt('dif.dat', dif, '%2.4f')
|
||||||
print(hresult)
|
print(hresult)
|
||||||
|
@@ -1,7 +0,0 @@
|
|||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
int main(){
|
|
||||||
const char* TEST = "test\n";
|
|
||||||
const char* LOL = "lol\n";
|
|
||||||
printf("%s\n",DTYPE);
|
|
||||||
}
|
|
@@ -136,7 +136,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
|
|||||||
os << "}, *id_grid_map.at(id_), stream); \n";
|
os << "}, *id_grid_map.at(id_), stream); \n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
|
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||||
const std::string &opname,
|
const std::string &opname,
|
||||||
const std::vector<ir::argument*>& args){
|
const std::vector<ir::argument*>& args){
|
||||||
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
|
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
|
||||||
@@ -151,7 +151,7 @@ void gen_register_kernel_builder(std::ostream &os, const std::string &name,
|
|||||||
os << ", " + opname << ");\n";
|
os << ", " + opname << ");\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void gen_register_op(std::ostream &os, const std::string &name,
|
void gen_tf_register_op(std::ostream &os, const std::string &name,
|
||||||
const std::vector<ir::argument*>& args,
|
const std::vector<ir::argument*>& args,
|
||||||
const std::vector<std::string>& outputs){
|
const std::vector<std::string>& outputs){
|
||||||
os << "REGISTER_OP(\"" << name << "\")\n";
|
os << "REGISTER_OP(\"" << name << "\")\n";
|
||||||
@@ -195,15 +195,12 @@ extern int get_program_id(int);
|
|||||||
)";
|
)";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<std::string,
|
void make_module(const std::string& src, ir::module* ir,
|
||||||
std::string> make_tensorflow_src(std::string src,
|
const runtime::function::options_space_t& opt) {
|
||||||
const std::vector<std::string>& outputs,
|
std::string copy = preheader() + src;
|
||||||
const runtime::function::options_space_t& opt)
|
|
||||||
{
|
|
||||||
src = preheader() + src;
|
|
||||||
// pre-process
|
// pre-process
|
||||||
TokenSequence tokens;
|
TokenSequence tokens;
|
||||||
Preprocessor cpp(&src, true);
|
Preprocessor cpp(©, true);
|
||||||
for(auto it: opt.defines){
|
for(auto it: opt.defines){
|
||||||
cpp.AddMacro(it.first, &it.second[0]);
|
cpp.AddMacro(it.first, &it.second[0]);
|
||||||
}
|
}
|
||||||
@@ -211,11 +208,19 @@ std::tuple<std::string,
|
|||||||
// parse
|
// parse
|
||||||
Parser parser(tokens);
|
Parser parser(tokens);
|
||||||
parser.Parse();
|
parser.Parse();
|
||||||
|
Generator gen(&parser);
|
||||||
|
gen.Gen(ir);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<std::string,
|
||||||
|
std::string> make_tensorflow_src(const std::string& src,
|
||||||
|
const std::vector<std::string>& outputs,
|
||||||
|
const runtime::function::options_space_t& opt)
|
||||||
|
{
|
||||||
// triton-ir code-gen
|
// triton-ir code-gen
|
||||||
ir::context ctx;
|
ir::context ctx;
|
||||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||||
Generator gen(&parser);
|
make_module(src, &*ir, opt);
|
||||||
gen.Gen(&*ir);
|
|
||||||
// function
|
// function
|
||||||
ir::function* fn = ir->get_function_list().front();
|
ir::function* fn = ir->get_function_list().front();
|
||||||
std::string name = fn->get_name();
|
std::string name = fn->get_name();
|
||||||
@@ -287,16 +292,145 @@ private:
|
|||||||
|
|
||||||
// register kernel builder
|
// register kernel builder
|
||||||
)";
|
)";
|
||||||
gen_register_kernel_builder(oss, cc_name, opname, fn->args());
|
gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args());
|
||||||
oss << R"(
|
oss << R"(
|
||||||
// register op
|
// register op
|
||||||
)";
|
)";
|
||||||
gen_register_op(oss, cc_name, fn->args(), outputs);
|
gen_tf_register_op(oss, cc_name, fn->args(), outputs);
|
||||||
|
|
||||||
|
|
||||||
return {oss.str(), name};
|
return {oss.str(), name};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline std::string to_torch_ty(ir::type *ty) {
|
||||||
|
if(ty->is_integer_ty(1))
|
||||||
|
return "bool";
|
||||||
|
if(ty->is_integer_ty(8))
|
||||||
|
return "int8";
|
||||||
|
if(ty->is_integer_ty(16))
|
||||||
|
return "int16";
|
||||||
|
if(ty->is_integer_ty(32))
|
||||||
|
return "int32";
|
||||||
|
if(ty->is_integer_ty(64))
|
||||||
|
return "int64";
|
||||||
|
if(ty->is_half_ty())
|
||||||
|
return "float16";
|
||||||
|
if(ty->is_float_ty())
|
||||||
|
return "float32";
|
||||||
|
if(ty->is_double_ty())
|
||||||
|
return "float64";
|
||||||
|
if(ty->is_pointer_ty())
|
||||||
|
return "Tensor";
|
||||||
|
throw std::runtime_error("unknown type");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void gen_torch_signature(std::ostringstream& oss,
|
||||||
|
ir::function* fn,
|
||||||
|
const std::vector<std::string>& outputs,
|
||||||
|
const std::string& name) {
|
||||||
|
const auto& args = fn->args();
|
||||||
|
std::vector<ir::type*> out_types;
|
||||||
|
for(const std::string& out: outputs) {
|
||||||
|
auto it = std::find_if(args.begin(), args.end(),
|
||||||
|
[&](ir::argument* arg) { return arg->get_name() == out; });
|
||||||
|
if(it == args.end())
|
||||||
|
throw std::runtime_error("unknown argument");
|
||||||
|
out_types.push_back((*it)->get_type());
|
||||||
|
}
|
||||||
|
|
||||||
|
oss << "std::tuple<";
|
||||||
|
for(size_t i = 0; i < out_types.size(); i++){
|
||||||
|
if(i > 0)
|
||||||
|
oss << ", ";
|
||||||
|
oss << to_torch_ty(out_types[i]);
|
||||||
|
}
|
||||||
|
oss << "> ";
|
||||||
|
oss << name << "(";
|
||||||
|
oss << "int64 id" << std::endl;
|
||||||
|
for(size_t i = 0; i < args.size(); i++) {
|
||||||
|
ir::argument* arg = args[i];
|
||||||
|
if(i > 0)
|
||||||
|
oss << ", ";
|
||||||
|
oss << to_torch_ty(arg->get_type()) << " " << arg->get_name();
|
||||||
|
}
|
||||||
|
oss << ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
void gen_torch_init_driver(std::ostringstream &oss) {
|
||||||
|
oss << " // Wrap CUDA handles" << std::endl;
|
||||||
|
oss << " c10::DeviceIndex device = torcha.storage().device().index();" << std::endl;
|
||||||
|
oss << " // Get stream" << std::endl;
|
||||||
|
oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl;
|
||||||
|
oss << " triton::driver::cu_stream stream(custream, false);" << std::endl;
|
||||||
|
oss << " triton::driver::context* ctx = stream.context();" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void gen_torch_make_handles(std::ostream &os,
|
||||||
|
const std::vector<ir::argument*>& args) {
|
||||||
|
for(unsigned i = 0; i < args.size(); i++){
|
||||||
|
ir::argument *arg = args[i];
|
||||||
|
if(!arg->get_type()->is_pointer_ty())
|
||||||
|
continue;
|
||||||
|
const std::string& name = arg->get_name();
|
||||||
|
os << " drv::cu_buffer cu_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage.data(), false);\n ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||||
|
os << " (*id_fn_map.at(id))({";
|
||||||
|
for(unsigned i = 0; i < args.size() ; i++){
|
||||||
|
ir::argument *arg = args[i];
|
||||||
|
std::string name = arg->get_name();
|
||||||
|
if(arg->get_type()->is_pointer_ty())
|
||||||
|
name = "&cu_" + name;
|
||||||
|
if(i > 0)
|
||||||
|
os << ", ";
|
||||||
|
os << name;
|
||||||
|
}
|
||||||
|
os << "}, *id_grid_map.at(id), stream); \n";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::tuple<std::string,
|
||||||
|
std::string> make_pytorch_src(const std::string& src,
|
||||||
|
const std::vector<std::string>& outputs,
|
||||||
|
const runtime::function::options_space_t& opt) {
|
||||||
|
// triton-ir code-gen
|
||||||
|
ir::context ctx;
|
||||||
|
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||||
|
make_module(src, &*ir, opt);
|
||||||
|
// function
|
||||||
|
ir::function* fn = ir->get_function_list().front();
|
||||||
|
std::string name = fn->get_name();
|
||||||
|
// generate framework code
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << R"(
|
||||||
|
#include "triton/driver/buffer.h"
|
||||||
|
#include "triton/driver/backend.h"
|
||||||
|
#include "triton/driver/stream.h"
|
||||||
|
#include "triton/runtime/function.h"
|
||||||
|
|
||||||
|
namespace rt = triton::runtime;
|
||||||
|
namespace drv = triton::driver;
|
||||||
|
|
||||||
|
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||||
|
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||||
|
|
||||||
|
)";
|
||||||
|
|
||||||
|
gen_torch_signature(oss, fn, outputs, name);
|
||||||
|
oss << " {" << std::endl;
|
||||||
|
gen_torch_init_driver(oss);
|
||||||
|
gen_torch_make_handles(oss, fn->args());
|
||||||
|
gen_torch_make_launch_function(oss, fn->args());
|
||||||
|
oss << std::endl << "}";
|
||||||
|
|
||||||
|
oss << "static auto registry = torch::jit::RegisterOperators(\"triton::" << name << "\", &" << name << ");" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
typedef triton::runtime::function::options_t options_t;
|
typedef triton::runtime::function::options_t options_t;
|
||||||
typedef triton::runtime::function::options_space_t options_space_t;
|
typedef triton::runtime::function::options_space_t options_space_t;
|
||||||
|
|
||||||
@@ -307,6 +441,8 @@ PYBIND11_MODULE(libtriton, m) {
|
|||||||
m.def("make_tensorflow_src", &make_tensorflow_src,
|
m.def("make_tensorflow_src", &make_tensorflow_src,
|
||||||
"Creates C++ source code for a custom Tensorflow op "
|
"Creates C++ source code for a custom Tensorflow op "
|
||||||
"corresponding to the specified Triton kernel");
|
"corresponding to the specified Triton kernel");
|
||||||
|
m.def("make_pytorch_src", &make_pytorch_src,
|
||||||
|
"Creates C++ source code for a custom PyTorch op ");
|
||||||
|
|
||||||
// bindings for triton classes
|
// bindings for triton classes
|
||||||
pybind11::class_<options_t>(m, "options")
|
pybind11::class_<options_t>(m, "options")
|
||||||
|
@@ -11,18 +11,60 @@ import setuptools.command.build_ext
|
|||||||
import setuptools
|
import setuptools
|
||||||
# triton
|
# triton
|
||||||
import libtriton
|
import libtriton
|
||||||
# frameworks
|
|
||||||
import tensorflow as tf
|
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
|
|
||||||
|
|
||||||
extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
|
torch_id = 'torch'
|
||||||
|
tensorflow_id = 'tensorflow'
|
||||||
|
|
||||||
|
torch = None
|
||||||
|
tensorflow = None
|
||||||
|
tf_extra_ops = None
|
||||||
|
|
||||||
|
def _import_torch():
|
||||||
|
global torch
|
||||||
|
if torch is None:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def _import_tensorflow():
|
||||||
|
global tensorflow
|
||||||
|
if tensorflow is None:
|
||||||
|
import tensorflow
|
||||||
|
|
||||||
|
def _import_tf_extra_ops():
|
||||||
|
global tf_extra_ops
|
||||||
|
if tf_extra_ops is None:
|
||||||
|
path = os.path.dirname(libtriton.__file__)
|
||||||
|
path = os.path.join(path, 'libextra_tf_ops.so')
|
||||||
|
_import_tensorflow()
|
||||||
|
tf_extra_ops = tensorflow.load_op_library(path)
|
||||||
|
|
||||||
|
|
||||||
def make_bindings(src, out, grid):
|
def _find_framework(default = None):
|
||||||
|
is_tf_imported = 'tensorflow' in sys.modules
|
||||||
|
is_torch_imported = 'torch' in sys.modules
|
||||||
|
if default:
|
||||||
|
if default not in [tensorflow_id, torch_id]:
|
||||||
|
raise ValueError('unsupported framework')
|
||||||
|
else:
|
||||||
|
return default
|
||||||
|
elif is_tf_imported and not is_torch_imported:
|
||||||
|
return tensorflow_id
|
||||||
|
elif is_torch_imported and not is_tf_imported:
|
||||||
|
return torch_id
|
||||||
|
else:
|
||||||
|
raise ValueError('cannot determine imported framework, '
|
||||||
|
'please provide framework argument')
|
||||||
|
|
||||||
|
|
||||||
|
def _make_framework_src(src, out, grid, framework):
|
||||||
|
if framework == tensorflow_id:
|
||||||
return libtriton.make_tensorflow_src(src, out, grid)
|
return libtriton.make_tensorflow_src(src, out, grid)
|
||||||
|
elif framework == torch_id:
|
||||||
|
return libtriton.make_torch_src(src, out, grid)
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
def make_cache_path(src):
|
def _make_cache_path(src):
|
||||||
md5 = hashlib.sha1(src.encode())
|
md5 = hashlib.sha1(src.encode())
|
||||||
hexhash = md5.hexdigest()
|
hexhash = md5.hexdigest()
|
||||||
home = os.path.expanduser('~')
|
home = os.path.expanduser('~')
|
||||||
@@ -32,10 +74,10 @@ def make_cache_path(src):
|
|||||||
os.makedirs(cachepath)
|
os.makedirs(cachepath)
|
||||||
return cachepath
|
return cachepath
|
||||||
|
|
||||||
def write_bindings(src, root):
|
def _write_bindings(src, root, framework):
|
||||||
cpp = os.path.join(root, 'tensorflow.cpp')
|
cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
|
||||||
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||||
so = os.path.join(root, 'tensorflow{suffix}'.format(suffix=suffix))
|
so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
|
||||||
recompile = False
|
recompile = False
|
||||||
# recompile if .so does not exist
|
# recompile if .so does not exist
|
||||||
if not os.path.exists(cpp) or not os.path.exists(so):
|
if not os.path.exists(cpp) or not os.path.exists(so):
|
||||||
@@ -50,18 +92,32 @@ def write_bindings(src, root):
|
|||||||
# return path of cpp file
|
# return path of cpp file
|
||||||
return (cpp, so)
|
return (cpp, so)
|
||||||
|
|
||||||
def build(src, path):
|
def _build(src, path, framework):
|
||||||
# include directories
|
# include directories
|
||||||
triton_include_dirs = ['/home/philippe/development/triton/include']
|
triton_include_dirs = ['/home/philippe/development/triton/include']
|
||||||
tensorflow_include_dirs = [tf.sysconfig.get_include()]
|
include_dirs = triton_include_dirs
|
||||||
cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
|
|
||||||
include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
|
|
||||||
# library directories
|
# library directories
|
||||||
triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
|
triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
|
||||||
tensorflow_library_dirs = [tf.sysconfig.get_lib()]
|
library_dirs = triton_library_dirs
|
||||||
library_dirs = triton_library_dirs + tensorflow_library_dirs
|
|
||||||
# libraries
|
# libraries
|
||||||
libraries = ['tensorflow_framework', 'triton']
|
libraries = ['triton']
|
||||||
|
# add framework
|
||||||
|
if framework == tensorflow_id:
|
||||||
|
_import_tensorflow()
|
||||||
|
library_dirs += [tensorflow.sysconfig.get_lib()]
|
||||||
|
include_dirs += [tensorflow.sysconfig.get_lib()]
|
||||||
|
libraries += ['tensorflow_framework']
|
||||||
|
elif framework == torch_id:
|
||||||
|
_import_torch()
|
||||||
|
prefix = os.path.dirname(torch.__file__)
|
||||||
|
library_dirs += [os.path.join(prefix, 'lib')]
|
||||||
|
include_dirs += [os.path.join(prefix, 'lib', 'include'),
|
||||||
|
os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
|
||||||
|
os.path.join(prefix, 'include'),
|
||||||
|
os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
|
||||||
|
libraries += ['torch']
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
# extra arguments
|
# extra arguments
|
||||||
extra_compile_args = []
|
extra_compile_args = []
|
||||||
extra_link_args = []
|
extra_link_args = []
|
||||||
@@ -93,25 +149,138 @@ def build(src, path):
|
|||||||
setuptools.setup(**args)
|
setuptools.setup(**args)
|
||||||
shutil.rmtree(tmp)
|
shutil.rmtree(tmp)
|
||||||
|
|
||||||
def _cvt_to_def_str(obj):
|
def _cvt_to_def_str(obj, framework):
|
||||||
|
# bool
|
||||||
if isinstance(obj, bool):
|
if isinstance(obj, bool):
|
||||||
return str(int(obj))
|
return str(int(obj))
|
||||||
if isinstance(obj, tf.DType):
|
# tensorflow type
|
||||||
return {tf.int8: 'char',
|
if framework == tensorflow_id:
|
||||||
tf.int16: 'short',
|
_import_tensorflow()
|
||||||
tf.int32: 'int',
|
if isinstance(obj, tensorflow.DType):
|
||||||
tf.int64: 'long',
|
return {tensorflow.int8: 'char',
|
||||||
tf.float16: 'half',
|
tensorflow.int16: 'short',
|
||||||
tf.float32: 'float',
|
tensorflow.int32: 'int',
|
||||||
tf.float64: 'double'}[obj]
|
tensorflow.int64: 'long',
|
||||||
|
tensorflow.float16: 'half',
|
||||||
|
tensorflow.float32: 'float',
|
||||||
|
tensorflow.float64: 'double'}[obj]
|
||||||
|
# torch type
|
||||||
|
elif framework == torch_id:
|
||||||
|
_import_torch()
|
||||||
|
if isinstance(obj, torch.dtype):
|
||||||
|
return {torch.int8: 'char',
|
||||||
|
torch.int16: 'short',
|
||||||
|
torch.int32: 'int',
|
||||||
|
torch.int64: 'long',
|
||||||
|
torch.float16: 'half',
|
||||||
|
torch.float32: 'float',
|
||||||
|
torch.float64: 'double'}[obj]
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
# default
|
||||||
return str(obj)
|
return str(obj)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_framework_op(src, outputs, options, framework):
|
||||||
|
src, name = _make_framework_src(src, outputs, options, framework)
|
||||||
|
cache_path = _make_cache_path(src)
|
||||||
|
cpp, so = _write_bindings(src, cache_path, framework)
|
||||||
|
_build(cpp, cache_path, framework)
|
||||||
|
if framework == tensorflow_id:
|
||||||
|
_import_tensorflow()
|
||||||
|
return tensorflow.load_op_library(so).__dict__[name]
|
||||||
|
elif framework == torch_id:
|
||||||
|
_import_torch()
|
||||||
|
torch.ops.load_library(so)
|
||||||
|
return torch.ops.triton.__dict__[name]
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
def _make_grid(args) :
|
||||||
|
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
|
||||||
|
def grid(opt):
|
||||||
|
for x in scalars:
|
||||||
|
x.set_assume_initialized()
|
||||||
|
result = args[-1](opt)
|
||||||
|
for x in scalars:
|
||||||
|
x.unset_assume_initialized()
|
||||||
|
return result
|
||||||
|
return grid
|
||||||
|
|
||||||
|
class op:
|
||||||
|
|
||||||
|
def __init__(self, src, outputs, framework = None):
|
||||||
|
self.fw_id = dict()
|
||||||
|
self.fw_ops = dict()
|
||||||
|
self.fw_grids = dict()
|
||||||
|
self.src = src
|
||||||
|
self.outputs = outputs
|
||||||
|
self.framework = _find_framework(None)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
# create a new op when defines are different
|
||||||
|
key = zip(kwargs.keys(), kwargs.values())
|
||||||
|
if key not in self.fw_ops:
|
||||||
|
# code generation options
|
||||||
|
defines = []
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
cvt = lambda x: _cvt_to_def_str(x, self.framework)
|
||||||
|
try:
|
||||||
|
values = list(map(cvt, v))
|
||||||
|
except TypeError:
|
||||||
|
values = [cvt(v)]
|
||||||
|
defines.append((k, values))
|
||||||
|
opt = libtriton.options_space()
|
||||||
|
opt.defines = defines
|
||||||
|
opt.num_warps = [1, 2, 4, 8]
|
||||||
|
# create unique id for this op
|
||||||
|
op_id = libtriton.make_op_id()
|
||||||
|
self.fw_id[key] = op_id
|
||||||
|
# register function
|
||||||
|
libtriton.register_fn(op_id, self.src, opt)
|
||||||
|
self.fw_ops[key] = _make_framework_op(self.src, self.outputs, opt, self.framework)
|
||||||
|
|
||||||
|
# retrieve framework op
|
||||||
|
op_id = self.fw_id[key]
|
||||||
|
op = self.fw_ops[key]
|
||||||
|
# register grid
|
||||||
|
grid = _make_grid(args)
|
||||||
|
self.fw_grids[key] = grid
|
||||||
|
libtriton.register_grid(op_id, self.fw_grids[key])
|
||||||
|
# create operands
|
||||||
|
op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
|
||||||
|
# call framework op
|
||||||
|
return op(*op_args, id=op_id)
|
||||||
|
|
||||||
|
|
||||||
|
# class register_gradient:
|
||||||
|
|
||||||
|
# def __init__(self, op):
|
||||||
|
# self.op = op
|
||||||
|
|
||||||
|
# def __call__(self, f):
|
||||||
|
# name = 'Dot'
|
||||||
|
# ops.RegisterGradient(name)(f)
|
||||||
|
|
||||||
|
|
||||||
|
def empty(shapes, framework = None):
|
||||||
|
framework = _find_framework(framework)
|
||||||
|
if framework == tensorflow_id:
|
||||||
|
_import_tensorflow()
|
||||||
|
_import_tf_extra_ops
|
||||||
|
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
|
||||||
|
args = tensorflow.stack(args)
|
||||||
|
return tf_extra_ops.alloc_empty(args)
|
||||||
|
elif framework == torch_id:
|
||||||
|
_import_torch()
|
||||||
|
return torch.empty(*shapes)
|
||||||
|
|
||||||
class scalar:
|
class scalar:
|
||||||
|
|
||||||
def __init__(self, x):
|
def __init__(self, x):
|
||||||
|
_import_tf_extra_ops()
|
||||||
self.id = libtriton.make_scalar_id()
|
self.id = libtriton.make_scalar_id()
|
||||||
self.handle = extra_ops.register_scalar(x, id=self.id)
|
self.handle = tf_extra_ops.register_scalar(x, id=self.id)
|
||||||
self.assume_initialized = False
|
self.assume_initialized = False
|
||||||
|
|
||||||
def set_assume_initialized(self):
|
def set_assume_initialized(self):
|
||||||
@@ -174,83 +343,6 @@ class lazy_shape:
|
|||||||
return scalar(self.shape[key])
|
return scalar(self.shape[key])
|
||||||
|
|
||||||
def shape(A) :
|
def shape(A) :
|
||||||
return lazy_shape(tf.shape(A))
|
_import_tensorflow()
|
||||||
|
return lazy_shape(tensorflow.shape(A))
|
||||||
|
|
||||||
def _make_tensorflow_op(src, outputs, options):
|
|
||||||
src, name = make_bindings(src, outputs, options)
|
|
||||||
cache_path = make_cache_path(src)
|
|
||||||
cpp, so = write_bindings(src, cache_path)
|
|
||||||
build(cpp, cache_path)
|
|
||||||
result = tf.load_op_library(so)
|
|
||||||
return result.__dict__[name]
|
|
||||||
|
|
||||||
def _make_grid(args) :
|
|
||||||
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
|
|
||||||
def grid(opt):
|
|
||||||
for x in scalars:
|
|
||||||
x.set_assume_initialized()
|
|
||||||
result = args[-1](opt)
|
|
||||||
for x in scalars:
|
|
||||||
x.unset_assume_initialized()
|
|
||||||
return result
|
|
||||||
return grid
|
|
||||||
|
|
||||||
class op:
|
|
||||||
|
|
||||||
def __init__(self, src, outputs):
|
|
||||||
self.fw_id = dict()
|
|
||||||
self.fw_ops = dict()
|
|
||||||
self.fw_grids = dict()
|
|
||||||
self.src = src
|
|
||||||
self.outputs = outputs
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
# create a new op when defines are different
|
|
||||||
key = zip(kwargs.keys(), kwargs.values())
|
|
||||||
if key not in self.fw_ops:
|
|
||||||
# code generation options
|
|
||||||
defines = []
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
try:
|
|
||||||
values = list(map(_cvt_to_def_str, v))
|
|
||||||
except TypeError:
|
|
||||||
values = [_cvt_to_def_str(v)]
|
|
||||||
defines.append((k, values))
|
|
||||||
opt = libtriton.options_space()
|
|
||||||
opt.defines = defines
|
|
||||||
opt.num_warps = [1, 2, 4, 8]
|
|
||||||
# create unique id for this op
|
|
||||||
op_id = libtriton.make_op_id()
|
|
||||||
self.fw_id[key] = op_id
|
|
||||||
# register function
|
|
||||||
libtriton.register_fn(op_id, self.src, opt)
|
|
||||||
self.fw_ops[key] = _make_tensorflow_op(self.src, self.outputs, opt)
|
|
||||||
|
|
||||||
# retrieve framework op
|
|
||||||
op_id = self.fw_id[key]
|
|
||||||
op = self.fw_ops[key]
|
|
||||||
# register grid
|
|
||||||
grid = _make_grid(args)
|
|
||||||
self.fw_grids[key] = grid
|
|
||||||
libtriton.register_grid(op_id, self.fw_grids[key])
|
|
||||||
# create operands
|
|
||||||
op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
|
|
||||||
# call framework op
|
|
||||||
return op(*op_args, id=op_id)
|
|
||||||
|
|
||||||
|
|
||||||
class register_gradient:
|
|
||||||
|
|
||||||
def __init__(self, op):
|
|
||||||
self.op = op
|
|
||||||
|
|
||||||
def __call__(self, f):
|
|
||||||
name = 'Dot'
|
|
||||||
ops.RegisterGradient(name)(f)
|
|
||||||
|
|
||||||
|
|
||||||
def empty(shapes):
|
|
||||||
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
|
|
||||||
args = tf.stack(args)
|
|
||||||
return extra_ops.alloc_empty(args)
|
|
||||||
|
Reference in New Issue
Block a user