diff --git a/include/triton/codegen/analysis/align.h b/include/triton/codegen/analysis/align.h index 513868aea..9f4926cfc 100644 --- a/include/triton/codegen/analysis/align.h +++ b/include/triton/codegen/analysis/align.h @@ -14,6 +14,7 @@ namespace ir { class cast_inst; class cmp_inst; class reshape_inst; + class dequantize_inst; class broadcast_inst; class binary_operator; class getelementptr_inst; @@ -34,6 +35,7 @@ private: std::vector populate_is_constant_phi(ir::phi_node* x); std::vector populate_is_constant_splat(ir::splat_inst* x); std::vector populate_is_constant_reshape(ir::reshape_inst* x); + std::vector populate_is_constant_dequantize(ir::dequantize_inst* x); std::vector populate_is_constant_broadcast(ir::broadcast_inst* x); std::vector populate_is_constant_binop(ir::binary_operator* x); std::vector populate_is_constant_cmp(ir::cmp_inst* x); @@ -44,6 +46,7 @@ private: std::vector populate_max_contiguous_phi(ir::phi_node* x); std::vector populate_max_contiguous_splat(ir::splat_inst* x); std::vector populate_max_contiguous_reshape(ir::reshape_inst* x); + std::vector populate_max_contiguous_dequantize(ir::dequantize_inst* x); std::vector populate_max_contiguous_broadcast(ir::broadcast_inst* x); std::vector populate_max_contiguous_binop(ir::binary_operator* x); std::vector populate_max_contiguous_gep(ir::getelementptr_inst* x); @@ -54,6 +57,7 @@ private: std::vector populate_starting_multiple_phi(ir::phi_node* x); std::vector populate_starting_multiple_splat(ir::splat_inst* x); std::vector populate_starting_multiple_reshape(ir::reshape_inst* x); + std::vector populate_starting_multiple_dequantize(ir::dequantize_inst* x); std::vector populate_starting_multiple_broadcast(ir::broadcast_inst* x); std::vector populate_starting_multiple_binop(ir::binary_operator* x); std::vector populate_starting_multiple_gep(ir::getelementptr_inst* x); diff --git a/include/triton/codegen/analysis/axes.h b/include/triton/codegen/analysis/axes.h index 759ed0f8f..9e8570b5c 100644 --- a/include/triton/codegen/analysis/axes.h +++ b/include/triton/codegen/analysis/axes.h @@ -25,6 +25,7 @@ private: void update_graph_reduce(ir::instruction *i); void update_graph_reshape(ir::instruction *i); void update_graph_trans(ir::instruction *i); + void update_graph_dequantize(ir::instruction *i); void update_graph_broadcast(ir::instruction *i); void update_graph_dot(ir::instruction *i); void update_graph_elementwise(ir::instruction *i, diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index bbf0417f1..cf7dacb09 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -152,7 +152,15 @@ private: std::tuple bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); Value* bf16_to_fp32(Value *in0); Value* fp32_to_bf16(Value *in0); - + std::tuple int16_to_float16x8( + Value *in0, Value *scale_x512, Value *shift + ); + std::tuple int32_to_float16x8( + Value *in0, Value *scale_x512, Value *shift + ); + std::tuple int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift); + std::tuple prepare_scale_shift(Value *scale, Value *shift); + void visit_dequantize_inst(ir::dequantize_inst*); void visit_cast_inst(ir::cast_inst*); void visit_return_inst(ir::return_inst*); void visit_cond_branch_inst(ir::cond_branch_inst*); @@ -265,7 +273,7 @@ private: /// idx for multi-stage pipeline std::map read_smem_idx_; std::map write_smem_idx_; - + /// triton bb -> llvm bb std::map bbs_; std::map> ords_; diff --git 
a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 8eb1c2ce3..d94dc4a2a 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -73,6 +73,8 @@ public: value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); value* create_ret_void(); value* create_ret(value *ret); + // Dequantize instructions + value* create_dequantize(value *src, value *scale, value *shift, type *dest_ty); // Cast instructions value* create_bitcast(value *src, type *dest_ty); value *create_cast(cast_op_t op, value *v, type *dst_ty); diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 4e60d3444..0ecdb409d 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -108,6 +108,8 @@ enum value_id_t: unsigned { // cmp INST_ICMP, INST_FCMP, + // dequantize + INST_DEQUANTIZE, // cast INST_CAST_TRUNC, INST_CAST_ZEXT, diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 8a1c3f7cf..8d5748694 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -274,6 +274,24 @@ protected: unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next); }; +//===----------------------------------------------------------------------===// +// dequantize_inst classes +//===----------------------------------------------------------------------===// + +class dequantize_inst: public instruction{ +private: + std::string repr_impl() const override { return "dequantize"; } + +protected: + dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next); + +public: + static dequantize_inst *create(value *arg, value *scale, value *shift, type *ty, + const std::string &name = "", instruction *next = nullptr); + + _TRITON_DEFINE_CLONE(dequantize_inst) + _TRITON_DEFINE_ACCEPT(dequantize_inst) +}; //===----------------------------------------------------------------------===// // cast_inst classes @@ -482,7 +500,7 @@ protected: std::string get_cache_modifier_repr() const { if (cache_ == CA) return ".ca"; if (cache_ == CG) return ".cg"; - return ""; + return ""; } CACHE_MODIFIER cache_; @@ -850,16 +868,16 @@ public: class dot_inst: public builtin_inst { public: enum TransT { NoTrans, Trans }; - enum DataType { - FP8, FP16, BF16, TF32, FP32, - INT1, INT4, INT8, INT32, + enum DataType { + FP8, FP16, BF16, TF32, FP32, + INT1, INT4, INT8, INT32, UNKNOWN, }; private: dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next); std::string repr_impl() const { return "dot"; } - + public: bool is_prefetched() const { return is_prefetched_; } void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } @@ -1046,11 +1064,11 @@ class prefetch_s_inst : public instruction { std::string repr_impl() const { return "prefetch_s"; } _TRITON_DEFINE_CLONE(prefetch_s_inst) _TRITON_DEFINE_ACCEPT(prefetch_s_inst) - + /// inc_: 0->first, 1->latch int inc_ = 0; public: - prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next) + prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next) : instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) { set_operand(0, arg); } diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 5f84f414f..b03b5f4fe 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -20,6 +20,7 @@ class getelementptr_inst; class 
icmp_inst; class fcmp_inst; +class dequantize_inst; class cast_inst; class trunc_inst; class z_ext_inst; @@ -124,6 +125,7 @@ public: virtual void visit_icmp_inst(icmp_inst*) = 0; virtual void visit_fcmp_inst(fcmp_inst*) = 0; + virtual void visit_dequantize_inst(dequantize_inst*) = 0; virtual void visit_cast_inst(cast_inst*) = 0; virtual void visit_return_inst(return_inst*) = 0; diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index 6bd6e4ef9..a4a066928 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -115,6 +115,18 @@ std::vector align::populate_is_constant_reshape(ir::reshape_ins return add_to_cache(x, result, is_constant_); } +std::vector align::populate_is_constant_dequantize(ir::dequantize_inst* x) { + auto x_shapes = get_shapes(x); + std::vector result; + ir::value *op = x->get_operand(0); + auto op_shapes = op->get_type()->get_block_shapes(); + auto op_cst = populate_is_constant(op); + for(size_t d = 0; d < x_shapes.size(); d++) { + result.push_back(op_cst[d]); + } + return add_to_cache(x, result, is_constant_); +} + std::vector align::populate_is_constant_broadcast(ir::broadcast_inst* x) { auto x_shapes = get_shapes(x); std::vector result; @@ -146,7 +158,7 @@ std::vector align::populate_is_constant_cmp(ir::cmp_inst* x) { // 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8 // 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4 // 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16 - // + // // if LHS is a range of N continuous (or equal) elements that starts at M, // and RHS is a set of N constants that start at K // then the result in constant in groups of gcd(M, K) @@ -212,6 +224,8 @@ std::vector align::populate_is_constant(ir::value *v) { return populate_is_constant_splat(x); if(auto *x = dynamic_cast(v)) return populate_is_constant_reshape(x); + if(auto *x = dynamic_cast(v)) + return populate_is_constant_dequantize(x); if(auto *x = dynamic_cast(v)) return populate_is_constant_broadcast(x); if(auto *x = dynamic_cast(v)) @@ -279,6 +293,23 @@ std::vector align::populate_max_contiguous_reshape(ir::reshape_inst* x return add_to_cache(x, result, max_contiguous_); } +std::vector align::populate_max_contiguous_dequantize(ir::dequantize_inst* x) { + auto shapes = get_shapes(x); + std::vector result; + ir::value *op = x->get_operand(0); + auto ret_last_dim = (x->get_type()->get_block_shapes()).back(); + auto op_last_dim = (op->get_type()->get_block_shapes()).back(); + auto op_mc = populate_max_contiguous(op); + for(size_t d = 0; d < shapes.size(); d++) { + unsigned factor = 1; + if (d == shapes.size() - 1) { + factor = ret_last_dim / op_last_dim; + } + result.push_back(factor * op_mc[d]); + } + return add_to_cache(x, result, max_contiguous_); +} + std::vector align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) { auto shapes = get_shapes(x); std::vector result; @@ -376,6 +407,8 @@ std::vector align::populate_max_contiguous(ir::value *v){ return populate_max_contiguous_splat(x); if(auto *x = dynamic_cast(v)) return populate_max_contiguous_reshape(x); + if(auto *x = dynamic_cast(v)) + return populate_max_contiguous_dequantize(x); if(auto *x = dynamic_cast(v)) return populate_max_contiguous_broadcast(x); if(auto *x = dynamic_cast(v)) @@ -420,6 +453,23 @@ std::vector align::populate_starting_multiple_reshape(ir::reshape_inst return add_to_cache(x, result, starting_multiple_); } +std::vector align::populate_starting_multiple_dequantize(ir::dequantize_inst* x){ + auto shapes = get_shapes(x); + std::vector 
result; + ir::value *op = x->get_operand(0); + auto ret_last_dim = (x->get_type()->get_block_shapes()).back(); + auto op_last_dim = (op->get_type()->get_block_shapes()).back(); + auto op_multiple = populate_starting_multiple(op); + for(size_t d = 0; d < shapes.size(); d++) { + unsigned factor = 1; + if (d == shapes.size() - 1) { + factor = ret_last_dim / op_last_dim; + } + result.push_back(factor * op_multiple[d]); + } + return add_to_cache(x, result, starting_multiple_); +} + std::vector align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){ auto result = populate_starting_multiple(x->get_operand(0)); return add_to_cache(x, result, starting_multiple_); @@ -539,6 +589,8 @@ std::vector align::populate_starting_multiple(ir::value *v){ return populate_starting_multiple_splat(x); if(auto *x = dynamic_cast(v)) return populate_starting_multiple_reshape(x); + if(auto *x = dynamic_cast(v)) + return populate_starting_multiple_dequantize(x); if(auto *x = dynamic_cast(v)) return populate_starting_multiple_broadcast(x); if(auto *x = dynamic_cast(v)) diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc index f079d2580..9e941fee6 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -56,6 +56,17 @@ void axes::update_graph_trans(ir::instruction *i) { graph_.add_edge({i, perm[d]}, {op, d}); } +void axes::update_graph_dequantize(ir::instruction *i) { + auto *dequantize = static_cast(i); + auto shapes = dequantize->get_type()->get_block_shapes(); + ir::value *op = dequantize->get_operand(0); + + // add edge except the last axis + for(unsigned d = 0; d < shapes.size() - 1; d ++){ + graph_.add_edge({i, d}, {op, d}); + } +} + void axes::update_graph_broadcast(ir::instruction *i) { auto *broadcast = static_cast(i); auto shapes = broadcast->get_type()->get_block_shapes(); @@ -79,7 +90,7 @@ void axes::update_graph_dot(ir::instruction *i) { graph_.add_edge({dot, d}, {D, d}); } -void axes::update_graph_elementwise(ir::instruction *i, +void axes::update_graph_elementwise(ir::instruction *i, bool is_masked_load_async) { if(i->get_num_operands() == 0) return; @@ -119,6 +130,7 @@ void axes::update_graph(ir::instruction *i) { case ir::INST_SPLAT: return update_graph_no_edge(i); case ir::INST_CAT: return update_graph_elementwise(i, true); case ir::INST_TRANS: return update_graph_trans(i); + case ir::INST_DEQUANTIZE: return update_graph_dequantize(i); case ir::INST_BROADCAST: return update_graph_broadcast(i); case ir::INST_DOT: return update_graph_dot(i); case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 526c64b47..415adaab2 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -99,6 +99,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define vec_ty(type, num_el) VectorType::get(type, num_el, false) #define ptr_ty(...) PointerType::get(__VA_ARGS__) // constants +#define i16(...) builder_->getInt16(__VA_ARGS__) #define i32(...) builder_->getInt32(__VA_ARGS__) // ops #define and_(...) 
builder_->CreateAnd(__VA_ARGS__) @@ -854,6 +855,234 @@ void generator::visit_cast_inst(ir::cast_inst* x) { } } +std::tuple generator::int16_to_float16x8( + Value *in0, Value *scale_x512, Value *shift +){ + /* unpacking 8 int2s packed into an int16 to 8 float16s + * the algorithm is similar to + * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/ + fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1492-L1563 + */ + Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); + InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false), + "{" + ".reg .b32 a<2>, b<4>; \n\t" // input is 0xab,cd,ef,gh,ab,cd,ef,gh, each a, b etc occupies two bits. + "and.b32 a0, 0x30300303, $4; \n\t" // set a0 to 0x0b,00,0f,00,00,0d,00,0h + "and.b32 a1, 0xc0c00c0c, $4; \n\t" // set a1 to 0xa0,00,e0,00,00,c0,00,g0 + "prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x00,00,00,0d,00,00,00,0h + "prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00,00,00,c0,00,00,00,g0 + "prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x00,00,0b,00,00,00,0f,00 + "prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00,00,a0,00,00,00,e0,00 + "mov.b32 a0, 0x78007800; \n\t" // a0 = 32768 + "mov.b32 a1, 0x70007000; \n\t" // a1 = 8192 + "mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768. + "mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 8192. + "mov.b32 a0, 0x68006800; \n\t" // a0 = 2048 + "mov.b32 a1, 0x60006000; \n\t" // a1 = 512 + "mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 2048. + "mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 512. + "fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift. + "fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift. + "fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out2 = b2 * scale + shift. + "fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out3 = b3 * scale + shift. 
+ "}", "=r,=r,=r,=r,r,r,r", false); + + Value *packed_in = UndefValue::get(vec_ty(i16_ty, 2)); + packed_in = insert_elt(packed_in, in0, (int)0); + packed_in = insert_elt(packed_in, in0, (int)1); + Value *in = bit_cast(packed_in, i32_ty); + + Value *ret = call(ptx, {in, scale_x512, shift}); + Value *packed_ret0 = extract_val(ret, {0}); + Value *packed_ret1 = extract_val(ret, {1}); + Value *packed_ret2 = extract_val(ret, {2}); + Value *packed_ret3 = extract_val(ret, {3}); + Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h + Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g + Value *ret2 = extract_elt(packed_ret2, (uint64_t)0); // f + Value *ret3 = extract_elt(packed_ret3, (uint64_t)0); // e + Value *ret4 = extract_elt(packed_ret0, (uint64_t)1); // d + Value *ret5 = extract_elt(packed_ret1, (uint64_t)1); // c + Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b + Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a + return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7); +} + +std::tuple generator::int32_to_float16x8( + Value *in0, Value *scale_x512, Value *shift +){ + /* unpacking 8 int4s packed into an int32 to 8 float16s + * the algorithm is similar to + * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/ + fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1566-L1619 + */ + Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); + InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false), + "{" + ".reg .b32 a<2>, b<4>; \n\t" + "and.b32 a0, 0x0f0f0f0f, $4; \n\t" // If input is 0xabcdefgh set a to 0x0b0d0f0h + "and.b32 a1, 0xf0f0f0f0, $4; \n\t" // If input is 0xabcdefgh set a to 0xa0c0e0g0 + "prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x000f000h + "prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00e000g0 + "prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x000b000d + "prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00a000c0 + "mov.b32 a0, 0x78007800; \n\t" + "mov.b32 a1, 0x68006800; \n\t" + "mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768. + "mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 2048. + "mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 32768. + "mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 2048. + "fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift. + "fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift. + "fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out0 = b0 * scale + shift. + "fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out1 = b1 * scale + shift. 
+ "}", "=r,=r,=r,=r,r,r,r", false); + + Value *ret = call(ptx, {in0, scale_x512, shift}); + Value *packed_ret0 = extract_val(ret, {0}); + Value *packed_ret1 = extract_val(ret, {1}); + Value *packed_ret2 = extract_val(ret, {2}); + Value *packed_ret3 = extract_val(ret, {3}); + Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h + Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g + Value *ret2 = extract_elt(packed_ret0, (uint64_t)1); // f + Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // e + Value *ret4 = extract_elt(packed_ret2, (uint64_t)0); // d + Value *ret5 = extract_elt(packed_ret3, (uint64_t)0); // c + Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b + Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a + return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7); +} + +std::tuple generator::int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift){ + /* unpacking 4 int8s packed into an int32 to 4 fp16s + * the algorithm is similar to + * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/ + fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1622-L1646 + */ + Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); + InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false), + "{" + ".reg .b32 a, b<2>; \n\t" + "prmt.b32 b0, 0, $2, 0x0504; \n\t" // If input is 0xabcdefgh set b0 to 0x00ef00gh + "prmt.b32 b1, 0, $2, 0x0706; \n\t" // If input is 0xabcdefgh set b1 to 0x00ab00cd + "mov.b32 a, 0x78007800; \n\t" + "mul.f16x2 b0, b0, a; \n\t" // b0 = b0 * 32768. + "mul.f16x2 b1, b1, a; \n\t" // b1 = b1 * 32768. + "fma.rn.f16x2 $0, b0, $3, $4; \n\t" // out0 = b0 * scale + shift. + "fma.rn.f16x2 $1, b1, $3, $4; \n\t" // out1 = b1 * scale + shift. 
+ "}", "=r,=r,r,r,r", false); + + Value *ret = call(ptx, {in0, scale_x512, shift}); + Value *packed_ret0 = extract_val(ret, {0}); + Value *packed_ret1 = extract_val(ret, {1}); + Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // gh + Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); // ef + Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); // cd + Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // ab + return std::make_tuple(ret0, ret1, ret2, ret3); +} + +std::tuple generator::prepare_scale_shift(Value *scale, Value *shift){ + Value *scale_x512 = fmul(scale, bit_cast(i16(0x6000), f16_ty)); + Value *p_scale_x512 = UndefValue::get(vec_ty(f16_ty, 2)); + p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)0); + p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)1); + p_scale_x512 = bit_cast(p_scale_x512, i32_ty); + + Value *p_shift = UndefValue::get(vec_ty(f16_ty, 2)); + p_shift = insert_elt(p_shift, shift, (int)0); + p_shift = insert_elt(p_shift, shift, (int)1); + p_shift = bit_cast(p_shift, i32_ty); + + return std::make_tuple(p_scale_x512, p_shift); +} + +/** + * \brief Code Generation for `dequantize` + */ +void generator::visit_dequantize_inst(ir::dequantize_inst* x) { + ir::value *op = x->get_operand(0); + + auto src_ty_size_in_bits = op->get_type()->get_scalar_ty()->get_primitive_size_in_bits(); + + auto ret_last_dim = (x->get_type()->get_block_shapes()).back(); + auto op_last_dim = (op->get_type()->get_block_shapes()).back(); + + auto x_idxs = idxs_.at(x); + auto op_idxs = idxs_.at(op); + + ir::value *scale = x->get_operand(1); + ir::value *shift = x->get_operand(2); + + Value *p_scale_x512, *p_shift; + std::tie(p_scale_x512, p_shift) = prepare_scale_shift(vals_[scale][{}], vals_[shift][{}]); + + int ld = layouts_->get(x)->get_order(0); + int contiguous = layouts_->get(x)->to_scanline()->nts(ld); + + int op_ld = layouts_->get(op)->get_order(0); + int op_contiguous = layouts_->get(op)->to_scanline()->nts(op_ld); + + std::string err_msg; + err_msg = "unsupported dequantization, cannot vectorize properly. x_idxs.size(): " + + std::to_string(x_idxs.size()) + "; op_idxs.size(): " + + std::to_string(op_idxs.size()) + "; contiguous: " + + std::to_string(contiguous) + "; op_contiguous: " + + std::to_string(op_contiguous) + ". 
if the condition " + "is not met, please try adjusting block_size, num_warps or " + "using tl.multiple_of to hint the input/output ptr address."; + + if (ret_last_dim == 8 * op_last_dim) { + if((x_idxs.size() != 8 * op_idxs.size()) || (contiguous != 8 * op_contiguous)) { + throw std::runtime_error(err_msg); + } + + auto cvt = [&]( + Value* a, Value* scale, Value* shift + ){ + if (src_ty_size_in_bits == 16){ // int2 quantization, int16 to 8 fp16s + return int16_to_float16x8(a, scale, shift); + } else if (src_ty_size_in_bits == 32) { // int4 quantization, int32 to 8 fp16s + return int32_to_float16x8(a, scale, shift); + } else { + throw std::runtime_error("unsupported conversion"); + } + }; + + for(size_t j = 0; j < op_idxs.size(); j++){ + size_t i = j * 8; + std::tie(vals_[x][x_idxs[i+0]], + vals_[x][x_idxs[i+1]], + vals_[x][x_idxs[i+2]], + vals_[x][x_idxs[i+3]], + vals_[x][x_idxs[i+4]], + vals_[x][x_idxs[i+5]], + vals_[x][x_idxs[i+6]], + vals_[x][x_idxs[i+7]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift); + } + } else if (ret_last_dim == 4 * op_last_dim && src_ty_size_in_bits == 32) { // int8 quantization, int32 to 4 fp16s + if((x_idxs.size() != 4 * op_idxs.size()) || (contiguous != 4 * op_contiguous)) { + throw std::runtime_error(err_msg); + } + + auto cvt = [&](Value* a, Value* scale, Value* shift){ + return int32_to_float16x4(a, scale, shift); + }; + + for(size_t j = 0; j < op_idxs.size(); j++){ + size_t i = j * 4; + std::tie(vals_[x][x_idxs[i+0]], + vals_[x][x_idxs[i+1]], + vals_[x][x_idxs[i+2]], + vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift); + } + } else { + throw std::runtime_error("unsupported dequantization"); + } + return; +} + /** * \brief Code Generation for `return` */ @@ -907,7 +1136,7 @@ void generator::visit_load_inst(ir::load_inst* x){ vec = std::min(layout->contig_per_thread(ord[0]), aln); // TODO: generalize - is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && + is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && (a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1)); if(is_mma_first_row) vec = std::min(2, aln); @@ -1009,7 +1238,7 @@ void generator::visit_load_inst(ir::load_inst* x){ std::vector arg_tys = {pred->getType(), ptr->getType()}; for(Value *v: others) arg_tys.push_back(v->getType()); - if (has_l2_evict_policy) + if (has_l2_evict_policy) arg_tys.push_back(i64_ty); FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false); // --- @@ -1025,7 +1254,7 @@ void generator::visit_load_inst(ir::load_inst* x){ asm_cstrt += ","; asm_cstrt += (width == 64) ? "l" : ((width == 32) ? 
"r" : "c"); } - if (has_l2_evict_policy) + if (has_l2_evict_policy) asm_cstrt += ",l"; // --- // finally call inline ASM @@ -1036,8 +1265,8 @@ void generator::visit_load_inst(ir::load_inst* x){ args.push_back(v); if (has_l2_evict_policy) args.push_back(policies_.at(x->get_eviction_policy())); - - + + Value *_ret = call(inlineAsm, args); // if(!op->get_type()->is_block_ty()){ // Value* cond = icmp_eq(tid, i32(0)); @@ -1050,7 +1279,7 @@ void generator::visit_load_inst(ir::load_inst* x){ // _ret = load(shptr); // add_barrier(); // } - + // --- // extract and store return values // --- @@ -1104,7 +1333,7 @@ void generator::visit_store_inst(ir::store_inst * x){ // vec = std::min(nts, aln); vec = std::min(layout->contig_per_thread(ord[0]), aln); // TODO: generalize - bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && + bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && (a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1)); if(is_mma_first_row) vec = std::min(2, aln); @@ -1166,7 +1395,7 @@ void generator::visit_store_inst(ir::store_inst * x){ std::vector arg_tys = {pred->getType(), ptr->getType()}; for(int ii = 0; ii < n_words; ii++) arg_tys.push_back(val_arg_ty); - if (has_l2_evict_policy) + if (has_l2_evict_policy) arg_tys.push_back(i64_ty); FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false); // --- @@ -1177,7 +1406,7 @@ void generator::visit_store_inst(ir::store_inst * x){ asm_cstrt += ","; asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c"); } - if (has_l2_evict_policy) + if (has_l2_evict_policy) asm_cstrt += ",l"; // --- // finally call inline ASM @@ -1817,13 +2046,13 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va namespace { class mma16816_smem_loader { public: - mma16816_smem_loader(int wpt, std::vector order, int k_order, - std::vector tile_shape, - std::vector instr_shape, std::vector mat_shape, - int per_phase, int max_phase, int dtsize, Builder *builder, + mma16816_smem_loader(int wpt, std::vector order, int k_order, + std::vector tile_shape, + std::vector instr_shape, std::vector mat_shape, + int per_phase, int max_phase, int dtsize, Builder *builder, adder add, multiplier mul, geper gep) : wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape), - instr_shape_(instr_shape), mat_shape_(mat_shape), + instr_shape_(instr_shape), mat_shape_(mat_shape), per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder), add(add), mul(mul), gep(gep) { // compute compile-time constant variables & types @@ -1837,7 +2066,7 @@ public: need_trans_ = k_order_ != order_[0]; can_use_ldmatrix_ = dtsize == 2 || (!need_trans_); - // we need more pointers at the fast-changing axis, + // we need more pointers at the fast-changing axis, if (can_use_ldmatrix_) num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]]; else // warning: this only works for tf32 & need transpose @@ -1873,7 +2102,7 @@ public: Value *s0 = urem(s, i32(2)); Value *s1 = udiv(s, i32(2)); - // We use different orders for a & b for better performance. + // We use different orders for a & b for better performance. Value *k_mat_arr = (k_order_ == 1) ? s1 : s0; Value *nk_mat_arr = (k_order_ == 1) ? 
s0 : s1; mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)), @@ -1884,7 +2113,7 @@ public: Value *s_mat_off = mat_off[order_[1]]; // offset inside a matrix Value *s_off_in_mat = c; - + std::vector offs(num_ptr_); Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); // pre-compute strided offset @@ -1898,7 +2127,7 @@ public: } else if (dtsize_ == 4 && need_trans_) { // load tf32 matrices with lds32 Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]] - Value *s_off_in_mat = urem(lane, i32(4)); // + Value *s_off_in_mat = urem(lane, i32(4)); // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); std::vector offs(num_ptr_); @@ -1945,7 +2174,7 @@ public: Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)), mul(nk_mat_arr, i32(mat_arr_stride_))); Value *s_mat_off = k_mat_arr; // always 0? - + for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) { for (int elem_off = 0; elem_off < 4; ++elem_off) { int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off; @@ -1971,10 +2200,10 @@ public: throw std::runtime_error("invalid smem load config"); } - std::tuple + std::tuple load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn, Value *pre_ptr, Value *next_ptr, std::vector &off, std::vector &ptrs, - FunctionType *ldmatrix_ty, Type *smem_ptr_ty, + FunctionType *ldmatrix_ty, Type *smem_ptr_ty, std::map> &prefetch_latch_to_bb_) { assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned"); int mat_idx[2] = {mat0, mat1}; @@ -2006,7 +2235,7 @@ public: std::string trans = need_trans_ ? ".trans" : ""; // the offset (in byte) on the strided axis is a constant int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_; - InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty, + InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty, "ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 " "{$0, $1, $2, $3}, " "[$4 + " + std::to_string(s_offset) + "];", @@ -2015,7 +2244,7 @@ public: res_v4 = call(ldmatrix_ty, ld_fn, {ptr}); if (k == 0 && inc == 1 && is_prefetch) prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4); - return {extract_val(res_v4, std::vector{0}), + return {extract_val(res_v4, std::vector{0}), extract_val(res_v4, std::vector{1}), extract_val(res_v4, std::vector{2}), extract_val(res_v4, std::vector{3})}; @@ -2062,13 +2291,13 @@ public: Value *i32_elems[4]; for (int i=0; i<4; ++i) i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4)); - + Value *elem00, *elem01, *elem02, *elem03; Value *elem10, *elem11, *elem12, *elem13; Value *elem20, *elem21, *elem22, *elem23; Value *elem30, *elem31, *elem32, *elem33; Value *i8_elems[4*4]; - if (k_order_ == 1) { // + if (k_order_ == 1) { // i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem))); i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem))); i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem))); @@ -2155,7 +2384,7 @@ private: int s_mat_stride_; // stride when moving to next not-k mat int warp_off_stride_; - int mat_arr_stride_; // matrix arrangement (inside a load) stride + int mat_arr_stride_; // matrix arrangement (inside a load) stride bool need_trans_, can_use_ldmatrix_; int num_ptr_; @@ -2232,7 +2461,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); smem_ptr_ty = ptr_ty(f16_ty, 3); ldmatrix_ty = 
FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); - phi_ty = fp16x2_ty; + phi_ty = fp16x2_ty; } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) { mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); smem_ptr_ty = ptr_ty(bf16_ty, 3); @@ -2303,8 +2532,8 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: if(is_a_shared) { const int per_phase_a = swizzle_->get_per_phase(layout_a); const int max_phase_a = swizzle_->get_max_phase(layout_a); - mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, - {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, + mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, + {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep); std::vector off_a = a_loader.compute_offs(warp_m, lane); int num_ptr_a = a_loader.get_num_ptr(); @@ -2319,7 +2548,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: // loading function load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable { auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a], - shared_next_ptr_[layout_a], off_a, ptrs_a, + shared_next_ptr_[layout_a], off_a, ptrs_a, ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); register_lds2(ha, m, k, inc, ha0, is_prefetch); register_lds2(ha, m+1, k, inc, ha1, is_prefetch); @@ -2389,12 +2618,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: for(int i = 0; i < num_ptr_b; i++) ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty); - + // loading function std::function load_b; load_b = [&](int n, int k, int inc, bool is_prefetch) { auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b], - shared_next_ptr_[layout_b], off_b, ptrs_b, + shared_next_ptr_[layout_b], off_b, ptrs_b, ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); register_lds2(hb, n, k, inc, hb0, is_prefetch); register_lds2(hb, n+1, k, inc, hb2, is_prefetch); @@ -2419,7 +2648,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: (m + 1)*cols_per_thread + (n*2 + 0), (m + 1)*cols_per_thread + (n*2 + 1) }; - Value *nc = call(mma_ty, mma_fn, + Value *nc = call(mma_ty, mma_fn, {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}], hb[{n, k}], hb[{n, k+1}], fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); @@ -2608,7 +2837,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) { return visit_mma884(dot, A, B, D, NK); if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80) return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()? 
- if (dot->get_type()->get_scalar_ty()->is_fp32_ty() && + if (dot->get_type()->get_scalar_ty()->is_fp32_ty() && A->get_type()->get_scalar_ty()->is_fp32_ty()) return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add); throw std::runtime_error("dot has invalid operand type"); @@ -2710,7 +2939,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va warps_per_inner = layout->to_mma()->wpt(1); col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size(); warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id; - } + } assert(warp_j != nullptr); // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]); @@ -3367,7 +3596,7 @@ void generator::visit_constant_fp(ir::constant_fp *x){ if (x->get_type()->get_scalar_ty()->is_bf16_ty()) { // highest 16 bits of fp32 float fp32_value = x->get_value(); - uint16_t bf16_raw = (*reinterpret_cast(&fp32_value) + uint16_t bf16_raw = (*reinterpret_cast(&fp32_value) & 0xffff0000) >> 16; std::stringstream const_str; const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 120b575cf..7f7dfdc98 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -120,6 +120,14 @@ value *builder::create_ret(value* val) { return insert(return_inst::create(ctx_, val)); } +//===----------------------------------------------------------------------===// +// dequantize instructions +//===----------------------------------------------------------------------===// + +value* builder::create_dequantize(value *src, value *scale, value *shift, type *dst_ty){ + return insert(dequantize_inst::create(src, scale, shift, dst_ty)); +} + //===----------------------------------------------------------------------===// // cast instructions //===----------------------------------------------------------------------===// diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 7831e1650..8f6631e34 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -323,6 +323,21 @@ unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &nam set_operand(0, v); } +//===----------------------------------------------------------------------===// +// dequantize_inst classes +//===----------------------------------------------------------------------===// + +dequantize_inst::dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next) + : instruction(ty, INST_DEQUANTIZE, 3, name, next) { + set_operand(0, v); + set_operand(1, scale); + set_operand(2, shift); +} + +dequantize_inst *dequantize_inst::create(value *arg, value *scale, value *shift, type *ty, const std::string &name, instruction *next){ + return new dequantize_inst(ty, arg, scale, shift, name, next); +} + //===----------------------------------------------------------------------===// // cast_inst classes //===----------------------------------------------------------------------===// @@ -584,7 +599,7 @@ masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTI set_operand(2, mask); } -masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction, +masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction, const std::string &name, instruction *next) { return new masked_store_inst(ptr, val, mask, eviction, name, next); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 8918e8809..cb3fdbb6e 100644 --- a/python/src/triton.cc +++ 
b/python/src/triton.cc @@ -83,8 +83,8 @@ void cu_enqueue(uint64_t stream, uint64_t kernel, CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size, CU_LAUNCH_PARAM_END }; - drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, - block_0, block_1, block_2, + drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, + block_0, block_1, block_2, shared_mem, (CUstream)stream, nullptr, config); } @@ -97,8 +97,8 @@ void hip_enqueue(uint64_t stream, uint64_t kernel, HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size, HIP_LAUNCH_PARAM_END }; - drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2, - block_0, block_1, block_2, + drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2, + block_0, block_1, block_2, shared_mem, (hipStream_t)stream, nullptr, config); } @@ -302,8 +302,8 @@ void init_triton_runtime(py::module &&m) { // cache key - m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, - py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages, + m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, + py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages, py::dict extern_libs, py::function add_to_cache, py::object grid){ // parse arguments to compute cache key, compile-time constants and packed kernel arguments long _num_warps = PyLong_AsLong(num_warps.ptr()); @@ -351,8 +351,8 @@ void init_triton_runtime(py::module &&m) { // release the gil in case the enqueue blocks // cuda will block if too many ops are enqueued py::gil_scoped_release allow_threads; - drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, - _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, + drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, + _num_warps*32, 1, 1, shared_mem, (CUstream)_stream, nullptr, config); } return bin; @@ -372,7 +372,7 @@ void init_triton_runtime(py::module &&m) { m.def("max_shared_memory", [](backend_t backend, uint64_t device) { if (backend == HOST) return 0; - if(backend == CUDA) + if(backend == CUDA) return cuGetInfo(device); if(backend == ROCM) return hipGetInfo(device); @@ -422,7 +422,7 @@ void init_triton_runtime(py::module &&m) { hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem); }); - + } /*****************************************************************************/ @@ -430,9 +430,9 @@ void init_triton_runtime(py::module &&m) { /*****************************************************************************/ typedef std::map asm_map_t; -// --------------------------------------- +// --------------------------------------- // Compile Triton-IR to assembly -// --------------------------------------- +// --------------------------------------- void init_triton_codegen(py::module &&m) { m.def("compile_ttir", @@ -550,13 +550,13 @@ void init_triton_ir(py::module &&m) { .value("CA", ir::load_inst::CA) .value("CG", ir::load_inst::CG) .export_values(); - + py::enum_(m, "EVICTION_POLICY") .value("NORMAL", ir::load_inst::NORMAL) .value("EVICT_FIRST", ir::load_inst::EVICT_FIRST) .value("EVICT_LAST", ir::load_inst::EVICT_LAST) .export_values(); - + py::enum_(m, "REDUCE_OP") .value("ADD", ir::reduce_inst::ADD) .value("FADD", ir::reduce_inst::FADD) @@ -573,7 +573,7 @@ void init_triton_ir(py::module &&m) { .value("ARGFMIN", 
ir::reduce_inst::ARGFMIN) .value("ARGFMAX", ir::reduce_inst::ARGFMAX) .value("XOR", ir::reduce_inst::XOR); - + py::enum_(m, "ATOMIC_OP") .value("ADD", ir::atomic_rmw_op_t::Add) .value("FADD", ir::atomic_rmw_op_t::FAdd) @@ -704,7 +704,7 @@ void init_triton_ir(py::module &&m) { py::class_(m, "function_type") .def_property_readonly("ret_ty", &ir::function_type::get_return_ty) - .def_property_readonly("arg_tys", [](ir::function_type* self){ + .def_property_readonly("arg_tys", [](ir::function_type* self){ return std::vector(self->params_begin(), self->params_end()); }); @@ -713,7 +713,7 @@ void init_triton_ir(py::module &&m) { py::class_(m, "block_type") .def_property_readonly("shape", &ir::block_type::get_shapes) .def_property_readonly("numel", &ir::type::get_tile_num_elements); - + py::class_(m, "struct_type") .def("get", &ir::struct_type::get, ret::reference) .def_property_readonly("num_types", &ir::struct_type::get_num_types); @@ -834,6 +834,8 @@ void init_triton_ir(py::module &&m) { .def("create_br", &ir::builder::create_br, ret::reference) .def("create_cond_br", &ir::builder::create_cond_br, ret::reference) .def("create_ret_void", &ir::builder::create_ret_void, ret::reference) + // Dequantize instructions + .def("create_dequantize", &ir::builder::create_dequantize, ret::reference) // Cast instructions .def("create_bitcast", &ir::builder::create_bitcast, ret::reference) .def("create_cast", &ir::builder::create_cast, ret::reference) @@ -857,27 +859,27 @@ void init_triton_ir(py::module &&m) { .def("create_frem", &ir::builder::create_frem, ret::reference) .def("create_fadd", &ir::builder::create_fadd, ret::reference) .def("create_fsub", &ir::builder::create_fsub, ret::reference) - .def("create_mul", &ir::builder::create_mul, ret::reference, - py::arg("lhs"), py::arg("rhs"), + .def("create_mul", &ir::builder::create_mul, ret::reference, + py::arg("lhs"), py::arg("rhs"), py::arg("has_nuw")=false, py::arg("has_nsw")=false) .def("create_sdiv", &ir::builder::create_sdiv, ret::reference) .def("create_udiv", &ir::builder::create_udiv, ret::reference) .def("create_srem", &ir::builder::create_srem, ret::reference) .def("create_urem", &ir::builder::create_urem, ret::reference) - .def("create_add", &ir::builder::create_add, ret::reference, - py::arg("lhs"), py::arg("rhs"), + .def("create_add", &ir::builder::create_add, ret::reference, + py::arg("lhs"), py::arg("rhs"), py::arg("has_nuw")=false, py::arg("has_nsw")=false) .def("create_sub", &ir::builder::create_sub, ret::reference, - py::arg("lhs"), py::arg("rhs"), + py::arg("lhs"), py::arg("rhs"), py::arg("has_nuw")=false, py::arg("has_nsw")=false) .def("create_shl", &ir::builder::create_shl, ret::reference, - py::arg("lhs"), py::arg("rhs"), + py::arg("lhs"), py::arg("rhs"), py::arg("has_nuw")=false, py::arg("has_nsw")=false) .def("create_lshr", &ir::builder::create_lshr, ret::reference, - py::arg("lhs"), py::arg("rhs"), + py::arg("lhs"), py::arg("rhs"), py::arg("has_nuw")=false, py::arg("has_nsw")=false) .def("create_ashr", &ir::builder::create_ashr, ret::reference, - py::arg("lhs"), py::arg("rhs"), + py::arg("lhs"), py::arg("rhs"), py::arg("has_nuw")=false, py::arg("has_nsw")=false) // GEP .def("create_gep", &ir::builder::create_gep, ret::reference) diff --git a/python/test/unit/language/test_dequantize.py b/python/test/unit/language/test_dequantize.py new file mode 100644 index 000000000..93935a4b0 --- /dev/null +++ b/python/test/unit/language/test_dequantize.py @@ -0,0 +1,261 @@ +# flake8: noqa: F821,F841 + +import random + +import torch + +import 
triton +import triton.language as tl + + +@triton.jit +def dequantize_kernel_int8(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr): + w_offsets = tl.arange(0, BLOCK_SIZE // 4) + mask = w_offsets < (size // 4) + input_ptrs = input_ptr + 1 + w_offsets + input = tl.load(input_ptrs, mask=mask, other=0) + scale_shift = tl.load(input_ptr) + scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True) + shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True) + output = tl.dequantize(input, scale, shift, 8) + offsets = tl.arange(0, BLOCK_SIZE) + output_ptrs = tl.multiple_of(output_ptr + offsets, 4) + tl.store(output_ptrs, output, mask=offsets < size) + + +@triton.jit +def dequantize_kernel_scale_shift_int8( + output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr +): + w_offsets = tl.arange(0, BLOCK_SIZE // 4) + mask = w_offsets < (size // 4) + input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1) + input = tl.load(input_ptrs, mask=mask, other=0) + scale = tl.load(scale_ptr) + shift = tl.load(shift_ptr) + output = tl.dequantize(input, scale, shift, 8) + offsets = tl.arange(0, BLOCK_SIZE) + output_ptrs = tl.multiple_of(output_ptr + offsets, 4) + tl.store(output_ptrs, output, mask=offsets < size) + + +@triton.jit +def dequantize_kernel_int4(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr): + w_offsets = tl.arange(0, BLOCK_SIZE // 8) + mask = w_offsets < (size // 8) + input_ptrs = input_ptr + 1 + w_offsets + input = tl.load(input_ptrs, mask=mask, other=0) + scale_shift = tl.load(input_ptr) + scale = (scale_shift & 65535).to(tl.int16).to(tl.float16, bitcast=True) + shift = (scale_shift >> 16).to(tl.int16).to(tl.float16, bitcast=True) + output = tl.dequantize(input, scale, shift, 4) + offsets = tl.arange(0, BLOCK_SIZE) + output_ptrs = tl.multiple_of(output_ptr + offsets, 8) + tl.store(output_ptrs, output, mask=offsets < size) + + +@triton.jit +def dequantize_kernel_scale_shift_int4( + output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr +): + w_offsets = tl.arange(0, BLOCK_SIZE // 8) + mask = w_offsets < (size // 8) + input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1) + input = tl.load(input_ptrs, mask=mask, other=0) + scale = tl.load(scale_ptr) + shift = tl.load(shift_ptr) + output = tl.dequantize(input, scale, shift, 4) + offsets = tl.arange(0, BLOCK_SIZE) + output_ptrs = tl.multiple_of(output_ptr + offsets, 8) + tl.store(output_ptrs, output, mask=offsets < size) + + +@triton.jit +def dequantize_kernel_int2(output_ptr, input_ptr, size, BLOCK_SIZE: tl.constexpr): + w_offsets = tl.arange(0, BLOCK_SIZE // 8) + mask = w_offsets < (size // 8) + input_ptrs = tl.multiple_of(input_ptr + 2 + w_offsets, 1) + input = tl.load(input_ptrs, mask=mask, other=0) + scale = tl.load(input_ptr).to(tl.float16, bitcast=True) + shift = tl.load(input_ptr + 1).to(tl.float16, bitcast=True) + output = tl.dequantize(input, scale, shift, 2) + offsets = tl.arange(0, BLOCK_SIZE) + output_ptrs = tl.multiple_of(output_ptr + offsets, 8) + tl.store(output_ptrs, output, mask=offsets < size) + + +@triton.jit +def dequantize_kernel_scale_shift_int2( + output_ptr, input_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr +): + w_offsets = tl.arange(0, BLOCK_SIZE // 8) + mask = w_offsets < (size // 8) + input_ptrs = tl.multiple_of(input_ptr + w_offsets, 1) + input = tl.load(input_ptrs, mask=mask, other=0) + scale = tl.load(scale_ptr) + shift = tl.load(shift_ptr) + output = tl.dequantize(input, scale, shift, 2) + offsets = tl.arange(0, BLOCK_SIZE) + 
output_ptrs = tl.multiple_of(output_ptr + offsets, 8) + tl.store(output_ptrs, output, mask=offsets < size) + + +def test_dequantize_int8() -> None: + for i in range(10): + if i < 5: + size = random.randrange(16, 128, 4) + else: + size = random.randrange(132, 1024, 4) + device = torch.device(torch.cuda.current_device()) + + scale_val = random.uniform(0.1, 4.0) + shift_val = random.uniform(-10.0, 10.0) + scale = torch.tensor(scale_val, dtype=torch.float16, device=device) + shift = torch.tensor(shift_val, dtype=torch.float16, device=device) + scale_shift = torch.tensor( + [scale_val, shift_val], + dtype=torch.float16, + device=device, + ).view(torch.int32) + + input_int8 = torch.randint( + 0, 256, (size,), dtype=torch.uint8, device=device + ) + input_int32 = input_int8.view(torch.int32) + + input = torch.cat((scale_shift, input_int32)) + expected = (input_int8 * scale + shift).to(torch.float16) + + output = torch.empty([size], dtype=torch.float16, device=device) + block_size = max(triton.next_power_of_2(size), 128) + grid = (1,) + dequantize_kernel_int8[grid]( + output, input, size, BLOCK_SIZE=block_size, num_warps=1 + ) + rtol, atol = 1e-02, 1e-02 + assert torch.allclose(output, expected, rtol, atol) + + output = torch.empty([size], dtype=torch.float16, device=device) + dequantize_kernel_scale_shift_int8[grid]( + output, + input_int32, + scale, + shift, + size, + BLOCK_SIZE=block_size, + num_warps=1, + ) + assert torch.allclose(output, expected, rtol, atol) + + +def test_dequantize_int4() -> None: + for i in range(10): + if i < 5: + size = random.randrange(16, 256, 8) + else: + size = random.randrange(264, 1024, 8) + device = torch.device(torch.cuda.current_device()) + + scale_val = random.uniform(0.1, 4.0) + shift_val = random.uniform(-10.0, 10.0) + scale = torch.tensor(scale_val, dtype=torch.float16, device=device) + shift = torch.tensor(shift_val, dtype=torch.float16, device=device) + scale_shift = torch.tensor( + [scale_val, shift_val], + dtype=torch.float16, + device=device, + ).view(torch.int32) + + input_int8 = torch.randint( + 0, 256, (size // 2,), dtype=torch.uint8, device=device + ) + input_int32 = input_int8.view(torch.int32) + + input_int8_h1 = input_int8 >> 4 + input_int8_h0 = input_int8 & 15 + + input_int4_val = torch.stack( + (input_int8_h0, input_int8_h1), dim=1 + ).flatten() + + input = torch.cat((scale_shift, input_int32)) + expected = (input_int4_val * scale + shift).to(torch.float16) + + output = torch.empty([size], dtype=torch.float16, device=device) + block_size = max(triton.next_power_of_2(size), 256) + grid = (1,) + dequantize_kernel_int4[grid]( + output, input, size, BLOCK_SIZE=block_size, num_warps=1 + ) + rtol, atol = 1e-02, 1e-02 + assert torch.allclose(output, expected, rtol, atol) + + output = torch.empty([size], dtype=torch.float16, device=device) + dequantize_kernel_scale_shift_int4[grid]( + output, + input_int32, + scale, + shift, + size, + BLOCK_SIZE=block_size, + num_warps=1, + ) + assert torch.allclose(output, expected, rtol, atol) + + +def test_dequantize_int2() -> None: + for i in range(10): + if i < 5: + size = random.randrange(16, 256, 8) + else: + size = random.randrange(264, 1024, 8) + device = torch.device(torch.cuda.current_device()) + + scale_val = random.uniform(0.1, 4.0) + shift_val = random.uniform(-10.0, 10.0) + scale = torch.tensor(scale_val, dtype=torch.float16, device=device) + shift = torch.tensor(shift_val, dtype=torch.float16, device=device) + scale_shift = torch.tensor( + [scale_val, shift_val], + dtype=torch.float16, + device=device, + 
).view(torch.int16) + + input_int8 = torch.randint( + 0, 256, (size // 4,), dtype=torch.uint8, device=device + ) + input_int16 = input_int8.view(torch.int16) + + input_int8_q3 = input_int8 >> 6 + input_int8_q2 = (input_int8 >> 4) & 3 + input_int8_q1 = (input_int8 >> 2) & 3 + input_int8_q0 = input_int8 & 3 + + input_int2_val = torch.stack( + (input_int8_q0, input_int8_q1, input_int8_q2, input_int8_q3), dim=1 + ).flatten() + + input = torch.cat((scale_shift, input_int16)) + expected = (input_int2_val * scale + shift).to(torch.float16) + + output = torch.empty([size], dtype=torch.float16, device=device) + block_size = max(triton.next_power_of_2(size), 256) + grid = (1,) + + dequantize_kernel_int2[grid]( + output, input, size, BLOCK_SIZE=block_size, num_warps=1 + ) + rtol, atol = 1e-02, 1e-02 + assert torch.allclose(output, expected, rtol, atol) + + output = torch.empty([size], dtype=torch.float16, device=device) + dequantize_kernel_scale_shift_int2[grid]( + output, + input_int16, + scale, + shift, + size, + BLOCK_SIZE=block_size, + num_warps=1, + ) + assert torch.allclose(output, expected, rtol, atol) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 63a9ab7f2..69f49c146 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -685,6 +685,20 @@ def zeros(shape, dtype, _builder=None): return semantic.zeros(shape, dtype, _builder) +# ----------------------- +# dequantize +# ----------------------- + + +@builtin +def dequantize(input, scale, shift, nbit, dst_ty=float16, _builder=None): + """ + Tries to dequantize the input to given dtype + """ + nbit = _constexpr_to_value(nbit) + return semantic.dequantize(input, scale, shift, nbit, dst_ty, _builder) + + # ----------------------- # Shape Manipulation # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 11306851c..62e4a30bd 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -544,6 +544,31 @@ def broadcast_impl_value(lhs: tl.tensor, # (scalar, scalar) => returns original blocks return lhs, rhs + +####### +# dequantize +####### + +def dequantize(input: tl.tensor, + scale: tl.tensor, + shift: tl.tensor, + nbit: int, + dst_ty: tl.dtype, + builder: ir.builder) -> tl.tensor: + input_ty = input.type + assert input_ty.is_block() + assert input_ty.element_ty.is_int32() or input_ty.element_ty.is_int16() + assert nbit in [2, 4, 8] + assert dst_ty == tl.float16 + + shape = input_ty.get_block_shapes() + factor = input_ty.element_ty.primitive_bitwidth // nbit + dst_shape = shape[:-1] + [factor * shape[-1]] + + dst_ty = tl.block_type(dst_ty, dst_shape) + return tl.tensor(builder.create_dequantize(input.handle, scale.handle, shift.handle, dst_ty.to_ir(builder)), dst_ty) + + ####### # cast #######
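
Usage note (not part of the patch): the following is a minimal sketch of how the new tl.dequantize builtin is intended to be called from a Triton kernel, distilled from dequantize_kernel_scale_shift_int8 and test_dequantize_int8 above. The kernel name, tensor names, and the concrete scale/shift/size values are illustrative assumptions; the int8 packing (four uint8 values per int32, dequantized as q * scale + shift in fp16) and the launch parameters mirror the test.

import torch
import triton
import triton.language as tl


@triton.jit
def dequant_int8_kernel(out_ptr, packed_ptr, scale_ptr, shift_ptr, size, BLOCK_SIZE: tl.constexpr):
    # each int32 element packs four uint8 values; tl.dequantize returns a block 4x wider in fp16
    w_offsets = tl.arange(0, BLOCK_SIZE // 4)
    packed = tl.load(packed_ptr + w_offsets, mask=w_offsets < (size // 4), other=0)
    scale = tl.load(scale_ptr)  # fp16 scalar
    shift = tl.load(shift_ptr)  # fp16 scalar
    out = tl.dequantize(packed, scale, shift, 8)  # out = q * scale + shift
    offsets = tl.arange(0, BLOCK_SIZE)
    output_ptrs = tl.multiple_of(out_ptr + offsets, 4)  # alignment hint helps vectorization
    tl.store(output_ptrs, out, mask=offsets < size)


size = 128  # illustrative; BLOCK_SIZE must cover it (test uses max(next_power_of_2(size), 128))
device = torch.device(torch.cuda.current_device())
q = torch.randint(0, 256, (size,), dtype=torch.uint8, device=device)
scale = torch.tensor(0.5, dtype=torch.float16, device=device)
shift = torch.tensor(-1.0, dtype=torch.float16, device=device)
out = torch.empty(size, dtype=torch.float16, device=device)
dequant_int8_kernel[(1,)](out, q.view(torch.int32), scale, shift, size, BLOCK_SIZE=size, num_warps=1)
assert torch.allclose(out, (q * scale + shift).to(torch.float16), rtol=1e-2, atol=1e-2)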