more progress
This commit is contained in:
@@ -33,20 +33,20 @@ endif()
|
||||
if(BUILD_PYTHON_MODULE)
|
||||
message(STATUS "Adding Python module")
|
||||
# PyBind11 wrapper source file
|
||||
file(GLOB_RECURSE PYTHON_SRC python/src/tensorflow.cpp)
|
||||
file(GLOB_RECURSE PYTHON_SRC python/src/tensorflow.cc)
|
||||
# update include directory
|
||||
include_directories(python/src/ ${PYTHON_INCLUDE_DIRS} ${TF_INCLUDE_DIRS})
|
||||
# update link directories
|
||||
link_directories(${TF_LIB_DIRS})
|
||||
# extra tensorflow ops (e.g., alloc_empty)
|
||||
file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cpp)
|
||||
file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cc)
|
||||
add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC})
|
||||
target_link_libraries(extra_tf_ops ${TF_LIBS})
|
||||
endif()
|
||||
|
||||
|
||||
# Triton
|
||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cpp lib/*.cc)
|
||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
target_link_libraries(triton LLVM)
|
||||
|
||||
|
@@ -59,12 +59,11 @@ class metaparameter;
|
||||
namespace runtime{
|
||||
|
||||
|
||||
typedef std::array<size_t, 3> grid_t;
|
||||
typedef std::vector<size_t> grid_t;
|
||||
typedef std::map<std::string, size_t> params_t;
|
||||
|
||||
template<typename T> T convert(const std::string& name);
|
||||
template<> long convert<long>(const std::string& name) { return std::stol(name); }
|
||||
template<> int convert<int>(const std::string& name) { return std::stoi(name); }
|
||||
template<typename T> inline T convert(const std::string& name);
|
||||
template<> inline long convert<long>(const std::string& name) { return std::stol(name); }
|
||||
template<> inline int convert<int>(const std::string& name) { return std::stoi(name); }
|
||||
|
||||
class function {
|
||||
public:
|
||||
@@ -91,7 +90,7 @@ private:
|
||||
class caller {
|
||||
public:
|
||||
caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt_);
|
||||
void operator()(driver::stream *stream, const std::array<size_t, 3>& grid, const std::vector<arg>& args) const;
|
||||
void operator()(driver::stream *stream, const grid_t& grid, const std::vector<arg>& args) const;
|
||||
const options_t opt() const { return opt_; }
|
||||
|
||||
private:
|
||||
@@ -113,7 +112,7 @@ private:
|
||||
|
||||
public:
|
||||
function(const std::string& src, const options_space_t& opt = options_space_t());
|
||||
void operator()(const std::vector<arg>& args, const std::array<size_t, 3>& grid, driver::stream* stream);
|
||||
void operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream* stream);
|
||||
void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
|
||||
std::string make_tensorflow_src(const std::vector<size_t> &outputs, const std::string ¯o);
|
||||
|
||||
|
@@ -15,8 +15,8 @@ namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
grids::grids(size_t num_warps): num_warps_(num_warps){
|
||||
}
|
||||
grids::grids(size_t num_warps): num_warps_(num_warps)
|
||||
{ }
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
bool result = false;
|
||||
|
@@ -93,7 +93,7 @@ function::caller::caller(ir::function *ir, std::shared_ptr<driver::module> paren
|
||||
}
|
||||
|
||||
|
||||
void function::caller::operator ()(driver::stream *stream, const std::array<size_t, 3>& grid, const std::vector<arg>& args) const {
|
||||
void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, const std::vector<arg>& args) const {
|
||||
if(args.size() != param_tys_.size())
|
||||
throw std::runtime_error("invalid number of arguments");
|
||||
for(size_t i = 0; i < args.size(); i++){
|
||||
@@ -106,6 +106,12 @@ void function::caller::operator ()(driver::stream *stream, const std::array<size
|
||||
else
|
||||
bin_->setArg(i, size_of(ty), arg_i.data());
|
||||
}
|
||||
// sanity check
|
||||
if(_grid.size() > 3)
|
||||
throw std::runtime_error("grid size must be no greater than 3");
|
||||
std::array<size_t, 3> grid;
|
||||
for(size_t i = 0; i < 3; i++)
|
||||
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
|
||||
stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1});
|
||||
}
|
||||
|
||||
@@ -207,7 +213,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
}
|
||||
|
||||
std::string preheader() {
|
||||
return R"(
|
||||
return
|
||||
R"(
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
@@ -228,9 +235,10 @@ function::function(const std::string &src, const options_space_t& opt): src_(sr
|
||||
}
|
||||
|
||||
void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
||||
/* determine if should re-tune or not */
|
||||
cache_key_t key;
|
||||
// re-tune if device is difference
|
||||
|
||||
/* figure out if the kernel should be re-tuned */
|
||||
// re-tune if device is different
|
||||
key.first = stream->context()->device();
|
||||
// re-tune if any int argument is different
|
||||
for(size_t i = 0; i < args.size(); i++){
|
||||
|
@@ -3,49 +3,89 @@ import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
src = """
|
||||
const tunable int TM = {128};
|
||||
const tunable int TN = {128};
|
||||
const tunable int TK = {32};
|
||||
#if AT == 1
|
||||
#define USEA ^a
|
||||
#else
|
||||
#define USEA a
|
||||
#endif
|
||||
|
||||
void matmul(restrict read_only align(16) half *A,
|
||||
restrict read_only align(16) half *B,
|
||||
restrict read_only align(16) half *C,
|
||||
#if BT == 1
|
||||
#define USEB ^b
|
||||
#else
|
||||
#define USEB b
|
||||
#endif
|
||||
|
||||
void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C __noalias __readonly __aligned(16),
|
||||
int M, int N, int K,
|
||||
multiple_of(8) int lda, multiple_of(8) int ldb, int ldc)
|
||||
{
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc) {
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int ryb[TN] = ridy * TN + (0 ... TN);
|
||||
int rxa[TM] = ridx * TM + 0 ... TM;
|
||||
int ryb[TN] = ridy * TN + 0 ... TN;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
float xc[TM, TN] = 0;
|
||||
half* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
|
||||
half* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
|
||||
half a[TM, TK] = *pa;
|
||||
half b[TN, TK] = *pb;
|
||||
|
||||
/* pointers for A */
|
||||
#if AT == 1
|
||||
TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
|
||||
TYPE a[TK, TM] = *pa;
|
||||
#else
|
||||
TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
|
||||
TYPE a[TM, TK] = *pa;
|
||||
#endif
|
||||
|
||||
/* pointers for B */
|
||||
#if BT == 1
|
||||
TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
|
||||
TYPE b[TN, TK] = *pb;
|
||||
#else
|
||||
TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
|
||||
TYPE b[TK, TN] = *pb;
|
||||
#endif
|
||||
|
||||
/* reduction loop */
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = dot(a, trans(b), xc);
|
||||
xc = USEA @ USEB + xc;
|
||||
#if AT == 1
|
||||
pa = pa + TK;
|
||||
#else
|
||||
pa = pa + TK*lda;
|
||||
#endif
|
||||
#if BT == 1
|
||||
pb = pb + TK*ldb;
|
||||
#else
|
||||
pb = pb + TK;
|
||||
#endif
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
|
||||
/* epilogue */
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
half* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis]*ldc;
|
||||
half c[TM, TN] = xc;
|
||||
TYPE* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
TYPE c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = c;
|
||||
*pc = c;
|
||||
}
|
||||
"""
|
||||
|
||||
def cdiv(a, b):
|
||||
return -(-a // b)
|
||||
|
||||
class dot:
|
||||
|
||||
def __init__(self):
|
||||
self.matmul = triton.make_tensorflow_op(src, ['C'], ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN'])
|
||||
def __init__(self, trans_a = False, trans_b = True):
|
||||
self.dot = triton.op(src, ['C'])
|
||||
self.trans_a = trans_a
|
||||
self.trans_b = trans_b
|
||||
|
||||
def __call__(self, a, b):
|
||||
shape_a = tf.shape(a)
|
||||
@@ -57,9 +97,13 @@ class dot:
|
||||
ldb = K
|
||||
ldc = N
|
||||
c = triton.empty([M, N])
|
||||
return self.matmul.matmul(a, b, c, M, N, K, lda, ldb, ldc)
|
||||
return self.dot(a, b, c, M, N, K, lda, ldb, ldc,
|
||||
lambda opt: [cdiv(M, opt.D('TM')), cdiv(N, opt.D('TN')), 1],
|
||||
AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
|
||||
TM = [128], TN = [128], TK = [32])
|
||||
|
||||
dot_tn = dot()
|
||||
|
||||
def run_dot():
|
||||
M, N, K = 128, 128, 128
|
||||
a = tf.placeholder(tf.float16, shape=[M, K])
|
||||
|
@@ -35,6 +35,7 @@ class CMakeBuild(build_ext):
|
||||
self.build_extension(ext)
|
||||
|
||||
def build_extension(self, ext):
|
||||
self.debug = True
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
|
||||
# python directors
|
||||
python_include_dirs = distutils.sysconfig.get_python_inc()
|
||||
|
@@ -1,11 +1,14 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <string>
|
||||
#include <regex>
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/lang/lang.h"
|
||||
#include "triton/lang/code_gen.h"
|
||||
#include "triton/lang/parser.h"
|
||||
#include "triton/lang/cpp.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
@@ -14,14 +17,33 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||
extern int yyparse();
|
||||
extern YY_BUFFER_STATE yy_scan_string(const char * str);
|
||||
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
|
||||
extern triton::lang::translation_unit *ast_root;
|
||||
|
||||
using namespace triton;
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
|
||||
/* TF triton op properties */
|
||||
|
||||
std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
|
||||
std::map<size_t, rt::function*> id_fn_map;
|
||||
|
||||
void register_grid(size_t id,
|
||||
const rt::function::grid_fn_ty& grid_fn) {
|
||||
id_grid_map[id] = grid_fn;
|
||||
}
|
||||
|
||||
size_t register_fn(const std::string& src,
|
||||
const rt::function::options_space_t& opt) {
|
||||
size_t id = id_grid_map.size();
|
||||
bool is_inserted = id_fn_map.insert({id, new rt::function(src, opt)}).second;
|
||||
if(!is_inserted)
|
||||
assert(false);
|
||||
return id;
|
||||
}
|
||||
|
||||
|
||||
/* TF source-code generation */
|
||||
|
||||
inline std::string to_tf_ty(ir::type *ty) {
|
||||
if(ty->is_integer_ty(1))
|
||||
return "bool";
|
||||
@@ -59,21 +81,6 @@ inline std::string ref_to_tf_ty(ir::type *ty) {
|
||||
return res;
|
||||
}
|
||||
|
||||
inline triton::lang::translation_unit *make_ast(const char *src) {
|
||||
YY_BUFFER_STATE buffer = yy_scan_string(src);
|
||||
yyparse();
|
||||
yy_delete_buffer(buffer);
|
||||
triton::lang::translation_unit *program = ast_root;
|
||||
return program;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ir::module> make_ir(ir::context& ctx, triton::lang::translation_unit *program) {
|
||||
// create Triton-IR from AST
|
||||
ir::module* module = new ir::module("", ctx);
|
||||
program->codegen(module);
|
||||
return std::unique_ptr<ir::module>(module);
|
||||
}
|
||||
|
||||
|
||||
void gen_extract_inputs(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
for(unsigned i = 0; i < args.size(); i++){
|
||||
@@ -102,24 +109,8 @@ void gen_make_handles(std::ostream &os, const std::vector<ir::argument*>& args)
|
||||
}
|
||||
}
|
||||
|
||||
void gen_make_spmd_grid(std::ostream &os, const std::vector<std::string>& macros) {
|
||||
std::regex regex("#([a-zA-Z]([a-zA-Z]|[0-9])*)");
|
||||
std::vector<std::string> grids = macros;
|
||||
for(size_t i = grids.size(); i < 3; i++)
|
||||
grids.push_back("1");
|
||||
std::string grid = "rt::grid_t{";
|
||||
for(size_t i = 0; i < grids.size(); i++){
|
||||
if(i > 0)
|
||||
grid += ", ";
|
||||
grid += std::regex_replace(grids[i], regex, "x.at(\"$1\")");
|
||||
}
|
||||
grid += "}";
|
||||
|
||||
os << " auto grid = [&](const rt::params_t& x) { return " << grid << "; };\n ";
|
||||
}
|
||||
|
||||
void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
|
||||
os << " fn_({";
|
||||
os << " (*id_fn_map.at(id_))({";
|
||||
for(unsigned i = 0; i < args.size() ; i++){
|
||||
ir::argument *arg = args[i];
|
||||
std::string name = arg->get_name();
|
||||
@@ -129,7 +120,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
|
||||
os << ", ";
|
||||
os << name;
|
||||
}
|
||||
os << "}, grid, stream); \n";
|
||||
os << "}, id_grid_map.at(id_), stream); \n";
|
||||
}
|
||||
|
||||
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
|
||||
@@ -168,20 +159,55 @@ void gen_register_op(std::ostream &os, const std::string &name,
|
||||
throw std::runtime_error("unknown output");
|
||||
os << " .Output(\"out" << i << ": " << to_tf_scalar_ty(args[idx]->get_type()) << "\")\n";
|
||||
}
|
||||
os << " .Attr(\"id: int\")" << std::endl;
|
||||
os << ";\n";
|
||||
}
|
||||
|
||||
std::string make_tensorflow_src(const std::string src,
|
||||
inline std::string preheader() {
|
||||
return
|
||||
R"(
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
|
||||
#define __readonly __attribute__((readonly))
|
||||
#define __writeonly __attribute__((writeonly))
|
||||
#define __noalias __attribute__((noalias))
|
||||
#define __aligned(A) __attribute__((aligned(A)))
|
||||
#define __multipleof(A) __attribute__((multipleof(A)))
|
||||
|
||||
extern int get_program_id(int);
|
||||
)";
|
||||
}
|
||||
|
||||
std::tuple<std::string,
|
||||
std::string> make_tensorflow_src(std::string src,
|
||||
const std::vector<std::string>& outputs,
|
||||
const std::vector<std::string>& macros) {
|
||||
triton::lang::translation_unit *ast = make_ast(src.c_str());
|
||||
triton::ir::context context;
|
||||
std::unique_ptr<ir::module> ir = make_ir(context, ast);
|
||||
const runtime::function::options_space_t& opt)
|
||||
{
|
||||
src = preheader() + src;
|
||||
// pre-process
|
||||
TokenSequence tokens;
|
||||
Preprocessor cpp(&src, true);
|
||||
for(auto it: opt.defines){
|
||||
cpp.AddMacro(it.first, &it.second[0]);
|
||||
}
|
||||
cpp.Process(tokens);
|
||||
// parse
|
||||
Parser parser(tokens);
|
||||
parser.Parse();
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::unique_ptr<ir::module>(new ir::module("", ctx));
|
||||
Generator gen(&parser);
|
||||
gen.Gen(&*ir);
|
||||
// function
|
||||
ir::function* fn = ir->get_function_list().front();
|
||||
std::string name = fn->get_name();
|
||||
name[0] = static_cast<char>(std::toupper(name[0]));
|
||||
std::string opname = name + "Op";
|
||||
std::string cc_name = name;
|
||||
cc_name[0] = static_cast<char>(std::toupper(cc_name[0]));
|
||||
std::string opname = cc_name + "Op";
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << R"(
|
||||
@@ -204,12 +230,16 @@ using GPUDevice = Eigen::GpuDevice;
|
||||
namespace rt = triton::runtime;
|
||||
namespace drv = triton::driver;
|
||||
|
||||
std::string src = R"TTKERNSRC( )" + src + ")TTKERNSRC\";" + R"(
|
||||
extern std::map<size_t, rt::function::grid_fn_ty> id_grid_map;
|
||||
extern std::map<size_t, rt::function*> id_fn_map;
|
||||
|
||||
|
||||
class )" << opname << R"(: public OpKernel {
|
||||
public:
|
||||
explicit )" << opname << R"((OpKernelConstruction* context)
|
||||
: OpKernel(context), fn_(src) { }
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context){
|
||||
// get device/stream
|
||||
@@ -229,9 +259,7 @@ oss << R"(
|
||||
)";
|
||||
gen_make_handles(oss, fn->args());
|
||||
oss << R"(
|
||||
// create spmd grid
|
||||
)";
|
||||
gen_make_spmd_grid(oss, macros);
|
||||
oss << R"(
|
||||
// launch function
|
||||
)";
|
||||
@@ -240,22 +268,42 @@ oss << R"(
|
||||
}
|
||||
|
||||
private:
|
||||
rt::function fn_;
|
||||
int id_;
|
||||
};
|
||||
|
||||
// register kernel builder
|
||||
)";
|
||||
gen_register_kernel_builder(oss, name, opname, fn->args());
|
||||
gen_register_kernel_builder(oss, cc_name, opname, fn->args());
|
||||
oss << R"(
|
||||
// register op
|
||||
)";
|
||||
gen_register_op(oss, name, fn->args(), outputs);
|
||||
gen_register_op(oss, cc_name, fn->args(), outputs);
|
||||
|
||||
return oss.str();
|
||||
return {oss.str(), name};
|
||||
}
|
||||
|
||||
typedef triton::runtime::function::options_t options_t;
|
||||
typedef triton::runtime::function::options_space_t options_space_t;
|
||||
|
||||
PYBIND11_MODULE(libtriton, m) {
|
||||
m.doc() = "Python bindings to the C++ Triton API";
|
||||
m.def("make_tensorflow_src", &make_tensorflow_src, "Creates C++ source code for a custom Tensorflow op corresponding to the specified Triton kernel");
|
||||
|
||||
// framework binding source code generation
|
||||
m.def("make_tensorflow_src", &make_tensorflow_src,
|
||||
"Creates C++ source code for a custom Tensorflow op "
|
||||
"corresponding to the specified Triton kernel");
|
||||
|
||||
// bindings for triton classes
|
||||
pybind11::class_<options_t>(m, "options")
|
||||
.def(pybind11::init<>())
|
||||
.def("D", &options_t::D<int>);
|
||||
|
||||
pybind11::class_<options_space_t>(m, "options_space")
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("defines", &options_space_t::defines)
|
||||
.def_readwrite("num_warps", &options_space_t::num_warps);
|
||||
|
||||
// hooks into triton constructs since frameworks may not use pybind11
|
||||
m.def("register_grid", ®ister_grid);
|
||||
m.def("register_fn", ®ister_fn);
|
||||
}
|
||||
|
@@ -91,13 +91,71 @@ def build(src, path):
|
||||
setuptools.setup(**args)
|
||||
shutil.rmtree(tmp)
|
||||
|
||||
def make_tensorflow_op(src, outputs, grids):
|
||||
bindings = make_bindings(src, outputs, grids)
|
||||
cache_path = make_cache_path(bindings)
|
||||
cpp, so = write_bindings(bindings, cache_path)
|
||||
def _cvt_to_def_str(obj):
|
||||
if isinstance(obj, bool):
|
||||
return str(int(obj))
|
||||
if isinstance(obj, tf.DType):
|
||||
return {tf.int8: 'char',
|
||||
tf.int16: 'short',
|
||||
tf.int32: 'int',
|
||||
tf.int64: 'long',
|
||||
tf.float16: 'half',
|
||||
tf.float32: 'float',
|
||||
tf.float64: 'double'}[obj]
|
||||
return str(obj)
|
||||
|
||||
class op:
|
||||
|
||||
def _make_tensorflow_op(self, src, outputs, options):
|
||||
src, name = make_bindings(src, outputs, options)
|
||||
cache_path = make_cache_path(src)
|
||||
cpp, so = write_bindings(src, cache_path)
|
||||
build(cpp, cache_path)
|
||||
result = tf.load_op_library(so)
|
||||
return result
|
||||
return result.__dict__[name]
|
||||
|
||||
def __init__(self, src, outputs):
|
||||
self.fw_ops = dict()
|
||||
self.src = src
|
||||
self.outputs = outputs
|
||||
pass
|
||||
|
||||
def D(self, name):
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# recompilation key
|
||||
key = zip(kwargs.keys(), kwargs.values())
|
||||
# create a new op when non-iterable defines are different
|
||||
if key not in self.fw_ops:
|
||||
# code generation options
|
||||
defines = []
|
||||
for k, v in kwargs.items():
|
||||
try:
|
||||
values = list(map(_cvt_to_def_str, v))
|
||||
except TypeError:
|
||||
values = [_cvt_to_def_str(v)]
|
||||
defines.append((k, values))
|
||||
opt = libtriton.options_space()
|
||||
opt.defines = defines
|
||||
opt.num_warps = [1, 2, 4, 8]
|
||||
# register framework op
|
||||
id = libtriton.register_fn(self.src, opt)
|
||||
self.fw_ops[key] = (self._make_tensorflow_op(self.src, self.outputs, opt), id)
|
||||
# retrieve framework op
|
||||
op, id = self.fw_ops[key]
|
||||
libtriton.register_grid(id, args[-1])
|
||||
op_args = args[:-1]
|
||||
return op(*op_args, id=id)
|
||||
|
||||
|
||||
def make_tensorflow_op(src, outputs, grids):
|
||||
src, name = make_bindings(src, outputs, grids)
|
||||
cache_path = make_cache_path(src)
|
||||
cpp, so = write_bindings(src, cache_path)
|
||||
build(cpp, cache_path)
|
||||
result = tf.load_op_library(so)
|
||||
return result.__dict__[name]
|
||||
|
||||
def empty(shapes):
|
||||
return extra_ops.alloc_empty(tf.stack(shapes))
|
||||
|
Reference in New Issue
Block a user