[IR] Added special-purpose dequantize
instruction (#759)
It is currently necessary for optimal performance in quantized workloads to add a special-purpose instruction in the IR. Backward compatibility with this instruction is *NOT* guaranteed.
This commit is contained in:
@@ -115,6 +115,18 @@ std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_ins
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_dequantize(ir::dequantize_inst* x) {
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto op_cst = populate_is_constant(op);
|
||||
for(size_t d = 0; d < x_shapes.size(); d++) {
|
||||
result.push_back(op_cst[d]);
|
||||
}
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
@@ -146,7 +158,7 @@ std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
|
||||
// 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8
|
||||
// 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4
|
||||
// 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16
|
||||
//
|
||||
//
|
||||
// if LHS is a range of N continuous (or equal) elements that starts at M,
|
||||
// and RHS is a set of N constants that start at K
|
||||
// then the result in constant in groups of gcd(M, K)
|
||||
@@ -212,6 +224,8 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
||||
return populate_is_constant_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
return populate_is_constant_reshape(x);
|
||||
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
|
||||
return populate_is_constant_dequantize(x);
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
|
||||
return populate_is_constant_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
@@ -279,6 +293,23 @@ std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_dequantize(ir::dequantize_inst* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
|
||||
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
|
||||
auto op_mc = populate_max_contiguous(op);
|
||||
for(size_t d = 0; d < shapes.size(); d++) {
|
||||
unsigned factor = 1;
|
||||
if (d == shapes.size() - 1) {
|
||||
factor = ret_last_dim / op_last_dim;
|
||||
}
|
||||
result.push_back(factor * op_mc[d]);
|
||||
}
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
@@ -376,6 +407,8 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
|
||||
return populate_max_contiguous_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
return populate_max_contiguous_reshape(x);
|
||||
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
|
||||
return populate_max_contiguous_dequantize(x);
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
|
||||
return populate_max_contiguous_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
@@ -420,6 +453,23 @@ std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_dequantize(ir::dequantize_inst* x){
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
|
||||
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
|
||||
auto op_multiple = populate_starting_multiple(op);
|
||||
for(size_t d = 0; d < shapes.size(); d++) {
|
||||
unsigned factor = 1;
|
||||
if (d == shapes.size() - 1) {
|
||||
factor = ret_last_dim / op_last_dim;
|
||||
}
|
||||
result.push_back(factor * op_multiple[d]);
|
||||
}
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
|
||||
auto result = populate_starting_multiple(x->get_operand(0));
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
@@ -539,6 +589,8 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
||||
return populate_starting_multiple_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
return populate_starting_multiple_reshape(x);
|
||||
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
|
||||
return populate_starting_multiple_dequantize(x);
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
|
||||
return populate_starting_multiple_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||
|
@@ -56,6 +56,17 @@ void axes::update_graph_trans(ir::instruction *i) {
|
||||
graph_.add_edge({i, perm[d]}, {op, d});
|
||||
}
|
||||
|
||||
void axes::update_graph_dequantize(ir::instruction *i) {
|
||||
auto *dequantize = static_cast<ir::dequantize_inst*>(i);
|
||||
auto shapes = dequantize->get_type()->get_block_shapes();
|
||||
ir::value *op = dequantize->get_operand(0);
|
||||
|
||||
// add edge except the last axis
|
||||
for(unsigned d = 0; d < shapes.size() - 1; d ++){
|
||||
graph_.add_edge({i, d}, {op, d});
|
||||
}
|
||||
}
|
||||
|
||||
void axes::update_graph_broadcast(ir::instruction *i) {
|
||||
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
|
||||
auto shapes = broadcast->get_type()->get_block_shapes();
|
||||
@@ -79,7 +90,7 @@ void axes::update_graph_dot(ir::instruction *i) {
|
||||
graph_.add_edge({dot, d}, {D, d});
|
||||
}
|
||||
|
||||
void axes::update_graph_elementwise(ir::instruction *i,
|
||||
void axes::update_graph_elementwise(ir::instruction *i,
|
||||
bool is_masked_load_async) {
|
||||
if(i->get_num_operands() == 0)
|
||||
return;
|
||||
@@ -119,6 +130,7 @@ void axes::update_graph(ir::instruction *i) {
|
||||
case ir::INST_SPLAT: return update_graph_no_edge(i);
|
||||
case ir::INST_CAT: return update_graph_elementwise(i, true);
|
||||
case ir::INST_TRANS: return update_graph_trans(i);
|
||||
case ir::INST_DEQUANTIZE: return update_graph_dequantize(i);
|
||||
case ir::INST_BROADCAST: return update_graph_broadcast(i);
|
||||
case ir::INST_DOT: return update_graph_dot(i);
|
||||
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
|
||||
|
Reference in New Issue
Block a user