[BACKEND] Added Int8 mma (#440)
@@ -33,7 +33,9 @@ inline bool is_hmma_c(ir::value *v, int sm){
    result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
             (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
             (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
              x->allow_tf32() && sm >= 80) ||
             (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
              sm >= 80);
  }
  return result;
}
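
For orientation: the hunk above extends the tensor-core eligibility predicate so that an i8 x i8 dot also maps to mma on Ampere. Distilled into a standalone check (a sketch; `ScalarKind` and `hmma_eligible` are illustrative names, not part of the codebase):

// Sketch of the operand-type rule encoded in is_hmma_c above.
enum class ScalarKind { FP16, BF16, FP32, INT8, OTHER };

bool hmma_eligible(ScalarKind a, ScalarKind b, bool allow_tf32, int sm) {
  if (a == ScalarKind::FP16 && b == ScalarKind::FP16) return true;
  if (a == ScalarKind::BF16 && b == ScalarKind::BF16) return true;
  if (a == ScalarKind::FP32 && b == ScalarKind::FP32)
    return allow_tf32 && sm >= 80;   // tf32 path, Ampere only
  if (a == ScalarKind::INT8 && b == ScalarKind::INT8)
    return sm >= 80;                 // new int8 path, Ampere only
  return false;
}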
@@ -63,7 +65,7 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
        return mma_type;
      }
    } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
      throw std::runtime_error("integer tensor cores are not yet supported");
      // throw std::runtime_error("integer tensor cores are not yet supported");
      // // integer tensor cores
      // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
      //   mma_type = mma_layout::INT32_INT1_INT1_INT32;
@@ -73,10 +75,10 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
      //   mma_type = mma_layout::INT32_INT4_INT4_INT32;
      //   return mma_type;
      // }
      // if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
      //   mma_type = mma_layout::INT32_INT8_INT8_INT32;
      //   return mma_type;
      // }
      if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
        mma_type = mma_layout::INT32_INT8_INT8_INT32;
        return mma_type;
      }
    }
  }
  return mma_layout::NOT_APPLICABLE;
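
The newly enabled `INT32_INT8_INT8_INT32` type corresponds to the sm_80 s8 tensor-core path. For reference, the shape tables it indexes into look roughly like the following (a sketch; the concrete values here are an assumption for illustration, the authoritative tables live in `mma_layout`):

// {m, n, k} shapes: instr = one mma.sync instruction, mat = one
// 8x8-style sub-matrix as seen by the shared-memory loader.
struct MmaShapes { int instr[3]; int mat[3]; };
const MmaShapes kFp16 = {{16, 8, 16}, {8, 8, 8}};   // FP32_FP16_FP16_FP32
const MmaShapes kTf32 = {{16, 8, 8},  {8, 8, 4}};   // FP32_TF32_TF32_FP32
const MmaShapes kInt8 = {{16, 8, 32}, {8, 8, 16}};  // INT32_INT8_INT8_INT32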
@@ -444,11 +446,21 @@ shared_layout::shared_layout(data_layout *arg,
    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
    mma_vec_     = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
    mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];

    // for now, disable swizzle when using lds.8
    if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
      if (order_[0] == 0) // need transpose
        allow_swizzle_ = false;
  } else if (hmma_dot_b_) {
    assert(order_.size() == 2);
    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
    mma_vec_     = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
    mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];

    // for now, disable swizzle when using lds.8
    if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
      if (order_[0] == 1) // need transpose
        allow_swizzle_ = false;
  }

  // size
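
The two new guards encode the same rule for both operands: when the int8 operand needs a transpose (A with the strided axis first, B with the contiguous axis first), loading falls back to lds.8 and XOR swizzling is switched off for now. A distilled version of the decision (an illustrative helper, not in the codebase):

// need_trans mirrors the order_[0] checks above: A needs a transpose
// when order[0] == 0, B when order[0] == 1.
bool allow_swizzle(bool is_int8_mma, bool is_operand_a, int order0) {
  if (!is_int8_mma)
    return true;
  bool need_trans = is_operand_a ? (order0 == 0) : (order0 == 1);
  return !need_trans;  // lds.8 path: swizzle disabled for now
}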
@@ -41,9 +41,15 @@ void swizzle::run(ir::module &) {
          vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
      }
      else {
        if (!layout->allow_swizzle()) {
          per_phase_[layout] = 1;
          max_phase_[layout] = 1;
          vec_[layout] = 1;
        } else {
          per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
          max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
          vec_[layout]       = layout->get_mma_vec();
        }
      }
    }
}
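
For a layout that keeps swizzling, the parameters derive from 128-byte phases: `per_phase` is how many contiguous thread-rows fit in 128 bytes, `max_phase` how many distinct XOR phases the strided mat dimension spans, and `vec` the load vector width. A worked example with assumed inputs (fp16, 32 contiguous elements per thread-row, 8x8 mats):

#include <algorithm>
#include <cstdio>

int main() {
  int dtsize = 2;                                       // fp16
  int contig = 32;                                      // mts(ord[0]) * nts(ord[0])
  int per_phase = std::max(128 / (contig * dtsize), 1); // 128/64 = 2
  int max_phase = 8 / per_phase;                        // mma_strided / per_phase = 4
  int vec = 8;                                          // layout->get_mma_vec()
  std::printf("per_phase=%d max_phase=%d vec=%d\n", per_phase, max_phase, vec);
  // With allow_swizzle() == false, all three collapse to 1 instead.
}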
@@ -1,78 +0,0 @@
#pragma once

#include <numeric>
#include <sstream>
#include <iomanip>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

namespace triton::codegen {
// types
#define void_ty              builder_->getVoidTy()
#define f16_ty               builder_->getHalfTy()
#define bf16_ty              builder_->getBFloatTy()
#define f32_ty               builder_->getFloatTy()
#define i8_ty                builder_->getInt8Ty()
#define i32_ty               builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...)          PointerType::get(__VA_ARGS__)
// constants
#define i32(...)             builder_->getInt32(__VA_ARGS__)
// ops
#define and_(...)            builder_->CreateAnd(__VA_ARGS__)
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
#define atomic_rmw(...)      builder_->CreateAtomicRMW(__VA_ARGS__)
#define bin_op(...)          builder_->CreateBinOp(__VA_ARGS__)
#define bit_cast(...)        builder_->CreateBitCast(__VA_ARGS__)
#define br(...)              builder_->CreateBr(__VA_ARGS__)
#define call(...)            builder_->CreateCall(__VA_ARGS__)
#define cast(...)            builder_->CreateCast(__VA_ARGS__)
#define cond_br(...)         builder_->CreateCondBr(__VA_ARGS__)
#define exact_udiv(...)      builder_->CreateExactUDiv(__VA_ARGS__)
#define extract_elt(...)     builder_->CreateExtractElement(__VA_ARGS__)
#define extract_val(...)     builder_->CreateExtractValue(__VA_ARGS__)
#define fadd(...)            builder_->CreateFAdd(__VA_ARGS__)
#define fcmp(...)            builder_->CreateFCmp(__VA_ARGS__)
#define fmul(...)            builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...)          builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...)            builder_->CreateFSub(__VA_ARGS__)
#define icmp(...)            builder_->CreateICmp(__VA_ARGS__)
#define icmp_eq(...)         builder_->CreateICmpEQ(__VA_ARGS__)
#define icmp_sge(...)        builder_->CreateICmpSGE(__VA_ARGS__)
#define icmp_sle(...)        builder_->CreateICmpSLE(__VA_ARGS__)
#define icmp_ult(...)        builder_->CreateICmpULT(__VA_ARGS__)
#define insert_elt(...)      builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...)       builder_->CreateIntrinsic(__VA_ARGS__)
#define load(...)            builder_->CreateLoad(__VA_ARGS__)
#define lshr(...)            builder_->CreateLShr(__VA_ARGS__)
#define max_num(...)         builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...)         builder_->CreateMinNum(__VA_ARGS__)
#define neg(...)             builder_->CreateNeg(__VA_ARGS__)
#define phi(...)             builder_->CreatePHI(__VA_ARGS__)
#define ret(...)             builder_->CreateRet(__VA_ARGS__)
#define select(...)          builder_->CreateSelect(__VA_ARGS__)
#define store(...)           builder_->CreateStore(__VA_ARGS__)
#define sub(...)             builder_->CreateSub(__VA_ARGS__)
#define shl(...)             builder_->CreateShl(__VA_ARGS__)
#define udiv(...)            builder_->CreateUDiv(__VA_ARGS__)
#define urem(...)            builder_->CreateURem(__VA_ARGS__)
#define splat(...)           builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...)            builder_->CreateXor(__VA_ARGS__)

} // namespace triton::codegen
@@ -1,6 +1,7 @@
#include <numeric>
#include <sstream>
#include <iomanip>
#include <stdexcept>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
@@ -1355,9 +1356,6 @@ public:
    need_trans_ = k_order_ != order_[0];
    can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);

    // std::cout << can_use_ldmatrix_ << std::endl;
    // std::cout << need_trans_ << std::endl;

    // we need more pointers at the fast-changing axis,
    if (can_use_ldmatrix_)
      num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
@@ -1365,6 +1363,9 @@ public:
      num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
    num_ptr_ = std::max<int>(num_ptr_, 2);

    // special rule for i8/u8, 4 ptrs for each matrix
    if (!can_use_ldmatrix_ && dtsize_ == 1)
      num_ptr_ *= 4;

    // load_v4 stride (in num of mats)
    int load_stride_in_mat[2];
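
The quadrupling reflects the lds.8 path: each thread reads one byte per access and needs the four consecutive bytes of each matrix through four distinct pointers. With assumed example numbers (not taken from a specific kernel):

#include <algorithm>
#include <cstdio>

int main() {
  int tile_contig = 64, wpt = 2, mat_contig = 16;  // tile_shape[order[0]], warps, mat_shape[order[0]]
  int num_ptr = tile_contig / wpt / mat_contig;    // = 2
  num_ptr = std::max(num_ptr, 2);
  bool can_use_ldmatrix = false;                   // i8 + transpose
  int dtsize = 1;
  if (!can_use_ldmatrix && dtsize == 1)
    num_ptr *= 4;                                  // 4 byte-pointers per matrix
  std::printf("num_ptr = %d\n", num_ptr);          // prints 8
}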
@@ -1445,6 +1446,46 @@ public:
      }
      return offs;
      // throw std::runtime_error("not implemented");
    } else if (dtsize_ == 1 && need_trans_) {
      // load i8/u8 matrices with lds8
      Value *c_off_in_mat = udiv(lane, i32(4)); // row within the mat
      Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread loads 4 cols

      // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
      std::vector<Value*> offs(num_ptr_);
      for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
        int k_mat_arr_int  = (k_order_ == 1) ? mat/2 : mat%2;
        int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
        if (k_mat_arr_int > 0) // we don't need pointers for k
          continue;
        Value *k_mat_arr  = i32(k_mat_arr_int);
        Value *nk_mat_arr = i32(nk_mat_arr_int);
        // physical offset (before swizzling)
        Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
                               mul(nk_mat_arr, i32(mat_arr_stride_)));
        Value *s_mat_off = k_mat_arr; // always 0?

        for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) {
          for (int elem_off = 0; elem_off < 4; ++elem_off) {
            int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off;

            Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
            Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off));

            // disable swizzling ...
            // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
            // c_mat_off_i = xor_(c_mat_off_i, phase);

            Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
            Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_)));
            // To prevent out-of-bound access when the tile is too small
            c_off = urem(c_off, i32(tile_shape_[order_[0]]));
            s_off = urem(s_off, i32(tile_shape_[order_[1]]));
            offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_)));
          }
        }
      }
      return offs;
    } else
      throw std::runtime_error("invalid smem load config");
  }
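
Before the (currently disabled) swizzle XOR, the in-mat part of these offsets gives each group of four lanes one row and each lane within the group four consecutive columns. A tiny standalone program reproduces the pattern (assuming the 8x16 i8 mat of the sm_80 shape):

#include <cstdio>

// Mirrors c_off_in_mat = lane/4 and s_off_in_mat = (lane%4)*4 above.
int main() {
  for (int lane = 0; lane < 32; ++lane) {
    int c_off_in_mat = lane / 4;        // row within the mat
    int s_off_in_mat = (lane % 4) * 4;  // first of 4 consecutive columns
    std::printf("lane %2d -> row %d, cols %2d..%2d\n",
                lane, c_off_in_mat, s_off_in_mat, s_off_in_mat + 3);
  }
}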
@@ -1461,8 +1502,10 @@ public:
    int ptr_idx = -1;
    if (can_use_ldmatrix_)
      ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
    else // tf32 & trans
    else if (dtsize_ == 4 && need_trans_) // tf32 & trans
      ptr_idx = mat_idx[order_[0]];
    else // i8 & trans
      ptr_idx = mat_idx[order_[0]] * 4;

    auto get_ptr = [&](int idx) -> Value* {
      Value *ptr = nullptr;
@@ -1495,9 +1538,7 @@ public:
              extract_val(res_v4, std::vector<unsigned>{1}),
              extract_val(res_v4, std::vector<unsigned>{2}),
              extract_val(res_v4, std::vector<unsigned>{3})};
    } else {
      // assert(false && "should not be here");
      assert(dtsize_ == 4 && need_trans_);
    } else if (dtsize_ == 4 && need_trans_) { // use lds.32 to load tf32 matrices
      Value *ptr2 = get_ptr(ptr_idx+1);
      assert(s_mat_stride_ == 1);
      int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
@@ -1521,7 +1562,96 @@ public:
        prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
      }
      return {elem0, elem1, elem2, elem3};
    }
    } else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices
      Value *ptr00 = get_ptr(ptr_idx);
      Value *ptr01 = get_ptr(ptr_idx+1);
      Value *ptr02 = get_ptr(ptr_idx+2);
      Value *ptr03 = get_ptr(ptr_idx+3);

      Value *ptr10 = get_ptr(ptr_idx+4);
      Value *ptr11 = get_ptr(ptr_idx+5);
      Value *ptr12 = get_ptr(ptr_idx+6);
      Value *ptr13 = get_ptr(ptr_idx+7);

      assert(s_mat_stride_ == 1);
      int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
      int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;

      Value *i8v4_elems[4];
      Value *i32_elems[4];
      for (int i=0; i<4; ++i)
        i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4));

      Value *elem00, *elem01, *elem02, *elem03;
      Value *elem10, *elem11, *elem12, *elem13;
      Value *elem20, *elem21, *elem22, *elem23;
      Value *elem30, *elem31, *elem32, *elem33;
      Value *i8_elems[4*4];
      if (k_order_ == 1) { // for a (mn first)
        i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
        i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
        i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
        i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));

        assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));

        i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
        i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
        i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
        i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));

        i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[2*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));

        i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));

        for (int m=0; m<4; ++m) {
          for (int e=0; e<4; ++e)
            i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
          i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
        }
      } else { // for b (k first)
        i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
        i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
        i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
        i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));

        assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));

        i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
        i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
        i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
        i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));

        i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));

        i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
        i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));

        for (int m=0; m<4; ++m) {
          for (int e=0; e<4; ++e)
            i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
          i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
        }
      }
      if (k == 0 && inc == 1 && is_prefetch) {
        for (int m = 0; m < 4; ++m)
        for (int e = 0; e < 4; ++e)
          prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]);
      }
      return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]};
    } else
      throw std::runtime_error("invalid smem load");
  }

  int get_num_ptr() const { return num_ptr_; }
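
Each group of four lds.8 results is inserted into a `<4 x i8>` vector and bitcast to `i32`, which matches the `r`-register operands the mma expects. A host-side analogue of that packing (a sketch; little-endian lane order assumed):

#include <cstdint>
#include <cstring>

// Analogue of insert_elt + bit_cast above: four bytes packed into one
// 32-bit register operand.
uint32_t pack_i8x4(int8_t e0, int8_t e1, int8_t e2, int8_t e3) {
  int8_t v[4] = {e0, e1, e2, e3};
  uint32_t packed;
  std::memcpy(&packed, v, sizeof(packed));
  return packed;
}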
@@ -1596,12 +1726,18 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
  const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
  const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);

  // floating point types
  Type *fp32_ty = f32_ty;
  Type *fp16x2_ty = vec_ty(f16_ty, 2);
  Type *bf16x2_ty = vec_ty(bf16_ty, 2);
  Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
  Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
  Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
  // integer types
  Type *i8x4_ty = vec_ty(i8_ty, 4);
  Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty});
  Type *i32_pack4_ty  = StructType::get(*ctx_, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty});

  FunctionType *ldmatrix_ty = nullptr;
  FunctionType *mma_ty = nullptr;
@@ -1630,6 +1766,16 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
    smem_ptr_ty = ptr_ty(fp32_ty, 3);
    ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
    phi_ty = fp32_ty;
  } else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) {
    // FIXME: We should use i8 here (but nvptx will generate extra casts when using i8)
    mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
    smem_ptr_ty = ptr_ty(i8_ty, 3);
    ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
    phi_ty = i32_ty;
    // mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
    // smem_ptr_ty = ptr_ty(i8_ty, 3);
    // ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
    // phi_ty = i8x4_ty;
  } else
    throw std::runtime_error("mma16816 data type not supported");

@@ -1690,7 +1836,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
                                             " {$4, $5, $6, $7},"
                                             " {$8, $9},"
                                             " {$10, $11, $12, $13};",
                                             "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
                                             "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);

  // create mma & unpack result, m, n, k are offsets in mat
  auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
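
The output constraints switch from `=f` to `=r` because the int8 mma accumulates in 32-bit integer registers rather than floats. Put together, the inline asm for the int8 variant would look roughly like this (a sketch: the exact opcode string here is an assumption, m16n8k32 being the sm_80 s8 tensor-core shape; `mma_ty` is the function type built above):

// Hedged sketch of the int8 mma call with the "=r" constraints above.
InlineAsm *mma_fn = InlineAsm::get(
    mma_ty,
    "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
    "{$0, $1, $2, $3}, {$4, $5, $6, $7}, {$8, $9}, {$10, $11, $12, $13};",
    "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", /*hasSideEffects=*/true);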
@@ -1715,12 +1861,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
  ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);

  auto register_lds2 =
    [&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) {
    [&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
      if (k < 2 && is_prefetch) {
        ir::basic_block* inc_block = phiA->get_incoming_block(inc);
        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block));
        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
      } else
        vals[{n, k}] = val;
        vals[{mn, k}] = val;
  };

  auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
@@ -1922,7 +2068,10 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
    return visit_mma884(dot, A, B, D, NK);
  if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
    return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
  return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
  if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
      A->get_type()->get_scalar_ty()->is_fp32_ty())
    return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
  throw std::runtime_error("dot has invalid operand type");
}

void generator::visit_trans_inst(ir::trans_inst* trans) {
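
The dispatch change means the FFMA fallback is now reserved for fp32 x fp32; any other combination must match an mma path or compilation fails loudly instead of silently lowering to scalar code. Distilled (an illustrative helper, not in the codebase):

enum class Lowering { MMA884, MMA16816, FMADOT, INVALID };

// Mirrors the dispatch order in visit_dot_inst above; the boolean
// parameters stand in for the real checks.
Lowering pick_dot_lowering(bool is_outer, bool is_mma, int sm,
                           bool c_fp32, bool a_fp32) {
  if (!is_outer && is_mma && sm < 80)  return Lowering::MMA884;    // pre-Ampere
  if (!is_outer && is_mma && sm >= 80) return Lowering::MMA16816;  // Ampere
  if (c_fp32 && a_fp32)                return Lowering::FMADOT;    // FFMA fallback
  return Lowering::INVALID;  // previously fell through to fmadot silently
}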
@@ -61,7 +61,8 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
  // dot(a, b, c) + d -> dot(a, b, c + d)
  // d + dot(a, b, c) -> dot(a, b, c + d)
  auto add = dynamic_cast<ir::binary_operator*>(value);
  if(add && add->get_op() == ir::binary_op_t::FAdd) {
  if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
    bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
    ir::value *lhs = add->get_operand(0);
    ir::value *rhs = add->get_operand(1);
    ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
@@ -72,11 +73,17 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
    ir::value *other = (dot == lhs) ? rhs : lhs;
    ir::value *acc = dot->get_operand(2);
    ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
    ir::constant_fp *_0 = nullptr;
    ir::constant *_0 = nullptr;
    if(splat)
      _0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0));
    if(!(_0 && _0->get_value() == 0.0))
      _0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
    if(!_0)
      return false;
    if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
      if (fp_0->get_value() != 0.0)
        return false;
    if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
      if (int_0->get_value() != 0)
        return false;
    ir::value *a = dot->get_operand(0);
    ir::value *b = dot->get_operand(1);
    builder.set_insert_point(add);
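
With this generalization, `dot(a, b, splat(0)) + d` folds to `dot(a, b, d)` for an integer `Add` as well as an `FAdd`. The zero-splat test, distilled into a predicate (a sketch assuming the `triton/ir` headers; note the real code above simply bails out of the rewrite when the constant is nonzero):

#include "triton/ir/constant.h"  // assumed location of ir::constant et al.

// True when the splatted accumulator constant is a literal zero, for
// either a floating-point or an integer constant.
bool is_zero_splat_operand(ir::constant *c) {
  if (auto *fp0 = dynamic_cast<ir::constant_fp*>(c))
    return fp0->get_value() == 0.0;
  if (auto *i0 = dynamic_cast<ir::constant_int*>(c))
    return i0->get_value() == 0;
  return false;
}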
@@ -33,6 +33,9 @@ void prefetch::run(ir::module &mod) {
      if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
            dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
            (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
             && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
            (dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
             && dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
             && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
           )
         )