From 59d371c6eb3e8f904c9bc879e5df07b0b79131ad Mon Sep 17 00:00:00 2001
From: daadaada
Date: Fri, 28 Jan 2022 01:12:44 +0800
Subject: [PATCH] [BACKEND] Added Int8 mma (#440)

---
 include/triton/codegen/analysis/layout.h |   2 +
 lib/codegen/analysis/layout.cc           |  24 +++-
 lib/codegen/analysis/swizzle.cc          |   6 +
 lib/codegen/selection/common.h           |  78 ----------
 lib/codegen/selection/generator.cc       | 175 +++++++++++++++++++++--
 lib/codegen/transform/peephole.cc        |  15 +-
 lib/codegen/transform/prefetch.cc        |   3 +
 lib/driver/llvm.cc                       |   6 +-
 lib/ir/dispatch.cc                       |   6 +-
 python/setup.py                          |   2 +-
 python/test/unit/language/test_core.py   |  30 ++--
 11 files changed, 232 insertions(+), 115 deletions(-)
 delete mode 100644 lib/codegen/selection/common.h

diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h
index dc5150f05..b6376d7cc 100644
--- a/include/triton/codegen/analysis/layout.h
+++ b/include/triton/codegen/analysis/layout.h
@@ -268,6 +268,7 @@ public:
   void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
   int get_mma_vec() { return mma_vec_;}
   int get_mma_strided() { return mma_strided_; }
+  bool allow_swizzle() const { return allow_swizzle_; }
   data_layout* get_arg_layout() { return arg_layout_; }
 
 private:
@@ -281,6 +282,7 @@ private:
   data_layout* arg_layout_;
   int mma_vec_;
   int mma_strided_;
+  bool allow_swizzle_ = true;
   target *tgt_;
 };
 
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index d00959e45..2206f5b6a 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -33,7 +33,9 @@ inline bool is_hmma_c(ir::value *v, int sm){
     result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
              (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
              (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
-              x->allow_tf32() && sm >= 80);
+              x->allow_tf32() && sm >= 80) ||
+             (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
+              sm >= 80);
   }
   return result;
 }
@@ -63,7 +65,7 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
         return mma_type;
       }
     } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
-      throw std::runtime_error("integer tensor cores are not yet supported");
+      // throw std::runtime_error("integer tensor cores are not yet supported");
       // // integer tensor cores
       // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
       //   mma_type = mma_layout::INT32_INT1_INT1_INT32;
@@ -73,10 +75,10 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
       //   mma_type = mma_layout::INT32_INT4_INT4_INT32;
       //   return mma_type;
       // }
-      // if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
-      //   mma_type = mma_layout::INT32_INT8_INT8_INT32;
-      //   return mma_type;
-      // }
+      if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
+        mma_type = mma_layout::INT32_INT8_INT8_INT32;
+        return mma_type;
+      }
     }
   }
   return mma_layout::NOT_APPLICABLE;
@@ -444,11 +446,21 @@ shared_layout::shared_layout(data_layout *arg,
     std::vector mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
     mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
     mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
+
+    // for now, disable swizzle when using lds.8
+    if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
+      if (order_[0] == 0) // need transpose
+        allow_swizzle_ = false;
   } else if (hmma_dot_b_) {
     assert(order_.size() == 2);
     std::vector mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
     mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
     mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
+
+    // for now, disable swizzle when using lds.8
+    if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
+      if (order_[0] == 1) // need transpose
+        allow_swizzle_ = false;
   }
 
   // size
diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc
index 1dbae10d4..5737f80a0 100644
--- a/lib/codegen/analysis/swizzle.cc
+++ b/lib/codegen/analysis/swizzle.cc
@@ -41,9 +41,15 @@ void swizzle::run(ir::module &) {
       vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
     }
     else {
+      if (!layout->allow_swizzle()) {
+        per_phase_[layout] = 1;
+        max_phase_[layout] = 1;
+        vec_[layout] = 1;
+      } else {
       per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
       max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
       vec_[layout] = layout->get_mma_vec();
+      }
     }
   }
 }
diff --git a/lib/codegen/selection/common.h b/lib/codegen/selection/common.h
deleted file mode 100644
index c4b0951da..000000000
--- a/lib/codegen/selection/common.h
+++ /dev/null
@@ -1,78 +0,0 @@
-#pragma once
-
-#include
-#include
-#include
-#include "triton/codegen/selection/generator.h"
-#include "triton/codegen/target.h"
-#include "triton/codegen/analysis/axes.h"
-#include "triton/codegen/analysis/allocation.h"
-#include "triton/codegen/analysis/align.h"
-#include "triton/codegen/analysis/swizzle.h"
-#include "triton/codegen/transform/coalesce.h"
-#include "triton/ir/context.h"
-#include "triton/ir/module.h"
-#include "triton/ir/function.h"
-#include "triton/ir/type.h"
-#include "llvm/IR/Module.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/IntrinsicsNVPTX.h"
-#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/Attributes.h"
-#include "llvm/IR/InlineAsm.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
-
-namespace triton::codegen {
-// types
-#define void_ty builder_->getVoidTy()
-#define f16_ty builder_->getHalfTy()
-#define bf16_ty builder_->getBFloatTy()
-#define f32_ty builder_->getFloatTy()
-#define i8_ty builder_->getInt8Ty()
-#define i32_ty builder_->getInt32Ty()
-#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
-#define ptr_ty(...) PointerType::get(__VA_ARGS__)
-// constants
-#define i32(...) builder_->getInt32(__VA_ARGS__)
-// ops
-#define and_(...) builder_->CreateAnd(__VA_ARGS__)
-#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
-#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
-#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
-#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
-#define br(...) builder_->CreateBr(__VA_ARGS__)
-#define call(...) builder_->CreateCall(__VA_ARGS__)
-#define cast(...) builder_->CreateCast(__VA_ARGS__)
-#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__)
-#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__)
-#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__)
-#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
-#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
-#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
-#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
-#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
-#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
-#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
-#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
-#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
-#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
-#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
-#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
-#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
-#define load(...) builder_->CreateLoad(__VA_ARGS__)
-#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
-#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
-#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
-#define neg(...) builder_->CreateNeg(__VA_ARGS__)
-#define phi(...) builder_->CreatePHI(__VA_ARGS__)
-#define ret(...) builder_->CreateRet(__VA_ARGS__)
-#define select(...) builder_->CreateSelect(__VA_ARGS__)
-#define store(...) builder_->CreateStore(__VA_ARGS__)
-#define sub(...) builder_->CreateSub(__VA_ARGS__)
-#define shl(...) builder_->CreateShl(__VA_ARGS__)
-#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
-#define urem(...) builder_->CreateURem(__VA_ARGS__)
-#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
-#define xor_(...) builder_->CreateXor(__VA_ARGS__)
-
-} // namespace triton::codegen
\ No newline at end of file
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index b180ecb12..a55991475 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -1,6 +1,7 @@
 #include
 #include
 #include
+#include
 #include "triton/codegen/selection/generator.h"
 #include "triton/codegen/target.h"
 #include "triton/codegen/analysis/axes.h"
@@ -1355,9 +1356,6 @@ public:
     need_trans_ = k_order_ != order_[0];
     can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);
 
-    // std::cout << can_use_ldmatrix_ << std::endl;
-    // std::cout << need_trans_ << std::endl;
-
     // we need more pointers at the fast-changing axis,
     if (can_use_ldmatrix_)
       num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
@@ -1365,6 +1363,9 @@ public:
       num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
     num_ptr_ = std::max(num_ptr_, 2);
 
+    // special rule for i8/u8, 4 ptrs for each matrix
+    if (!can_use_ldmatrix_ && dtsize_ == 1)
+      num_ptr_ *= 4;
 
     // load_v4 stride (in num of mats)
     int load_stride_in_mat[2];
@@ -1445,6 +1446,46 @@ public:
       }
       return offs;
       // throw std::runtime_error("not implemented");
+    } else if (dtsize_ == 1 && need_trans_) {
+      // load i8/u8 matrices with lds8
+      Value *c_off_in_mat = udiv(lane, i32(4)); //
+      Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread load 4 cols
+
+      // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
+      std::vector offs(num_ptr_);
+      for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
+        int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
+        int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
+        if (k_mat_arr_int > 0) // we don't need pointers for k
+          continue;
+        Value *k_mat_arr = i32(k_mat_arr_int);
+        Value *nk_mat_arr = i32(nk_mat_arr_int);
+        // physical offset (before swizzling)
+        Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
+                               mul(nk_mat_arr, i32(mat_arr_stride_)));
+        Value *s_mat_off = k_mat_arr; // always 0?
+
+        for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) {
+          for (int elem_off = 0; elem_off < 4; ++elem_off) {
+            int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off;
+
+            Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
+            Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off));
+
+            // disable swizzling ...
+            // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
+            // c_mat_off_i = xor_(c_mat_off_i, phase);
+
+            Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
+            Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_)));
+            // To prevent out-of-bound access when the tile is too small
+            c_off = urem(c_off, i32(tile_shape_[order_[0]]));
+            s_off = urem(s_off, i32(tile_shape_[order_[1]]));
+            offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_)));
+          }
+        }
+      }
+      return offs;
     } else
       throw std::runtime_error("invalid smem load config");
   }
@@ -1461,8 +1502,10 @@ public:
     int ptr_idx = -1;
     if (can_use_ldmatrix_)
       ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
-    else // tf32 & trans
+    else if (dtsize_ == 4 && need_trans_) // tf32 & trans
       ptr_idx = mat_idx[order_[0]];
+    else // i8 & trans
+      ptr_idx = mat_idx[order_[0]] * 4;
 
     auto get_ptr = [&](int idx) -> Value* {
       Value *ptr = nullptr;
@@ -1495,9 +1538,7 @@ public:
               extract_val(res_v4, std::vector{1}),
               extract_val(res_v4, std::vector{2}),
               extract_val(res_v4, std::vector{3})};
-    } else {
-      // assert(false && "should not be here");
-      assert(dtsize_ == 4 && need_trans_);
+    } else if (dtsize_ == 4 && need_trans_) {
       // use lds.32 to load tf32 matrices
       Value *ptr2 = get_ptr(ptr_idx+1);
       assert(s_mat_stride_ == 1);
       int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
@@ -1521,7 +1562,96 @@ public:
         prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
       }
       return {elem0, elem1, elem2, elem3};
-    }
+    } else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices
+      Value *ptr00 = get_ptr(ptr_idx);
+      Value *ptr01 = get_ptr(ptr_idx+1);
+      Value *ptr02 = get_ptr(ptr_idx+2);
+      Value *ptr03 = get_ptr(ptr_idx+3);
+
+      Value *ptr10 = get_ptr(ptr_idx+4);
+      Value *ptr11 = get_ptr(ptr_idx+5);
+      Value *ptr12 = get_ptr(ptr_idx+6);
+      Value *ptr13 = get_ptr(ptr_idx+7);
+
+      assert(s_mat_stride_ == 1);
+      int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
+      int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
+
+      Value *i8v4_elems[4];
+      Value *i32_elems[4];
+      for (int i=0; i<4; ++i)
+        i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4));
+
+      Value *elem00, *elem01, *elem02, *elem03;
+      Value *elem10, *elem11, *elem12, *elem13;
+      Value *elem20, *elem21, *elem22, *elem23;
+      Value *elem30, *elem31, *elem32, *elem33;
+      Value *i8_elems[4*4];
+      if (k_order_ == 1) { //
+        i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
+        i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
+        i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
+        i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
+
+        assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
+
+        i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
+        i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
+        i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
+        i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
+
+        i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[2*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
+
+        i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
+
+        for (int m=0; m<4; ++m) {
+          for (int e=0; e<4; ++e)
+            i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
+          i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
+        }
+      } else { // for b (k first)
+        i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
+        i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
+        i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
+        i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
+
+        assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
+
+        i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
+        i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
+        i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
+        i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
+
+        i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
+
+        i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
+        i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
+
+        for (int m=0; m<4; ++m) {
+          for (int e=0; e<4; ++e)
+            i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
+          i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
+        }
+      }
+      if (k == 0 && inc == 1 && is_prefetch) {
+        for (int m = 0; m < 4; ++m)
+          for (int e = 0; e < 4; ++e)
+            prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]);
+      }
+      return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]};
+    } else
+      throw std::runtime_error("invalid smem load");
   }
 
   int get_num_ptr() const { return num_ptr_; }
@@ -1596,12 +1726,18 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
   const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
   const int num_rep_k = std::max(NK/mma_instr_k, 1);
 
+  // floating point types
   Type *fp32_ty = f32_ty;
   Type *fp16x2_ty = vec_ty(f16_ty, 2);
   Type *bf16x2_ty = vec_ty(bf16_ty, 2);
   Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
   Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
   Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
+  // integer types
+  Type *i8x4_ty = vec_ty(i8_ty, 4);
+  Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty});
+  Type *i32_pack4_ty = StructType::get(*ctx_, std::vector{i32_ty, i32_ty, i32_ty, i32_ty});
+
   FunctionType *ldmatrix_ty = nullptr;
   FunctionType *mma_ty = nullptr;
@@ -1630,6 +1766,16 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
     smem_ptr_ty = ptr_ty(fp32_ty, 3);
     ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector{smem_ptr_ty}, false);
     phi_ty = fp32_ty;
+  } else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) {
+    // FIXME: We should use i8 here (but nvptx will generate extra casts when using i8)
+    mma_ty = FunctionType::get(i32_pack4_ty, std::vector{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
+    smem_ptr_ty = ptr_ty(i8_ty, 3);
+    ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector{smem_ptr_ty}, false);
+    phi_ty = i32_ty;
+    // mma_ty = FunctionType::get(i32_pack4_ty, std::vector{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
+    // smem_ptr_ty = ptr_ty(i8_ty, 3);
+    // ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector{smem_ptr_ty}, false);
+    // phi_ty = i8x4_ty;
   } else
     throw std::runtime_error("mma16816 data type not supported");
 
@@ -1690,7 +1836,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
                              " {$4, $5, $6, $7},"
                              " {$8, $9},"
                              " {$10, $11, $12, $13};",
-                             "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
+                             "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);
 
   // create mma & unpack result, m, n, k are offsets in mat
   auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
@@ -1715,12 +1861,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
   ir::phi_node* phiB = dynamic_cast(B);
 
   auto register_lds2 =
-    [&](std::map, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) {
+    [&](std::map, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
       if (k < 2 && is_prefetch) {
         ir::basic_block* inc_block = phiA->get_incoming_block(inc);
-        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block));
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
       } else
-        vals[{n, k}] = val;
+        vals[{mn, k}] = val;
     };
 
   auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
@@ -1922,7 +2068,10 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
     return visit_mma884(dot, A, B, D, NK);
   if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
     return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
-  return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
+  if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
+      A->get_type()->get_scalar_ty()->is_fp32_ty())
+    return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
+  throw std::runtime_error("dot has invalid operand type");
 }
 
 void generator::visit_trans_inst(ir::trans_inst* trans) {
diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc
index bae8fe828..b381d3cb0 100644
--- a/lib/codegen/transform/peephole.cc
+++ b/lib/codegen/transform/peephole.cc
@@ -61,7 +61,8 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
   // dot(a, b, c) + d -> dot(a, b, c + d)
   // d + dot(a, b, c) -> dot(a, b, c + d)
   auto add = dynamic_cast(value);
-  if(add && add->get_op() == ir::binary_op_t::FAdd) {
+  if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
+    bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
     ir::value *lhs = add->get_operand(0);
     ir::value *rhs = add->get_operand(1);
     ir::dot_inst *lhs_dot = dynamic_cast(lhs);
@@ -72,11 +73,17 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
     ir::value *other = (dot == lhs) ? rhs : lhs;
     ir::value *acc = dot->get_operand(2);
     ir::splat_inst *splat = dynamic_cast(acc);
-    ir::constant_fp *_0 = nullptr;
+    ir::constant *_0 = nullptr;
     if(splat)
-      _0 = dynamic_cast(splat->get_operand(0));
-    if(!(_0 && _0->get_value() == 0.0))
+      _0 = dynamic_cast(splat->get_operand(0));
+    if(!_0)
       return false;
+    if (auto *fp_0 = dynamic_cast(_0))
+      if (fp_0->get_value() != 0.0)
+        return false;
+    if (auto *int_0 = dynamic_cast(_0))
+      if (int_0->get_value() != 0)
+        return false;
     ir::value *a = dot->get_operand(0);
     ir::value *b = dot->get_operand(1);
     builder.set_insert_point(add);
diff --git a/lib/codegen/transform/prefetch.cc b/lib/codegen/transform/prefetch.cc
index f94d8f368..30b2a10f2 100644
--- a/lib/codegen/transform/prefetch.cc
+++ b/lib/codegen/transform/prefetch.cc
@@ -33,6 +33,9 @@ void prefetch::run(ir::module &mod) {
     if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
           dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
           (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
+           && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
+          (dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
+           && dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
            && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
          )
        )
diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc
index 910f23a1a..ada5459d6 100644
--- a/lib/driver/llvm.cc
+++ b/lib/driver/llvm.cc
@@ -46,6 +46,7 @@
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
 #include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Scalar.h"
 
 // begin AMD stuff
 #include "llvm/Support/FileSystem.h"
@@ -121,9 +122,12 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
   init_llvm();
   // verify and store llvm
   llvm::legacy::PassManager pm;
-  // module->print(llvm::outs(), nullptr);
   pm.add(llvm::createVerifierPass());
+  // pm.add(llvm::createDeadCodeEliminationPass());
+  // pm.add(llvm::createEarlyCSEPass());
   pm.run(*module);
+  // module->print(llvm::outs(), nullptr);
 
+  // create machine
   module->setTargetTriple(triple);
   std::string error;
diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc
index 69c76b5e5..e3168b375 100644
--- a/lib/ir/dispatch.cc
+++ b/lib/ir/dispatch.cc
@@ -726,7 +726,11 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask
 //===----------------------------------------------------------------------===//
 
 ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
-  ir::value *_0 = builder->get_float32(0);
+  ir::value *_0 = nullptr;
+  if (lhs->get_type()->is_int_or_tileint_ty())
+    _0 = builder->get_int32(0);
+  else
+    _0 = builder->get_float32(0);
   unsigned M = lhs->get_type()->get_block_shapes()[0];
   unsigned N = rhs->get_type()->get_block_shapes()[1];
   _0 = builder->create_splat(_0, {M, N});
diff --git a/python/setup.py b/python/setup.py
index db22c14af..86e3e5160 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -77,7 +77,7 @@ class CMakeBuild(build_ext):
 
     def build_extension(self, ext):
         llvm_include_dir, llvm_library_dir = get_llvm()
-        # self.debug = True
+        self.debug = True
         extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
         # create build directories
         build_suffix = 'debug' if self.debug else 'release'
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 4d6e32aa6..3e35700f8 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -661,11 +661,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # ---------------
 
 
-@pytest.mark.parametrize("epilogue, allow_tf32",
-                         [(epilogue, allow_tf32)
+@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
+                         [(epilogue, allow_tf32, dtype)
                           for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
-                          for allow_tf32 in [True, False]])
-def test_dot(epilogue, allow_tf32, device='cuda'):
+                          for allow_tf32 in [True, False]
+                          for dtype in ['float32', 'int8']
+                          if not (allow_tf32 and (dtype == 'int8'))])
+def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80:
+        if dtype == 'int8':
+            pytest.skip("Only test int8 on devices with sm >= 80")
+        elif dtype == 'float32' and allow_tf32:
+            pytest.skip("Only test tf32 on devices with sm >= 80")
+
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
@@ -693,18 +702,15 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
     # input
     M, N, K = 64, 64, 32
     rs = RandomState(17)
-    x = numpy_random((M, K), dtype_str='float32', rs=rs)
-    y = numpy_random((K, N), dtype_str='float32', rs=rs)
+    x = numpy_random((M, K), dtype_str=dtype, rs=rs)
+    y = numpy_random((K, N), dtype_str=dtype, rs=rs)
     if allow_tf32:
-        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-        if cc < 80:
-            pytest.skip("Only test tf32 on devices with sm >= 80")
         x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
         y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
     x_tri = to_triton(x, device=device)
     y_tri = to_triton(y, device=device)
     # triton result
-    z = numpy_random((M, N), dtype_str='float32', rs=rs)
+    z = numpy_random((M, N), dtype_str=dtype, rs=rs)
     z_tri = to_triton(z, device=device)
     if epilogue == 'trans':
         z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
@@ -732,8 +738,10 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
         assert 'st.global.v4' in ptx
     if allow_tf32:
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
-    else:
+    elif dtype == 'float32':
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
+    elif dtype == 'int8':
+        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
 
 
 def test_dot_without_load():
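
Editor's illustrative note (not part of the patch): with the int8 path added above, tl.dot on two int8 blocks accumulates in int32 and, on sm >= 80, lowers to the mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32 instruction that the updated test asserts on. Below is a minimal usage sketch modeled on test_dot; the kernel name, hard-coded tile sizes, and single-block launch are assumptions for illustration, not code from this patch, and it assumes a CUDA device with sm >= 80 and a Triton build that includes this change.

# Hypothetical example, not part of the patch.
import torch
import triton
import triton.language as tl


@triton.jit
def int8_dot_kernel(X, Y, Z,
                    stride_xm, stride_xk,
                    stride_yk, stride_yn,
                    stride_zm, stride_zn):
    # one 64x64x32 tile, matching the M, N, K used by test_dot
    off_m = tl.arange(0, 64)
    off_n = tl.arange(0, 64)
    off_k = tl.arange(0, 32)
    x = tl.load(X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk)  # int8 block
    y = tl.load(Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn)  # int8 block
    z = tl.dot(x, y)  # int8 x int8 with an int32 accumulator (INT32_INT8_INT8_INT32 mma)
    tl.store(Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn, z)


M, N, K = 64, 64, 32
x = torch.randint(-128, 127, (M, K), dtype=torch.int8, device='cuda')
y = torch.randint(-128, 127, (K, N), dtype=torch.int8, device='cuda')
z = torch.empty((M, N), dtype=torch.int32, device='cuda')
int8_dot_kernel[(1,)](x, y, z,
                      x.stride(0), x.stride(1),
                      y.stride(0), y.stride(1),
                      z.stride(0), z.stride(1))
# int32 reference computed on the host
assert torch.equal(z.cpu(), torch.matmul(x.cpu().int(), y.cpu().int()))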