[PYTHON][OPS] Added batch normalization op
This commit is contained in:
@@ -99,6 +99,8 @@ private:
|
||||
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
|
||||
caller autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
||||
|
||||
public:
|
||||
static std::string preheader();
|
||||
|
||||
public:
|
||||
function(const std::string& src, const options_space_t& opt = options_space_t());
|
||||
|
@@ -289,6 +289,11 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) {
|
||||
|
||||
|
||||
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
if(!x->get_type()->is_tile_ty()){
|
||||
Value *ptr = get_value(x->get_pointer_operand(), {});
|
||||
set_value(x, {}, builder_->CreateLoad(ptr));
|
||||
return;
|
||||
}
|
||||
// find vector size
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr)->order[0];
|
||||
|
@@ -229,6 +229,13 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
||||
ir::value* ret = ret_;
|
||||
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
|
||||
return set_ret(bld_->create_get_program_id(axis->get_value()));
|
||||
else
|
||||
return should_not_happen();
|
||||
}
|
||||
if(name == "sqrtf"){
|
||||
VisitExpr(funcCall->Args()->at(0));
|
||||
ir::value* ret = ret_;
|
||||
return set_ret(bld_->create_sqrt(ret));
|
||||
}
|
||||
return error_not_implemented();
|
||||
}
|
||||
@@ -274,10 +281,11 @@ void Generator::VisitDeclaration(Declaration* decl) {
|
||||
// initialize declaration
|
||||
ir::type::id_t id = ty->get_type_id();
|
||||
if(id == ir::type::StructTyID)
|
||||
assert(false);
|
||||
should_not_happen();
|
||||
if(inits.size() > 1)
|
||||
assert(false);
|
||||
val = inits[0];
|
||||
should_not_happen();
|
||||
if(inits.size() > 0)
|
||||
val = inits[0];
|
||||
assert(val->get_type() == ty);
|
||||
// update scope symbols table
|
||||
const std::string &name = obj->Name();
|
||||
|
@@ -113,7 +113,7 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
|
||||
arg arg_i = args.at(i);
|
||||
arg_type ty = arg_i.type();
|
||||
if(ty != param_tys_.at(i))
|
||||
throw std::runtime_error("invalid type");
|
||||
throw std::runtime_error("invalid type for argument " + std::to_string(i));
|
||||
if(ty == BUFFER_T)
|
||||
bin_->setArg(i, *((driver::buffer**)arg_i.data()));
|
||||
else
|
||||
@@ -253,16 +253,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
return std::unique_ptr<driver::module>();
|
||||
barriers.run(module);
|
||||
// std::cout << "isel" << std::endl;
|
||||
// ir::print(module, std::cout);
|
||||
isel.visit(module, *llvm);
|
||||
// std::cout << "done" << std::endl;
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
// done
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string preheader() {
|
||||
std::string function::preheader() {
|
||||
return
|
||||
R"(
|
||||
#define bool _Bool
|
||||
@@ -277,6 +275,7 @@ R"(
|
||||
#define __multipleof(A) __attribute__((multipleof(A)))
|
||||
|
||||
extern int get_program_id(int);
|
||||
extern float sqrtf(float);
|
||||
)";
|
||||
}
|
||||
|
||||
|
@@ -77,7 +77,7 @@ class CMakeBuild(build_ext):
|
||||
pass
|
||||
|
||||
cfg = 'Debug' if self.debug else 'Release'
|
||||
cfg = 'Release'
|
||||
#cfg = 'Release'
|
||||
build_args = ['--config', cfg]
|
||||
|
||||
if platform.system() == "Windows":
|
||||
|
@@ -211,27 +211,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
|
||||
os << ";\n";
|
||||
}
|
||||
|
||||
inline std::string preheader() {
|
||||
return
|
||||
R"(
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
|
||||
#define __readonly __attribute__((readonly))
|
||||
#define __writeonly __attribute__((writeonly))
|
||||
#define __noalias __attribute__((noalias))
|
||||
#define __aligned(A) __attribute__((aligned(A)))
|
||||
#define __multipleof(A) __attribute__((multipleof(A)))
|
||||
|
||||
extern int get_program_id(int);
|
||||
)";
|
||||
}
|
||||
|
||||
void make_module(const std::string& src, ir::module* ir,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
std::string copy = preheader() + src;
|
||||
std::string copy = triton::runtime::function::preheader() + src;
|
||||
// pre-process
|
||||
TokenSequence tokens;
|
||||
Preprocessor cpp(©, true);
|
||||
@@ -341,11 +323,11 @@ inline std::string to_torch_ty(ir::type *ty) {
|
||||
if(ty->is_integer_ty())
|
||||
return "int64_t";
|
||||
if(ty->is_half_ty())
|
||||
return "float16";
|
||||
return "double";
|
||||
if(ty->is_float_ty())
|
||||
return "float32";
|
||||
return "double";
|
||||
if(ty->is_double_ty())
|
||||
return "float64";
|
||||
return "double";
|
||||
if(ty->is_pointer_ty())
|
||||
return "torch::Tensor";
|
||||
throw std::runtime_error("unknown type");
|
||||
@@ -363,11 +345,11 @@ inline std::string to_c_ty(ir::type *ty) {
|
||||
if(ty->is_integer_ty(64))
|
||||
return "int64_t";
|
||||
if(ty->is_half_ty())
|
||||
return "float16";
|
||||
return "half";
|
||||
if(ty->is_float_ty())
|
||||
return "float32";
|
||||
return "float";
|
||||
if(ty->is_double_ty())
|
||||
return "float64";
|
||||
return "double";
|
||||
if(ty->is_pointer_ty())
|
||||
return "drv::cu_buffer";
|
||||
throw std::runtime_error("unknown type");
|
||||
|
@@ -1,2 +1,3 @@
|
||||
from .dot import _dot, dot
|
||||
from .einsum import _einsum, einsum
|
||||
from .batchnorm import _batchnorm, batchnorm
|
75
python/triton/ops/batchnorm.py
Normal file
75
python/triton/ops/batchnorm.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import triton
|
||||
import math
|
||||
|
||||
class _batchnorm(triton.function):
|
||||
|
||||
fwd_src = """
|
||||
void batchnormForward(float *Y, float *M, float *V,
|
||||
float *X, float *G, float *B,
|
||||
int N, float rcpN, float eps) {
|
||||
int rx[TM] = 0 ... TM;
|
||||
float *px[TM];
|
||||
float x[TM] = 0;
|
||||
int c = get_program_id(1);
|
||||
float g = *(G + c);
|
||||
float b = *(B + c);
|
||||
|
||||
float mean[TM] = 0;
|
||||
px = X + rx + c*N;
|
||||
for(int i = 0; i < N; i = i + TM){
|
||||
x = *px;
|
||||
mean = mean + x;
|
||||
px = px + TM;
|
||||
}
|
||||
float *pm = M + c;
|
||||
float m = mean[+] * rcpN;
|
||||
*pm = m;
|
||||
|
||||
float var[TM] = 0;
|
||||
px = X + rx + c*N;
|
||||
for(int i = 0; i < N; i = i + TM){
|
||||
x = *px;
|
||||
x = x - m;
|
||||
var = var + x*x;
|
||||
px = px + TM;
|
||||
}
|
||||
float v = var[+] * rcpN;
|
||||
float *pv = V + c;
|
||||
*pv = v;
|
||||
float rstdg = 1 / sqrtf(v + eps) * g;
|
||||
|
||||
px = X + rx + c*N;
|
||||
float* py[TM] = Y + rx + c*N;
|
||||
for(int i = 0; i < N; i = i + TM){
|
||||
x = *px;
|
||||
float y[TM] = (x - m)*rstdg + b;
|
||||
*py = y;
|
||||
px = px + TM;
|
||||
py = py + TM;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, gamma, beta, eps):
|
||||
shape = triton.shape(x)
|
||||
dtype = x.dtype
|
||||
# allocate outputs
|
||||
C, H, W, B = shape[0], shape[1], shape[2], shape[3]
|
||||
y = triton.empty(shape, dtype=dtype)
|
||||
mean = triton.empty([C], dtype=dtype)
|
||||
var = triton.empty([C], dtype=dtype)
|
||||
# execute kernels
|
||||
N = H*W*B
|
||||
_batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps,
|
||||
lambda opt: [1, C],
|
||||
TM = 128)
|
||||
# save
|
||||
ctx.eps = eps
|
||||
ctx.save_for_backward(x, gamma, beta, mean, var)
|
||||
return y, mean, var
|
||||
|
||||
|
||||
batchnorm = _batchnorm.apply
|
@@ -181,15 +181,16 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
|
||||
@staticmethod
|
||||
def forward(ctx, subscripts, a, b, bench = 0):
|
||||
ctx.save_for_backward(a, b)
|
||||
# parse
|
||||
if type(subscripts) is str:
|
||||
einsum_a, einsum_bc = subscripts.split(",")
|
||||
einsum_b, einsum_c = einsum_bc.split("->")
|
||||
else:
|
||||
einsum_a, einsum_b, einsum_c = subscripts
|
||||
|
||||
shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
|
||||
einsum_a, einsum_b, einsum_c,
|
||||
triton.shape(a), triton.shape(b))
|
||||
# save for backward
|
||||
ctx.trans_a = ta
|
||||
ctx.trans_b = tb
|
||||
ctx.einsum_a = einsum_a
|
||||
@@ -197,6 +198,7 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
|
||||
ctx.einsum_c = einsum_c
|
||||
ctx.bench = bench
|
||||
ctx.bmnk = bmnk
|
||||
# run
|
||||
return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user