diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp index 36a9bacfb..8ff9dc854 100644 --- a/examples/python/tensorflow/dot.cpp +++ b/examples/python/tensorflow/dot.cpp @@ -69,14 +69,14 @@ class BlockSparseGemmOp : public OpKernel { [&](){ stream->synchronize(); }, ctx->device()); return 2.*M*N*K / ts * 1e-3; }; - std::string src = triton::dnn::gemm::src(false, false, "fp16", "fp16", 1, 1); + std::string src = triton::dnn::gemm::src(false, true, "fp16", "fp16", 1, 1); // just-in-time compile source-code -// jit.autotune("matmul", src.c_str(), benchmark); + jit.autotune("matmul", src.c_str(), benchmark); // jit.add_module("matmul", src.c_str(), {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1}); // jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1}); // jit.add_module("matmul", src.c_str(), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 }); // jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1}); -// jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN + jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN triton::driver::kernel* kernel = jit.get_function("matmul"); triton::jit::launch_information info = jit.get_launch_info("matmul"); std::cout << benchmark(kernel, info) << std::endl;; diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index db6d67702..e1d62f4cd 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -166,7 +166,7 @@ std::vector tune::get_params(ir::module &mod) { for(ir::instruction *i : block->get_inst_list()) for(auto &x: params_[i]) if(seen.insert(x.second).second && !x.second->has_value()){ - std::cout << i->get_name() << " " << x.first << std::endl; +// std::cout << i->get_name() << " " << x.first << std::endl; result.push_back(x.second); } @@ -233,13 +233,13 @@ void tune::run(ir::module &mod) { continue; if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 4, 4)); + std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 2, 4)); *params_.at(i).at("nts.d0") = *tmp; } if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr tmp1(ir::metaparameter::create(ctx, ty, 4, 4)); - std::unique_ptr tmp2(ir::metaparameter::create(ctx, ty, 4, 4)); + std::unique_ptr tmp1(ir::metaparameter::create(ctx, ty, 2, 4)); + std::unique_ptr tmp2(ir::metaparameter::create(ctx, ty, 2, 4)); *params_.at(i).at("nts.d0") = *tmp1; *params_.at(i).at("nts.d1") = *tmp2; }