[IR] Added special-purpose dequantize instruction (#759)

It is currently necessary for optimal performance in quantized workloads to add a special-purpose instruction in the IR. Backward compatibility with this instruction is *NOT* guaranteed.
This commit is contained in:
Yu Guo
2022-10-12 14:14:45 -07:00
committed by GitHub
parent 33e6f0df7f
commit 71b46acc42
16 changed files with 728 additions and 73 deletions

View File

@@ -14,6 +14,7 @@ namespace ir {
class cast_inst;
class cmp_inst;
class reshape_inst;
class dequantize_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
@@ -34,6 +35,7 @@ private:
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_dequantize(ir::dequantize_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_cmp(ir::cmp_inst* x);
@@ -44,6 +46,7 @@ private:
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
@@ -54,6 +57,7 @@ private:
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);

View File

@@ -25,6 +25,7 @@ private:
void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i);
void update_graph_trans(ir::instruction *i);
void update_graph_dequantize(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i,

View File

@@ -152,7 +152,15 @@ private:
std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int16_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int32_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*> int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift);
std::tuple<Value*, Value*> prepare_scale_shift(Value *scale, Value *shift);
void visit_dequantize_inst(ir::dequantize_inst*);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
@@ -265,7 +273,7 @@ private:
/// idx for multi-stage pipeline
std::map<analysis::data_layout*, Value*> read_smem_idx_;
std::map<analysis::data_layout*, Value*> write_smem_idx_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;

View File

@@ -73,6 +73,8 @@ public:
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
value* create_ret(value *ret);
// Dequantize instructions
value* create_dequantize(value *src, value *scale, value *shift, type *dest_ty);
// Cast instructions
value* create_bitcast(value *src, type *dest_ty);
value *create_cast(cast_op_t op, value *v, type *dst_ty);

View File

@@ -108,6 +108,8 @@ enum value_id_t: unsigned {
// cmp
INST_ICMP,
INST_FCMP,
// dequantize
INST_DEQUANTIZE,
// cast
INST_CAST_TRUNC,
INST_CAST_ZEXT,

View File

@@ -274,6 +274,24 @@ protected:
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
};
//===----------------------------------------------------------------------===//
// dequantize_inst classes
//===----------------------------------------------------------------------===//
class dequantize_inst: public instruction{
private:
std::string repr_impl() const override { return "dequantize"; }
protected:
dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next);
public:
static dequantize_inst *create(value *arg, value *scale, value *shift, type *ty,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dequantize_inst)
_TRITON_DEFINE_ACCEPT(dequantize_inst)
};
//===----------------------------------------------------------------------===//
// cast_inst classes
@@ -482,7 +500,7 @@ protected:
std::string get_cache_modifier_repr() const {
if (cache_ == CA) return ".ca";
if (cache_ == CG) return ".cg";
return "";
return "";
}
CACHE_MODIFIER cache_;
@@ -850,16 +868,16 @@ public:
class dot_inst: public builtin_inst {
public:
enum TransT { NoTrans, Trans };
enum DataType {
FP8, FP16, BF16, TF32, FP32,
INT1, INT4, INT8, INT32,
enum DataType {
FP8, FP16, BF16, TF32, FP32,
INT1, INT4, INT8, INT32,
UNKNOWN,
};
private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
public:
bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
@@ -1046,11 +1064,11 @@ class prefetch_s_inst : public instruction {
std::string repr_impl() const { return "prefetch_s"; }
_TRITON_DEFINE_CLONE(prefetch_s_inst)
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
/// inc_: 0->first, 1->latch
int inc_ = 0;
public:
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
set_operand(0, arg);
}

View File

@@ -20,6 +20,7 @@ class getelementptr_inst;
class icmp_inst;
class fcmp_inst;
class dequantize_inst;
class cast_inst;
class trunc_inst;
class z_ext_inst;
@@ -124,6 +125,7 @@ public:
virtual void visit_icmp_inst(icmp_inst*) = 0;
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
virtual void visit_dequantize_inst(dequantize_inst*) = 0;
virtual void visit_cast_inst(cast_inst*) = 0;
virtual void visit_return_inst(return_inst*) = 0;