[codegen] fixed issue in double buffering pointer update

author Philippe Tillet
date   2019-08-28 17:50:45 -07:00
parent 59281f5794
commit d457482539
9 changed files with 135 additions and 114 deletions


@@ -1577,7 +1577,8 @@ void selection::run(ir::module &src, Module &dst) {
offset->addIncoming(next_offset, llvm_inc_block);
}
else {
offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*4)), llvm_inc_block);
unsigned num_bytes = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*num_bytes)), llvm_inc_block);
}
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
}
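Note: this is the fix named in the commit title. The old line hard-coded the element size as 4 bytes (the `4` in `2*4`, i.e. fp32), so the offset used to step between the two halves of the double buffer was wrong for fp16 and other scalar widths. The new code derives the size from the buffered type. A minimal sketch of the computation, using the names from the diff (and assuming `get_num_bytes` reports the combined size of both halves, hence the division by 2):

    // Offset between the two halves of the double buffer, in elements.
    // Assumes get_num_bytes() covers both halves of the buffer.
    unsigned bits      = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits();
    unsigned num_bytes = bits / 8;                          // 2 for fp16, 4 for fp32
    unsigned offset    = alloc_->get_num_bytes(phi) / (2 * num_bytes);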


@@ -72,7 +72,7 @@ handle<T>::~handle(){
try{
if(has_ownership_ && h_ && h_.unique())
_delete(*h_);
}catch(const exception::cuda::deinitialized&){
}catch(const exception::cuda::base&){
// order of destruction for global variables
// is not guaranteed
}
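Note: the destructor previously swallowed only `exception::cuda::deinitialized`; it now swallows any `exception::cuda::base`, since a handle destroyed during static teardown can hit other driver errors too, and a destructor must not throw. The pattern, restated with comments (types and members as in the diff):

    template<class T>
    handle<T>::~handle(){
      try{
        // release the underlying driver object only if we own the last reference
        if(has_ownership_ && h_ && h_.unique())
          _delete(*h_);
      }catch(const exception::cuda::base&){
        // destruction order of globals vs. the CUDA driver is unspecified,
        // so any driver error here is swallowed rather than rethrown
      }
    }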


@@ -26,6 +26,7 @@
#include "triton/driver/error.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
@@ -240,7 +241,6 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
@@ -250,8 +250,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
try{
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
}catch(exception::cuda::base const &){
#ifdef TRITON_LOG_PTX_ERROR
std::cerr << "Compilation Failed! Log: " << std::endl;
std::cerr << errbuf << std::endl;
#endif
throw;
}
}
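Note: two changes in this file. `IRPrintingPasses.h` is newly included, and the PTX error log is now printed only when `TRITON_LOG_PTX_ERROR` is defined; either way the exception is rethrown, which the autotuner in the next file relies on. A sketch of how the JIT error buffer is wired up, following the `opt`/`optval` names in the diff (the buffer size here is an assumption):

    char errbuf[8192];                                   // size is an assumption
    CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
                          CU_JIT_ERROR_LOG_BUFFER};
    void* optval[] = {(void*)(uintptr_t)sizeof(errbuf), (void*)errbuf};
    try{
      dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
    }catch(exception::cuda::base const &){
    #ifdef TRITON_LOG_PTX_ERROR
      std::cerr << "Compilation Failed! Log: " << std::endl << errbuf << std::endl;
    #endif
      throw;                                             // callers may recover
    }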


@@ -12,6 +12,7 @@
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/driver/error.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/print.h"
@@ -166,6 +167,8 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
bin = make_bin(*ir, stream->context(), opt);
}catch(const std::runtime_error& e) {
return;
}catch(const driver::exception::cuda::invalid_ptx& e) {
return;
}
// benchmark
ir::function *tmp = ir->get_function_list()[0];
@@ -178,6 +181,8 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
}
};
_parallel_loop_nest<std::string>(space, benchmark, 1);
if(!ret)
throw std::runtime_error("could not find valid option in provided space");
return *ret;
}
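Note: `autotune` now skips configurations whose generated PTX fails to load (`invalid_ptx`, made catchable by the rethrow above) in addition to front-end `std::runtime_error`s, and throws if nothing in the option space produced a runnable kernel. A self-contained sketch of that policy under assumed names (`pick_first_valid` and `build` are illustrative, not Triton's API; the real loop keeps the fastest candidate rather than the first):

    #include <optional>
    #include <stdexcept>
    #include <vector>

    template<class Config, class Build>
    Config pick_first_valid(const std::vector<Config>& space, Build build){
      std::optional<Config> ret;
      for(const Config& c : space){
        try { build(c); }                        // may throw for a bad config
        catch(const std::runtime_error&){ continue; }
        ret = c;                                 // real code keeps the fastest
        break;
      }
      if(!ret)
        throw std::runtime_error("could not find valid option in provided space");
      return *ret;
    }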


@@ -47,7 +47,7 @@ class CMakeBuild(build_ext):
tf_libs = 'tensorflow_framework'
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_EXAMPLES=OFF',
'-DBUILD_TESTS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
'-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
'-DTF_INCLUDE_DIRS=' + tf_include_dirs,


@@ -160,7 +160,8 @@ void gen_register_op(std::ostream &os, const std::string &name,
std::string name = arg->get_name();
auto tolower = [](char c) { return std::tolower(c);};
std::transform(name.begin(), name.end(), name.begin(), tolower);
os << " .Input(\"" << name << ": " << to_tf_scalar_ty(arg->get_type()) << "\")\n";
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
os << " .Input(\"" << name << ": T" << i << "\")\n";
}
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
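Note: instead of hard-wiring each input's dtype through `to_tf_scalar_ty`, the generator now emits a per-input type attribute `T<i>` and binds the input to it, so one registered op accepts any of the listed scalar types. For an input named `a` at position 0, the emitted registration fragment would read (illustrative output, not part of the diff):

    .Attr("T0 : {bool, int8, int16, int32, int64, float16, float32, float64}")
    .Input("a: T0")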


@@ -10,11 +10,6 @@
#include "cuda/cublas.h"
struct perf_t {
double triton;
double cublas;
};
namespace drv = triton::driver;
namespace rt = triton::runtime;
@@ -22,6 +17,14 @@ inline size_t ceil(size_t x, size_t y) {
return (x + y - 1) / y;
};
inline rt::function::grid_fn_ty grid(size_t M, size_t N) {
return [M, N](const rt::function::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM")),
ceil(N, x.D<int>("TN"))};
};
}
std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef half_float::half NumericT;
@@ -33,9 +36,9 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
int32_t ldb = BT ? N : K;
int32_t ldc = M;
// create inputs
auto dc = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
auto da = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
auto db = std::unique_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
auto dc = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
// create options
rt::function::options_space_t opt;
opt.defines.push_back({"TYPE", {ty}});
@@ -47,11 +50,6 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
opt.defines.push_back({"TN", {"16", "32", "64", "128"}});
opt.defines.push_back({"TK", {"32"}});
opt.num_warps = {1, 2, 4, 8};
// create grid
auto grid = [&](const rt::function::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM")),
ceil(N, x.D<int>("TN"))};
};
// create function
rt::function function(src::dot, opt);
// benchmark available libraries
@@ -68,7 +66,7 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
result.push_back(tflops(cublas_ms));
}
// triton
double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream);
double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);}, stream);
result.push_back(tflops(triton_ms));
// done
return result;
@@ -80,11 +78,25 @@ int main() {
triton::driver::stream* stream = triton::driver::stream::create(context);
// shapes to benchmark
typedef std::tuple<bool, bool, int, int, int> config_t;
std::vector<config_t> configs = {
config_t{false, true, 512, 512, 512},
config_t{false, true, 2048, 2048, 2048},
config_t{false, true, 8192, 8192, 8192}
};
std::vector<config_t> configs;
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
{false, true},
{true, false}}){
std::vector<config_t> tmp = {
config_t{x[0], x[1], 8192, 8192, 8192}
// config_t{x[0], x[1], 16, 2048, 2048},
// config_t{x[0], x[1], 32, 2048, 2048},
// config_t{x[0], x[1], 64, 2048, 2048},
// config_t{x[0], x[1], 128, 2048, 2048},
// config_t{x[0], x[1], 7000, 2048, 2048},
// config_t{x[0], x[1], 16, 4096, 4096},
// config_t{x[0], x[1], 32, 4096, 4096},
// config_t{x[0], x[1], 64, 4096, 4096},
// config_t{x[0], x[1], 128, 4096, 4096},
// config_t{x[0], x[1], 7000, 4096, 4096},
};
configs.insert(configs.end(), tmp.begin(), tmp.end());
}
// does the work
bool AT, BT;
int32_t M, N, K;
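Note: in this benchmark the local `perf_t` struct and the per-call grid lambda are deleted; `grid(M, N)` is now a free factory returning the launch-grid functor, and the config list is generated for three of the four transpose layouts (AT/BT) with the smaller shapes left commented out. Usage, as in the updated benchmark call:

    // The factory captures the problem shape; the runtime evaluates the
    // returned functor once per autotuned option to size the launch grid.
    rt::function function(src::dot, opt);
    function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);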


@@ -30,21 +30,21 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
float xc[TM, TN] = 0;
#ifdef AT
TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
bool checka[TK, TM] = rka[:, newaxis] < K;
TYPE a[TK, TM] = checka ? *pa : 0;
bool checka[TK, TM] = rka[:, newaxis] < TK;
TYPE a[TK, TM] = *pa;
#else
TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
bool checka[TM, TK] = rka[newaxis, :] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
bool checka[TM, TK] = rka[newaxis, :] < TK;
TYPE a[TM, TK] = *pa;
#endif
#ifdef BT
TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
bool checkb[TN, TK] = rkb[newaxis, :] < K;
TYPE b[TN, TK] = checkb ? *pb : 0;
bool checkb[TN, TK] = rkb[newaxis, :] < TK;
TYPE b[TN, TK] = *pb;
#else
TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
bool checkb[TK, TN] = rkb[:, newaxis] < K;
TYPE b[TK, TN] = checkb ? *pb : 0;
bool checkb[TK, TN] = rkb[:, newaxis] < TK;
TYPE b[TK, TN] = *pb;
#endif
for(int k = K; k > 0; k = k - TK){
xc = USEA @ USEB + xc;
@@ -60,8 +60,8 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
#endif
checka = k > TK;
checkb = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
a = *pa;
b = *pb;
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
@@ -70,8 +70,8 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
*?(checkc) pc = c;
*pc = c;
}
)";
}
}
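Note: this hunk removes the bounds predication from the kernel: the masked loads (`checka ? *pa : 0`) become plain `*pa`, the predicated store (`*?(checkc) pc = c`) becomes `*pc = c`, and the remaining `checka`/`checkb` definitions now compare against `TK` and are no longer consumed by the loads. The kernel is therefore only correct when M, N and K are exact multiples of the tile sizes, which the rewritten unit test below guarantees by construction. A host-side sketch of the implied precondition (hypothetical helper, not in this commit):

    // With unpredicated loads and stores, every dimension must divide
    // evenly into its tile.
    inline bool shapes_ok(int M, int N, int K, int TM, int TN, int TK){
      return M % TM == 0 && N % TN == 0 && K % TK == 0;
    }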


@@ -9,6 +9,9 @@
#include "src/dot.h"
#include "cuda/cublas.h"
namespace drv = triton::driver;
namespace rt = triton::runtime;
template<class T>
void diff(const std::vector<T>& x, const std::vector<T>& y){
for(size_t i = 0; i < x.size(); i++)
@@ -44,16 +47,44 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
cpu_ref<T, false, false>(c, a, b, M, N, K);
}
struct perf_t {
double triton;
double cublas;
inline size_t ceil(size_t x, size_t y) {
return (x + y - 1) / y;
};
namespace drv = triton::driver;
namespace rt = triton::runtime;
inline rt::function::grid_fn_ty grid(size_t M, size_t N) {
return [M, N](const rt::function::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM")),
ceil(N, x.D<int>("TN"))};
};
}
perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
namespace aux{
template<std::size_t...> struct seq{};
template<std::size_t N, std::size_t... Is>
struct gen_seq : gen_seq<N-1, N-1, Is...>{};
template<std::size_t... Is>
struct gen_seq<0, Is...> : seq<Is...>{};
template<class Ch, class Tr, class Tuple, std::size_t... Is>
void print_tuple(std::basic_ostream<Ch,Tr>& os, Tuple const& t, seq<Is...>){
using swallow = int[];
(void)swallow{0, (void(os << (Is == 0? "" : ", ") << std::get<Is>(t)), 0)...};
}
} // aux::
template<class Ch, class Tr, class... Args>
auto operator<<(std::basic_ostream<Ch, Tr>& os, std::tuple<Args...> const& t)
-> std::basic_ostream<Ch, Tr>&
{
os << "(";
aux::print_tuple(os, t, aux::gen_seq<sizeof...(Args)>());
return os << ")";
}
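Note: the `aux` block above is the classic pre-C++14 index-sequence trick: `gen_seq<N>` expands to `seq<0, ..., N-1>`, and the pack expansion inside `print_tuple` streams each tuple element with a comma separator, so configurations can be printed directly:

    // Usage: any std::tuple now streams as "(a, b, c, ...)".
    std::cout << std::make_tuple(false, true, 128, 128, 128) << std::endl;
    // prints: (0, 1, 128, 128, 128)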
bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, size_t nwarp){
typedef half_float::half NumericT;
std::string ty = "half";
size_t dt_nbytes = sizeof(NumericT);
@@ -71,12 +102,12 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
hb[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
for(size_t i = 0; i < hc.size(); i++)
hc[i] = static_cast<NumericT>((double)0);
drv::buffer* dc = drv::buffer::create(context, hc.size()*dt_nbytes);
drv::buffer* da = drv::buffer::create(context, ha.size()*dt_nbytes);
drv::buffer* db = drv::buffer::create(context, hb.size()*dt_nbytes);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hc.size()*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, ha.size()*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hb.size()*dt_nbytes));
stream->write(&*da, true, 0, ha);
stream->write(&*db, true, 0, hb);
stream->write(&*dc, true, 0, hc);
stream->synchronize();
// run
rt::function::options_space_t opt;
@@ -85,81 +116,50 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
opt.defines.push_back({"AT", {""}});
if(BT)
opt.defines.push_back({"BT", {""}});
opt.defines.push_back({"TM", {"16", "32", "64", "128"}});
opt.defines.push_back({"TN", {"16", "32", "64", "128"}});
opt.defines.push_back({"TK", {"32"}});
opt.num_warps = {1, 2, 4, 8};
opt.defines.push_back({"TM", {std::to_string(TM)}});
opt.defines.push_back({"TN", {std::to_string(TN)}});
opt.defines.push_back({"TK", {std::to_string(TK)}});
opt.num_warps = {nwarp};
rt::function function(src::dot, opt);
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
auto grid = [&](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D<int>("TM")), ceil(N, x.D<int>("TN")), 1}; };
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
perf_t res;
res.triton = tflops(triton::tools::bench([&]() { function({da, db, dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream));
NumericT alpha(static_cast<double>(1));
NumericT beta(static_cast<double>(0));
cublasGemmAlgo_t fastest;
cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, da, lda, db, ldb, &beta, dc, ldc, &fastest);
res.cublas = tflops(triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
&alpha, da, lda, db, ldb, &beta, dc, ldc, nullptr, fastest); },
stream));
try {
function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);
} catch (const std::runtime_error& e) {
return true;
}
// test
// stream->read(dc, true, 0, hc);
// std::vector<NumericT> rc(hc.size());
// cpu_ref(AT, BT, M, N, K, rc, ha, hb);
// for(size_t i = 0; i < M*N; i++)
// if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
// exit(EXIT_FAILURE);
// }
// std::cout << hc[0] << " " << std::endl;
// std::cout << "Pass!" << std::endl;
// clean-up
delete dc;
delete da;
delete db;
return res;
stream->read(&*dc, true, 0, hc);
std::vector<NumericT> rc(hc.size());
cpu_ref(AT, BT, M, N, K, rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2)
return false;
return true;
}
int main() {
struct config_t{
bool AT;
bool BT;
int32_t M;
int32_t N;
int32_t K;
std::string repr() {
std::ostringstream oss;
oss << AT << " " << BT << " " << M << " " << N << " " << K;
return oss.str();
}
perf_t perf(triton::driver::stream *stream){
return do_bench(stream, AT, BT, M, N, K);
}
};
// shapes to benchmark
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
{false, true, 128, 128, 128}
// {false, true, 128, 128, 128},
// {false, false, 128, 128, 128},
// {true, false, 128, 128, 128},
// {true, true, 128, 128, 128}
// {false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512},
// {true, true, 8192, 512, 512}
};
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context);
// shapes to benchmark
typedef std::tuple<bool, bool, int, int, int, int, int, int, int> config_t;
std::vector<config_t> configs;
for(bool AT: std::array<bool, 2>{false, true})
for(bool BT: std::array<bool, 2>{false, true})
for(int TM: std::vector<int>{16, 128})
for(int TN: std::vector<int>{16, 128})
for(int TK: std::vector<int>{16, 32})
for(int nwarps: std::vector<int>{1, 2, 4, 8}){
configs.push_back(config_t{AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
}
// does the work
for(config_t c: configs){
perf_t perf = c.perf(stream);
std::cout << "// " << c.repr() << ", " << perf.triton << ", " << perf.cublas << std::endl;
bool AT, BT;
int M, N, K, TM, TN, TK, nwarp;
for(const auto& c: configs){
std::tie(AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
std::cout << "Testing " << c << " ... " << std::flush;
if(do_test(stream, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
std::cout << " Pass! " << std::endl;
else
std::cout << " Fail! " << std::endl;
}
}
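Note: the unit test no longer benchmarks anything. `do_bench` (which returned Triton and cuBLAS TFLOPS) becomes `do_test`, which pins one configuration (`TM`, `TN`, `TK`, `nwarp`) per run instead of autotuning, replaces raw `drv::buffer*` plus manual `delete` with `std::shared_ptr`, reads the result back, and compares it against `cpu_ref` at 1e-2 relative tolerance; a configuration that fails to compile counts as a pass (the early `return true`). The per-element criterion, isolated as a sketch (float casts added here for `half_float::half`):

    #include <algorithm>
    #include <cmath>

    // Mirrors the check at the end of do_test: an element passes if it is
    // finite and within 1e-2 relative error of the CPU reference.
    template<class T>
    bool element_ok(T out, T ref){
      float o = (float)out, r = (float)ref;
      return !std::isinf(o) && !std::isnan(o)
          && std::abs(o - r) / std::max(o, r) <= 1e-2;
    }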