[PYTHON][OPS] Added batch normalization op
This commit is contained in:
@@ -99,6 +99,8 @@ private:
|
|||||||
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
|
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
|
||||||
caller autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
caller autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
||||||
|
|
||||||
|
public:
|
||||||
|
static std::string preheader();
|
||||||
|
|
||||||
public:
|
public:
|
||||||
function(const std::string& src, const options_space_t& opt = options_space_t());
|
function(const std::string& src, const options_space_t& opt = options_space_t());
|
||||||
|
@@ -289,6 +289,11 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) {
|
|||||||
|
|
||||||
|
|
||||||
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||||
|
if(!x->get_type()->is_tile_ty()){
|
||||||
|
Value *ptr = get_value(x->get_pointer_operand(), {});
|
||||||
|
set_value(x, {}, builder_->CreateLoad(ptr));
|
||||||
|
return;
|
||||||
|
}
|
||||||
// find vector size
|
// find vector size
|
||||||
ir::value *ptr = x->get_pointer_operand();
|
ir::value *ptr = x->get_pointer_operand();
|
||||||
size_t ld = layouts_->get(ptr)->order[0];
|
size_t ld = layouts_->get(ptr)->order[0];
|
||||||
|
@@ -229,6 +229,13 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
|||||||
ir::value* ret = ret_;
|
ir::value* ret = ret_;
|
||||||
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
|
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
|
||||||
return set_ret(bld_->create_get_program_id(axis->get_value()));
|
return set_ret(bld_->create_get_program_id(axis->get_value()));
|
||||||
|
else
|
||||||
|
return should_not_happen();
|
||||||
|
}
|
||||||
|
if(name == "sqrtf"){
|
||||||
|
VisitExpr(funcCall->Args()->at(0));
|
||||||
|
ir::value* ret = ret_;
|
||||||
|
return set_ret(bld_->create_sqrt(ret));
|
||||||
}
|
}
|
||||||
return error_not_implemented();
|
return error_not_implemented();
|
||||||
}
|
}
|
||||||
@@ -274,10 +281,11 @@ void Generator::VisitDeclaration(Declaration* decl) {
|
|||||||
// initialize declaration
|
// initialize declaration
|
||||||
ir::type::id_t id = ty->get_type_id();
|
ir::type::id_t id = ty->get_type_id();
|
||||||
if(id == ir::type::StructTyID)
|
if(id == ir::type::StructTyID)
|
||||||
assert(false);
|
should_not_happen();
|
||||||
if(inits.size() > 1)
|
if(inits.size() > 1)
|
||||||
assert(false);
|
should_not_happen();
|
||||||
val = inits[0];
|
if(inits.size() > 0)
|
||||||
|
val = inits[0];
|
||||||
assert(val->get_type() == ty);
|
assert(val->get_type() == ty);
|
||||||
// update scope symbols table
|
// update scope symbols table
|
||||||
const std::string &name = obj->Name();
|
const std::string &name = obj->Name();
|
||||||
|
@@ -113,7 +113,7 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
|
|||||||
arg arg_i = args.at(i);
|
arg arg_i = args.at(i);
|
||||||
arg_type ty = arg_i.type();
|
arg_type ty = arg_i.type();
|
||||||
if(ty != param_tys_.at(i))
|
if(ty != param_tys_.at(i))
|
||||||
throw std::runtime_error("invalid type");
|
throw std::runtime_error("invalid type for argument " + std::to_string(i));
|
||||||
if(ty == BUFFER_T)
|
if(ty == BUFFER_T)
|
||||||
bin_->setArg(i, *((driver::buffer**)arg_i.data()));
|
bin_->setArg(i, *((driver::buffer**)arg_i.data()));
|
||||||
else
|
else
|
||||||
@@ -253,16 +253,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
return std::unique_ptr<driver::module>();
|
return std::unique_ptr<driver::module>();
|
||||||
barriers.run(module);
|
barriers.run(module);
|
||||||
// std::cout << "isel" << std::endl;
|
// std::cout << "isel" << std::endl;
|
||||||
// ir::print(module, std::cout);
|
|
||||||
isel.visit(module, *llvm);
|
isel.visit(module, *llvm);
|
||||||
// std::cout << "done" << std::endl;
|
|
||||||
// return binary
|
// return binary
|
||||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||||
// done
|
// done
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string preheader() {
|
std::string function::preheader() {
|
||||||
return
|
return
|
||||||
R"(
|
R"(
|
||||||
#define bool _Bool
|
#define bool _Bool
|
||||||
@@ -277,6 +275,7 @@ R"(
|
|||||||
#define __multipleof(A) __attribute__((multipleof(A)))
|
#define __multipleof(A) __attribute__((multipleof(A)))
|
||||||
|
|
||||||
extern int get_program_id(int);
|
extern int get_program_id(int);
|
||||||
|
extern float sqrtf(float);
|
||||||
)";
|
)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -77,7 +77,7 @@ class CMakeBuild(build_ext):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
cfg = 'Debug' if self.debug else 'Release'
|
cfg = 'Debug' if self.debug else 'Release'
|
||||||
cfg = 'Release'
|
#cfg = 'Release'
|
||||||
build_args = ['--config', cfg]
|
build_args = ['--config', cfg]
|
||||||
|
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
|
@@ -211,27 +211,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
|
|||||||
os << ";\n";
|
os << ";\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::string preheader() {
|
|
||||||
return
|
|
||||||
R"(
|
|
||||||
#define bool _Bool
|
|
||||||
#define true 1
|
|
||||||
#define false 0
|
|
||||||
#define __bool_true_false_are_defined 1
|
|
||||||
|
|
||||||
#define __readonly __attribute__((readonly))
|
|
||||||
#define __writeonly __attribute__((writeonly))
|
|
||||||
#define __noalias __attribute__((noalias))
|
|
||||||
#define __aligned(A) __attribute__((aligned(A)))
|
|
||||||
#define __multipleof(A) __attribute__((multipleof(A)))
|
|
||||||
|
|
||||||
extern int get_program_id(int);
|
|
||||||
)";
|
|
||||||
}
|
|
||||||
|
|
||||||
void make_module(const std::string& src, ir::module* ir,
|
void make_module(const std::string& src, ir::module* ir,
|
||||||
const runtime::function::options_space_t& opt) {
|
const runtime::function::options_space_t& opt) {
|
||||||
std::string copy = preheader() + src;
|
std::string copy = triton::runtime::function::preheader() + src;
|
||||||
// pre-process
|
// pre-process
|
||||||
TokenSequence tokens;
|
TokenSequence tokens;
|
||||||
Preprocessor cpp(©, true);
|
Preprocessor cpp(©, true);
|
||||||
@@ -341,11 +323,11 @@ inline std::string to_torch_ty(ir::type *ty) {
|
|||||||
if(ty->is_integer_ty())
|
if(ty->is_integer_ty())
|
||||||
return "int64_t";
|
return "int64_t";
|
||||||
if(ty->is_half_ty())
|
if(ty->is_half_ty())
|
||||||
return "float16";
|
return "double";
|
||||||
if(ty->is_float_ty())
|
if(ty->is_float_ty())
|
||||||
return "float32";
|
return "double";
|
||||||
if(ty->is_double_ty())
|
if(ty->is_double_ty())
|
||||||
return "float64";
|
return "double";
|
||||||
if(ty->is_pointer_ty())
|
if(ty->is_pointer_ty())
|
||||||
return "torch::Tensor";
|
return "torch::Tensor";
|
||||||
throw std::runtime_error("unknown type");
|
throw std::runtime_error("unknown type");
|
||||||
@@ -363,11 +345,11 @@ inline std::string to_c_ty(ir::type *ty) {
|
|||||||
if(ty->is_integer_ty(64))
|
if(ty->is_integer_ty(64))
|
||||||
return "int64_t";
|
return "int64_t";
|
||||||
if(ty->is_half_ty())
|
if(ty->is_half_ty())
|
||||||
return "float16";
|
return "half";
|
||||||
if(ty->is_float_ty())
|
if(ty->is_float_ty())
|
||||||
return "float32";
|
return "float";
|
||||||
if(ty->is_double_ty())
|
if(ty->is_double_ty())
|
||||||
return "float64";
|
return "double";
|
||||||
if(ty->is_pointer_ty())
|
if(ty->is_pointer_ty())
|
||||||
return "drv::cu_buffer";
|
return "drv::cu_buffer";
|
||||||
throw std::runtime_error("unknown type");
|
throw std::runtime_error("unknown type");
|
||||||
|
@@ -1,2 +1,3 @@
|
|||||||
from .dot import _dot, dot
|
from .dot import _dot, dot
|
||||||
from .einsum import _einsum, einsum
|
from .einsum import _einsum, einsum
|
||||||
|
from .batchnorm import _batchnorm, batchnorm
|
75
python/triton/ops/batchnorm.py
Normal file
75
python/triton/ops/batchnorm.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import triton
|
||||||
|
import math
|
||||||
|
|
||||||
|
class _batchnorm(triton.function):
|
||||||
|
|
||||||
|
fwd_src = """
|
||||||
|
void batchnormForward(float *Y, float *M, float *V,
|
||||||
|
float *X, float *G, float *B,
|
||||||
|
int N, float rcpN, float eps) {
|
||||||
|
int rx[TM] = 0 ... TM;
|
||||||
|
float *px[TM];
|
||||||
|
float x[TM] = 0;
|
||||||
|
int c = get_program_id(1);
|
||||||
|
float g = *(G + c);
|
||||||
|
float b = *(B + c);
|
||||||
|
|
||||||
|
float mean[TM] = 0;
|
||||||
|
px = X + rx + c*N;
|
||||||
|
for(int i = 0; i < N; i = i + TM){
|
||||||
|
x = *px;
|
||||||
|
mean = mean + x;
|
||||||
|
px = px + TM;
|
||||||
|
}
|
||||||
|
float *pm = M + c;
|
||||||
|
float m = mean[+] * rcpN;
|
||||||
|
*pm = m;
|
||||||
|
|
||||||
|
float var[TM] = 0;
|
||||||
|
px = X + rx + c*N;
|
||||||
|
for(int i = 0; i < N; i = i + TM){
|
||||||
|
x = *px;
|
||||||
|
x = x - m;
|
||||||
|
var = var + x*x;
|
||||||
|
px = px + TM;
|
||||||
|
}
|
||||||
|
float v = var[+] * rcpN;
|
||||||
|
float *pv = V + c;
|
||||||
|
*pv = v;
|
||||||
|
float rstdg = 1 / sqrtf(v + eps) * g;
|
||||||
|
|
||||||
|
px = X + rx + c*N;
|
||||||
|
float* py[TM] = Y + rx + c*N;
|
||||||
|
for(int i = 0; i < N; i = i + TM){
|
||||||
|
x = *px;
|
||||||
|
float y[TM] = (x - m)*rstdg + b;
|
||||||
|
*py = y;
|
||||||
|
px = px + TM;
|
||||||
|
py = py + TM;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, gamma, beta, eps):
|
||||||
|
shape = triton.shape(x)
|
||||||
|
dtype = x.dtype
|
||||||
|
# allocate outputs
|
||||||
|
C, H, W, B = shape[0], shape[1], shape[2], shape[3]
|
||||||
|
y = triton.empty(shape, dtype=dtype)
|
||||||
|
mean = triton.empty([C], dtype=dtype)
|
||||||
|
var = triton.empty([C], dtype=dtype)
|
||||||
|
# execute kernels
|
||||||
|
N = H*W*B
|
||||||
|
_batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps,
|
||||||
|
lambda opt: [1, C],
|
||||||
|
TM = 128)
|
||||||
|
# save
|
||||||
|
ctx.eps = eps
|
||||||
|
ctx.save_for_backward(x, gamma, beta, mean, var)
|
||||||
|
return y, mean, var
|
||||||
|
|
||||||
|
|
||||||
|
batchnorm = _batchnorm.apply
|
@@ -181,15 +181,16 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, subscripts, a, b, bench = 0):
|
def forward(ctx, subscripts, a, b, bench = 0):
|
||||||
ctx.save_for_backward(a, b)
|
ctx.save_for_backward(a, b)
|
||||||
|
# parse
|
||||||
if type(subscripts) is str:
|
if type(subscripts) is str:
|
||||||
einsum_a, einsum_bc = subscripts.split(",")
|
einsum_a, einsum_bc = subscripts.split(",")
|
||||||
einsum_b, einsum_c = einsum_bc.split("->")
|
einsum_b, einsum_c = einsum_bc.split("->")
|
||||||
else:
|
else:
|
||||||
einsum_a, einsum_b, einsum_c = subscripts
|
einsum_a, einsum_b, einsum_c = subscripts
|
||||||
|
|
||||||
shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
|
shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
|
||||||
einsum_a, einsum_b, einsum_c,
|
einsum_a, einsum_b, einsum_c,
|
||||||
triton.shape(a), triton.shape(b))
|
triton.shape(a), triton.shape(b))
|
||||||
|
# save for backward
|
||||||
ctx.trans_a = ta
|
ctx.trans_a = ta
|
||||||
ctx.trans_b = tb
|
ctx.trans_b = tb
|
||||||
ctx.einsum_a = einsum_a
|
ctx.einsum_a = einsum_a
|
||||||
@@ -197,6 +198,7 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
|
|||||||
ctx.einsum_c = einsum_c
|
ctx.einsum_c = einsum_c
|
||||||
ctx.bench = bench
|
ctx.bench = bench
|
||||||
ctx.bmnk = bmnk
|
ctx.bmnk = bmnk
|
||||||
|
# run
|
||||||
return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench)
|
return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench)
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user