[IR] Added special-purpose dequantize instruction (#759)
For optimal performance in quantized workloads, it is currently necessary to add a special-purpose dequantize instruction to the IR. Backward compatibility with this instruction is *NOT* guaranteed.
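As context for the change, below is a minimal usage sketch of the new tl.dequantize builtin, adapted from the int8 test kernel added in this commit (the kernel name is illustrative). Each int32 element of the packed input holds four quantized int8 values; they are unpacked to float16 as q * scale + shift.

import triton
import triton.language as tl


@triton.jit
def dequantize_int8_kernel(output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr):
    # one packed int32 element per four output fp16 values
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input = tl.load(tl.multiple_of(input_ptr + w_offsets, 1), mask=mask, other=0)
    scale = tl.load(scale_ptr)  # fp16 scalar
    shift = tl.load(shift_ptr)  # fp16 scalar
    # nbit=8: unpack four int8 values per int32 and apply q * scale + shift
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    tl.store(tl.multiple_of(output_ptr + offsets, 4), output, mask=offsets < size)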
@@ -14,6 +14,7 @@ namespace ir {
class cast_inst;
class cmp_inst;
class reshape_inst;
class dequantize_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
@@ -34,6 +35,7 @@ private:
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_dequantize(ir::dequantize_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_cmp(ir::cmp_inst* x);
@@ -44,6 +46,7 @@ private:
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
@@ -54,6 +57,7 @@ private:
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
@@ -25,6 +25,7 @@ private:
void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i);
void update_graph_trans(ir::instruction *i);
void update_graph_dequantize(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i,
@@ -152,7 +152,15 @@ private:
std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0);

std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int16_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int32_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*> int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift);
std::tuple<Value*, Value*> prepare_scale_shift(Value *scale, Value *shift);
void visit_dequantize_inst(ir::dequantize_inst*);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
@@ -73,6 +73,8 @@ public:
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
value* create_ret(value *ret);
// Dequantize instructions
value* create_dequantize(value *src, value *scale, value *shift, type *dest_ty);
// Cast instructions
value* create_bitcast(value *src, type *dest_ty);
value *create_cast(cast_op_t op, value *v, type *dst_ty);
@@ -108,6 +108,8 @@ enum value_id_t: unsigned {
// cmp
INST_ICMP,
INST_FCMP,
// dequantize
INST_DEQUANTIZE,
// cast
INST_CAST_TRUNC,
INST_CAST_ZEXT,
@@ -274,6 +274,24 @@ protected:
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
};

//===----------------------------------------------------------------------===//
// dequantize_inst classes
//===----------------------------------------------------------------------===//

class dequantize_inst: public instruction{
private:
std::string repr_impl() const override { return "dequantize"; }

protected:
dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next);

public:
static dequantize_inst *create(value *arg, value *scale, value *shift, type *ty,
const std::string &name = "", instruction *next = nullptr);

_TRITON_DEFINE_CLONE(dequantize_inst)
_TRITON_DEFINE_ACCEPT(dequantize_inst)
};

//===----------------------------------------------------------------------===//
// cast_inst classes
@@ -20,6 +20,7 @@ class getelementptr_inst;

class icmp_inst;
class fcmp_inst;
class dequantize_inst;
class cast_inst;
class trunc_inst;
class z_ext_inst;
@@ -124,6 +125,7 @@ public:

virtual void visit_icmp_inst(icmp_inst*) = 0;
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
virtual void visit_dequantize_inst(dequantize_inst*) = 0;
virtual void visit_cast_inst(cast_inst*) = 0;

virtual void visit_return_inst(return_inst*) = 0;
@@ -115,6 +115,18 @@ std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_ins
return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_dequantize(ir::dequantize_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++) {
result.push_back(op_cst[d]);
}
return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
@@ -212,6 +224,8 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
return populate_is_constant_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_is_constant_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_is_constant_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_is_constant_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
@@ -279,6 +293,23 @@ std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x
return add_to_cache(x, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_dequantize(ir::dequantize_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++) {
unsigned factor = 1;
if (d == shapes.size() - 1) {
factor = ret_last_dim / op_last_dim;
}
result.push_back(factor * op_mc[d]);
}
return add_to_cache(x, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
@@ -376,6 +407,8 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_max_contiguous_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_max_contiguous_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_max_contiguous_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
@@ -420,6 +453,23 @@ std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst
return add_to_cache(x, result, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple_dequantize(ir::dequantize_inst* x){
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
auto op_multiple = populate_starting_multiple(op);
for(size_t d = 0; d < shapes.size(); d++) {
unsigned factor = 1;
if (d == shapes.size() - 1) {
factor = ret_last_dim / op_last_dim;
}
result.push_back(factor * op_multiple[d]);
}
return add_to_cache(x, result, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
@@ -539,6 +589,8 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
return populate_starting_multiple_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_starting_multiple_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_starting_multiple_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_starting_multiple_broadcast(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
@@ -56,6 +56,17 @@ void axes::update_graph_trans(ir::instruction *i) {
graph_.add_edge({i, perm[d]}, {op, d});
}

void axes::update_graph_dequantize(ir::instruction *i) {
auto *dequantize = static_cast<ir::dequantize_inst*>(i);
auto shapes = dequantize->get_type()->get_block_shapes();
ir::value *op = dequantize->get_operand(0);

// add edge except the last axis
for(unsigned d = 0; d < shapes.size() - 1; d ++){
graph_.add_edge({i, d}, {op, d});
}
}

void axes::update_graph_broadcast(ir::instruction *i) {
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
auto shapes = broadcast->get_type()->get_block_shapes();
@@ -119,6 +130,7 @@ void axes::update_graph(ir::instruction *i) {
case ir::INST_SPLAT: return update_graph_no_edge(i);
case ir::INST_CAT: return update_graph_elementwise(i, true);
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_DEQUANTIZE: return update_graph_dequantize(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
@@ -99,6 +99,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
// constants
#define i16(...) builder_->getInt16(__VA_ARGS__)
#define i32(...) builder_->getInt32(__VA_ARGS__)
// ops
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
@@ -854,6 +855,234 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
}
}

std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> generator::int16_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
){
/* unpacking 8 int2s packed into an int16 to 8 float16s
* the algorithm is similar to
* https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/
fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1492-L1563
*/
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
"{"
".reg .b32 a<2>, b<4>; \n\t" // input is 0xab,cd,ef,gh,ab,cd,ef,gh, each a, b etc occupies two bits.
"and.b32 a0, 0x30300303, $4; \n\t" // set a0 to 0x0b,00,0f,00,00,0d,00,0h
"and.b32 a1, 0xc0c00c0c, $4; \n\t" // set a1 to 0xa0,00,e0,00,00,c0,00,g0
"prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x00,00,00,0d,00,00,00,0h
"prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00,00,00,c0,00,00,00,g0
"prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x00,00,0b,00,00,00,0f,00
"prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00,00,a0,00,00,00,e0,00
"mov.b32 a0, 0x78007800; \n\t" // a0 = 32768
"mov.b32 a1, 0x70007000; \n\t" // a1 = 8192
"mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768.
"mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 8192.
"mov.b32 a0, 0x68006800; \n\t" // a0 = 2048
"mov.b32 a1, 0x60006000; \n\t" // a1 = 512
"mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 2048.
"mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 512.
"fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift.
"fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift.
"fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out2 = b2 * scale + shift.
"fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out3 = b3 * scale + shift.
"}", "=r,=r,=r,=r,r,r,r", false);

Value *packed_in = UndefValue::get(vec_ty(i16_ty, 2));
packed_in = insert_elt(packed_in, in0, (int)0);
packed_in = insert_elt(packed_in, in0, (int)1);
Value *in = bit_cast(packed_in, i32_ty);

Value *ret = call(ptx, {in, scale_x512, shift});
Value *packed_ret0 = extract_val(ret, {0});
Value *packed_ret1 = extract_val(ret, {1});
Value *packed_ret2 = extract_val(ret, {2});
Value *packed_ret3 = extract_val(ret, {3});
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h
Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g
Value *ret2 = extract_elt(packed_ret2, (uint64_t)0); // f
Value *ret3 = extract_elt(packed_ret3, (uint64_t)0); // e
Value *ret4 = extract_elt(packed_ret0, (uint64_t)1); // d
Value *ret5 = extract_elt(packed_ret1, (uint64_t)1); // c
Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b
Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a
return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7);
}

std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> generator::int32_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
){
/* unpacking 8 int4s packed into an int32 to 8 float16s
* the algorithm is similar to
* https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/
fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1566-L1619
*/
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
"{"
".reg .b32 a<2>, b<4>; \n\t"
"and.b32 a0, 0x0f0f0f0f, $4; \n\t" // If input is 0xabcdefgh set a to 0x0b0d0f0h
"and.b32 a1, 0xf0f0f0f0, $4; \n\t" // If input is 0xabcdefgh set a to 0xa0c0e0g0
"prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x000f000h
"prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00e000g0
"prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x000b000d
"prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00a000c0
"mov.b32 a0, 0x78007800; \n\t"
"mov.b32 a1, 0x68006800; \n\t"
"mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768.
"mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 2048.
"mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 32768.
"mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 2048.
"fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift.
"fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift.
"fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out2 = b2 * scale + shift.
"fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out3 = b3 * scale + shift.
"}", "=r,=r,=r,=r,r,r,r", false);

Value *ret = call(ptx, {in0, scale_x512, shift});
Value *packed_ret0 = extract_val(ret, {0});
Value *packed_ret1 = extract_val(ret, {1});
Value *packed_ret2 = extract_val(ret, {2});
Value *packed_ret3 = extract_val(ret, {3});
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h
Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g
Value *ret2 = extract_elt(packed_ret0, (uint64_t)1); // f
Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // e
Value *ret4 = extract_elt(packed_ret2, (uint64_t)0); // d
Value *ret5 = extract_elt(packed_ret3, (uint64_t)0); // c
Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b
Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a
return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7);
}

std::tuple<Value*, Value*, Value*, Value*> generator::int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift){
/* unpacking 4 int8s packed into an int32 to 4 fp16s
* the algorithm is similar to
* https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/
fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1622-L1646
*/
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false),
"{"
".reg .b32 a, b<2>; \n\t"
"prmt.b32 b0, 0, $2, 0x0504; \n\t" // If input is 0xabcdefgh set b0 to 0x00ef00gh
"prmt.b32 b1, 0, $2, 0x0706; \n\t" // If input is 0xabcdefgh set b1 to 0x00ab00cd
"mov.b32 a, 0x78007800; \n\t"
"mul.f16x2 b0, b0, a; \n\t" // b0 = b0 * 32768.
"mul.f16x2 b1, b1, a; \n\t" // b1 = b1 * 32768.
"fma.rn.f16x2 $0, b0, $3, $4; \n\t" // out0 = b0 * scale + shift.
"fma.rn.f16x2 $1, b1, $3, $4; \n\t" // out1 = b1 * scale + shift.
"}", "=r,=r,r,r,r", false);

Value *ret = call(ptx, {in0, scale_x512, shift});
Value *packed_ret0 = extract_val(ret, {0});
Value *packed_ret1 = extract_val(ret, {1});
Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // gh
Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); // ef
Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); // cd
Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // ab
return std::make_tuple(ret0, ret1, ret2, ret3);
}

std::tuple<Value*, Value*> generator::prepare_scale_shift(Value *scale, Value *shift){
Value *scale_x512 = fmul(scale, bit_cast(i16(0x6000), f16_ty));
Value *p_scale_x512 = UndefValue::get(vec_ty(f16_ty, 2));
p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)0);
p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)1);
p_scale_x512 = bit_cast(p_scale_x512, i32_ty);

Value *p_shift = UndefValue::get(vec_ty(f16_ty, 2));
p_shift = insert_elt(p_shift, shift, (int)0);
p_shift = insert_elt(p_shift, shift, (int)1);
p_shift = bit_cast(p_shift, i32_ty);

return std::make_tuple(p_scale_x512, p_shift);
}

/**
* \brief Code Generation for `dequantize`
*/
void generator::visit_dequantize_inst(ir::dequantize_inst* x) {
ir::value *op = x->get_operand(0);

auto src_ty_size_in_bits = op->get_type()->get_scalar_ty()->get_primitive_size_in_bits();

auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
auto op_last_dim = (op->get_type()->get_block_shapes()).back();

auto x_idxs = idxs_.at(x);
auto op_idxs = idxs_.at(op);

ir::value *scale = x->get_operand(1);
ir::value *shift = x->get_operand(2);

Value *p_scale_x512, *p_shift;
std::tie(p_scale_x512, p_shift) = prepare_scale_shift(vals_[scale][{}], vals_[shift][{}]);

int ld = layouts_->get(x)->get_order(0);
int contiguous = layouts_->get(x)->to_scanline()->nts(ld);

int op_ld = layouts_->get(op)->get_order(0);
int op_contiguous = layouts_->get(op)->to_scanline()->nts(op_ld);

std::string err_msg;
err_msg = "unsupported dequantization, cannot vectorize properly. x_idxs.size(): "
+ std::to_string(x_idxs.size()) + "; op_idxs.size(): "
+ std::to_string(op_idxs.size()) + "; contiguous: "
+ std::to_string(contiguous) + "; op_contiguous: "
+ std::to_string(op_contiguous) + ". if the condition "
"is not met, please try adjusting block_size, num_warps or "
"using tl.multiple_of to hint the input/output ptr address.";

if (ret_last_dim == 8 * op_last_dim) {
if((x_idxs.size() != 8 * op_idxs.size()) || (contiguous != 8 * op_contiguous)) {
throw std::runtime_error(err_msg);
}

auto cvt = [&](
Value* a, Value* scale, Value* shift
){
if (src_ty_size_in_bits == 16){ // int2 quantization, int16 to 8 fp16s
return int16_to_float16x8(a, scale, shift);
} else if (src_ty_size_in_bits == 32) { // int4 quantization, int32 to 8 fp16s
return int32_to_float16x8(a, scale, shift);
} else {
throw std::runtime_error("unsupported conversion");
}
};

for(size_t j = 0; j < op_idxs.size(); j++){
size_t i = j * 8;
std::tie(vals_[x][x_idxs[i+0]],
vals_[x][x_idxs[i+1]],
vals_[x][x_idxs[i+2]],
vals_[x][x_idxs[i+3]],
vals_[x][x_idxs[i+4]],
vals_[x][x_idxs[i+5]],
vals_[x][x_idxs[i+6]],
vals_[x][x_idxs[i+7]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift);
}
} else if (ret_last_dim == 4 * op_last_dim && src_ty_size_in_bits == 32) { // int8 quantization, int32 to 4 fp16s
if((x_idxs.size() != 4 * op_idxs.size()) || (contiguous != 4 * op_contiguous)) {
throw std::runtime_error(err_msg);
}

auto cvt = [&](Value* a, Value* scale, Value* shift){
return int32_to_float16x4(a, scale, shift);
};

for(size_t j = 0; j < op_idxs.size(); j++){
size_t i = j * 4;
std::tie(vals_[x][x_idxs[i+0]],
vals_[x][x_idxs[i+1]],
vals_[x][x_idxs[i+2]],
vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift);
}
} else {
throw std::runtime_error("unsupported dequantization");
}
return;
}

/**
* \brief Code Generation for `return`
*/
@@ -120,6 +120,14 @@ value *builder::create_ret(value* val) {
return insert(return_inst::create(ctx_, val));
}

//===----------------------------------------------------------------------===//
// dequantize instructions
//===----------------------------------------------------------------------===//

value* builder::create_dequantize(value *src, value *scale, value *shift, type *dst_ty){
return insert(dequantize_inst::create(src, scale, shift, dst_ty));
}

//===----------------------------------------------------------------------===//
// cast instructions
//===----------------------------------------------------------------------===//
@@ -323,6 +323,21 @@ unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &nam
set_operand(0, v);
}

//===----------------------------------------------------------------------===//
// dequantize_inst classes
//===----------------------------------------------------------------------===//

dequantize_inst::dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next)
: instruction(ty, INST_DEQUANTIZE, 3, name, next) {
set_operand(0, v);
set_operand(1, scale);
set_operand(2, shift);
}

dequantize_inst *dequantize_inst::create(value *arg, value *scale, value *shift, type *ty, const std::string &name, instruction *next){
return new dequantize_inst(ty, arg, scale, shift, name, next);
}

//===----------------------------------------------------------------------===//
// cast_inst classes
//===----------------------------------------------------------------------===//
@@ -834,6 +834,8 @@ void init_triton_ir(py::module &&m) {
.def("create_br", &ir::builder::create_br, ret::reference)
.def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
.def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
// Dequantize instructions
.def("create_dequantize", &ir::builder::create_dequantize, ret::reference)
// Cast instructions
.def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
.def("create_cast", &ir::builder::create_cast, ret::reference)
python/test/unit/language/test_dequantize.py (new file)
@@ -0,0 +1,261 @@
# flake8: noqa: F821,F841

import random

import torch

import triton
import triton.language as tl


@triton.jit
def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input_ptrs = input_ptr + 1 + w_offsets
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale_shift = tl.load(input_ptr)
    scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
    shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int8(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    mask = w_offsets < (size // 4)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 8)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 4)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = input_ptr + 1 + w_offsets
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale_shift = tl.load(input_ptr)
    scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True)
    shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 4)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int4(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 4)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(input_ptr).to(tl.float16, bitcast=True)
    shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True)
    output = tl.dequantize(input, scale, shift, 2)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


@triton.jit
def dequantize_kernel_scale_shift_int2(
    output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr
):
    w_offsets = tl.arange(0, BLOCK_SIZE // 8)
    mask = w_offsets < (size // 8)
    input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1)
    input = tl.load(input_ptrs, mask=mask, other=0)
    scale = tl.load(scale_ptr)
    shift = tl.load(shift_ptr)
    output = tl.dequantize(input, scale, shift, 2)
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(output_ptr + offsets, 8)
    tl.store(output_ptrs, output, mask=offsets < size)


def test_dequantize_int8() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 128, 4)
        else:
            size = random.randrange(132, 1024, 4)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int32)

        input_int8 = torch.randint(
            0, 256, (size,), dtype=torch.uint8, device=device
        )
        input_int32 = input_int8.view(torch.int32)

        input = torch.cat((scale_shift, input_int32))
        expected = (input_int8 * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 128)
        grid = (1,)
        dequantize_kernel_int8[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int8[grid](
            output,
            input_int32,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)


def test_dequantize_int4() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 256, 8)
        else:
            size = random.randrange(264, 1024, 8)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int32)

        input_int8 = torch.randint(
            0, 256, (size // 2,), dtype=torch.uint8, device=device
        )
        input_int32 = input_int8.view(torch.int32)

        input_int8_h1 = input_int8 >> 4
        input_int8_h0 = input_int8 & 15

        input_int4_val = torch.stack(
            (input_int8_h0, input_int8_h1), dim=1
        ).flatten()

        input = torch.cat((scale_shift, input_int32))
        expected = (input_int4_val * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 256)
        grid = (1,)
        dequantize_kernel_int4[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int4[grid](
            output,
            input_int32,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)


def test_dequantize_int2() -> None:
    for i in range(10):
        if i < 5:
            size = random.randrange(16, 256, 8)
        else:
            size = random.randrange(264, 1024, 8)
        device = torch.device(torch.cuda.current_device())

        scale_val = random.uniform(0.1, 4.0)
        shift_val = random.uniform(-10.0, 10.0)
        scale = torch.tensor(scale_val, dtype=torch.float16, device=device)
        shift = torch.tensor(shift_val, dtype=torch.float16, device=device)
        scale_shift = torch.tensor(
            [scale_val, shift_val],
            dtype=torch.float16,
            device=device,
        ).view(torch.int16)

        input_int8 = torch.randint(
            0, 256, (size // 4,), dtype=torch.uint8, device=device
        )
        input_int16 = input_int8.view(torch.int16)

        input_int8_q3 = input_int8 >> 6
        input_int8_q2 = (input_int8 >> 4) & 3
        input_int8_q1 = (input_int8 >> 2) & 3
        input_int8_q0 = input_int8 & 3

        input_int2_val = torch.stack(
            (input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1
        ).flatten()

        input = torch.cat((scale_shift, input_int16))
        expected = (input_int2_val * scale + shift).to(torch.float16)

        output = torch.empty([size], dtype=torch.float16, device=device)
        block_size = max(triton.next_power_of_2(size), 256)
        grid = (1,)

        dequantize_kernel_int2[grid](
            output, input, size, BLOCK_SIZE=block_size, num_warps=1
        )
        rtol, atol = 1e-02, 1e-02
        assert torch.allclose(output, expected, rtol, atol)

        output = torch.empty([size], dtype=torch.float16, device=device)
        dequantize_kernel_scale_shift_int2[grid](
            output,
            input_int16,
            scale,
            shift,
            size,
            BLOCK_SIZE=block_size,
            num_warps=1,
        )
        assert torch.allclose(output, expected, rtol, atol)
@@ -685,6 +685,20 @@ def zeros(shape, dtype, _builder=None):
    return semantic.zeros(shape, dtype, _builder)


# -----------------------
# dequantize
# -----------------------


@builtin
def dequantize(input, scale, shift, nbit, dst_ty=float16, _builder=None):
    """
    Tries to dequantize the input to the given dtype
    """
    nbit = _constexpr_to_value(nbit)
    return semantic.dequantize(input, scale, shift, nbit, dst_ty, _builder)


# -----------------------
# Shape Manipulation
# -----------------------
@@ -544,6 +544,31 @@ def broadcast_impl_value(lhs: tl.tensor,
    # (scalar, scalar) => returns original blocks
    return lhs, rhs


#######
# dequantize
#######

def dequantize(input: tl.tensor,
               scale: tl.tensor,
               shift: tl.tensor,
               nbit: int,
               dst_ty: tl.dtype,
               builder: ir.builder) -> tl.tensor:
    input_ty = input.type
    assert input_ty.is_block()
    assert input_ty.element_ty.is_int32() or input_ty.element_ty.is_int16()
    assert nbit in [2, 4, 8]
    assert dst_ty == tl.float16

    shape = input_ty.get_block_shapes()
    factor = input_ty.element_ty.primitive_bitwidth // nbit
    dst_shape = shape[:-1] + [factor * shape[-1]]

    dst_ty = tl.block_type(dst_ty, dst_shape)
    return tl.tensor(builder.create_dequantize(input.handle, scale.handle, shift.handle, dst_ty.to_ir(builder)), dst_ty)


#######
# cast
#######