[PYTHON][OPS] Added batch normalization op
This commit is contained in:
@@ -211,27 +211,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
|
||||
os << ";\n";
|
||||
}
|
||||
|
||||
inline std::string preheader() {
|
||||
return
|
||||
R"(
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
|
||||
#define __readonly __attribute__((readonly))
|
||||
#define __writeonly __attribute__((writeonly))
|
||||
#define __noalias __attribute__((noalias))
|
||||
#define __aligned(A) __attribute__((aligned(A)))
|
||||
#define __multipleof(A) __attribute__((multipleof(A)))
|
||||
|
||||
extern int get_program_id(int);
|
||||
)";
|
||||
}
|
||||
|
||||
void make_module(const std::string& src, ir::module* ir,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
std::string copy = preheader() + src;
|
||||
std::string copy = triton::runtime::function::preheader() + src;
|
||||
// pre-process
|
||||
TokenSequence tokens;
|
||||
Preprocessor cpp(©, true);
|
||||
@@ -341,11 +323,11 @@ inline std::string to_torch_ty(ir::type *ty) {
|
||||
if(ty->is_integer_ty())
|
||||
return "int64_t";
|
||||
if(ty->is_half_ty())
|
||||
return "float16";
|
||||
return "double";
|
||||
if(ty->is_float_ty())
|
||||
return "float32";
|
||||
return "double";
|
||||
if(ty->is_double_ty())
|
||||
return "float64";
|
||||
return "double";
|
||||
if(ty->is_pointer_ty())
|
||||
return "torch::Tensor";
|
||||
throw std::runtime_error("unknown type");
|
||||
@@ -363,11 +345,11 @@ inline std::string to_c_ty(ir::type *ty) {
|
||||
if(ty->is_integer_ty(64))
|
||||
return "int64_t";
|
||||
if(ty->is_half_ty())
|
||||
return "float16";
|
||||
return "half";
|
||||
if(ty->is_float_ty())
|
||||
return "float32";
|
||||
return "float";
|
||||
if(ty->is_double_ty())
|
||||
return "float64";
|
||||
return "double";
|
||||
if(ty->is_pointer_ty())
|
||||
return "drv::cu_buffer";
|
||||
throw std::runtime_error("unknown type");
|
||||
|
Reference in New Issue
Block a user