Changed auto-tuner parameter ranges (nts.d0/nts.d1 metaparameters now created with 2, 4 instead of 4, 4)
This commit is contained in:
@@ -69,14 +69,14 @@ class BlockSparseGemmOp : public OpKernel {
|
|||||||
[&](){ stream->synchronize(); }, ctx->device());
|
[&](){ stream->synchronize(); }, ctx->device());
|
||||||
return 2.*M*N*K / ts * 1e-3;
|
return 2.*M*N*K / ts * 1e-3;
|
||||||
};
|
};
|
||||||
std::string src = triton::dnn::gemm::src(false, false, "fp16", "fp16", 1, 1);
|
std::string src = triton::dnn::gemm::src(false, true, "fp16", "fp16", 1, 1);
|
||||||
// just-in-time compile source-code
|
// just-in-time compile source-code
|
||||||
// jit.autotune("matmul", src.c_str(), benchmark);
|
jit.autotune("matmul", src.c_str(), benchmark);
|
||||||
// jit.add_module("matmul", src.c_str(), {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
|
// jit.add_module("matmul", src.c_str(), {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
|
||||||
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
|
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
|
||||||
// jit.add_module("matmul", src.c_str(), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
|
// jit.add_module("matmul", src.c_str(), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
|
||||||
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
|
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
|
||||||
// jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
|
jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
|
||||||
triton::driver::kernel* kernel = jit.get_function("matmul");
|
triton::driver::kernel* kernel = jit.get_function("matmul");
|
||||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||||
std::cout << benchmark(kernel, info) << std::endl;;
|
std::cout << benchmark(kernel, info) << std::endl;;
|
||||||
|
@@ -166,7 +166,7 @@ std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
|
|||||||
for(ir::instruction *i : block->get_inst_list())
|
for(ir::instruction *i : block->get_inst_list())
|
||||||
for(auto &x: params_[i])
|
for(auto &x: params_[i])
|
||||||
if(seen.insert(x.second).second && !x.second->has_value()){
|
if(seen.insert(x.second).second && !x.second->has_value()){
|
||||||
std::cout << i->get_name() << " " << x.first << std::endl;
|
// std::cout << i->get_name() << " " << x.first << std::endl;
|
||||||
result.push_back(x.second);
|
result.push_back(x.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -233,13 +233,13 @@ void tune::run(ir::module &mod) {
|
|||||||
continue;
|
continue;
|
||||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 4));
|
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
|
||||||
*params_.at(i).at("nts.d0") = *tmp;
|
*params_.at(i).at("nts.d0") = *tmp;
|
||||||
}
|
}
|
||||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 4));
|
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
|
||||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 4));
|
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
|
||||||
*params_.at(i).at("nts.d0") = *tmp1;
|
*params_.at(i).at("nts.d0") = *tmp1;
|
||||||
*params_.at(i).at("nts.d1") = *tmp2;
|
*params_.at(i).at("nts.d1") = *tmp2;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user