[jit] changed default metaparameter ranges
This commit is contained in:
@@ -123,6 +123,7 @@ int main() {
|
||||
};
|
||||
triton::jit jit(context);
|
||||
jit.add_module(src, params);
|
||||
jit.autotune(src, benchmark);
|
||||
triton::driver::kernel kernel = jit.get_function("matmul");
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
benchmark(kernel, info);
|
||||
|
@@ -412,7 +412,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
|
||||
assert(expr_ == nullptr);
|
||||
//TODO: implement ranges
|
||||
value = ir::metaparameter::create(mod->get_context(), ty, 4, 8);
|
||||
value = ir::metaparameter::create(mod->get_context(), ty, 8, 64);
|
||||
}
|
||||
if(expr_){
|
||||
value = expr_->codegen(mod);
|
||||
|
@@ -145,9 +145,9 @@ void tune::run(ir::module &mod) {
|
||||
// Layout parameters
|
||||
while(!nodes_.empty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
ir::metaparameter *mp0 = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
ir::metaparameter *mp1 = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
ir::metaparameter *mp2 = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
ir::metaparameter *mp0 = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||
ir::metaparameter *mp1 = ir::metaparameter::create(ctx, ty, 4, 8);
|
||||
ir::metaparameter *mp2 = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
connected_components(*nodes_.begin(), {mp0, mp1, mp2}, nodes_, dependencies_);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user