[PYTHON] Added support for FP16 scalar kernel arguments
This commit is contained in:
committed by
Philippe Tillet
parent
e7461a862b
commit
78cd54b0c8
@@ -53,13 +53,14 @@ inline bool is_int_type(arg_type ty){
|
||||
}
|
||||
|
||||
class arg {
|
||||
private:
|
||||
public:
|
||||
union value_t {
|
||||
bool int1;
|
||||
int8_t int8;
|
||||
int16_t int16;
|
||||
int32_t int32;
|
||||
int64_t int64;
|
||||
uint16_t fp16;
|
||||
float fp32;
|
||||
double fp64;
|
||||
driver::buffer* buf;
|
||||
@@ -67,6 +68,7 @@ private:
|
||||
|
||||
public:
|
||||
// construct from primitive types
|
||||
arg(arg_type ty, value_t val): ty_(ty) { val_ = val; }
|
||||
arg(int32_t x): ty_(INT32_T) { val_.int32 = x; }
|
||||
arg(int64_t x): ty_(INT64_T) { val_.int64 = x; }
|
||||
arg(float x): ty_(FLOAT_T) { val_.fp32 = x; }
|
||||
|
@@ -530,8 +530,36 @@ void gen_torch_make_handles(std::ostream &os,
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_val_struct_name(rt::arg_type ty){
|
||||
switch(ty){
|
||||
case rt::INT1_T: return "int1";
|
||||
case rt::INT8_T: return "int8";
|
||||
case rt::INT16_T: return "int16";
|
||||
case rt::INT32_T: return "int32";
|
||||
case rt::INT64_T: return "int64";
|
||||
case rt::HALF_T: return "fp16";
|
||||
case rt::FLOAT_T: return "fp32";
|
||||
case rt::DOUBLE_T: return "fp64";
|
||||
case rt::BUFFER_T: return "buf";
|
||||
default: return "";
|
||||
}
|
||||
}
|
||||
|
||||
void gen_torch_make_launch_function(std::ostream &os,
|
||||
const std::vector<rt::arg_type>& args) {
|
||||
os << " namespace rt = triton::runtime;\n ";
|
||||
os << " std::vector<rt::arg> args;\n ";
|
||||
for(unsigned i = 0; i < args.size(); i++){
|
||||
std::string name = "arg_" + std::to_string(i);
|
||||
if(args[i] == rt::BUFFER_T)
|
||||
name = "&" + name;
|
||||
if(args[i] == rt::HALF_T)
|
||||
name = "*((uint16_t*)&" + name + ")";
|
||||
os << "rt::arg_type ty" << i << " = (rt::arg_type)(" << args[i] << ");\n ";
|
||||
os << "rt::arg::value_t val" << i << ";\n ";
|
||||
os << "val" << i << "." << get_val_struct_name(args[i]) << " = " << name << ";\n ";
|
||||
os << "args.push_back(rt::arg(ty" << i << ", val" << i << "));\n ";
|
||||
}
|
||||
os << " std::function<void()> run = [&](){\n ";
|
||||
os << " (*id_fn_map.at({id, dev_id}))({";
|
||||
for(unsigned i = 0; i < args.size() ; i++){
|
||||
|
Reference in New Issue
Block a user