diff --git a/lib/codegen/selection/selection.cc b/lib/codegen/selection/selection.cc index 5059a3130..b04209f61 100644 --- a/lib/codegen/selection/selection.cc +++ b/lib/codegen/selection/selection.cc @@ -1577,7 +1577,8 @@ void selection::run(ir::module &src, Module &dst) { offset->addIncoming(next_offset, llvm_inc_block); } else { - offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*4)), llvm_inc_block); + unsigned num_bytes = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + offset->addIncoming(dst_builder.getInt32(alloc_->get_num_bytes(phi)/(2*num_bytes)), llvm_inc_block); } ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); } diff --git a/lib/driver/handle.cc b/lib/driver/handle.cc index a0013f347..8899eb30e 100755 --- a/lib/driver/handle.cc +++ b/lib/driver/handle.cc @@ -72,7 +72,7 @@ handle::~handle(){ try{ if(has_ownership_ && h_ && h_.unique()) _delete(*h_); - }catch(const exception::cuda::deinitialized&){ + }catch(const exception::cuda::base&){ // order of destruction for global variables // is not guaranteed } diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 486d7d588..96a7c0f08 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -26,6 +26,7 @@ #include "triton/driver/error.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Verifier.h" +#include "llvm/IR/IRPrintingPasses.h" #include "llvm/IR/Module.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" @@ -240,7 +241,6 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) { cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ -// std::cout << source << std::endl; cu_context::context_switcher ctx_switch(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; @@ -250,8 +250,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo try{ dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval); }catch(exception::cuda::base const &){ +#ifdef TRITON_LOG_PTX_ERROR std::cerr << "Compilation Failed! Log: " << std::endl; std::cerr << errbuf << std::endl; +#endif throw; } } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 204d05b89..5c93eb452 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -12,6 +12,7 @@ #include "triton/driver/stream.h" #include "triton/driver/kernel.h" #include "triton/driver/module.h" +#include "triton/driver/error.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/print.h" @@ -166,6 +167,8 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr bin = make_bin(*ir, stream->context(), opt); }catch(const std::runtime_error& e) { return; + }catch(const driver::exception::cuda::invalid_ptx& e) { + return; } // benchmark ir::function *tmp = ir->get_function_list()[0]; @@ -178,6 +181,8 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr } }; _parallel_loop_nest(space, benchmark, 1); + if(!ret) + throw std::runtime_error("could not find valid option in provided space"); return *ret; } diff --git a/python/setup.py b/python/setup.py index ef5fa9865..b9285f84f 100644 --- a/python/setup.py +++ b/python/setup.py @@ -47,7 +47,7 @@ class CMakeBuild(build_ext): tf_libs = 'tensorflow_framework' cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, - '-DBUILD_EXAMPLES=OFF', + '-DBUILD_TESTS=OFF', '-DBUILD_PYTHON_MODULE=ON', '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs, '-DTF_INCLUDE_DIRS=' + tf_include_dirs, diff --git a/python/src/tensorflow.cc b/python/src/tensorflow.cc index e71f6a77a..fde2d84ec 100644 --- a/python/src/tensorflow.cc +++ b/python/src/tensorflow.cc @@ -160,7 +160,8 @@ void gen_register_op(std::ostream &os, const std::string &name, std::string name = arg->get_name(); auto tolower = [](char c) { return std::tolower(c);}; std::transform(name.begin(), name.end(), name.begin(), tolower); - os << " .Input(\"" << name << ": " << to_tf_scalar_ty(arg->get_type()) << "\")\n"; + os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl; + os << " .Input(\"" << name << ": T" << i << "\")\n"; } for(size_t i = 0; i < outputs.size(); i++){ std::string name = outputs[i]; diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 63e5e877d..4176d2377 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -10,11 +10,6 @@ #include "cuda/cublas.h" -struct perf_t { - double triton; - double cublas; -}; - namespace drv = triton::driver; namespace rt = triton::runtime; @@ -22,6 +17,14 @@ inline size_t ceil(size_t x, size_t y) { return (x + y - 1) / y; }; +inline rt::function::grid_fn_ty grid(size_t M, size_t N) { + return [M, N](const rt::function::options_t& x) { + return rt::grid_t{ceil(M, x.D("TM")), + ceil(N, x.D("TN"))}; + }; +} + + std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ typedef half_float::half NumericT; @@ -33,9 +36,9 @@ std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i int32_t ldb = BT ? N : K; int32_t ldc = M; // create inputs - auto dc = std::unique_ptr(drv::buffer::create(context, M*N*dt_nbytes)); auto da = std::unique_ptr(drv::buffer::create(context, M*K*dt_nbytes)); auto db = std::unique_ptr(drv::buffer::create(context, K*N*dt_nbytes)); + auto dc = std::unique_ptr(drv::buffer::create(context, M*N*dt_nbytes)); // create options rt::function::options_space_t opt; opt.defines.push_back({"TYPE", {ty}}); @@ -47,11 +50,6 @@ std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i opt.defines.push_back({"TN", {"16", "32", "64", "128"}}); opt.defines.push_back({"TK", {"32"}}); opt.num_warps = {1, 2, 4, 8}; - // create grid - auto grid = [&](const rt::function::options_t& x) { - return rt::grid_t{ceil(M, x.D("TM")), - ceil(N, x.D("TN"))}; - }; // create function rt::function function(src::dot, opt); // benchmark available libraries @@ -68,7 +66,7 @@ std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i result.push_back(tflops(cublas_ms)); } // triton - double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream); + double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);}, stream); result.push_back(tflops(triton_ms)); // done return result; @@ -80,11 +78,25 @@ int main() { triton::driver::stream* stream = triton::driver::stream::create(context); // shapes to benchmark typedef std::tuple config_t; - std::vector configs = { - config_t{false, true, 512, 512, 512}, - config_t{false, true, 2048, 2048, 2048}, - config_t{false, true, 8192, 8192, 8192} - }; + std::vector configs; + for(auto x: std::vector>{{false, false}, + {false, true}, + {true, false}}){ + std::vector tmp = { + config_t{x[0], x[1], 8192, 8192, 8192} +// config_t{x[0], x[1], 16, 2048, 2048}, +// config_t{x[0], x[1], 32, 2048, 2048}, +// config_t{x[0], x[1], 64, 2048, 2048}, +// config_t{x[0], x[1], 128, 2048, 2048}, +// config_t{x[0], x[1], 7000, 2048, 2048}, +// config_t{x[0], x[1], 16, 4096, 4096}, +// config_t{x[0], x[1], 32, 4096, 4096}, +// config_t{x[0], x[1], 64, 4096, 4096}, +// config_t{x[0], x[1], 128, 4096, 4096}, +// config_t{x[0], x[1], 7000, 4096, 4096}, + }; + configs.insert(configs.end(), tmp.begin(), tmp.end()); + } // does the work bool AT, BT; int32_t M, N, K; diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h index 00814c0f0..9df0643a6 100644 --- a/tests/common/src/dot.h +++ b/tests/common/src/dot.h @@ -30,21 +30,21 @@ void dot(TYPE * A __noalias __readonly __aligned(16), float xc[TM, TN] = 0; #ifdef AT TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda; - bool checka[TK, TM] = rka[:, newaxis] < K; - TYPE a[TK, TM] = checka ? *pa : 0; + bool checka[TK, TM] = rka[:, newaxis] < TK; + TYPE a[TK, TM] = *pa; #else TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis]; - bool checka[TM, TK] = rka[newaxis, :] < K; - TYPE a[TM, TK] = checka ? *pa : 0; + bool checka[TM, TK] = rka[newaxis, :] < TK; + TYPE a[TM, TK] = *pa; #endif #ifdef BT TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis]; - bool checkb[TN, TK] = rkb[newaxis, :] < K; - TYPE b[TN, TK] = checkb ? *pb : 0; + bool checkb[TN, TK] = rkb[newaxis, :] < TK; + TYPE b[TN, TK] = *pb; #else TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb; - bool checkb[TK, TN] = rkb[:, newaxis] < K; - TYPE b[TK, TN] = checkb ? *pb : 0; + bool checkb[TK, TN] = rkb[:, newaxis] < TK; + TYPE b[TK, TN] = *pb; #endif for(int k = K; k > 0; k = k - TK){ xc = USEA @ USEB + xc; @@ -60,8 +60,8 @@ void dot(TYPE * A __noalias __readonly __aligned(16), #endif checka = k > TK; checkb = k > TK; - a = checka ? *pa : 0; - b = checkb ? *pb : 0; + a = *pa; + b = *pb; } int rxc[TM] = ridx * TM + (0 ... TM); int ryc[TN] = ridy * TN + (0 ... TN); @@ -70,8 +70,8 @@ void dot(TYPE * A __noalias __readonly __aligned(16), bool checkc0[TM] = rxc < M; bool checkc1[TN] = ryc < N; bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - *?(checkc) pc = c; + *pc = c; } )"; -} \ No newline at end of file +} diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc index 3ddc8953e..298b79a44 100644 --- a/tests/unit/dot.cc +++ b/tests/unit/dot.cc @@ -9,6 +9,9 @@ #include "src/dot.h" #include "cuda/cublas.h" +namespace drv = triton::driver; +namespace rt = triton::runtime; + template void diff(const std::vector& x, const std::vector& y){ for(size_t i = 0; i < x.size(); i++) @@ -44,16 +47,44 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K, cpu_ref(c, a, b, M, N, K); } - -struct perf_t { - double triton; - double cublas; +inline size_t ceil(size_t x, size_t y) { + return (x + y - 1) / y; }; -namespace drv = triton::driver; -namespace rt = triton::runtime; +inline rt::function::grid_fn_ty grid(size_t M, size_t N) { + return [M, N](const rt::function::options_t& x) { + return rt::grid_t{ceil(M, x.D("TM")), + ceil(N, x.D("TN"))}; + }; +} -perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ +namespace aux{ +template struct seq{}; + +template +struct gen_seq : gen_seq{}; + +template +struct gen_seq<0, Is...> : seq{}; + +template +void print_tuple(std::basic_ostream& os, Tuple const& t, seq){ + using swallow = int[]; + (void)swallow{0, (void(os << (Is == 0? "" : ", ") << std::get(t)), 0)...}; +} +} // aux:: + +template +auto operator<<(std::basic_ostream& os, std::tuple const& t) + -> std::basic_ostream& +{ + os << "("; + aux::print_tuple(os, t, aux::gen_seq()); + return os << ")"; +} + + +bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, size_t nwarp){ typedef half_float::half NumericT; std::string ty = "half"; size_t dt_nbytes = sizeof(NumericT); @@ -71,12 +102,12 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int hb[i] = static_cast((float)rand()/RAND_MAX); for(size_t i = 0; i < hc.size(); i++) hc[i] = static_cast((double)0); - drv::buffer* dc = drv::buffer::create(context, hc.size()*dt_nbytes); - drv::buffer* da = drv::buffer::create(context, ha.size()*dt_nbytes); - drv::buffer* db = drv::buffer::create(context, hb.size()*dt_nbytes); - stream->write(da, true, 0, ha); - stream->write(db, true, 0, hb); - stream->write(dc, true, 0, hc); + auto dc = std::shared_ptr(drv::buffer::create(context, hc.size()*dt_nbytes)); + auto da = std::shared_ptr(drv::buffer::create(context, ha.size()*dt_nbytes)); + auto db = std::shared_ptr(drv::buffer::create(context, hb.size()*dt_nbytes)); + stream->write(&*da, true, 0, ha); + stream->write(&*db, true, 0, hb); + stream->write(&*dc, true, 0, hc); stream->synchronize(); // run rt::function::options_space_t opt; @@ -85,81 +116,50 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int opt.defines.push_back({"AT", {""}}); if(BT) opt.defines.push_back({"BT", {""}}); - opt.defines.push_back({"TM", {"16", "32", "64", "128"}}); - opt.defines.push_back({"TN", {"16", "32", "64", "128"}}); - opt.defines.push_back({"TK", {"32"}}); - opt.num_warps = {1, 2, 4, 8}; + opt.defines.push_back({"TM", {std::to_string(TM)}}); + opt.defines.push_back({"TN", {std::to_string(TN)}}); + opt.defines.push_back({"TK", {std::to_string(TK)}}); + opt.num_warps = {nwarp}; rt::function function(src::dot, opt); - - auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; }; - auto grid = [&](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D("TM")), ceil(N, x.D("TN")), 1}; }; - - auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; - perf_t res; - res.triton = tflops(triton::tools::bench([&]() { function({da, db, dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream)); - NumericT alpha(static_cast(1)); - NumericT beta(static_cast(0)); - cublasGemmAlgo_t fastest; - cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, da, lda, db, ldb, &beta, dc, ldc, &fastest); - res.cublas = tflops(triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, - &alpha, da, lda, db, ldb, &beta, dc, ldc, nullptr, fastest); }, - stream)); - + try { + function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream); + } catch (const std::runtime_error& e) { + return true; + } // test -// stream->read(dc, true, 0, hc); -// std::vector rc(hc.size()); -// cpu_ref(AT, BT, M, N, K, rc, ha, hb); -// for(size_t i = 0; i < M*N; i++) -// if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){ -// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; -// exit(EXIT_FAILURE); -// } -// std::cout << hc[0] << " " << std::endl; -// std::cout << "Pass!" << std::endl; - - // clean-up - delete dc; - delete da; - delete db; - return res; + stream->read(&*dc, true, 0, hc); + std::vector rc(hc.size()); + cpu_ref(AT, BT, M, N, K, rc, ha, hb); + for(size_t i = 0; i < M*N; i++) + if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2) + return false; + return true; } int main() { - struct config_t{ - bool AT; - bool BT; - int32_t M; - int32_t N; - int32_t K; - - std::string repr() { - std::ostringstream oss; - oss << AT << " " << BT << " " << M << " " << N << " " << K; - return oss.str(); - } - - perf_t perf(triton::driver::stream *stream){ - return do_bench(stream, AT, BT, M, N, K); - } - }; - // shapes to benchmark - std::vector configs = { -// {false, false, 8192, 512, 512}, - {false, true, 128, 128, 128} -// {false, true, 128, 128, 128}, -// {false, false, 128, 128, 128}, -// {true, false, 128, 128, 128}, -// {true, true, 128, 128, 128} -// {false, true, 32768, 256, 512} -// {true, false, 8192, 512, 512}, -// {true, true, 8192, 512, 512} - }; // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); triton::driver::stream* stream = triton::driver::stream::create(context); + // shapes to benchmark + typedef std::tuple config_t; + std::vector configs; + for(bool AT: std::array{false, true}) + for(bool BT: std::array{false, true}) + for(int TM: std::vector{16, 128}) + for(int TN: std::vector{16, 128}) + for(int TK: std::vector{16, 32}) + for(int nwarps: std::vector{1, 2, 4, 8}){ + configs.push_back(config_t{AT, BT, 128, 128, 128, TM, TN, TK, nwarps}); + } // does the work - for(config_t c: configs){ - perf_t perf = c.perf(stream); - std::cout << "// " << c.repr() << ", " << perf.triton << ", " << perf.cublas << std::endl; + bool AT, BT; + int M, N, K, TM, TN, TK, nwarp; + for(const auto& c: configs){ + std::tie(AT, BT, M, N, K, TM, TN, TK, nwarp) = c; + std::cout << "Testing " << c << " ... " << std::flush; + if(do_test(stream, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp)) + std::cout << " Pass! " << std::endl; + else + std::cout << " Fail! " << std::endl; } }