From 78cd54b0c820a506968466c3a390ace273e30c37 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 15 May 2020 16:10:20 -0400 Subject: [PATCH] [PYTHON] Added support for FP16 scalar kernel arguments --- include/triton/runtime/arg.h | 4 +++- python/src/bindings.cc | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/include/triton/runtime/arg.h b/include/triton/runtime/arg.h index 4c932b8a4..acee97ae8 100644 --- a/include/triton/runtime/arg.h +++ b/include/triton/runtime/arg.h @@ -53,13 +53,14 @@ inline bool is_int_type(arg_type ty){ } class arg { -private: +public: union value_t { bool int1; int8_t int8; int16_t int16; int32_t int32; int64_t int64; + uint16_t fp16; float fp32; double fp64; driver::buffer* buf; @@ -67,6 +68,7 @@ private: public: // construct from primitive types + arg(arg_type ty, value_t val): ty_(ty) { val_ = val; } arg(int32_t x): ty_(INT32_T) { val_.int32 = x; } arg(int64_t x): ty_(INT64_T) { val_.int64 = x; } arg(float x): ty_(FLOAT_T) { val_.fp32 = x; } diff --git a/python/src/bindings.cc b/python/src/bindings.cc index 696e2e2b6..e36760316 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -530,8 +530,36 @@ void gen_torch_make_handles(std::ostream &os, } } +std::string get_val_struct_name(rt::arg_type ty){ + switch(ty){ + case rt::INT1_T: return "int1"; + case rt::INT8_T: return "int8"; + case rt::INT16_T: return "int16"; + case rt::INT32_T: return "int32"; + case rt::INT64_T: return "int64"; + case rt::HALF_T: return "fp16"; + case rt::FLOAT_T: return "fp32"; + case rt::DOUBLE_T: return "fp64"; + case rt::BUFFER_T: return "buf"; + default: return ""; + } +} + void gen_torch_make_launch_function(std::ostream &os, const std::vector& args) { + os << " namespace rt = triton::runtime;\n "; + os << " std::vector args;\n "; + for(unsigned i = 0; i < args.size(); i++){ + std::string name = "arg_" + std::to_string(i); + if(args[i] == rt::BUFFER_T) + name = "&" + name; + if(args[i] == rt::HALF_T) + name = "*((uint16_t*)&" + name + ")"; + os << "rt::arg_type ty" << i << " = (rt::arg_type)(" << args[i] << ");\n "; + os << "rt::arg::value_t val" << i << ";\n "; + os << "val" << i << "." << get_val_struct_name(args[i]) << " = " << name << ";\n "; + os << "args.push_back(rt::arg(ty" << i << ", val" << i << "));\n "; + } os << " std::function run = [&](){\n "; os << " (*id_fn_map.at({id, dev_id}))({"; for(unsigned i = 0; i < args.size() ; i++){