[examples/python/tensorflow] bugfix in tensorflow wrapper example
This commit is contained in:
@@ -10,6 +10,10 @@ FLEX_TARGET(Lexer ${CMAKE_CURRENT_SOURCE_DIR}/include/triton/ast/scanner.l ${CMA
|
|||||||
get_filename_component(BISON_Parser_INCLUDE_DIRECTORIES ${BISON_Parser_OUTPUT_HEADER} DIRECTORY)
|
get_filename_component(BISON_Parser_INCLUDE_DIRECTORIES ${BISON_Parser_OUTPUT_HEADER} DIRECTORY)
|
||||||
include_directories(${BISON_Parser_INCLUDE_DIRECTORIES})
|
include_directories(${BISON_Parser_INCLUDE_DIRECTORIES})
|
||||||
|
|
||||||
|
#execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)"
|
||||||
|
# OUTPUT_VARIABLE TF_ABI OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
|
#add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
||||||
|
|
||||||
# LLVM
|
# LLVM
|
||||||
find_package(LLVM REQUIRED CONFIG)
|
find_package(LLVM REQUIRED CONFIG)
|
||||||
message(STATUS ${LLVM_INCLUDE_DIRS})
|
message(STATUS ${LLVM_INCLUDE_DIRS})
|
||||||
@@ -24,7 +28,7 @@ if(NOT CMAKE_BUILD_TYPE)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Gather headers for cmake-based IDEs
|
# Gather headers for cmake-based IDEs
|
||||||
file( GLOB_RECURSE ALL_SRC *.cpp *.hpp *.h *.py *.y *.l)
|
file( GLOB_RECURSE ALL_SRC *.cpp *.hpp *.h *.py *.y *.l CMakeLists*)
|
||||||
add_custom_target( ALL SOURCES ${ALL_SRC} )
|
add_custom_target( ALL SOURCES ${ALL_SRC} )
|
||||||
|
|
||||||
# Compiler flags
|
# Compiler flags
|
||||||
|
@@ -5,7 +5,7 @@
|
|||||||
#include "triton/driver/backend.h"
|
#include "triton/driver/backend.h"
|
||||||
#include "triton/driver/stream.h"
|
#include "triton/driver/stream.h"
|
||||||
|
|
||||||
std::string src =
|
const char* src =
|
||||||
R"(
|
R"(
|
||||||
const tunable int32 TM = {16, 32, 64};
|
const tunable int32 TM = {16, 32, 64};
|
||||||
const tunable int32 TN = {16, 32, 64};
|
const tunable int32 TN = {16, 32, 64};
|
||||||
|
@@ -53,26 +53,8 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
|
|||||||
fp32 b[TN, 1] = checkb ? *pb : 0;
|
fp32 b[TN, 1] = checkb ? *pb : 0;
|
||||||
c = dot(a, trans(b), c);
|
c = dot(a, trans(b), c);
|
||||||
}
|
}
|
||||||
int32 ridx = get_range_id(0);
|
|
||||||
int32 ridy = get_range_id(1);
|
|
||||||
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||||
int32 *plock = locks + ridx + ridy*grid0;
|
*pc = c;
|
||||||
while(__atomic_cas(plock, 0, 1));
|
|
||||||
int32 *pcount = plock + grid0*grid1;
|
|
||||||
int32 count = *pcount;
|
|
||||||
int32 countp1 = select(count == GZ - 1, 0, count + 1);
|
|
||||||
int1 checkc0[TM] = rxc < M;
|
|
||||||
int1 checkc1[TN] = ryc < N;
|
|
||||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
|
||||||
if(count == 0) {
|
|
||||||
@checkc *pc = c;
|
|
||||||
*pcount = countp1;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
@checkc *pc = c + *pc;
|
|
||||||
*pcount = countp1;
|
|
||||||
}
|
|
||||||
__atomic_cas(plock, 1, 0);
|
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
@@ -1,12 +1,14 @@
|
|||||||
execute_process(COMMAND python -c "from os.path import dirname; import tensorflow as tf; print(dirname(dirname(tf.sysconfig.get_include())))"
|
execute_process(COMMAND python -c "from os.path import dirname; import tensorflow as tf; print(dirname(dirname(tf.sysconfig.get_include())))"
|
||||||
OUTPUT_VARIABLE TF_INC OUTPUT_STRIP_TRAILING_WHITESPACE)
|
OUTPUT_VARIABLE TF_INC OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
#execute_process(COMMAND python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())"
|
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())"
|
||||||
# OUTPUT_VARIABLE TF_LIB)
|
OUTPUT_VARIABLE TF_LIB OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
#execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)"
|
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)"
|
||||||
# OUTPUT_VARIABLE TF_ABI)
|
OUTPUT_VARIABLE TF_ABI OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
|
|
||||||
set(CUDA_HOME "/usr/local/cuda")
|
set(CUDA_HOME "/usr/local/cuda")
|
||||||
include_directories("${TF_INC}/tensorflow/include")
|
include_directories("${TF_INC}/tensorflow/include")
|
||||||
include_directories("${CUDA_HOME}/include")
|
include_directories("${CUDA_HOME}/include")
|
||||||
|
link_directories(${TF_LIB})
|
||||||
|
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
||||||
add_library(tf_blocksparse SHARED blocksparse.cpp)
|
add_library(tf_blocksparse SHARED blocksparse.cpp)
|
||||||
#link_libraries(tf_blocksparse ${TF_LIB})
|
target_link_libraries(tf_blocksparse tensorflow_framework triton)
|
||||||
|
@@ -66,35 +66,18 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
|
|||||||
fp32 b[TN, 1] = checkb ? *pb : 0;
|
fp32 b[TN, 1] = checkb ? *pb : 0;
|
||||||
c = dot(a, trans(b), c);
|
c = dot(a, trans(b), c);
|
||||||
}
|
}
|
||||||
int32 ridx = get_range_id(0);
|
|
||||||
int32 ridy = get_range_id(1);
|
|
||||||
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||||
int32 *plock = locks + ridx + ridy*grid0;
|
*pc = c;
|
||||||
while(__atomic_cas(plock, 0, 1));
|
|
||||||
int32 *pcount = plock + grid0*grid1;
|
|
||||||
int32 count = *pcount;
|
|
||||||
int32 countp1 = select(count == GZ - 1, 0, count + 1);
|
|
||||||
int1 checkc0[TM] = rxc < M;
|
|
||||||
int1 checkc1[TN] = ryc < N;
|
|
||||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
|
||||||
if(count == 0) {
|
|
||||||
@checkc *pc = c;
|
|
||||||
*pcount = countp1;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
@checkc *pc = c + *pc;
|
|
||||||
*pcount = countp1;
|
|
||||||
}
|
|
||||||
__atomic_cas(plock, 1, 0);
|
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
REGISTER_OP("BlockSparseGemm")
|
REGISTER_OP("BlockSparseMatMul")
|
||||||
|
.Input("a: T")
|
||||||
|
.Input("b: T")
|
||||||
|
.Input("locks: int32")
|
||||||
|
.Output("c: T")
|
||||||
.Attr("T: {float}")
|
.Attr("T: {float}")
|
||||||
.Input("A: float")
|
;
|
||||||
.Input("B: float")
|
|
||||||
.Input("locks: int")
|
|
||||||
.Output("C: float");
|
|
||||||
|
|
||||||
class BlockSparseGemmOp : public OpKernel {
|
class BlockSparseGemmOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
@@ -104,59 +87,60 @@ class BlockSparseGemmOp : public OpKernel {
|
|||||||
void Compute(OpKernelContext* context){
|
void Compute(OpKernelContext* context){
|
||||||
// get device/stream
|
// get device/stream
|
||||||
GPUDevice device = context->eigen_device<GPUDevice>();
|
GPUDevice device = context->eigen_device<GPUDevice>();
|
||||||
triton::driver::cu_stream stream(device.stream(), false);
|
triton::driver::cu_stream sstream(device.stream(), false);
|
||||||
|
triton::driver::context* ctx = sstream.context();
|
||||||
|
triton::driver::stream* stream = &sstream;
|
||||||
// get inputs
|
// get inputs
|
||||||
const Tensor& a = context->input(0);
|
const Tensor& a = context->input(0);
|
||||||
const Tensor& b = context->input(1);
|
const Tensor& b = context->input(1);
|
||||||
const Tensor& locks = context->input(2);
|
const Tensor& locks = context->input(2);
|
||||||
// get shapes
|
// get shapes
|
||||||
const int64 M = a.dim_size(0);
|
const int32_t M = a.dim_size(0);
|
||||||
const int64 N = b.dim_size(0);
|
const int32_t N = b.dim_size(0);
|
||||||
const int64 K = a.dim_size(1);
|
const int32_t K = a.dim_size(1);
|
||||||
// allocate output
|
// allocate output
|
||||||
Tensor* c = nullptr;
|
Tensor* c = nullptr;
|
||||||
TensorShape out_shape({M, N});
|
TensorShape out_shape({(int64)M, (int64)N});
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
|
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
|
||||||
// return early if possible
|
// return early if possible
|
||||||
if (out_shape.num_elements() == 0)
|
if (out_shape.num_elements() == 0)
|
||||||
return;
|
return;
|
||||||
// wraps into buffers
|
// initialize default compute device
|
||||||
triton::driver::cu_buffer ta(stream.context(), (CUdeviceptr)a.flat<float>().data(), false);
|
triton::jit jit(ctx);
|
||||||
triton::driver::cu_buffer tb(stream.context(), (CUdeviceptr)b.flat<float>().data(), false);
|
// matrix multiplication parameters
|
||||||
triton::driver::cu_buffer tlocks(stream.context(), (CUdeviceptr)locks.flat<int32_t>().data(), false);
|
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
|
||||||
triton::driver::cu_buffer tc(stream.context(), (CUdeviceptr)c->flat<float>().data(), false);
|
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
|
||||||
// launch info
|
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
|
||||||
triton::jit jit(stream.context());
|
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
|
||||||
|
stream->synchronize();
|
||||||
|
// just-in-time compile source-code
|
||||||
jit.add_module("matmul", src, {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
|
jit.add_module("matmul", src, {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
|
||||||
triton::driver::kernel* kernel = jit.get_function("matmul");
|
triton::driver::kernel* kernel = jit.get_function("matmul");
|
||||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||||
int64 TM = info.global_range_size[0];
|
// launch info
|
||||||
int64 TN = info.global_range_size[1];
|
unsigned TM = info.global_range_size[0];
|
||||||
|
unsigned TN = info.global_range_size[1];
|
||||||
unsigned nthreads = info.num_threads;
|
unsigned nthreads = info.num_threads;
|
||||||
int64 GZ = jit.get_int("GZ");
|
unsigned GZ = jit.get_int("GZ");
|
||||||
std::array<size_t, 3> grid;
|
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
|
||||||
grid[0] = (M + TM - 1)/TM;
|
|
||||||
grid[1] = (N + TN - 1)/TN;
|
|
||||||
grid[2] = GZ;
|
|
||||||
// set argument
|
// set argument
|
||||||
kernel->setArg(0, &ta);
|
kernel->setArg(0, *da.cu());
|
||||||
kernel->setArg(1, &tb);
|
kernel->setArg(1, *db.cu());
|
||||||
kernel->setArg(2, &tc);
|
kernel->setArg(2, *dc.cu());
|
||||||
kernel->setArg(3, M);
|
kernel->setArg(3, M);
|
||||||
kernel->setArg(4, N);
|
kernel->setArg(4, N);
|
||||||
kernel->setArg(5, K);
|
kernel->setArg(5, K);
|
||||||
kernel->setArg(6, M);
|
kernel->setArg(6, M);
|
||||||
kernel->setArg(7, N);
|
kernel->setArg(7, N);
|
||||||
kernel->setArg(8, M);
|
kernel->setArg(8, M);
|
||||||
kernel->setArg(9, tlocks);
|
kernel->setArg(9, *dlocks.cu());
|
||||||
kernel->setArg(10, grid[0]);
|
kernel->setArg(10, grid[0]);
|
||||||
kernel->setArg(11, grid[1]);
|
kernel->setArg(11, grid[1]);
|
||||||
// dry run
|
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||||
stream.enqueue(kernel, grid, {nthreads, 1, 1}, nullptr, nullptr);
|
stream->synchronize();
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("BlockSparse").Device(DEVICE_GPU), BlockSparseGemmOp);
|
REGISTER_KERNEL_BUILDER(Name("BlockSparseMatMul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlockSparseGemmOp);
|
||||||
|
20
examples/python/tensorflow/blocksparse.py
Normal file
20
examples/python/tensorflow/blocksparse.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import os
|
||||||
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
data_files_path = tf.resource_loader.get_data_files_path()
|
||||||
|
library_dir = '/home/philippe/development/triton/build/examples/python/tensorflow'
|
||||||
|
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
|
||||||
|
|
||||||
|
M, N, K = 512, 512, 512
|
||||||
|
a = tf.placeholder(tf.float32, shape=[M, K])
|
||||||
|
b = tf.placeholder(tf.float32, shape=[N, K])
|
||||||
|
locks = tf.placeholder(tf.int32, shape=[4096])
|
||||||
|
c = module.block_sparse_mat_mul(a, b, locks)
|
||||||
|
# Run
|
||||||
|
sess = tf.InteractiveSession()
|
||||||
|
sess.run(tf.global_variables_initializer())
|
||||||
|
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
|
||||||
|
a: np.random.rand(M, K),
|
||||||
|
b: np.random.rand(N, K)})
|
||||||
|
print(result)
|
@@ -1,74 +0,0 @@
|
|||||||
import os, sys
|
|
||||||
from os.path import dirname
|
|
||||||
from distutils.core import setup, Extension
|
|
||||||
from glob import glob
|
|
||||||
from build import build_clib_subclass, build_ext_subclass
|
|
||||||
|
|
||||||
|
|
||||||
def recursive_glob(rootdir='.', suffix=''):
|
|
||||||
return [os.path.join(looproot, filename)
|
|
||||||
for looproot, _, filenames in os.walk(rootdir)
|
|
||||||
for filename in filenames if filename.endswith(suffix)]
|
|
||||||
|
|
||||||
def main():
|
|
||||||
|
|
||||||
path = os.path.join(os.pardir, 'include')
|
|
||||||
include = [path, os.path.join(path, 'isaac', 'external', 'CUDA')]
|
|
||||||
src = recursive_glob(os.path.join(os.pardir,'lib'), 'cpp')
|
|
||||||
flags = ['-std=c++11', '-fPIC', '-D_GLIBCXX_USE_CXX11_ABI=0']
|
|
||||||
core = ('core', {'sources': src, 'include_dirs': include, 'cflags': flags})
|
|
||||||
|
|
||||||
# Extensions
|
|
||||||
extensions = []
|
|
||||||
|
|
||||||
# Isaac
|
|
||||||
extensions += [Extension('_isaac',
|
|
||||||
sources=recursive_glob(os.path.join('src','bind'), 'cpp'),
|
|
||||||
libraries=[],
|
|
||||||
library_dirs=[],
|
|
||||||
extra_compile_args=flags,
|
|
||||||
extra_link_args=[],
|
|
||||||
include_dirs=include + [os.path.join('src', 'bind')])]
|
|
||||||
|
|
||||||
# Tensorflow
|
|
||||||
try:
|
|
||||||
import tensorflow as tf
|
|
||||||
tf_include = tf.sysconfig.get_include()
|
|
||||||
extensions += [Extension('_tensorflow',
|
|
||||||
sources=[os.path.join('src', 'extensions', 'tensorflow.cpp')],
|
|
||||||
libraries = ['tensorflow_framework'],
|
|
||||||
extra_compile_args= flags,
|
|
||||||
include_dirs = include + [tf_include, os.path.join(tf_include, 'external', 'nsync', 'public')],
|
|
||||||
library_dirs = [tf.sysconfig.get_lib()])]
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# Setup
|
|
||||||
setup(
|
|
||||||
name='blocksparse',
|
|
||||||
version='1.0',
|
|
||||||
author='Philippe Tillet',
|
|
||||||
author_email='ptillet@g.harvard.edu',
|
|
||||||
packages=['isaac', 'isaac.pytorch', 'isaac.pytorch.models', 'isaac.pytorch.c_lib'],
|
|
||||||
libraries=[core],
|
|
||||||
ext_package='isaac',
|
|
||||||
ext_modules=extensions,
|
|
||||||
cmdclass={'build_clib': build_clib_subclass, 'build_ext': build_ext_subclass},
|
|
||||||
classifiers=['Environment :: Console',
|
|
||||||
'Development Status :: 4 - Beta',
|
|
||||||
'Intended Audience :: Developers',
|
|
||||||
'Intended Audience :: Other Audience',
|
|
||||||
'Intended Audience :: Science/Research',
|
|
||||||
'Natural Language :: English',
|
|
||||||
'Programming Language :: C++',
|
|
||||||
'Programming Language :: Python',
|
|
||||||
'Programming Language :: Python :: 3',
|
|
||||||
'Topic :: Scientific/Engineering',
|
|
||||||
'Topic :: Scientific/Engineering :: Mathematics',
|
|
||||||
'Topic :: Scientific/Engineering :: Physics',
|
|
||||||
'Topic :: Scientific/Engineering :: Machine Learning']
|
|
||||||
)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@@ -20,6 +20,7 @@ namespace codegen{
|
|||||||
class target {
|
class target {
|
||||||
public:
|
public:
|
||||||
target(bool is_gpu): is_gpu_(is_gpu){}
|
target(bool is_gpu): is_gpu_(is_gpu){}
|
||||||
|
virtual ~target() {}
|
||||||
virtual void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn) = 0;
|
virtual void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn) = 0;
|
||||||
virtual llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
|
virtual llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
|
||||||
virtual llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) = 0;
|
virtual llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) = 0;
|
||||||
|
@@ -89,17 +89,17 @@ public:
|
|||||||
private:
|
private:
|
||||||
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
|
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
|
||||||
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
|
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
|
||||||
std::unique_ptr<ir::module> make_triton_module(const std::string &name, const std::string &src);
|
std::unique_ptr<ir::module> make_triton_module(const char* name, const char* src);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
jit(driver::context* context);
|
jit(driver::context* context);
|
||||||
void autotune(const std::string &name, const std::string &src, benchmark_t benchmark);
|
void autotune(const char* name, const char* src, benchmark_t benchmark);
|
||||||
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
|
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
|
||||||
void add_module(const std::string &name, const std::string &src, const std::vector<unsigned>& params = {});
|
void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
|
||||||
driver::kernel* get_function(const std::string &name);
|
driver::kernel* get_function(const char* name);
|
||||||
launch_information get_launch_info(const std::string &name);
|
launch_information get_launch_info(const char* name);
|
||||||
unsigned get_int(const std::string &name);
|
unsigned get_int(const char* name);
|
||||||
driver::buffer *get_buffer(const std::string &name);
|
driver::buffer* get_buffer(const char* name);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<driver::module*> modules_;
|
std::vector<driver::module*> modules_;
|
||||||
|
@@ -404,6 +404,7 @@ ir::value* while_statement::codegen(ir::module* mod) const{
|
|||||||
mod->seal_block(builder.get_insert_block());
|
mod->seal_block(builder.get_insert_block());
|
||||||
mod->seal_block(next_bb);
|
mod->seal_block(next_bb);
|
||||||
builder.set_insert_point(next_bb);
|
builder.set_insert_point(next_bb);
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Selection statement */
|
/* Selection statement */
|
||||||
|
@@ -19,7 +19,6 @@
|
|||||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
24
lib/jit.cpp
24
lib/jit.cpp
@@ -79,9 +79,9 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
|
|||||||
return std::unique_ptr<llvm::Module>(result);
|
return std::unique_ptr<llvm::Module>(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<ir::module> jit::make_triton_module(const std::string &name, const std::string &src) {
|
std::unique_ptr<ir::module> jit::make_triton_module(const char *name, const char *src) {
|
||||||
// create AST from Triton-C source
|
// create AST from Triton-C source
|
||||||
YY_BUFFER_STATE buffer = yy_scan_string(src.c_str());
|
YY_BUFFER_STATE buffer = yy_scan_string(src);
|
||||||
yyparse();
|
yyparse();
|
||||||
yy_delete_buffer(buffer);
|
yy_delete_buffer(buffer);
|
||||||
translation_unit *program = ast_root;
|
translation_unit *program = ast_root;
|
||||||
@@ -97,7 +97,7 @@ jit::jit(driver::context *context): driver_context_(context),
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void jit::autotune(const std::string &name, const std::string &src, benchmark_t benchmark) {
|
void jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
|
||||||
// find metaparameters
|
// find metaparameters
|
||||||
auto ptt_module = make_triton_module(name, src);
|
auto ptt_module = make_triton_module(name, src);
|
||||||
ir::module &tt_module = *ptt_module;
|
ir::module &tt_module = *ptt_module;
|
||||||
@@ -143,8 +143,8 @@ void jit::autotune(const std::string &name, const std::string &src, benchmark_t
|
|||||||
// Compile
|
// Compile
|
||||||
auto ll_module = make_llvm_module(tt_module, passes);
|
auto ll_module = make_llvm_module(tt_module, passes);
|
||||||
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
|
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
|
||||||
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name.c_str()));
|
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
|
||||||
launch_information info = launch_info_map_.at(name.c_str());
|
launch_information info = launch_info_map_.at(name);
|
||||||
for(unsigned p: params)
|
for(unsigned p: params)
|
||||||
std::cout << p << " " << std::flush;
|
std::cout << p << " " << std::flush;
|
||||||
// add globals
|
// add globals
|
||||||
@@ -191,26 +191,26 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
|||||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||||
}
|
}
|
||||||
|
|
||||||
void jit::add_module(const std::string &name, const std::string &src, const std::vector<unsigned> ¶ms) {
|
void jit::add_module(const char *name, const char *src, const std::vector<unsigned> ¶ms) {
|
||||||
auto ptt_module = make_triton_module(name, src);
|
auto ptt_module = make_triton_module(name, src);
|
||||||
add_module(*ptt_module, params);
|
add_module(*ptt_module, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
driver::kernel *jit::get_function(const std::string &name) {
|
driver::kernel *jit::get_function(const char *name) {
|
||||||
return driver::kernel::create(modules_.front(), name.c_str());
|
return driver::kernel::create(modules_.front(), name);
|
||||||
}
|
}
|
||||||
|
|
||||||
jit::launch_information jit::get_launch_info(const std::string &name) {
|
jit::launch_information jit::get_launch_info(const char *name) {
|
||||||
return launch_info_map_.at(name);
|
return launch_info_map_.at(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned jit::get_int(const std::string &name){
|
unsigned jit::get_int(const char *name){
|
||||||
return global_ints_.at(name);
|
return global_ints_.at(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
driver::buffer *jit::get_buffer(const std::string &name){
|
driver::buffer *jit::get_buffer(const char *name){
|
||||||
driver::cu_module *mod = (driver::cu_module*)modules_.front();
|
driver::cu_module *mod = (driver::cu_module*)modules_.front();
|
||||||
return mod->symbol(name.c_str());
|
return mod->symbol(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user