diff --git a/examples/cpp/dot.cc b/examples/cpp/dot.cc index 6e40f79d2..409b77217 100644 --- a/examples/cpp/dot.cc +++ b/examples/cpp/dot.cc @@ -153,10 +153,10 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int opt.defines.push_back({"AT", {""}}); if(BT) opt.defines.push_back({"BT", {""}}); - opt.defines.push_back({"TM", {"32"}}); - opt.defines.push_back({"TN", {"32"}}); + opt.defines.push_back({"TM", {"128"}}); + opt.defines.push_back({"TN", {"128"}}); opt.defines.push_back({"TK", {"32"}}); - opt.num_warps = {1, 2, 4, 8}; + opt.num_warps = {4}; rt::function function(src, opt); auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; }; @@ -169,16 +169,16 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int // test stream->synchronize(); - stream->read(dc, true, 0, hc); - std::vector rc(hc.size()); - cpu_ref(AT, BT, M, N, K, rc, ha, hb); - for(size_t i = 0; i < M*N; i++) - if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){ - std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; - exit(EXIT_FAILURE); - } - std::cout << hc[0] << " " << std::endl; - std::cout << "Pass!" << std::endl; +// stream->read(dc, true, 0, hc); +// std::vector rc(hc.size()); +// cpu_ref(AT, BT, M, N, K, rc, ha, hb); +// for(size_t i = 0; i < M*N; i++) +// if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){ +// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; +// exit(EXIT_FAILURE); +// } +// std::cout << hc[0] << " " << std::endl; +// std::cout << "Pass!" << std::endl; // clean-up delete dc; @@ -208,7 +208,7 @@ int main() { // shapes to benchmark std::vector configs = { // {false, false, 8192, 512, 512}, - {false, true, 128, 128, 128} + {false, true, 8192, 8192, 8192} // {false, true, 128, 128, 128}, // {false, false, 128, 128, 128}, // {true, false, 128, 128, 128}, diff --git a/lib/codegen/analysis/alignment.cc b/lib/codegen/analysis/alignment.cc index 69cf3479c..276422d10 100644 --- a/lib/codegen/analysis/alignment.cc +++ b/lib/codegen/analysis/alignment.cc @@ -4,6 +4,7 @@ #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" #include "triton/ir/type.h" +#include namespace triton { namespace codegen{ @@ -304,7 +305,7 @@ void alignment_info::run(ir::module &mod) { for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i: block->get_inst_list()){ populate_max_contiguous(i); -// std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << " " << max_contiguous_.at(i) << std::endl; + std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << " " << max_contiguous_.at(i) << std::endl; } } diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 486d7d588..2ed0160f7 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -240,7 +240,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) { cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ -// std::cout << source << std::endl; + std::cout << source << std::endl; cu_context::context_switcher ctx_switch(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index fdc9b6d15..f2f0a42bb 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -14,6 +14,7 @@ #include "triton/driver/module.h" #include "triton/ir/module.h" #include "triton/ir/function.h" +#include "triton/ir/print.h" #include "triton/tools/bench.hpp" #include "llvm/IR/Module.h" @@ -205,6 +206,8 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c shmem_allocation.run(); shmem_barriers.run(module); } + dce.run(module); + ir::print(module, std::cout); alignment_info.run(module); vectorize.run(module); dce.run(module);