From 94a2e10fe57015375b6b264c67d0453fe7d10789 Mon Sep 17 00:00:00 2001 From: daadaada Date: Wed, 12 Jan 2022 02:20:31 +0800 Subject: [PATCH] [BACKEND] Add bf16 & tf32 mma supports (on A100) (#426) --- include/triton/codegen/analysis/layout.h | 77 ++- include/triton/ir/builder.h | 2 +- include/triton/ir/dispatch.h | 2 +- include/triton/ir/instructions.h | 25 +- lib/codegen/analysis/layout.cc | 103 +++- lib/codegen/analysis/swizzle.cc | 9 +- lib/codegen/pass.cc | 1 - lib/codegen/selection/common.h | 78 +++ lib/codegen/selection/generator.cc | 607 +++++++++++++++-------- lib/codegen/transform/peephole.cc | 2 +- lib/codegen/transform/prefetch.cc | 9 +- lib/driver/llvm.cc | 1 + lib/ir/builder.cc | 4 +- lib/ir/dispatch.cc | 5 +- lib/ir/instructions.cc | 23 +- python/test/unit/language/test_core.py | 24 +- python/triton/language/core.py | 8 +- 17 files changed, 717 insertions(+), 263 deletions(-) create mode 100644 lib/codegen/selection/common.h diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 4d12e34c0..dc5150f05 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -109,6 +109,63 @@ protected: }; class mma_layout: public distributed_layout { +public: + enum TensorCoreType : uint8_t { + // floating-point tensor core instr + FP32_FP16_FP16_FP32 = 0, // default + FP32_BF16_BF16_FP32, + FP32_TF32_TF32_FP32, + // integer tensor core instr + INT32_INT1_INT1_INT32, // Not implemented + INT32_INT4_INT4_INT32, // Not implemented + INT32_INT8_INT8_INT32, // Not implemented + // + NOT_APPLICABLE, + }; + + // Used on nvidia GPUs with sm >= 80 + inline static const std::map> mma_instr_shape_ = { + {FP32_FP16_FP16_FP32, {16, 8, 16}}, + {FP32_BF16_BF16_FP32, {16, 8, 16}}, + {FP32_TF32_TF32_FP32, {16, 8, 8}}, + + {INT32_INT1_INT1_INT32, {16, 8, 256}}, + {INT32_INT4_INT4_INT32, {16, 8, 64}}, + {INT32_INT8_INT8_INT32, {16, 8, 32}}, + }; + + // shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices) + inline static const std::map> mma_mat_shape_ = { + {FP32_FP16_FP16_FP32, {8, 8, 8}}, + {FP32_BF16_BF16_FP32, {8, 8, 8}}, + {FP32_TF32_TF32_FP32, {8, 8, 4}}, + + {INT32_INT1_INT1_INT32, {8, 8, 64}}, + {INT32_INT4_INT4_INT32, {8, 8, 32}}, + {INT32_INT8_INT8_INT32, {8, 8, 16}}, + }; + + inline static const std::map mma_instr_ptx_ = { + {FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"}, + {FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"}, + {FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"}, + + {INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"}, + {INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"}, + {INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"}, + }; + + // vector length per ldmatrix (16*8/elelment_size_in_bits) + inline static const std::map mma_instr_vec_ = { + {FP32_FP16_FP16_FP32, 8}, + {FP32_BF16_BF16_FP32, 8}, + {FP32_TF32_TF32_FP32, 4}, + + {INT32_INT1_INT1_INT32, 128}, + {INT32_INT4_INT4_INT32, 32}, + {INT32_INT8_INT8_INT32, 16}, + }; + public: mma_layout(size_t num_warps, const std::vector& axes, @@ -116,7 +173,8 @@ public: const std::vector &values, analysis::align* align, target *tgt, shared_layout* layout_a, - shared_layout* layout_b); + shared_layout* layout_b, + ir::value *dot); void accept(layout_visitor* vst) { vst->visit_layout_mma(this); } // accessor int fpw(size_t k) { return fpw_.at(k); } @@ -124,6 
+182,16 @@ public: int spw(size_t k) { return spw_.at(k); } int rep(size_t k) { return rep_.at(k); } + // helpers for generator.cc + std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); } + std::vector get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); } + std::vector get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); } + int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); } + int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); } + + // setter + void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; } + private: // fragment per warp std::vector fpw_; @@ -135,6 +203,8 @@ private: std::vector spt_; // repetitions std::vector rep_; + + TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32; }; struct scanline_layout: public distributed_layout { @@ -182,7 +252,7 @@ public: const std::vector& shapes, const std::vector &values_, ir::type *ty, - analysis::align* align); + analysis::align* align, target *tgt); void accept(layout_visitor* vst) { vst->visit_layout_shared(this); } // accessors size_t get_size() { return size_; } @@ -197,6 +267,7 @@ public: ir::value* hmma_dot_b() { return hmma_dot_b_; } void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; } int get_mma_vec() { return mma_vec_;} + int get_mma_strided() { return mma_strided_; } data_layout* get_arg_layout() { return arg_layout_; } private: @@ -209,6 +280,8 @@ private: ir::value* hmma_dot_b_; data_layout* arg_layout_; int mma_vec_; + int mma_strided_; + target *tgt_; }; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 3a4094123..67ab47c90 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -154,7 +154,7 @@ public: value *create_cos(value* arg); value *create_sin(value* arg); value *create_log(value* arg); - value *create_dot(value *A, value *B, value *C); + value *create_dot(value *A, value *B, value *C, bool allow_tf32); value *create_trans(value *A, const std::vector &perm = {}); value *create_sqrt(value *A); value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis); diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index c90480f1e..d8293d231 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -80,7 +80,7 @@ struct dispatch{ static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); // linear algebra - static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder); + static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder); // indexing static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 7c147f634..ca1416f48 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -742,26 +742,29 @@ public: }; private: - dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next); + dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next); std::string repr_impl() const { return "dot"; } - - bool is_prefetched_ = false; - DataType C_type_ = DataType::FP32; - DataType A_type_ = DataType::FP16; - DataType B_type_ = DataType::FP16; public: bool is_prefetched() const { return is_prefetched_; } void set_prefetched(bool is_prefetched) { 
is_prefetched_ = is_prefetched; } + bool allow_tf32() const { return allow_tf32_; } public: - static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr); - static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); - static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); - static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); - static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); + static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); + static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); + static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); + static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); + static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(dot_inst) _TRITON_DEFINE_ACCEPT(dot_inst) + +private: + bool is_prefetched_ = false; + bool allow_tf32_ = false; + DataType C_type_ = DataType::FP32; + DataType A_type_ = DataType::FP16; + DataType B_type_ = DataType::FP16; }; //class outer_inst: public builtin_inst { diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 64163c91c..d00959e45 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -23,19 +23,65 @@ inline unsigned clamp(unsigned x, unsigned a, unsigned b) { return std::min(std::max(x, lo), hi); } -inline bool is_hmma_c(ir::value *v){ +inline bool is_hmma_c(ir::value *v, int sm){ bool result = false; if(auto *x = dynamic_cast(v)){ ir::value *a = x->get_operand(0); ir::type *a_ty = a->get_type(); ir::value *b = x->get_operand(1); ir::type *b_ty = b->get_type(); - result = a_ty->get_scalar_ty()->is_fp16_ty() && - b_ty->get_scalar_ty()->is_fp16_ty(); + result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) || + (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) || + (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() && + x->allow_tf32() && sm >= 80); } return result; } +static mma_layout::TensorCoreType get_mma_type(ir::value *v) { + mma_layout::TensorCoreType mma_type; + if (auto* dot = dynamic_cast(v)) { + ir::value* a = dot->get_operand(0); + ir::value* b = dot->get_operand(1); + ir::type* a_ty = a->get_type(); + ir::type* b_ty = b->get_type(); + ir::type* c_ty = v->get_type(); + + if (c_ty->get_scalar_ty()->is_fp32_ty()) { + // floating point tensor cores + if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) { + mma_type = mma_layout::FP32_FP16_FP16_FP32; + return mma_type; + } + if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) { + mma_type = mma_layout::FP32_BF16_BF16_FP32; + return mma_type; + } + if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() + && dot->allow_tf32()) { + mma_type = mma_layout::FP32_TF32_TF32_FP32; + return 
mma_type; + } + } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) { + throw std::runtime_error("integer tensor cores are not yet supported"); + // // integer tensor cores + // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) { + // mma_type = mma_layout::INT32_INT1_INT1_INT32; + // return mma_type; + // } + // if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) { + // mma_type = mma_layout::INT32_INT4_INT4_INT32; + // return mma_type; + // } + // if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) { + // mma_type = mma_layout::INT32_INT8_INT8_INT32; + // return mma_type; + // } + } + } + return mma_layout::NOT_APPLICABLE; +} + inline void extract_io_use(ir::value *v, std::set& result) { for(ir::user* u: v->get_users()){ auto i = dynamic_cast(u); @@ -52,11 +98,12 @@ inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) { } } -inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) { +inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) { for(ir::user* u: v->get_users()){ auto i = dynamic_cast(u); - if(i && is_hmma_c(i) && i->get_operand(n) == v) + if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) { result = i; + } } } @@ -142,7 +189,9 @@ mma_layout::mma_layout(size_t num_warps, const std::vector& shape, const std::vector &values, analysis::align* align, target* tgt, - shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) { + shared_layout *layout_a, shared_layout *layout_b, + ir::value *dot): distributed_layout(MMA, axes, shape, values, align) { + tensor_core_type_ = get_mma_type(dot); /* fragments per warp */ // try to make things as square as possible to maximize data re-use if(tgt->as_nvidia()->sm() < 80){ @@ -159,9 +208,9 @@ mma_layout::mma_layout(size_t num_warps, spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; } else{ - fpw_ = {1, 1, 1}; - spw_ = {16, 8, 1}; - rep_ = {2, 2, 1}; + // fpw_ = {1, 1, 1}; + spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 + // rep_ = {2, 2, 1}; } order_ = {0, 1}; @@ -356,7 +405,8 @@ shared_layout::shared_layout(data_layout *arg, const std::vector& shape, const std::vector &values, ir::type *ty, - analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) { + analysis::align* align, target *tgt) + : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) { size_ = 0; arg_layout_ = arg; @@ -382,12 +432,25 @@ shared_layout::shared_layout(data_layout *arg, for(ir::value* v: values){ extract_dot_use(v, dot_a, 0); extract_dot_use(v, dot_b, 1); - extract_hmma_dot_use(v, hmma_dot_a, 0); - extract_hmma_dot_use(v, hmma_dot_b, 1); + extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm()); + extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm()); } hmma_dot_a_ = hmma_dot_a; hmma_dot_b_ = hmma_dot_b; + // Update mma_vec + if (hmma_dot_a_) { + assert(order_.size() == 2); + std::vector mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_)); + mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m + mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2]; + } else if (hmma_dot_b_) { + assert(order_.size() == 2); + std::vector mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_)); + mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k + mma_strided_ = order_[0] == 1 ? 
mat_shape[2] : mat_shape[1]; + } + // size size_ = ty_->get_primitive_size_in_bits() / 8; for(auto s: shape_) @@ -451,7 +514,8 @@ void layouts::make_graph(ir::instruction *i) { void layouts::create(size_t id, const std::vector& values) { // if(layouts_.find(id) != layouts_.end()) // return; - auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c); + auto it_hmma_c = std::find_if(values.begin(), values.end(), + [&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); }); auto cmp = [](ir::value* x, ir::value *y) { std::pair xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()}; std::pair yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()}; @@ -473,13 +537,16 @@ void layouts::create(size_t id, const std::vector& values) { ir::value *b = dot->get_operand(1); create(groups_.at(a), values_.at(groups_.at(a))); create(groups_.at(b), values_.at(groups_.at(b))); - layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b))); + layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, + (shared_layout*)layouts_.at(groups_.at(a)), + (shared_layout*)layouts_.at(groups_.at(b)), + dot); } else if(it_cts != values.end()){ ir::instruction *cts = (ir::instruction*)*it_cts; ir::value *arg = cts->get_operand(0); create(groups_.at(arg), values_.at(groups_.at(arg))); - layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); + layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_); } else{ layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_); @@ -516,7 +583,7 @@ void layouts::run(ir::module &mod) { scanline_layout *layout = get(arg)->to_scanline(); shapes[axis] = layout->mts(axis); // create layout - layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_); + layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_); tmp_[red] = id; } if(auto *val = dynamic_cast(i)){ @@ -529,12 +596,12 @@ void layouts::run(ir::module &mod) { shape[k] = std::max(in_layout->shape_per_cta(k), out_layout->shape_per_cta(k)); } - layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_); + layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_); tmp_[val] = id; } if(auto *atom = dynamic_cast(i)){ id++; - layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_); + layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_); tmp_[atom] = id; } }); diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc index fcde938d9..1dbae10d4 100644 --- a/lib/codegen/analysis/swizzle.cc +++ b/lib/codegen/analysis/swizzle.cc @@ -19,6 +19,7 @@ void swizzle::run(ir::module &) { continue; ir::value* mma_dot_a = layout->hmma_dot_a(); ir::value* mma_dot_b = layout->hmma_dot_b(); + if(!mma_dot_a && !mma_dot_b){ per_phase_[layout] = 1; max_phase_[layout] = 1; @@ -39,10 +40,10 @@ void swizzle::run(ir::module &) { else vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); } - else{ - per_phase_[layout] = std::max(128 / 
(in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); - max_phase_[layout] = 8 / per_phase_[layout]; - vec_[layout] = 8; + else { + per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout]; + vec_[layout] = layout->get_mma_vec(); } } } diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index d38d81a9c..8921d6c84 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -85,7 +85,6 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC allocation.run(ir); prefetch_s.run(ir); barriers.run(ir); - // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); return llvm; diff --git a/lib/codegen/selection/common.h b/lib/codegen/selection/common.h new file mode 100644 index 000000000..c4b0951da --- /dev/null +++ b/lib/codegen/selection/common.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include "triton/codegen/selection/generator.h" +#include "triton/codegen/target.h" +#include "triton/codegen/analysis/axes.h" +#include "triton/codegen/analysis/allocation.h" +#include "triton/codegen/analysis/align.h" +#include "triton/codegen/analysis/swizzle.h" +#include "triton/codegen/transform/coalesce.h" +#include "triton/ir/context.h" +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/type.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsNVPTX.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +namespace triton::codegen { +// types +#define void_ty builder_->getVoidTy() +#define f16_ty builder_->getHalfTy() +#define bf16_ty builder_->getBFloatTy() +#define f32_ty builder_->getFloatTy() +#define i8_ty builder_->getInt8Ty() +#define i32_ty builder_->getInt32Ty() +#define vec_ty(type, num_el) VectorType::get(type, num_el, false) +#define ptr_ty(...) PointerType::get(__VA_ARGS__) +// constants +#define i32(...) builder_->getInt32(__VA_ARGS__) +// ops +#define and_(...) builder_->CreateAnd(__VA_ARGS__) +#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__) +#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__) +#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__) +#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__) +#define br(...) builder_->CreateBr(__VA_ARGS__) +#define call(...) builder_->CreateCall(__VA_ARGS__) +#define cast(...) builder_->CreateCast(__VA_ARGS__) +#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__) +#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__) +#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__) +#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__) +#define fadd(...) builder_->CreateFAdd(__VA_ARGS__) +#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__) +#define fmul(...) builder_->CreateFMul(__VA_ARGS__) +#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__) +#define fsub(...) builder_->CreateFSub(__VA_ARGS__) +#define icmp(...) builder_->CreateICmp(__VA_ARGS__) +#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__) +#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__) +#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__) +#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__) +#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) +#define intrinsic(...) 
builder_->CreateIntrinsic(__VA_ARGS__) +#define load(...) builder_->CreateLoad(__VA_ARGS__) +#define lshr(...) builder_->CreateLShr(__VA_ARGS__) +#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) +#define min_num(...) builder_->CreateMinNum(__VA_ARGS__) +#define neg(...) builder_->CreateNeg(__VA_ARGS__) +#define phi(...) builder_->CreatePHI(__VA_ARGS__) +#define ret(...) builder_->CreateRet(__VA_ARGS__) +#define select(...) builder_->CreateSelect(__VA_ARGS__) +#define store(...) builder_->CreateStore(__VA_ARGS__) +#define sub(...) builder_->CreateSub(__VA_ARGS__) +#define shl(...) builder_->CreateShl(__VA_ARGS__) +#define udiv(...) builder_->CreateUDiv(__VA_ARGS__) +#define urem(...) builder_->CreateURem(__VA_ARGS__) +#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__) +#define xor_(...) builder_->CreateXor(__VA_ARGS__) + +} // namespace triton::codegen \ No newline at end of file diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index a6148b2d1..b180ecb12 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -81,12 +81,13 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ // return (*builder_)->CreateGEP(ty, ptr, vals, name); //} - // types #define void_ty builder_->getVoidTy() #define f16_ty builder_->getHalfTy() +#define bf16_ty builder_->getBFloatTy() #define f32_ty builder_->getFloatTy() #define i8_ty builder_->getInt8Ty() +#define i16_ty builder_->getInt16Ty() #define i32_ty builder_->getInt32Ty() #define vec_ty(type, num_el) VectorType::get(type, num_el, false) #define ptr_ty(...) PointerType::get(__VA_ARGS__) @@ -133,7 +134,6 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ #define splat(...) builder_->CreateVectorSplat(__VA_ARGS__) #define xor_(...) 
builder_->CreateXor(__VA_ARGS__) - /** * \brief Convert Triton-IR Type to LLVM-IR Type */ @@ -162,7 +162,7 @@ Type *generator::cvt(ir::type *ty) { case ir::type::VoidTyID: return Type::getVoidTy(*ctx_); case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_); case ir::type::FP16TyID: return Type::getHalfTy(*ctx_); - case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); + case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_); case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); @@ -457,19 +457,25 @@ std::tuple generator::fp8x4_to_fp16x4(Value *in0 } Value* generator::bf16_to_fp32(Value *in0){ - Value *ret = UndefValue::get(vec_ty(builder_->getInt16Ty(), 2)); - ret = insert_elt(ret, in0, (uint64_t)1); - ret = insert_elt(ret, builder_->getInt16(0), (uint64_t)0); - return bit_cast(ret, builder_->getFloatTy()); + if (tgt_->as_nvidia()->sm() >= 80) { + InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false), + "cvt.rn.f32.bf16 $0, $1;", "=r,h", false); + return call(ptx, {in0}); + } else { + Value *ret = UndefValue::get(vec_ty(i16_ty, 2)); + ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1); + ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0); + return bit_cast(ret, f32_ty); + } } Value* generator::fp32_to_bf16(Value *in0){ if(tgt_->as_nvidia()->sm() >= 80){ - InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false), + InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false), "cvt.rn.bf16.f32 $0, $1;", "=h,r", false); return call(ptx, {in0}); } - return extract_elt(bit_cast(in0, vec_ty(builder_->getInt16Ty(), 2)), (uint64_t)1); + return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1); } /** @@ -514,12 +520,16 @@ void generator::visit_cast_inst(ir::cast_inst* x) { if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){ // FP32 -> BF16 if(op_sca_ty->is_fp32_ty()) - for(size_t i = 0; i < x_idxs.size(); i++) - vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]); + // for(size_t i = 0; i < x_idxs.size(); i++) + // vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]); + for (indices_t idx: idxs_.at(x)) { + Value *arg = vals_[x->get_operand(0)][idx]; + vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty); + } // BF16 -> FP32 if(ret_sca_ty->is_fp32_ty()) - for(size_t i = 0; i < x_idxs.size(); i++) - vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]); + for(size_t i = 0; i < x_idxs.size(); i++) + vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]); return; } @@ -678,6 +688,7 @@ void generator::visit_load_inst(ir::load_inst* x){ // --- std::vector ret_tys(n_words, IntegerType::get(*ctx_, width)); Type* ret_ty = ret_tys.size() > 1 ? 
StructType::get(*ctx_, ret_tys) : ret_tys[0]; + // ret_ty->print(llvm::outs()); std::vector arg_tys = {pred->getType(), ptr->getType()}; for(Value *v: others) arg_tys.push_back(v->getType()); @@ -747,15 +758,19 @@ void generator::visit_store_inst(ir::store_inst * x){ } auto idxs = idxs_.at(val_op); Type *ty = cvt(val_op->get_type()->get_scalar_ty()); + if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store + ty = f16_ty; for(size_t i = 0; i < idxs.size(); i += vec){ auto idx = idxs[i]; // pointer Value *ptr = vals_[ptr_op][idx]; - ptr = bit_cast(ptr, vec_ty(ty, vec)->getPointerTo(1)); + // vectorize + Type *v_ty = vec_ty(ty, vec); + ptr = bit_cast(ptr, v_ty->getPointerTo(1)); // value - Value* val = UndefValue::get(vec_ty(ty, vec)); + Value* val = UndefValue::get(v_ty); for(size_t ii = 0; ii < vec; ii++) - val = insert_elt(val, vals_.at(val_op)[idxs[i + ii]], ii); + val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii); if(mx){ Value *msk = vals_[mx->get_mask_operand()][idx]; Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {}); @@ -1317,6 +1332,229 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va vals_[C][idxs_[C][i]] = acc[i]; } +namespace { +class mma16816_smem_loader { +public: + mma16816_smem_loader(int wpt, std::vector order, int k_order, + std::vector tile_shape, + std::vector instr_shape, std::vector mat_shape, + int per_phase, int max_phase, int dtsize, Builder *builder, + adder add, multiplier mul, geper gep) + : wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape), + instr_shape_(instr_shape), mat_shape_(mat_shape), + per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder), + add(add), mul(mul), gep(gep) { + // compute compile-time constant variables & types + c_mat_shape_ = mat_shape[order[0]]; + s_mat_shape_ = mat_shape[order[1]]; + + c_stride_ = tile_shape[order[1]]; + s_stride_ = tile_shape[order[0]]; + + // rule: k must be the fast-changing axis + need_trans_ = k_order_ != order_[0]; + can_use_ldmatrix_ = dtsize == 2 || (!need_trans_); + + // std::cout << can_use_ldmatrix_ << std::endl; + // std::cout << need_trans_ << std::endl; + + // we need more pointers at the fast-changing axis, + if (can_use_ldmatrix_) + num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]]; + else // warning: this only works for tf32 & need transpose + num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]]; + num_ptr_ = std::max(num_ptr_, 2); + + + // load_v4 stride (in num of mats) + int load_stride_in_mat[2]; + load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2 + load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]); + p_load_stride_in_mat_ = load_stride_in_mat[order[0]]; + // stride in mat, used by load_v4 + s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]); + } + + std::vector compute_offs(Value *warp_off, Value *lane) { + // TODO: this needs to be moved to constructor (and extracted to arr_order) + mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_; + warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1]; + // start matrix logic offset (rename it as base_mat_off?) 
+ Value *mat_off[2] = {nullptr, nullptr}; + + if (can_use_ldmatrix_) { + // c: lane idx inside a group (a group is a collection of 8 contiguous threads) + // s: group idx (0,1,2,3) inside a warp + Value *c = urem(lane, i32(8)); + Value *s = udiv(lane, i32(8)); + // We can decompose s => s_0, s_1... + Value *s0 = urem(s, i32(2)); + Value *s1 = udiv(s, i32(2)); + + // We use different orders for a & b for better performance. + Value *k_mat_arr = (k_order_ == 1) ? s1 : s0; + Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1; + mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)), + mul(nk_mat_arr, i32(mat_arr_stride_))); + mat_off[k_order_] = k_mat_arr; + // physical offset (before swizzling) + Value *c_mat_off = mat_off[order_[0]]; + Value *s_mat_off = mat_off[order_[1]]; + // offset inside a matrix + Value *s_off_in_mat = c; + + std::vector offs(num_ptr_); + Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); + // pre-compute strided offset + Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_))); + for (int i=0; i < num_ptr_; ++i) { + Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_)); + c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle + offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_))); + } + return offs; + } else if (dtsize_ == 4 && need_trans_) { + // load tf32 matrices with lds32 + Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]] + Value *s_off_in_mat = urem(lane, i32(4)); // + + Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); + std::vector offs(num_ptr_); + for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time + int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2; + int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2; + if (k_mat_arr_int > 0) // we don't need pointers for k + continue; + Value *k_mat_arr = i32(k_mat_arr_int); + Value *nk_mat_arr = i32(nk_mat_arr_int); + // physical offset (before swizzling) + Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)), + mul(nk_mat_arr, i32(mat_arr_stride_))); + Value *s_mat_off = k_mat_arr; // always 0? + Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_))); + // FIXME: (k_order_ == 1?) 
is really dirty hack + for (int i = 0; i < num_ptr_/2; ++i) { + Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2))); + c_mat_off_i = xor_(c_mat_off_i, phase); + Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_))); + // TODO: move this out of the loop + c_off = urem(c_off, i32(tile_shape_[order_[0]])); + s_off = urem(s_off, i32(tile_shape_[order_[1]])); + offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_))); + } + } + return offs; + // throw std::runtime_error("not implemented"); + } else + throw std::runtime_error("invalid smem load config"); + } + + std::tuple + load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn, + Value *pre_ptr, Value *next_ptr, std::vector &off, std::vector &ptrs, + FunctionType *ldmatrix_ty, Type *smem_ptr_ty, + std::map> &prefetch_latch_to_bb_) { + assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned"); + int mat_idx[2] = {mat0, mat1}; + int k = mat_idx[k_order_]; + + int ptr_idx = -1; + if (can_use_ldmatrix_) + ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]); + else // tf32 & trans + ptr_idx = mat_idx[order_[0]]; + + auto get_ptr = [&](int idx) -> Value* { + Value *ptr = nullptr; + if (k == 0 && is_prefetch) { + if (inc == 0) + ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty); + else + ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty); + } else + ptr = ptrs.at(idx); + return ptr; + }; + Value *ptr = get_ptr(ptr_idx); + + Value *res_v4 = nullptr; + if (can_use_ldmatrix_) { + std::string trans = need_trans_ ? ".trans" : ""; + // the offset (in byte) on the strided axis is a constant + int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_; + InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty, + "ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 " + "{$0, $1, $2, $3}, " + "[$4 + " + std::to_string(s_offset) + "];", + "=r,=r,=r,=r,r", true); + assert(ptr); + res_v4 = call(ldmatrix_ty, ld_fn, {ptr}); + if (k == 0 && inc == 1 && is_prefetch) + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4); + return {extract_val(res_v4, std::vector{0}), + extract_val(res_v4, std::vector{1}), + extract_val(res_v4, std::vector{2}), + extract_val(res_v4, std::vector{3})}; + } else { + // assert(false && "should not be here"); + assert(dtsize_ == 4 && need_trans_); + Value *ptr2 = get_ptr(ptr_idx+1); + assert(s_mat_stride_ == 1); + int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_; + int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_; + Value *elem0, *elem1, *elem2, *elem3; + if (k_order_ == 1) { + elem0 = load(gep(ptr, i32(s_offset_elem))); + elem1 = load(gep(ptr2, i32(s_offset_elem))); + elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem))); + elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem))); + } else { // for b (k first) + elem0 = load(gep(ptr, i32(s_offset_elem))); + elem2 = load(gep(ptr2, i32(s_offset_elem))); + elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem))); + elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem))); + } + if (k == 0 && inc == 1 && is_prefetch) { + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0); + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1); + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2); + prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3); + } + return {elem0, elem1, elem2, elem3}; + } + } + + 
int get_num_ptr() const { return num_ptr_; } + +private: + int wpt_; + std::vector order_; + int k_order_; + std::vector tile_shape_; + std::vector instr_shape_; + std::vector mat_shape_; + int per_phase_, max_phase_; + int dtsize_; + + // generated + int c_mat_shape_, s_mat_shape_; + int c_stride_, s_stride_; + // p_: on the pointer axis + int p_load_stride_in_mat_; + int s_mat_stride_; + // stride when moving to next not-k mat + int warp_off_stride_; + int mat_arr_stride_; // matrix arrangement (inside a load) stride + bool need_trans_, can_use_ldmatrix_; + int num_ptr_; + + Builder *builder_; + adder add; + multiplier mul; + geper gep; +}; +} + /** * \brief Code Generation for `mma.16816` (A100) */ @@ -1338,35 +1576,65 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); bool is_a_row = ord_a[0] == 1; bool is_b_row = ord_b[0] == 1; - std::string a_trans = is_a_row ? "" : ".trans"; - std::string b_trans = is_b_row ? ".trans" : ""; - int stride_a_m = is_a_row ? shape_a[1] : 1; - int stride_a_k = is_a_row ? 1 : shape_a[0]; - int stride_b_n = is_b_row ? 1 : shape_b[0]; - int stride_b_k = is_b_row ? shape_b[1] : 1; - int stride_a0 = is_a_row ? stride_a_k : stride_a_m; - int stride_a1 = is_a_row ? stride_a_m : stride_a_k; - int stride_b0 = is_b_row ? stride_b_n : stride_b_k; - int stride_b1 = is_b_row ? stride_b_k : stride_b_n; - int lda = is_a_row ? stride_a_m : stride_a_k; - int ldb = is_b_row ? stride_b_k : stride_b_n; - int per_phase_a = swizzle_->get_per_phase(layout_a); - int max_phase_a = swizzle_->get_max_phase(layout_a); - int per_phase_b = swizzle_->get_per_phase(layout_b); - int max_phase_b = swizzle_->get_max_phase(layout_b); - int num_ptr_a = 8; - int num_ptr_b = 8; - int vec_a = 8; - int vec_b = 8; + + std::vector mma_instr_shape = layout->get_mma_instr_shape(); + const int mma_instr_m = mma_instr_shape[0]; + const int mma_instr_n = mma_instr_shape[1]; + const int mma_instr_k = mma_instr_shape[2]; + + std::vector mat_shape = layout->get_mma_mat_shape(); + const int mat_shape_m = mat_shape[0]; + const int mat_shape_n = mat_shape[1]; + const int mat_shape_k = mat_shape[2]; + + const int per_phase_a = swizzle_->get_per_phase(layout_a); + const int max_phase_a = swizzle_->get_max_phase(layout_a); + const int per_phase_b = swizzle_->get_per_phase(layout_b); + const int max_phase_b = swizzle_->get_max_phase(layout_b); + + const int num_rep_m = shapes[0] / layout->shape_per_cta(0); + const int num_rep_n = shapes[1] / layout->shape_per_cta(1); + const int num_rep_k = std::max(NK/mma_instr_k, 1); Type *fp32_ty = f32_ty; Type *fp16x2_ty = vec_ty(f16_ty, 2); + Type *bf16x2_ty = vec_ty(bf16_ty, 2); Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); + Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty}); Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty}); - FunctionType *ld_x4_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{ptr_ty(f16_ty, 3)}, false); + + FunctionType *ldmatrix_ty = nullptr; + FunctionType *mma_ty = nullptr; + Type *phi_ty = nullptr; + Type *smem_ptr_ty = nullptr; + + ir::type *A_ir_ty = A->get_type()->get_scalar_ty(); + ir::type *B_ir_ty = B->get_type()->get_scalar_ty(); + if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) { + mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, 
fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + smem_ptr_ty = ptr_ty(f16_ty, 3); + ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); + phi_ty = fp16x2_ty; + } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) { + // FIXME: We should use bf16 here. + mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + smem_ptr_ty = ptr_ty(f16_ty, 3); + ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); + phi_ty = fp16x2_ty; + // mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + // smem_ptr_ty = ptr_ty(bf16_ty, 3); + // ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector{smem_ptr_ty}, false); + // phi_ty = bf16x2_ty; + } else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) { + mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + smem_ptr_ty = ptr_ty(fp32_ty, 3); + ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector{smem_ptr_ty}, false); + phi_ty = fp32_ty; + } else + throw std::runtime_error("mma16816 data type not supported"); // left-hand-side values - std::map, std::pair> ha; + std::map, Value*> ha; std::map, Value*> hb; BasicBlock* CurrBB = builder_->GetInsertBlock(); @@ -1377,79 +1645,66 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value *lane = urem(thread, i32(32)); Value *warp = udiv(thread, i32(32)); - Value *warp12 = udiv(warp, i32(layout->wpt(0))); - Value *warp0 = urem(warp, i32(layout->wpt(0))); - Value *warp1 = urem(warp12, i32(layout->wpt(1))); + Value *warp_mn = udiv(warp, i32(layout->wpt(0))); + Value *warp_m = urem(warp, i32(layout->wpt(0))); + Value *warp_n = urem(warp_mn, i32(layout->wpt(1))); std::vector& fc = fcs.begin()->second; - Value *tidr8 = urem(lane, i32(8)); - Value *phase_a = urem(udiv(tidr8, i32(per_phase_a)), i32(max_phase_a)); - Value* off_a0 = mul(tidr8, i32(lda)); - Value *off_am = mul(add(urem(udiv(lane, i32(8)), i32(2)), mul(warp0, i32(2))), i32(8)); - Value *off_ak = mul(udiv(lane, i32(16)), i32(8)); - off_am = urem(off_am, i32(shape_a[0])); - off_ak = urem(off_ak, i32(shape_a[1])); - off_a0 = add(off_a0, is_a_row ? off_ak : off_am); - Value* off_a1 = is_a_row ? off_am : off_ak; - std::vector off_a(num_ptr_a); - for(int i = 0; i < num_ptr_a; i++){ - Value* off_a0i = add(off_a0, i32(i*16*(is_a_row?1:layout->wpt(0)))); - off_a0i = exact_udiv(off_a0i, i32(vec_a)); - off_a0i = xor_(off_a0i, phase_a); - off_a0i = mul(off_a0i, i32(vec_a)); - off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1))); - } + size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - Value *phase_b = urem(udiv(tidr8, i32(per_phase_b)), i32(max_phase_b)); - Value* off_b0 = mul(tidr8, i32(ldb)); - Value *off_bn = mul(add(mul(udiv(lane, i32(16)), i32(layout->wpt(1))), mul(warp1, i32(1))), i32(8)); - Value *off_bk = mul(urem(udiv(lane, i32(8)), i32(2)), i32(8)); - off_bn = urem(off_bn, i32(shape_b[1])); - off_bk = urem(off_bk, i32(shape_b[0])); - off_b0 = add(off_b0, is_b_row ? 
off_bn : off_bk); - Value* off_b1 = is_b_row ? off_bk : off_bn; - std::vector off_b(num_ptr_b); - for(int i = 0; i < num_ptr_b; i++){ - Value* off_b0i = add(off_b0, i32(i*(is_b_row?8*layout->wpt(1):16))); - off_b0i = exact_udiv(off_b0i, i32(vec_b)); - off_b0i = xor_(off_b0i, phase_b); - off_b0i = mul(off_b0i, i32(vec_b)); - off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1))); - } + // | -> k (row-major), since we have ldmatrix.trans, we only need to change stride + // v (s0_0(0), s1_0(2), | *num_rep_k + // m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2) + // ----------- + // *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0)) + mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, + {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, + per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep); + std::vector off_a = a_loader.compute_offs(warp_m, lane); + int num_ptr_a = a_loader.get_num_ptr(); + + // | -> n (col-major) + // v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n + // k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1)) + // ----------- + // *num_rep_k (stride in num of matrices(mat_stride_bk): 2) + mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b, + {mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n}, + per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep); + std::vector off_b = b_loader.compute_offs(warp_n, lane); + int num_ptr_b = b_loader.get_num_ptr(); builder_->SetInsertPoint(CurrBB); // A pointer std::vector ptrs_a(num_ptr_a); for(int i = 0; i < num_ptr_a; i++) - ptrs_a[i] = gep(shmems_[A], {off_a[i]}); + ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty); // B pointer std::vector ptrs_b(num_ptr_b); for(int i = 0; i < num_ptr_b; i++) - ptrs_b[i] = gep(shmems_[B], {off_b[i]}); + ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty); - FunctionType *mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - InlineAsm *mma_fn = InlineAsm::get(mma_ty, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{$0, $1, $2, $3}, " - "{$4, $5, $6, $7}, " - "{$8, $9}, " - "{$10, $11, $12, $13};", + InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() + + " {$0, $1, $2, $3}," + " {$4, $5, $6, $7}," + " {$8, $9}," + " {$10, $11, $12, $13};", "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true); - unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0); - unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1); - - // create mma & unpack result - auto call_mma = [&](unsigned m, unsigned n, unsigned K) { - unsigned cols_per_thread = num_rep_0 * 2; + // create mma & unpack result, m, n, k are offsets in mat + auto call_mma = [&](unsigned m, unsigned n, unsigned k) { + unsigned cols_per_thread = num_rep_m * 2; std::vector idx = { - (m*2 + 0) + (n*2 + 0)*cols_per_thread, - (m*2 + 0) + (n*2 + 1)*cols_per_thread, - (m*2 + 1) + (n*2 + 0)*cols_per_thread, - (m*2 + 1) + (n*2 + 1)*cols_per_thread + (m + 0) + (n*2 + 0)*cols_per_thread, + (m + 0) + (n*2 + 1)*cols_per_thread, + (m + 1) + (n*2 + 0)*cols_per_thread, + (m + 1) + (n*2 + 1)*cols_per_thread }; - Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second, - hb[{n, K}], hb[{n, K+8}], - fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); + Value *nc = call(mma_ty, mma_fn, + {ha[{m, k}], ha[{m+1, k}], ha[{m, 
k+1}], ha[{m+1, k+1}], + hb[{n, k}], hb[{n, k+1}], + fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); fc[idx[0]] = extract_val(nc, std::vector{0}); fc[idx[1]] = extract_val(nc, std::vector{1}); fc[idx[2]] = extract_val(nc, std::vector{2}); @@ -1459,131 +1714,83 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: ir::phi_node* phiA = dynamic_cast(A); ir::phi_node* phiB = dynamic_cast(B); - auto register_lds = - [&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) { - if (K <= 8 && is_prefetch) { - ir::basic_block* inc_block = phiA->get_incoming_block(inc); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block)); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block)); - } else - vals[{m, K}] = {val0, val1}; - }; - auto register_lds2 = - [&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) { - if (K <= 8 && is_prefetch) { + [&](std::map, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) { + if (k < 2 && is_prefetch) { ir::basic_block* inc_block = phiA->get_incoming_block(inc); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block)); + lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block)); } else - vals[{m, K}] = val; + vals[{n, k}] = val; }; - auto load_a = [&](int m, int K, int inc, bool is_prefetch) { - int offidx = (is_a_row ? K/16 : m) % num_ptr_a; - Value* ptra; - if(K == 0 && is_prefetch){ - if(inc == 0) - ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]); - else - ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]); - } - else - ptra = ptrs_a[offidx]; - int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a); - int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K; - InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 " - "{$0, $1, $2, $3}, [$4 + " + - std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", - "=r,=r,=r,=r,r", true); - Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra}); - if(K == 0 && inc == 1 && is_prefetch) - prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa); - Value *ha0 = extract_val(haa, std::vector{0}); - Value *ha1 = extract_val(haa, std::vector{1}); - Value *ha2 = extract_val(haa, std::vector{2}); - Value *ha3 = extract_val(haa, std::vector{3}); - register_lds(ha, m, K, inc, ha0, ha1, is_prefetch); - register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch); + auto load_a = [&](int m, int k, int inc, bool is_prefetch) { + auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a], + shared_next_ptr_[layout_a], off_a, ptrs_a, + ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); + register_lds2(ha, m, k, inc, ha0, is_prefetch); + register_lds2(ha, m+1, k, inc, ha1, is_prefetch); + register_lds2(ha, m, k+1, inc, ha2, is_prefetch); + register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch); }; - auto load_b = [&](int n, int K, int inc, bool is_prefetch) { - int offidx = (is_b_row ? n : K/16) % num_ptr_b; - Value* ptrb; - if(K == 0 && is_prefetch){ - if(inc == 0) - ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]); - else - ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]); - } - else - ptrb = ptrs_b[offidx]; - int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n; - int step_bk = is_b_row ? 
K : K / (num_ptr_b*8)*(num_ptr_b*8); - InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 " - "{$0, $1, $2, $3}, [$4 + " + - std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", - "=r,=r,=r,=r,r", true); - Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb}); - if(K == 0 && inc == 1 && is_prefetch) - prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb); - Value *hb0 = extract_val(hbb, std::vector{0}); - Value *hb1 = extract_val(hbb, std::vector{1}); - Value *hb2 = extract_val(hbb, std::vector{2}); - Value *hb3 = extract_val(hbb, std::vector{3}); - register_lds2(hb, n, K, inc, hb0, is_prefetch); - register_lds2(hb, n+1, K, inc, hb2, is_prefetch); - register_lds2(hb, n, K+8, inc, hb1, is_prefetch); - register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch); + auto load_b = [&](int n, int k, int inc, bool is_prefetch) { + auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b], + shared_next_ptr_[layout_b], off_b, ptrs_b, + ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); + register_lds2(hb, n, k, inc, hb0, is_prefetch); + register_lds2(hb, n+1, k, inc, hb2, is_prefetch); + register_lds2(hb, n, k+1, inc, hb1, is_prefetch); + register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch); }; if (C->is_prefetched()) { // create phis builder_->SetInsertPoint(CurrBB->getFirstNonPHI()); - for(unsigned m = 0; m < num_rep_0; m++){ - ha[{m, 0}].first = phi(fp16x2_ty, 2); - ha[{m, 0}].second = phi(fp16x2_ty, 2); - ha[{m, 8}].first = phi(fp16x2_ty, 2); - ha[{m, 8}].second = phi(fp16x2_ty, 2); + for(unsigned m = 0; m < num_rep_m; m++){ + ha[{2*m, 0}] = phi(phi_ty, 2); + ha[{2*m+1, 0}] = phi(phi_ty, 2); + ha[{2*m, 1}] = phi(phi_ty, 2); + ha[{2*m+1, 1}] = phi(phi_ty, 2); } - for(unsigned n = 0; n < num_rep_1; n+=2){ - hb[{n, 0}] = phi(fp16x2_ty, 2); - hb[{n+1, 0}] = phi(fp16x2_ty, 2); - hb[{n, 8}] = phi(fp16x2_ty, 2); - hb[{n+1, 8}] = phi(fp16x2_ty, 2); + for(unsigned n = 0; n < num_rep_n; n+=2){ + hb[{n, 0}] = phi(phi_ty, 2); + hb[{n+1, 0}] = phi(phi_ty, 2); + hb[{n, 1}] = phi(phi_ty, 2); + hb[{n+1, 1}] = phi(phi_ty, 2); } // insert prefetched lds at the end of loop header builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator()); - for(unsigned m = 0; m < num_rep_0; m++) - load_a(m, 0, 0, true); - for(unsigned n = 0; n < num_rep_1; n+=2) + for(unsigned m = 0; m < num_rep_m; m++) + load_a(2*m, 0, 0, true); + for(unsigned n = 0; n < num_rep_n; n+=2) load_b(n, 0, 0, true); // update accumulators builder_->SetInsertPoint(CurrBB); - for(unsigned K = 0; K < NK; K += 16){ - int NEXTK = (K + 16) % NK; + for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2 + int next_k = (k + 1) % num_rep_k; // prefetch A - for(unsigned m = 0; m < num_rep_0; m++) - load_a(m, NEXTK, 1, true); + for(unsigned m = 0; m < num_rep_m; m++) + load_a(2*m, 2*next_k, 1, true); // prefetch B - for(unsigned n = 0; n < num_rep_1; n+=2) - load_b(n, NEXTK, 1, true); + for(unsigned n = 0; n < num_rep_n; n+=2) + load_b(n, 2*next_k, 1, true); // tensor core ops - for(unsigned m = 0; m < num_rep_0; m++) - for(unsigned n = 0; n < num_rep_1; n++){ - call_mma(m, n, K); + for(unsigned m = 0; m < num_rep_m; m++) + for(unsigned n = 0; n < num_rep_n; n++){ + call_mma(2*m, n, 2*k); } } } else{ - for(unsigned K = 0; K < NK; K += 16) - for(unsigned m = 0; m < num_rep_0; m++) - for(unsigned n = 0; n < num_rep_1; n++){ - if(ha.find({m, K}) == ha.end()) - load_a(m, K, 0, false); - if(hb.find({n, 
K})==hb.end())
-        load_b(n, K, 0, false);
-      call_mma(m, n, K);
-    }
+    for (unsigned k = 0; k < num_rep_k; k++) {
+      for (unsigned m = 0; m < num_rep_m; m++)
+        load_a(2*m, 2*k, 0, /*is_prefetch*/false);
+      for (unsigned n = 0; n < num_rep_n; n+=2)
+        load_b(n, 2*k, 0, /*is_prefetch*/false);
+      for (unsigned m = 0; m < num_rep_m; m++)
+      for (unsigned n = 0; n < num_rep_n; n++)
+        call_mma(2*m, n, 2*k);
+    }
   }
   // write back
   unsigned i = 0;
@@ -1714,7 +1921,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
   if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
     return visit_mma884(dot, A, B, D, NK);
   if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
-    return visit_mma16816(dot, A, B, D, NK);
+    return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
   return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
 }
 
@@ -1752,13 +1959,13 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
   InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
   if(ty->getPrimitiveSizeInBits() <= 32)
     return call(shfl, {acc, i32(i)});
-  acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
+  acc = bit_cast(acc, vec_ty(f32_ty, 2));
   Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
   Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
   Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
   ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
   ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
-  return builder_->CreateBitCast(ret, ty);
+  return bit_cast(ret, ty);
 }
 
 /**
@@ -1936,6 +2143,10 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
   ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
   // pointer to temporary shared memory
   Type *ty = cvt(out->get_type()->get_scalar_ty());
+
+  if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
+    ty = f16_ty;
+
   // Orders
   analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(in));
   analysis::distributed_layout* out_layout = dynamic_cast(layouts_->get(out));
@@ -1976,7 +2187,7 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
       // stash value to shared mem
       indices_t idxs = {in_ax[0][i*max_ii + ii],
                         in_ax[1][j*max_jj + jj]};
-      store(vals_[in][idxs], ptr);
+      store(bit_cast(vals_[in][idxs], ty), ptr);
     }
   add_barrier();
   max_ii = out_ax[0].size()/n_reps[0];
diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc
index 864fea85c..bae8fe828 100644
--- a/lib/codegen/transform/peephole.cc
+++ b/lib/codegen/transform/peephole.cc
@@ -80,7 +80,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
   ir::value *a = dot->get_operand(0);
   ir::value *b = dot->get_operand(1);
   builder.set_insert_point(add);
-  ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
+  ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
   add->replace_all_uses_with(new_dot);
   return true;
 }
diff --git a/lib/codegen/transform/prefetch.cc b/lib/codegen/transform/prefetch.cc
index c54a92005..f94d8f368 100644
--- a/lib/codegen/transform/prefetch.cc
+++ b/lib/codegen/transform/prefetch.cc
@@ -29,8 +29,13 @@ void prefetch::run(ir::module &mod) {
   std::vector to_prefetch;
   ir::for_each_instruction(mod, [&](ir::instruction *i) {
     if (auto *dot = dynamic_cast(i)) {
-      // Now only do prefetching when dot is fp16
-      if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID)
+      // Now only do prefetching when dot is using tensor cores
+      if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
+            dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
+            (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
+             && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
+           )
+         )
         return;
       auto *a = dynamic_cast(dot->get_operand(0));
       auto *b = dynamic_cast(dot->get_operand(1));
diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc
index 7248d6cec..910f23a1a 100644
--- a/lib/driver/llvm.cc
+++ b/lib/driver/llvm.cc
@@ -121,6 +121,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
   init_llvm();
   // verify and store llvm
   llvm::legacy::PassManager pm;
+  // module->print(llvm::outs(), nullptr);
   pm.add(llvm::createVerifierPass());
   pm.run(*module);
   // create machine
diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc
index a8ba68d1c..bf59b5b39 100644
--- a/lib/ir/builder.cc
+++ b/lib/ir/builder.cc
@@ -367,8 +367,8 @@ value *builder::create_log(value *arg){
   return insert(log_inst::create(arg));
 }
 
-value *builder::create_dot(value *A, value *B, value *C) {
-  return insert(dot_inst::create_nn(A, B, C));
+value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
+  return insert(dot_inst::create_nn(A, B, C, allow_tf32));
 }
 
 value *builder::create_trans(value *A, const std::vector& perm) {
diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc
index a46c34b4f..12be53bfc 100644
--- a/lib/ir/dispatch.cc
+++ b/lib/ir/dispatch.cc
@@ -718,12 +718,13 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask
 //                               Linear Algebra
 //===----------------------------------------------------------------------===//
 
-ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
+ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
   ir::value *_0 = builder->get_float32(0);
   unsigned M = lhs->get_type()->get_block_shapes()[0];
   unsigned N = rhs->get_type()->get_block_shapes()[1];
   _0 = builder->create_splat(_0, {M, N});
-  return builder->create_dot(lhs, rhs, _0);
+  bool _allow_tf32 = allow_tf32->get_value() != 0;
+  return builder->create_dot(lhs, rhs, _0, _allow_tf32);
 }
 
 
diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc
index 0206b7e77..6e416a43e 100644
--- a/lib/ir/instructions.cc
+++ b/lib/ir/instructions.cc
@@ -577,40 +577,41 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
 //                               matmul_inst classes
 //===----------------------------------------------------------------------===//
 
-dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
+dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
                    const std::string &name, instruction *next)
     : builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
   set_operand(0, A);
   set_operand(1, B);
   set_operand(2, C);
+  allow_tf32_ = allow_tf32;
 }
 
 instruction *dot_inst::create(value *A, value *B, value *C,
-                              bool AT, bool BT,
+                              bool AT, bool BT, bool allow_tf32,
                               const std::string &name, instruction *next) {
   TransT OPA = AT ? Trans : NoTrans;
   TransT OPB = BT ? Trans : NoTrans;
-  return new dot_inst(A, B, C, OPA, OPB, name, next);
+  return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next);
 }
 
-instruction *dot_inst::create_nn(value *A, value *B, value *C,
+instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
                                  const std::string &name, instruction *next) {
-  return new dot_inst(A, B, C, NoTrans, NoTrans, name, next);
+  return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
 }
 
-instruction *dot_inst::create_nt(value *A, value *B, value *C,
+instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
                                  const std::string &name, instruction *next) {
-  return new dot_inst(A, B, C, NoTrans, Trans, name, next);
+  return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
 }
 
-instruction *dot_inst::create_tn(value *A, value *B, value *C,
+instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
                                  const std::string &name, instruction *next) {
-  return new dot_inst(A, B, C, Trans, NoTrans, name, next);
+  return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
 }
 
-instruction *dot_inst::create_tt(value *A, value *B, value *C,
+instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
                                  const std::string &name, instruction *next) {
-  return new dot_inst(A, B, C, Trans, Trans, name, next);
+  return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index d8e88a609..e32622005 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -10,6 +10,7 @@ import torch
 from numpy.random import RandomState
 
 import triton
+import triton._C.libtriton.triton as _triton
 import triton.language as tl
 from triton.code_gen import TensorWrapper, reinterpret
 
@@ -660,22 +661,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # ---------------
 
 
-@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
-def test_dot(epilogue, device='cuda'):
+@pytest.mark.parametrize("epilogue, allow_tf32",
+                         [(epilogue, allow_tf32)
+                          for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
+                          for allow_tf32 in [True, False]])
+def test_dot(epilogue, allow_tf32, device='cuda'):
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
                Y, stride_yk, stride_yn,
                Z, stride_zm, stride_zn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
+               ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
+               ALLOW_TF32: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
         off_k = tl.arange(0, BLOCK_K)
         Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
         Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
-        z = tl.dot(tl.load(Xs), tl.load(Ys))
+        z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32)
         if ADD_MATRIX:
             z += tl.load(Zs)
         if ADD_ROWS:
@@ -690,6 +695,12 @@ def test_dot(epilogue, device='cuda'):
     rs = RandomState(17)
     x = numpy_random((M, K), dtype_str='float32', rs=rs)
    y = numpy_random((K, N), dtype_str='float32', rs=rs)
+    if allow_tf32:
+        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+        if cc < 80:
+            pytest.skip("Only test tf32 on devices with sm >= 80")
+        x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
+        y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
     x_tri = to_triton(x, device=device)
     y_tri = to_triton(y, device=device)
     # triton result
@@ -703,7 +714,8 @@ def test_dot(epilogue, device='cuda'):
                  BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                  ADD_MATRIX=epilogue == 'add-matrix',
                  ADD_ROWS=epilogue == 'add-rows',
-                 ADD_COLS=epilogue == 'add-cols')
+                 ADD_COLS=epilogue == 'add-cols',
+                 ALLOW_TF32=allow_tf32)
     # torch result
     z_ref = np.matmul(x, y)
    if epilogue == 'add-matrix':
@@ -718,6 +730,8 @@ def test_dot(epilogue, device='cuda'):
     ptx = pgm.asm['ptx']
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
+    if allow_tf32:
+        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
 
 
 def test_dot_without_load():
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index d32da45c3..4ef6408d0 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -524,18 +524,19 @@ def reshape(input, shape, _builder=None):
 
 
 @builtin
-def dot(input, other, _builder=None):
+def dot(input, other, allow_tf32=True, _builder=None):
     """
     Returns the matrix product of two blocks.
 
-    The two blocks must be two dimensionals and have compatible inner dimensions.
+    The two blocks must be two-dimensional and have compatible inner dimensions.
 
     :param input: The first block to be multiplied.
-    :type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
+    :type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
     :param other: The second block to be multiplied.
-    :type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
+    :type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
+    :param allow_tf32: Whether to allow the dot product to be computed with TensorFloat-32 tensor cores for :code:`float32` inputs (requires compute capability >= 80).
     """
-    return frontend.dot(input, other, _builder)
+    return frontend.dot(input, other, allow_tf32, _builder)
 
 
 # -----------------------
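
Usage note (not part of the patch): a minimal sketch of how the new allow_tf32 flag surfaces at the tl.dot call site, modeled on the test kernel above. The kernel name, block sizes, and launch configuration are illustrative assumptions, not code from this commit.

import torch

import triton
import triton.language as tl


@triton.jit
def tf32_matmul_kernel(X, Y, Z,
                       stride_xm, stride_xk,
                       stride_yk, stride_yn,
                       stride_zm, stride_zn,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                       ALLOW_TF32: tl.constexpr):
    off_m = tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, BLOCK_N)
    off_k = tl.arange(0, BLOCK_K)
    x = tl.load(X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk)
    y = tl.load(Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn)
    # float32 inputs with ALLOW_TF32=True take the tf32 mma path on sm >= 80;
    # ALLOW_TF32=False keeps the plain fp32 FMA path
    z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
    tl.store(Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn, z)


# single-block launch where the block sizes cover the whole matrices
x = torch.randn((64, 32), dtype=torch.float32, device='cuda')
y = torch.randn((32, 64), dtype=torch.float32, device='cuda')
z = torch.empty((64, 64), dtype=torch.float32, device='cuda')
tf32_matmul_kernel[(1, 1)](x, y, z,
                           x.stride(0), x.stride(1),
                           y.stride(0), y.stride(1),
                           z.stride(0), z.stride(1),
                           BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, ALLOW_TF32=True)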