[PYTHON][OPS] Added batch normalization op

This commit is contained in:
Philippe Tillet
2019-10-29 17:29:11 -04:00
parent d9eacf937c
commit d65a94c768
9 changed files with 108 additions and 34 deletions

View File

@@ -211,27 +211,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
os << ";\n";
}
inline std::string preheader() {
return
R"(
#define bool _Bool
#define true 1
#define false 0
#define __bool_true_false_are_defined 1
#define __readonly __attribute__((readonly))
#define __writeonly __attribute__((writeonly))
#define __noalias __attribute__((noalias))
#define __aligned(A) __attribute__((aligned(A)))
#define __multipleof(A) __attribute__((multipleof(A)))
extern int get_program_id(int);
)";
}
void make_module(const std::string& src, ir::module* ir,
const runtime::function::options_space_t& opt) {
std::string copy = preheader() + src;
std::string copy = triton::runtime::function::preheader() + src;
// pre-process
TokenSequence tokens;
Preprocessor cpp(&copy, true);
@@ -341,11 +323,11 @@ inline std::string to_torch_ty(ir::type *ty) {
if(ty->is_integer_ty())
return "int64_t";
if(ty->is_half_ty())
return "float16";
return "double";
if(ty->is_float_ty())
return "float32";
return "double";
if(ty->is_double_ty())
return "float64";
return "double";
if(ty->is_pointer_ty())
return "torch::Tensor";
throw std::runtime_error("unknown type");
@@ -363,11 +345,11 @@ inline std::string to_c_ty(ir::type *ty) {
if(ty->is_integer_ty(64))
return "int64_t";
if(ty->is_half_ty())
return "float16";
return "half";
if(ty->is_float_ty())
return "float32";
return "float";
if(ty->is_double_ty())
return "float64";
return "double";
if(ty->is_pointer_ty())
return "drv::cu_buffer";
throw std::runtime_error("unknown type");