[PYTHON] Added support for FP16 scalar kernel arguments
This commit is contained in:
committed by
Philippe Tillet
parent
e7461a862b
commit
78cd54b0c8
@@ -530,8 +530,36 @@ void gen_torch_make_handles(std::ostream &os,
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_val_struct_name(rt::arg_type ty){
|
||||
switch(ty){
|
||||
case rt::INT1_T: return "int1";
|
||||
case rt::INT8_T: return "int8";
|
||||
case rt::INT16_T: return "int16";
|
||||
case rt::INT32_T: return "int32";
|
||||
case rt::INT64_T: return "int64";
|
||||
case rt::HALF_T: return "fp16";
|
||||
case rt::FLOAT_T: return "fp32";
|
||||
case rt::DOUBLE_T: return "fp64";
|
||||
case rt::BUFFER_T: return "buf";
|
||||
default: return "";
|
||||
}
|
||||
}
|
||||
|
||||
void gen_torch_make_launch_function(std::ostream &os,
|
||||
const std::vector<rt::arg_type>& args) {
|
||||
os << " namespace rt = triton::runtime;\n ";
|
||||
os << " std::vector<rt::arg> args;\n ";
|
||||
for(unsigned i = 0; i < args.size(); i++){
|
||||
std::string name = "arg_" + std::to_string(i);
|
||||
if(args[i] == rt::BUFFER_T)
|
||||
name = "&" + name;
|
||||
if(args[i] == rt::HALF_T)
|
||||
name = "*((uint16_t*)&" + name + ")";
|
||||
os << "rt::arg_type ty" << i << " = (rt::arg_type)(" << args[i] << ");\n ";
|
||||
os << "rt::arg::value_t val" << i << ";\n ";
|
||||
os << "val" << i << "." << get_val_struct_name(args[i]) << " = " << name << ";\n ";
|
||||
os << "args.push_back(rt::arg(ty" << i << ", val" << i << "));\n ";
|
||||
}
|
||||
os << " std::function<void()> run = [&](){\n ";
|
||||
os << " (*id_fn_map.at({id, dev_id}))({";
|
||||
for(unsigned i = 0; i < args.size() ; i++){
|
||||
|
Reference in New Issue
Block a user