[python] refactoring in anticipation of pytorch support

This commit is contained in:
Philippe Tillet
2019-08-29 17:06:59 -07:00
parent e3c953e79f
commit 141a823799
4 changed files with 385 additions and 156 deletions

View File

@@ -34,19 +34,19 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
/* pointers for A */
#if AT == 1
TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
TYPE* pa[TK, TM] = A + rka[:, newaxis]*lda + rxa[newaxis, :];
TYPE a[TK, TM] = *pa;
#else
TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
TYPE* pa[TM, TK] = A + rka[newaxis, :] + rxa[:, newaxis]*lda;
TYPE a[TM, TK] = *pa;
#endif
/* pointers for B */
#if BT == 1
TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
TYPE* pb[TN, TK] = B + rkb[newaxis, :] + ryb[:, newaxis]*ldb;
TYPE b[TN, TK] = *pb;
#else
TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
TYPE* pb[TK, TN] = B + rkb[:, newaxis]*ldb + ryb[newaxis, :];
TYPE b[TK, TN] = *pb;
#endif
@@ -54,14 +54,14 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
for(int k = K; k > 0; k = k - TK){
xc = USEA @ USEB + xc;
#if AT == 1
pa = pa + TK;
#else
pa = pa + TK*lda;
#else
pa = pa + TK;
#endif
#if BT == 1
pb = pb + TK*ldb;
#else
pb = pb + TK;
#else
pb = pb + TK*ldb;
#endif
a = *pa;
b = *pb;
@@ -70,19 +70,19 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
/* epilogue */
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
TYPE* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
TYPE c[TM, TN] = xc;
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
*pc = c;
*?(checkc) pc = c;
}
"""
def cdiv(a, b):
return -(-a // b)
class dot:
class dot_op:
def __init__(self, trans_a = False, trans_b = False):
self.dot = triton.op(src, ['C'])
@@ -93,10 +93,18 @@ class dot:
shape_a = triton.shape(a)
shape_b = triton.shape(b)
M = shape_a[0]
K = shape_a[1]
N = shape_b[0]
lda = M
ldb = K
Ka = shape_a[1]
Kb = shape_b[0]
N = shape_b[1]
# transpose shapes
if self.trans_a:
M, Ka = Ka, M
if self.trans_b:
Kb, N = N, Kb
K = Ka
# contiguous dimensions
lda = Ka
ldb = N
ldc = N
c = triton.empty([M, N])
return self.dot(a, b, c, M, N, K, lda, ldb, ldc,
@@ -104,34 +112,34 @@ class dot:
AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
TM = [128], TN = [ 128], TK = [32])
dot_nt = dot(False, True)
dot_nn = dot(False, False)
dot_tn = dot(True, False)
dot_tt = dot(True, True)
dot_nt = dot_op(False, True)
dot_nn = dot_op(False, False)
dot_tn = dot_op(True, False)
dot_tt = dot_op(True, True)
@triton.register_gradient(dot)
def _dot_grad(op, dy):
a = op.inputs[0]
b = op.inputs[1]
return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
# @triton.register_gradient(dot_op)
# def _dot_grad(op, dy):
# a = op.inputs[0]
# b = op.inputs[1]
# return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
def run_dot():
M, N, K = 128, 128, 128
a = tf.placeholder(tf.float16, shape=[M, K])
b = tf.placeholder(tf.float16, shape=[N, K])
# c = tf.matmul(a, b, transpose_a=True)
c = dot_nn(a, b)
grads = tf.gradients(c, [a])
c = dot_nt(a, b)
# grads = tf.gradients(c, [a])
# Reference
ha = np.random.rand(M, K).astype(np.float16)
hb = np.random.rand(N, K).astype(np.float16)
hb = np.random.rand(K, N).astype(np.float16)
# Run
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([grads], feed_dict = {a: ha,
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
# Test
hresult = np.dot(ha.T, hb.T).T
hresult = np.dot(ha, hb.T)
dif = np.abs(result - hresult)
np.savetxt('dif.dat', dif, '%2.4f')
print(hresult)

View File

@@ -1,7 +0,0 @@
#include <stdio.h>
int main(){
const char* TEST = "test\n";
const char* LOL = "lol\n";
printf("%s\n",DTYPE);
}

View File

@@ -136,7 +136,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
os << "}, *id_grid_map.at(id_), stream); \n";
}
void gen_register_kernel_builder(std::ostream &os, const std::string &name,
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
const std::string &opname,
const std::vector<ir::argument*>& args){
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
@@ -151,7 +151,7 @@ void gen_register_kernel_builder(std::ostream &os, const std::string &name,
os << ", " + opname << ");\n";
}
void gen_register_op(std::ostream &os, const std::string &name,
void gen_tf_register_op(std::ostream &os, const std::string &name,
const std::vector<ir::argument*>& args,
const std::vector<std::string>& outputs){
os << "REGISTER_OP(\"" << name << "\")\n";
@@ -195,15 +195,12 @@ extern int get_program_id(int);
)";
}
std::tuple<std::string,
std::string> make_tensorflow_src(std::string src,
const std::vector<std::string>& outputs,
const runtime::function::options_space_t& opt)
{
src = preheader() + src;
void make_module(const std::string& src, ir::module* ir,
const runtime::function::options_space_t& opt) {
std::string copy = preheader() + src;
// pre-process
TokenSequence tokens;
Preprocessor cpp(&src, true);
Preprocessor cpp(&copy, true);
for(auto it: opt.defines){
cpp.AddMacro(it.first, &it.second[0]);
}
@@ -211,11 +208,19 @@ std::tuple<std::string,
// parse
Parser parser(tokens);
parser.Parse();
Generator gen(&parser);
gen.Gen(ir);
}
std::tuple<std::string,
std::string> make_tensorflow_src(const std::string& src,
const std::vector<std::string>& outputs,
const runtime::function::options_space_t& opt)
{
// triton-ir code-gen
ir::context ctx;
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
Generator gen(&parser);
gen.Gen(&*ir);
make_module(src, &*ir, opt);
// function
ir::function* fn = ir->get_function_list().front();
std::string name = fn->get_name();
@@ -287,16 +292,145 @@ private:
// register kernel builder
)";
gen_register_kernel_builder(oss, cc_name, opname, fn->args());
gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args());
oss << R"(
// register op
)";
gen_register_op(oss, cc_name, fn->args(), outputs);
gen_tf_register_op(oss, cc_name, fn->args(), outputs);
return {oss.str(), name};
}
inline std::string to_torch_ty(ir::type *ty) {
if(ty->is_integer_ty(1))
return "bool";
if(ty->is_integer_ty(8))
return "int8";
if(ty->is_integer_ty(16))
return "int16";
if(ty->is_integer_ty(32))
return "int32";
if(ty->is_integer_ty(64))
return "int64";
if(ty->is_half_ty())
return "float16";
if(ty->is_float_ty())
return "float32";
if(ty->is_double_ty())
return "float64";
if(ty->is_pointer_ty())
return "Tensor";
throw std::runtime_error("unknown type");
}
void gen_torch_signature(std::ostringstream& oss,
ir::function* fn,
const std::vector<std::string>& outputs,
const std::string& name) {
const auto& args = fn->args();
std::vector<ir::type*> out_types;
for(const std::string& out: outputs) {
auto it = std::find_if(args.begin(), args.end(),
[&](ir::argument* arg) { return arg->get_name() == out; });
if(it == args.end())
throw std::runtime_error("unknown argument");
out_types.push_back((*it)->get_type());
}
oss << "std::tuple<";
for(size_t i = 0; i < out_types.size(); i++){
if(i > 0)
oss << ", ";
oss << to_torch_ty(out_types[i]);
}
oss << "> ";
oss << name << "(";
oss << "int64 id" << std::endl;
for(size_t i = 0; i < args.size(); i++) {
ir::argument* arg = args[i];
if(i > 0)
oss << ", ";
oss << to_torch_ty(arg->get_type()) << " " << arg->get_name();
}
oss << ")";
}
void gen_torch_init_driver(std::ostringstream &oss) {
oss << " // Wrap CUDA handles" << std::endl;
oss << " c10::DeviceIndex device = torcha.storage().device().index();" << std::endl;
oss << " // Get stream" << std::endl;
oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl;
oss << " triton::driver::cu_stream stream(custream, false);" << std::endl;
oss << " triton::driver::context* ctx = stream.context();" << std::endl;
}
void gen_torch_make_handles(std::ostream &os,
const std::vector<ir::argument*>& args) {
for(unsigned i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
if(!arg->get_type()->is_pointer_ty())
continue;
const std::string& name = arg->get_name();
os << " drv::cu_buffer cu_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage.data(), false);\n ";
}
}
void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argument*>& args) {
os << " (*id_fn_map.at(id))({";
for(unsigned i = 0; i < args.size() ; i++){
ir::argument *arg = args[i];
std::string name = arg->get_name();
if(arg->get_type()->is_pointer_ty())
name = "&cu_" + name;
if(i > 0)
os << ", ";
os << name;
}
os << "}, *id_grid_map.at(id), stream); \n";
}
std::tuple<std::string,
std::string> make_pytorch_src(const std::string& src,
const std::vector<std::string>& outputs,
const runtime::function::options_space_t& opt) {
// triton-ir code-gen
ir::context ctx;
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
make_module(src, &*ir, opt);
// function
ir::function* fn = ir->get_function_list().front();
std::string name = fn->get_name();
// generate framework code
std::ostringstream oss;
oss << R"(
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
namespace rt = triton::runtime;
namespace drv = triton::driver;
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
)";
gen_torch_signature(oss, fn, outputs, name);
oss << " {" << std::endl;
gen_torch_init_driver(oss);
gen_torch_make_handles(oss, fn->args());
gen_torch_make_launch_function(oss, fn->args());
oss << std::endl << "}";
oss << "static auto registry = torch::jit::RegisterOperators(\"triton::" << name << "\", &" << name << ");" << std::endl;
}
typedef triton::runtime::function::options_t options_t;
typedef triton::runtime::function::options_space_t options_space_t;
@@ -307,6 +441,8 @@ PYBIND11_MODULE(libtriton, m) {
m.def("make_tensorflow_src", &make_tensorflow_src,
"Creates C++ source code for a custom Tensorflow op "
"corresponding to the specified Triton kernel");
m.def("make_pytorch_src", &make_pytorch_src,
"Creates C++ source code for a custom PyTorch op ");
// bindings for triton classes
pybind11::class_<options_t>(m, "options")

View File

@@ -11,18 +11,60 @@ import setuptools.command.build_ext
import setuptools
# triton
import libtriton
# frameworks
import tensorflow as tf
from tensorflow.python.framework import ops
extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
torch_id = 'torch'
tensorflow_id = 'tensorflow'
torch = None
tensorflow = None
tf_extra_ops = None
def _import_torch():
global torch
if torch is None:
import torch
def _import_tensorflow():
global tensorflow
if tensorflow is None:
import tensorflow
def _import_tf_extra_ops():
global tf_extra_ops
if tf_extra_ops is None:
path = os.path.dirname(libtriton.__file__)
path = os.path.join(path, 'libextra_tf_ops.so')
_import_tensorflow()
tf_extra_ops = tensorflow.load_op_library(path)
def make_bindings(src, out, grid):
return libtriton.make_tensorflow_src(src, out, grid)
def _find_framework(default = None):
is_tf_imported = 'tensorflow' in sys.modules
is_torch_imported = 'torch' in sys.modules
if default:
if default not in [tensorflow_id, torch_id]:
raise ValueError('unsupported framework')
else:
return default
elif is_tf_imported and not is_torch_imported:
return tensorflow_id
elif is_torch_imported and not is_tf_imported:
return torch_id
else:
raise ValueError('cannot determine imported framework, '
'please provide framework argument')
def make_cache_path(src):
def _make_framework_src(src, out, grid, framework):
if framework == tensorflow_id:
return libtriton.make_tensorflow_src(src, out, grid)
elif framework == torch_id:
return libtriton.make_torch_src(src, out, grid)
else:
assert False
def _make_cache_path(src):
md5 = hashlib.sha1(src.encode())
hexhash = md5.hexdigest()
home = os.path.expanduser('~')
@@ -32,10 +74,10 @@ def make_cache_path(src):
os.makedirs(cachepath)
return cachepath
def write_bindings(src, root):
cpp = os.path.join(root, 'tensorflow.cpp')
def _write_bindings(src, root, framework):
cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(root, 'tensorflow{suffix}'.format(suffix=suffix))
so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
recompile = False
# recompile if .so does not exist
if not os.path.exists(cpp) or not os.path.exists(so):
@@ -50,18 +92,32 @@ def write_bindings(src, root):
# return path of cpp file
return (cpp, so)
def build(src, path):
def _build(src, path, framework):
# include directories
triton_include_dirs = ['/home/philippe/development/triton/include']
tensorflow_include_dirs = [tf.sysconfig.get_include()]
cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
include_dirs = triton_include_dirs
# library directories
triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
tensorflow_library_dirs = [tf.sysconfig.get_lib()]
library_dirs = triton_library_dirs + tensorflow_library_dirs
library_dirs = triton_library_dirs
# libraries
libraries = ['tensorflow_framework', 'triton']
libraries = ['triton']
# add framework
if framework == tensorflow_id:
_import_tensorflow()
library_dirs += [tensorflow.sysconfig.get_lib()]
include_dirs += [tensorflow.sysconfig.get_lib()]
libraries += ['tensorflow_framework']
elif framework == torch_id:
_import_torch()
prefix = os.path.dirname(torch.__file__)
library_dirs += [os.path.join(prefix, 'lib')]
include_dirs += [os.path.join(prefix, 'lib', 'include'),
os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
os.path.join(prefix, 'include'),
os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
libraries += ['torch']
else:
assert False
# extra arguments
extra_compile_args = []
extra_link_args = []
@@ -93,25 +149,138 @@ def build(src, path):
setuptools.setup(**args)
shutil.rmtree(tmp)
def _cvt_to_def_str(obj):
def _cvt_to_def_str(obj, framework):
# bool
if isinstance(obj, bool):
return str(int(obj))
if isinstance(obj, tf.DType):
return {tf.int8: 'char',
tf.int16: 'short',
tf.int32: 'int',
tf.int64: 'long',
tf.float16: 'half',
tf.float32: 'float',
tf.float64: 'double'}[obj]
# tensorflow type
if framework == tensorflow_id:
_import_tensorflow()
if isinstance(obj, tensorflow.DType):
return {tensorflow.int8: 'char',
tensorflow.int16: 'short',
tensorflow.int32: 'int',
tensorflow.int64: 'long',
tensorflow.float16: 'half',
tensorflow.float32: 'float',
tensorflow.float64: 'double'}[obj]
# torch type
elif framework == torch_id:
_import_torch()
if isinstance(obj, torch.dtype):
return {torch.int8: 'char',
torch.int16: 'short',
torch.int32: 'int',
torch.int64: 'long',
torch.float16: 'half',
torch.float32: 'float',
torch.float64: 'double'}[obj]
else:
assert False
# default
return str(obj)
def _make_framework_op(src, outputs, options, framework):
src, name = _make_framework_src(src, outputs, options, framework)
cache_path = _make_cache_path(src)
cpp, so = _write_bindings(src, cache_path, framework)
_build(cpp, cache_path, framework)
if framework == tensorflow_id:
_import_tensorflow()
return tensorflow.load_op_library(so).__dict__[name]
elif framework == torch_id:
_import_torch()
torch.ops.load_library(so)
return torch.ops.triton.__dict__[name]
else:
assert False
def _make_grid(args) :
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
def grid(opt):
for x in scalars:
x.set_assume_initialized()
result = args[-1](opt)
for x in scalars:
x.unset_assume_initialized()
return result
return grid
class op:
def __init__(self, src, outputs, framework = None):
self.fw_id = dict()
self.fw_ops = dict()
self.fw_grids = dict()
self.src = src
self.outputs = outputs
self.framework = _find_framework(None)
def __call__(self, *args, **kwargs):
# create a new op when defines are different
key = zip(kwargs.keys(), kwargs.values())
if key not in self.fw_ops:
# code generation options
defines = []
for k, v in kwargs.items():
cvt = lambda x: _cvt_to_def_str(x, self.framework)
try:
values = list(map(cvt, v))
except TypeError:
values = [cvt(v)]
defines.append((k, values))
opt = libtriton.options_space()
opt.defines = defines
opt.num_warps = [1, 2, 4, 8]
# create unique id for this op
op_id = libtriton.make_op_id()
self.fw_id[key] = op_id
# register function
libtriton.register_fn(op_id, self.src, opt)
self.fw_ops[key] = _make_framework_op(self.src, self.outputs, opt, self.framework)
# retrieve framework op
op_id = self.fw_id[key]
op = self.fw_ops[key]
# register grid
grid = _make_grid(args)
self.fw_grids[key] = grid
libtriton.register_grid(op_id, self.fw_grids[key])
# create operands
op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
# call framework op
return op(*op_args, id=op_id)
# class register_gradient:
# def __init__(self, op):
# self.op = op
# def __call__(self, f):
# name = 'Dot'
# ops.RegisterGradient(name)(f)
def empty(shapes, framework = None):
framework = _find_framework(framework)
if framework == tensorflow_id:
_import_tensorflow()
_import_tf_extra_ops
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
args = tensorflow.stack(args)
return tf_extra_ops.alloc_empty(args)
elif framework == torch_id:
_import_torch()
return torch.empty(*shapes)
class scalar:
def __init__(self, x):
_import_tf_extra_ops()
self.id = libtriton.make_scalar_id()
self.handle = extra_ops.register_scalar(x, id=self.id)
self.handle = tf_extra_ops.register_scalar(x, id=self.id)
self.assume_initialized = False
def set_assume_initialized(self):
@@ -174,83 +343,6 @@ class lazy_shape:
return scalar(self.shape[key])
def shape(A) :
return lazy_shape(tf.shape(A))
_import_tensorflow()
return lazy_shape(tensorflow.shape(A))
def _make_tensorflow_op(src, outputs, options):
src, name = make_bindings(src, outputs, options)
cache_path = make_cache_path(src)
cpp, so = write_bindings(src, cache_path)
build(cpp, cache_path)
result = tf.load_op_library(so)
return result.__dict__[name]
def _make_grid(args) :
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
def grid(opt):
for x in scalars:
x.set_assume_initialized()
result = args[-1](opt)
for x in scalars:
x.unset_assume_initialized()
return result
return grid
class op:
def __init__(self, src, outputs):
self.fw_id = dict()
self.fw_ops = dict()
self.fw_grids = dict()
self.src = src
self.outputs = outputs
pass
def __call__(self, *args, **kwargs):
# create a new op when defines are different
key = zip(kwargs.keys(), kwargs.values())
if key not in self.fw_ops:
# code generation options
defines = []
for k, v in kwargs.items():
try:
values = list(map(_cvt_to_def_str, v))
except TypeError:
values = [_cvt_to_def_str(v)]
defines.append((k, values))
opt = libtriton.options_space()
opt.defines = defines
opt.num_warps = [1, 2, 4, 8]
# create unique id for this op
op_id = libtriton.make_op_id()
self.fw_id[key] = op_id
# register function
libtriton.register_fn(op_id, self.src, opt)
self.fw_ops[key] = _make_tensorflow_op(self.src, self.outputs, opt)
# retrieve framework op
op_id = self.fw_id[key]
op = self.fw_ops[key]
# register grid
grid = _make_grid(args)
self.fw_grids[key] = grid
libtriton.register_grid(op_id, self.fw_grids[key])
# create operands
op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
# call framework op
return op(*op_args, id=op_id)
class register_gradient:
def __init__(self, op):
self.op = op
def __call__(self, f):
name = 'Dot'
ops.RegisterGradient(name)(f)
def empty(shapes):
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
args = tf.stack(args)
return extra_ops.alloc_empty(args)