[PYTHON] Added support for FP16 scalar kernel arguments

This commit is contained in:
Philippe Tillet
2020-05-15 16:10:20 -04:00
committed by Philippe Tillet
parent e7461a862b
commit 78cd54b0c8
2 changed files with 31 additions and 1 deletions

View File

@@ -530,8 +530,36 @@ void gen_torch_make_handles(std::ostream &os,
}
}
// Returns the field name associated with a runtime argument type. The code
// generator that follows emits this string right after "val<i>." when filling
// an rt::arg::value_t, so each name is expected to match a member of that
// type (the member list itself is declared elsewhere -- confirm when adding
// new types).
//
// ty: the runtime argument type to translate.
// Returns: the field name, or an empty string for any type not listed below.
std::string get_val_struct_name(rt::arg_type ty){
  // Stays empty when `ty` is not one of the handled types.
  const char* field = "";
  switch(ty){
    // integer scalars
    case rt::INT1_T:   field = "int1";  break;
    case rt::INT8_T:   field = "int8";  break;
    case rt::INT16_T:  field = "int16"; break;
    case rt::INT32_T:  field = "int32"; break;
    case rt::INT64_T:  field = "int64"; break;
    // floating-point scalars
    case rt::HALF_T:   field = "fp16";  break;
    case rt::FLOAT_T:  field = "fp32";  break;
    case rt::DOUBLE_T: field = "fp64";  break;
    // buffer arguments
    case rt::BUFFER_T: field = "buf";   break;
    default:                            break;
  }
  return field;
}
void gen_torch_make_launch_function(std::ostream &os,
const std::vector<rt::arg_type>& args) {
os << " namespace rt = triton::runtime;\n ";
os << " std::vector<rt::arg> args;\n ";
for(unsigned i = 0; i < args.size(); i++){
std::string name = "arg_" + std::to_string(i);
if(args[i] == rt::BUFFER_T)
name = "&" + name;
if(args[i] == rt::HALF_T)
name = "*((uint16_t*)&" + name + ")";
os << "rt::arg_type ty" << i << " = (rt::arg_type)(" << args[i] << ");\n ";
os << "rt::arg::value_t val" << i << ";\n ";
os << "val" << i << "." << get_val_struct_name(args[i]) << " = " << name << ";\n ";
os << "args.push_back(rt::arg(ty" << i << ", val" << i << "));\n ";
}
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at({id, dev_id}))({";
for(unsigned i = 0; i < args.size() ; i++){