[BACKEND] Add bf16 & tf32 mma support (on A100) (#426)

daadaada
2022-01-12 02:20:31 +08:00
committed by GitHub
parent efdabe6073
commit 94a2e10fe5
17 changed files with 717 additions and 263 deletions

View File

@@ -109,6 +109,63 @@ protected:
}; };
class mma_layout: public distributed_layout { class mma_layout: public distributed_layout {
public:
enum TensorCoreType : uint8_t {
// floating-point tensor core instr
FP32_FP16_FP16_FP32 = 0, // default
FP32_BF16_BF16_FP32,
FP32_TF32_TF32_FP32,
// integer tensor core instr
INT32_INT1_INT1_INT32, // Not implemented
INT32_INT4_INT4_INT32, // Not implemented
INT32_INT8_INT8_INT32, // Not implemented
//
NOT_APPLICABLE,
};
// Used on NVIDIA GPUs with sm >= 80
inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
{FP32_FP16_FP16_FP32, {16, 8, 16}},
{FP32_BF16_BF16_FP32, {16, 8, 16}},
{FP32_TF32_TF32_FP32, {16, 8, 8}},
{INT32_INT1_INT1_INT32, {16, 8, 256}},
{INT32_INT4_INT4_INT32, {16, 8, 64}},
{INT32_INT8_INT8_INT32, {16, 8, 32}},
};
// shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
{FP32_FP16_FP16_FP32, {8, 8, 8}},
{FP32_BF16_BF16_FP32, {8, 8, 8}},
{FP32_TF32_TF32_FP32, {8, 8, 4}},
{INT32_INT1_INT1_INT32, {8, 8, 64}},
{INT32_INT4_INT4_INT32, {8, 8, 32}},
{INT32_INT8_INT8_INT32, {8, 8, 16}},
};
inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
{FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
{FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
{FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
{INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
{INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
{INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
};
// vector length per ldmatrix (16*8 / element_size_in_bits)
inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
{FP32_FP16_FP16_FP32, 8},
{FP32_BF16_BF16_FP32, 8},
{FP32_TF32_TF32_FP32, 4},
{INT32_INT1_INT1_INT32, 128},
{INT32_INT4_INT4_INT32, 32},
{INT32_INT8_INT8_INT32, 16},
};
public: public:
mma_layout(size_t num_warps, mma_layout(size_t num_warps,
const std::vector<int>& axes, const std::vector<int>& axes,
@@ -116,7 +173,8 @@ public:
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
analysis::align* align, target *tgt, analysis::align* align, target *tgt,
shared_layout* layout_a, shared_layout* layout_a,
shared_layout* layout_b); shared_layout* layout_b,
ir::value *dot);
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); } void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
// accessor // accessor
int fpw(size_t k) { return fpw_.at(k); } int fpw(size_t k) { return fpw_.at(k); }
@@ -124,6 +182,16 @@ public:
int spw(size_t k) { return spw_.at(k); } int spw(size_t k) { return spw_.at(k); }
int rep(size_t k) { return rep_.at(k); } int rep(size_t k) { return rep_.at(k); }
// helpers for generator.cc
std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
// setter
void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }
private: private:
// fragment per warp // fragment per warp
std::vector<int> fpw_; std::vector<int> fpw_;
@@ -135,6 +203,8 @@ private:
std::vector<int> spt_; std::vector<int> spt_;
// repetitions // repetitions
std::vector<int> rep_; std::vector<int> rep_;
TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
}; };
struct scanline_layout: public distributed_layout { struct scanline_layout: public distributed_layout {
@@ -182,7 +252,7 @@ public:
const std::vector<unsigned>& shapes, const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_, const std::vector<ir::value *> &values_,
ir::type *ty, ir::type *ty,
analysis::align* align); analysis::align* align, target *tgt);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); } void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors // accessors
size_t get_size() { return size_; } size_t get_size() { return size_; }
@@ -197,6 +267,7 @@ public:
ir::value* hmma_dot_b() { return hmma_dot_b_; } ir::value* hmma_dot_b() { return hmma_dot_b_; }
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; } void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
int get_mma_vec() { return mma_vec_;} int get_mma_vec() { return mma_vec_;}
int get_mma_strided() { return mma_strided_; }
data_layout* get_arg_layout() { return arg_layout_; } data_layout* get_arg_layout() { return arg_layout_; }
private: private:
@@ -209,6 +280,8 @@ private:
ir::value* hmma_dot_b_; ir::value* hmma_dot_b_;
data_layout* arg_layout_; data_layout* arg_layout_;
int mma_vec_; int mma_vec_;
int mma_strided_;
target *tgt_;
}; };
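For orientation, the new static tables boil down to a lookup from tensor-core type to instruction shape and PTX opcode. Below is a minimal standalone mirror of those tables (illustration only, not the real mma_layout class; only the floating-point entries are reproduced):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Standalone mirror of mma_instr_shape_ / mma_instr_ptx_ for illustration.
enum TensorCoreType : uint8_t { FP32_FP16_FP16_FP32, FP32_BF16_BF16_FP32, FP32_TF32_TF32_FP32 };

static const std::map<TensorCoreType, std::vector<int>> instr_shape = {
  {FP32_FP16_FP16_FP32, {16, 8, 16}},
  {FP32_BF16_BF16_FP32, {16, 8, 16}},
  {FP32_TF32_TF32_FP32, {16, 8, 8}},
};

static const std::map<TensorCoreType, std::string> instr_ptx = {
  {FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
  {FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
  {FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
};

int main() {
  TensorCoreType t = FP32_TF32_TF32_FP32;
  const auto& s = instr_shape.at(t);
  // For tf32 the K extent of one mma instruction shrinks from 16 to 8.
  std::cout << instr_ptx.at(t) << "  (m=" << s[0] << ", n=" << s[1] << ", k=" << s[2] << ")\n";
}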

View File

@@ -154,7 +154,7 @@ public:
value *create_cos(value* arg); value *create_cos(value* arg);
value *create_sin(value* arg); value *create_sin(value* arg);
value *create_log(value* arg); value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C); value *create_dot(value *A, value *B, value *C, bool allow_tf32);
value *create_trans(value *A, const std::vector<int> &perm = {}); value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A); value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis); value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);

View File

@@ -80,7 +80,7 @@ struct dispatch{
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
// linear algebra // linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder); static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder);
// indexing // indexing
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder); static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);

View File

@@ -742,26 +742,29 @@ public:
}; };
private: private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next); dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; } std::string repr_impl() const { return "dot"; }
bool is_prefetched_ = false;
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
public: public:
bool is_prefetched() const { return is_prefetched_; } bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
bool allow_tf32() const { return allow_tf32_; }
public: public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr); static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dot_inst) _TRITON_DEFINE_CLONE(dot_inst)
_TRITON_DEFINE_ACCEPT(dot_inst) _TRITON_DEFINE_ACCEPT(dot_inst)
private:
bool is_prefetched_ = false;
bool allow_tf32_ = false;
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
}; };
//class outer_inst: public builtin_inst { //class outer_inst: public builtin_inst {

View File

@@ -23,19 +23,65 @@ inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
return std::min(std::max(x, lo), hi); return std::min(std::max(x, lo), hi);
} }
inline bool is_hmma_c(ir::value *v){ inline bool is_hmma_c(ir::value *v, int sm){
bool result = false; bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){ if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0); ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type(); ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1); ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type(); ir::type *b_ty = b->get_type();
result = a_ty->get_scalar_ty()->is_fp16_ty() && result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
b_ty->get_scalar_ty()->is_fp16_ty(); (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
x->allow_tf32() && sm >= 80);
} }
return result; return result;
} }
static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
mma_layout::TensorCoreType mma_type;
if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
ir::type* a_ty = a->get_type();
ir::type* b_ty = b->get_type();
ir::type* c_ty = v->get_type();
if (c_ty->get_scalar_ty()->is_fp32_ty()) {
// floating point tensor cores
if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
mma_type = mma_layout::FP32_FP16_FP16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
mma_type = mma_layout::FP32_BF16_BF16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
&& dot->allow_tf32()) {
mma_type = mma_layout::FP32_TF32_TF32_FP32;
return mma_type;
}
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
throw std::runtime_error("integer tensor cores are not yet supported");
// // integer tensor cores
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
// return mma_type;
// }
// if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
// return mma_type;
// }
// if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
// mma_type = mma_layout::INT32_INT8_INT8_INT32;
// return mma_type;
// }
}
}
return mma_layout::NOT_APPLICABLE;
}
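A simplified, standalone restatement of the rule that is_hmma_c / get_mma_type implement (type and function names below are made up for illustration; the real code also rejects the not-yet-implemented integer paths):

#include <cstdio>

enum class ScalarTy { FP16, BF16, FP32, OTHER };
enum class MmaType  { FP16_FP16, BF16_BF16, TF32_TF32, NOT_APPLICABLE };

// Which tensor-core path dot(a, b) can take, given the operand element types,
// the allow_tf32 flag on the dot instruction, and the target SM version.
MmaType pick_mma(ScalarTy a, ScalarTy b, bool allow_tf32, int sm) {
  if (a == ScalarTy::FP16 && b == ScalarTy::FP16) return MmaType::FP16_FP16;
  if (a == ScalarTy::BF16 && b == ScalarTy::BF16) return MmaType::BF16_BF16;
  if (a == ScalarTy::FP32 && b == ScalarTy::FP32 && allow_tf32 && sm >= 80)
    return MmaType::TF32_TF32;        // tf32 mma only exists on Ampere (sm >= 80)
  return MmaType::NOT_APPLICABLE;     // otherwise fall back to FMA codegen
}

int main() {
  // fp32 x fp32 dot with allow_tf32=true on an A100 (sm_80) maps to the tf32 path
  MmaType t = pick_mma(ScalarTy::FP32, ScalarTy::FP32, /*allow_tf32=*/true, /*sm=*/80);
  std::printf("%d\n", static_cast<int>(t));   // prints 2 == TF32_TF32
}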
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) { inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){ for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u); auto i = dynamic_cast<ir::io_inst*>(u);
@@ -52,11 +98,12 @@ inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
} }
} }
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) { inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
for(ir::user* u: v->get_users()){ for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u); auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v) if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
result = i; result = i;
}
} }
} }
@@ -142,7 +189,9 @@ mma_layout::mma_layout(size_t num_warps,
const std::vector<unsigned>& shape, const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
analysis::align* align, target* tgt, analysis::align* align, target* tgt,
shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) { shared_layout *layout_a, shared_layout *layout_b,
ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
tensor_core_type_ = get_mma_type(dot);
/* fragments per warp */ /* fragments per warp */
// try to make things as square as possible to maximize data re-use // try to make things as square as possible to maximize data re-use
if(tgt->as_nvidia()->sm() < 80){ if(tgt->as_nvidia()->sm() < 80){
@@ -159,9 +208,9 @@ mma_layout::mma_layout(size_t num_warps,
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
} }
else{ else{
fpw_ = {1, 1, 1}; // fpw_ = {1, 1, 1};
spw_ = {16, 8, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
rep_ = {2, 2, 1}; // rep_ = {2, 2, 1};
} }
order_ = {0, 1}; order_ = {0, 1};
@@ -356,7 +405,8 @@ shared_layout::shared_layout(data_layout *arg,
const std::vector<unsigned>& shape, const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
ir::type *ty, ir::type *ty,
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) { analysis::align* align, target *tgt)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {
size_ = 0; size_ = 0;
arg_layout_ = arg; arg_layout_ = arg;
@@ -382,12 +432,25 @@ shared_layout::shared_layout(data_layout *arg,
for(ir::value* v: values){ for(ir::value* v: values){
extract_dot_use(v, dot_a, 0); extract_dot_use(v, dot_a, 0);
extract_dot_use(v, dot_b, 1); extract_dot_use(v, dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, 0); extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
extract_hmma_dot_use(v, hmma_dot_b, 1); extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
} }
hmma_dot_a_ = hmma_dot_a; hmma_dot_a_ = hmma_dot_a;
hmma_dot_b_ = hmma_dot_b; hmma_dot_b_ = hmma_dot_b;
// Update mma_vec
if (hmma_dot_a_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
} else if (hmma_dot_b_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
}
// size // size
size_ = ty_->get_primitive_size_in_bits() / 8; size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shape_) for(auto s: shape_)
@@ -451,7 +514,8 @@ void layouts::make_graph(ir::instruction *i) {
void layouts::create(size_t id, const std::vector<ir::value*>& values) { void layouts::create(size_t id, const std::vector<ir::value*>& values) {
// if(layouts_.find(id) != layouts_.end()) // if(layouts_.find(id) != layouts_.end())
// return; // return;
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c); auto it_hmma_c = std::find_if(values.begin(), values.end(),
[&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
auto cmp = [](ir::value* x, ir::value *y) { auto cmp = [](ir::value* x, ir::value *y) {
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()}; std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()}; std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
@@ -473,13 +537,16 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
ir::value *b = dot->get_operand(1); ir::value *b = dot->get_operand(1);
create(groups_.at(a), values_.at(groups_.at(a))); create(groups_.at(a), values_.at(groups_.at(a)));
create(groups_.at(b), values_.at(groups_.at(b))); create(groups_.at(b), values_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b))); layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
(shared_layout*)layouts_.at(groups_.at(a)),
(shared_layout*)layouts_.at(groups_.at(b)),
dot);
} }
else if(it_cts != values.end()){ else if(it_cts != values.end()){
ir::instruction *cts = (ir::instruction*)*it_cts; ir::instruction *cts = (ir::instruction*)*it_cts;
ir::value *arg = cts->get_operand(0); ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg))); create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
} }
else{ else{
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_); layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
@@ -516,7 +583,7 @@ void layouts::run(ir::module &mod) {
scanline_layout *layout = get(arg)->to_scanline(); scanline_layout *layout = get(arg)->to_scanline();
shapes[axis] = layout->mts(axis); shapes[axis] = layout->mts(axis);
// create layout // create layout
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[red] = id; tmp_[red] = id;
} }
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){ if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
@@ -529,12 +596,12 @@ void layouts::run(ir::module &mod) {
shape[k] = std::max(in_layout->shape_per_cta(k), shape[k] = std::max(in_layout->shape_per_cta(k),
out_layout->shape_per_cta(k)); out_layout->shape_per_cta(k));
} }
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[val] = id; tmp_[val] = id;
} }
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){ if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++; id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[atom] = id; tmp_[atom] = id;
} }
}); });

View File

@@ -19,6 +19,7 @@ void swizzle::run(ir::module &) {
continue; continue;
ir::value* mma_dot_a = layout->hmma_dot_a(); ir::value* mma_dot_a = layout->hmma_dot_a();
ir::value* mma_dot_b = layout->hmma_dot_b(); ir::value* mma_dot_b = layout->hmma_dot_b();
if(!mma_dot_a && !mma_dot_b){ if(!mma_dot_a && !mma_dot_b){
per_phase_[layout] = 1; per_phase_[layout] = 1;
max_phase_[layout] = 1; max_phase_[layout] = 1;
@@ -39,10 +40,10 @@ void swizzle::run(ir::module &) {
else else
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
} }
else{ else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = 8 / per_phase_[layout]; max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = 8; vec_[layout] = layout->get_mma_vec();
} }
} }
} }
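On sm >= 80 the swizzle parameters are now driven by the ldmatrix matrix shape instead of the old hard-coded 8. A rough standalone sketch of the arithmetic, with example values for a scanline-coalesced fp16 operand (all constants here are assumptions, not taken from a real kernel):

#include <algorithm>
#include <cstdio>

// per_phase: rows that share one swizzle phase (one phase spans 128 contiguous bytes)
// max_phase: number of distinct phases, bounded by the strided extent of the mats read back
int main() {
  int mts = 32, nts = 4;             // threads / elements-per-thread on the contiguous axis (example)
  int dtsize = 2;                    // fp16/bf16 -> 2 bytes, tf32 -> 4 bytes
  int mma_vec = 8, mma_strided = 8;  // 8x8 ldmatrix tiles for fp16
  int per_phase = std::max(128 / (mts * nts * dtsize), 1);
  int max_phase = mma_strided / per_phase;
  int vec = mma_vec;                 // vector width of the shared-memory load
  std::printf("per_phase=%d max_phase=%d vec=%d\n", per_phase, max_phase, vec);
}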

View File

@@ -85,7 +85,6 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
allocation.run(ir); allocation.run(ir);
prefetch_s.run(ir); prefetch_s.run(ir);
barriers.run(ir); barriers.run(ir);
// ir.print(std::cout);
isel.visit(ir, *llvm); isel.visit(ir, *llvm);
shared_static = allocation.allocated_size(); shared_static = allocation.allocated_size();
return llvm; return llvm;

View File

@@ -0,0 +1,78 @@
#pragma once
#include <numeric>
#include <sstream>
#include <iomanip>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
namespace triton::codegen {
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getBFloatTy()
#define f32_ty builder_->getFloatTy()
#define i8_ty builder_->getInt8Ty()
#define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
// constants
#define i32(...) builder_->getInt32(__VA_ARGS__)
// ops
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
#define br(...) builder_->CreateBr(__VA_ARGS__)
#define call(...) builder_->CreateCall(__VA_ARGS__)
#define cast(...) builder_->CreateCast(__VA_ARGS__)
#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__)
#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__)
#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__)
#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
#define load(...) builder_->CreateLoad(__VA_ARGS__)
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
#define ret(...) builder_->CreateRet(__VA_ARGS__)
#define select(...) builder_->CreateSelect(__VA_ARGS__)
#define store(...) builder_->CreateStore(__VA_ARGS__)
#define sub(...) builder_->CreateSub(__VA_ARGS__)
#define shl(...) builder_->CreateShl(__VA_ARGS__)
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
#define urem(...) builder_->CreateURem(__VA_ARGS__)
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...) builder_->CreateXor(__VA_ARGS__)
} // namespace triton::codegen
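These macros simply forward their arguments to methods on builder_, so the generator reads like a small DSL over LLVM's IRBuilder. A toy, self-contained illustration of the pattern with a fake builder (not LLVM):

#include <cstdio>

// Toy stand-in for llvm::IRBuilder: each macro forwards to a builder_ method.
struct FakeBuilder {
  int CreateAdd(int a, int b) { return a + b; }
  int CreateXor(int a, int b) { return a ^ b; }
};

static FakeBuilder real_builder;
static FakeBuilder* builder_ = &real_builder;

#define add(...)  builder_->CreateAdd(__VA_ARGS__)
#define xor_(...) builder_->CreateXor(__VA_ARGS__)

int main() {
  // mirrors expressions like xor_(c_mat_off_i, phase) in the generator
  std::printf("%d\n", add(3, xor_(5, 2)));   // prints 10
}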

View File

@@ -81,12 +81,13 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
// return (*builder_)->CreateGEP(ty, ptr, vals, name); // return (*builder_)->CreateGEP(ty, ptr, vals, name);
//} //}
// types // types
#define void_ty builder_->getVoidTy() #define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy() #define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getBFloatTy()
#define f32_ty builder_->getFloatTy() #define f32_ty builder_->getFloatTy()
#define i8_ty builder_->getInt8Ty() #define i8_ty builder_->getInt8Ty()
#define i16_ty builder_->getInt16Ty()
#define i32_ty builder_->getInt32Ty() #define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false) #define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__) #define ptr_ty(...) PointerType::get(__VA_ARGS__)
@@ -133,7 +134,6 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__) #define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...) builder_->CreateXor(__VA_ARGS__) #define xor_(...) builder_->CreateXor(__VA_ARGS__)
/** /**
* \brief Convert Triton-IR Type to LLVM-IR Type * \brief Convert Triton-IR Type to LLVM-IR Type
*/ */
@@ -162,7 +162,7 @@ Type *generator::cvt(ir::type *ty) {
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_); case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_); case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_); case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_); case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_);
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
@@ -457,19 +457,25 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
} }
Value* generator::bf16_to_fp32(Value *in0){ Value* generator::bf16_to_fp32(Value *in0){
Value *ret = UndefValue::get(vec_ty(builder_->getInt16Ty(), 2)); if (tgt_->as_nvidia()->sm() >= 80) {
ret = insert_elt(ret, in0, (uint64_t)1); InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
ret = insert_elt(ret, builder_->getInt16(0), (uint64_t)0); "cvt.rn.f32.bf16 $0, $1;", "=r,h", false);
return bit_cast(ret, builder_->getFloatTy()); return call(ptx, {in0});
} else {
Value *ret = UndefValue::get(vec_ty(i16_ty, 2));
ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1);
ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0);
return bit_cast(ret, f32_ty);
}
} }
Value* generator::fp32_to_bf16(Value *in0){ Value* generator::fp32_to_bf16(Value *in0){
if(tgt_->as_nvidia()->sm() >= 80){ if(tgt_->as_nvidia()->sm() >= 80){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false), InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false),
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false); "cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
return call(ptx, {in0}); return call(ptx, {in0});
} }
return extract_elt(bit_cast(in0, vec_ty(builder_->getInt16Ty(), 2)), (uint64_t)1); return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1);
} }
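The pre-sm80 fallback relies on bf16 being the upper half of an fp32 bit pattern (the sm >= 80 path emits the cvt.rn.* PTX instructions instead). A standalone sketch of that bit-level view, for reference only:

#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 is the upper 16 bits of an fp32, so the fallback is pure bit manipulation
// (plain truncation here; the PTX path rounds to nearest).
float bf16_bits_to_fp32(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;   // bf16 in the high half, zeros in the low half
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

uint16_t fp32_to_bf16_bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);          // keep the high half
}

int main() {
  uint16_t b = fp32_to_bf16_bits(3.140625f);
  std::printf("%f\n", bf16_bits_to_fp32(b));      // prints 3.140625 (fits exactly in bf16)
}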
/** /**
@@ -514,12 +520,16 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){ if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
// FP32 -> BF16 // FP32 -> BF16
if(op_sca_ty->is_fp32_ty()) if(op_sca_ty->is_fp32_ty())
for(size_t i = 0; i < x_idxs.size(); i++) // for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]); // vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
for (indices_t idx: idxs_.at(x)) {
Value *arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
}
// BF16 -> FP32 // BF16 -> FP32
if(ret_sca_ty->is_fp32_ty()) if(ret_sca_ty->is_fp32_ty())
for(size_t i = 0; i < x_idxs.size(); i++) for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]); vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
return; return;
} }
@@ -678,6 +688,7 @@ void generator::visit_load_inst(ir::load_inst* x){
// --- // ---
std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width)); std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width));
Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0]; Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0];
// ret_ty->print(llvm::outs());
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()}; std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(Value *v: others) for(Value *v: others)
arg_tys.push_back(v->getType()); arg_tys.push_back(v->getType());
@@ -747,15 +758,19 @@ void generator::visit_store_inst(ir::store_inst * x){
} }
auto idxs = idxs_.at(val_op); auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty()); Type *ty = cvt(val_op->get_type()->get_scalar_ty());
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;
for(size_t i = 0; i < idxs.size(); i += vec){ for(size_t i = 0; i < idxs.size(); i += vec){
auto idx = idxs[i]; auto idx = idxs[i];
// pointer // pointer
Value *ptr = vals_[ptr_op][idx]; Value *ptr = vals_[ptr_op][idx];
ptr = bit_cast(ptr, vec_ty(ty, vec)->getPointerTo(1)); // vectorize
Type *v_ty = vec_ty(ty, vec);
ptr = bit_cast(ptr, v_ty->getPointerTo(1));
// value // value
Value* val = UndefValue::get(vec_ty(ty, vec)); Value* val = UndefValue::get(v_ty);
for(size_t ii = 0; ii < vec; ii++) for(size_t ii = 0; ii < vec; ii++)
val = insert_elt(val, vals_.at(val_op)[idxs[i + ii]], ii); val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii);
if(mx){ if(mx){
Value *msk = vals_[mx->get_mask_operand()][idx]; Value *msk = vals_[mx->get_mask_operand()][idx];
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {}); Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
@@ -1317,6 +1332,229 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
vals_[C][idxs_[C][i]] = acc[i]; vals_[C][idxs_[C][i]] = acc[i];
} }
namespace {
class mma16816_smem_loader {
public:
mma16816_smem_loader(int wpt, std::vector<int> order, int k_order,
std::vector<unsigned> tile_shape,
std::vector<int> instr_shape, std::vector<int> mat_shape,
int per_phase, int max_phase, int dtsize, Builder *builder,
adder add, multiplier mul, geper gep)
: wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape),
instr_shape_(instr_shape), mat_shape_(mat_shape),
per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder),
add(add), mul(mul), gep(gep) {
// compute compile-time constant variables & types
c_mat_shape_ = mat_shape[order[0]];
s_mat_shape_ = mat_shape[order[1]];
c_stride_ = tile_shape[order[1]];
s_stride_ = tile_shape[order[0]];
// rule: k must be the fast-changing axis
need_trans_ = k_order_ != order_[0];
can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);
// std::cout << can_use_ldmatrix_ << std::endl;
// std::cout << need_trans_ << std::endl;
// we need more pointers along the fast-changing axis
if (can_use_ldmatrix_)
num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
else // warning: this only works for tf32 and needs a transpose
num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
num_ptr_ = std::max<int>(num_ptr_, 2);
// load_v4 stride (in num of mats)
int load_stride_in_mat[2];
load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2
load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]);
p_load_stride_in_mat_ = load_stride_in_mat[order[0]];
// stride in mat, used by load_v4
s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]);
}
std::vector<Value*> compute_offs(Value *warp_off, Value *lane) {
// TODO: this needs to be moved to constructor (and extracted to arr_order)
mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_;
warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1];
// starting matrix logical offset (rename it base_mat_off?)
Value *mat_off[2] = {nullptr, nullptr};
if (can_use_ldmatrix_) {
// c: lane idx inside a group (a group is a collection of 8 contiguous threads)
// s: group idx (0,1,2,3) inside a warp
Value *c = urem(lane, i32(8));
Value *s = udiv(lane, i32(8));
// We can decompose s => s_0, s_1...
Value *s0 = urem(s, i32(2));
Value *s1 = udiv(s, i32(2));
// We use different orders for a & b for better performance.
Value *k_mat_arr = (k_order_ == 1) ? s1 : s0;
Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1;
mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
mat_off[k_order_] = k_mat_arr;
// physical offset (before swizzling)
Value *c_mat_off = mat_off[order_[0]];
Value *s_mat_off = mat_off[order_[1]];
// offset inside a matrix
Value *s_off_in_mat = c;
std::vector<Value*> offs(num_ptr_);
Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
// pre-compute strided offset
Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
for (int i=0; i < num_ptr_; ++i) {
Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_));
c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle
offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_)));
}
return offs;
} else if (dtsize_ == 4 && need_trans_) {
// load tf32 matrices with lds32
Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]]
Value *s_off_in_mat = urem(lane, i32(4)); //
Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
std::vector<Value*> offs(num_ptr_);
for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
if (k_mat_arr_int > 0) // we don't need pointers for k
continue;
Value *k_mat_arr = i32(k_mat_arr_int);
Value *nk_mat_arr = i32(nk_mat_arr_int);
// physical offset (before swizzling)
Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
Value *s_mat_off = k_mat_arr; // always 0?
Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
// FIXME: (k_order_ == 1?) is really a dirty hack
for (int i = 0; i < num_ptr_/2; ++i) {
Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
c_mat_off_i = xor_(c_mat_off_i, phase);
Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
// TODO: move this out of the loop
c_off = urem(c_off, i32(tile_shape_[order_[0]]));
s_off = urem(s_off, i32(tile_shape_[order_[1]]));
offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_)));
}
}
return offs;
// throw std::runtime_error("not implemented");
} else
throw std::runtime_error("invalid smem load config");
}
std::tuple<Value*, Value*, Value*, Value*>
load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn,
Value *pre_ptr, Value *next_ptr, std::vector<Value*> &off, std::vector<Value*> &ptrs,
FunctionType *ldmatrix_ty, Type *smem_ptr_ty,
std::map<ir::value*, std::vector<Value*>> &prefetch_latch_to_bb_) {
assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
int mat_idx[2] = {mat0, mat1};
int k = mat_idx[k_order_];
int ptr_idx = -1;
if (can_use_ldmatrix_)
ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
else // tf32 & trans
ptr_idx = mat_idx[order_[0]];
auto get_ptr = [&](int idx) -> Value* {
Value *ptr = nullptr;
if (k == 0 && is_prefetch) {
if (inc == 0)
ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty);
else
ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty);
} else
ptr = ptrs.at(idx);
return ptr;
};
Value *ptr = get_ptr(ptr_idx);
Value *res_v4 = nullptr;
if (can_use_ldmatrix_) {
std::string trans = need_trans_ ? ".trans" : "";
// the offset (in byte) on the strided axis is a constant
int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_;
InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty,
"ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 "
"{$0, $1, $2, $3}, "
"[$4 + " + std::to_string(s_offset) + "];",
"=r,=r,=r,=r,r", true);
assert(ptr);
res_v4 = call(ldmatrix_ty, ld_fn, {ptr});
if (k == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4);
return {extract_val(res_v4, std::vector<unsigned>{0}),
extract_val(res_v4, std::vector<unsigned>{1}),
extract_val(res_v4, std::vector<unsigned>{2}),
extract_val(res_v4, std::vector<unsigned>{3})};
} else {
// assert(false && "should not be here");
assert(dtsize_ == 4 && need_trans_);
Value *ptr2 = get_ptr(ptr_idx+1);
assert(s_mat_stride_ == 1);
int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
Value *elem0, *elem1, *elem2, *elem3;
if (k_order_ == 1) {
elem0 = load(gep(ptr, i32(s_offset_elem)));
elem1 = load(gep(ptr2, i32(s_offset_elem)));
elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
} else { // for b (k first)
elem0 = load(gep(ptr, i32(s_offset_elem)));
elem2 = load(gep(ptr2, i32(s_offset_elem)));
elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
}
if (k == 0 && inc == 1 && is_prefetch) {
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
}
return {elem0, elem1, elem2, elem3};
}
}
int get_num_ptr() const { return num_ptr_; }
private:
int wpt_;
std::vector<int> order_;
int k_order_;
std::vector<unsigned> tile_shape_;
std::vector<int> instr_shape_;
std::vector<int> mat_shape_;
int per_phase_, max_phase_;
int dtsize_;
// generated
int c_mat_shape_, s_mat_shape_;
int c_stride_, s_stride_;
// p_: on the pointer axis
int p_load_stride_in_mat_;
int s_mat_stride_;
// stride when moving to next not-k mat
int warp_off_stride_;
int mat_arr_stride_; // matrix arrangement (inside a load) stride
bool need_trans_, can_use_ldmatrix_;
int num_ptr_;
Builder *builder_;
adder add;
multiplier mul;
geper gep;
};
}
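The heart of compute_offs above is the XOR swizzle applied to the contiguous matrix index. A scalar model of the ldmatrix branch, with illustrative constants (the real code builds LLVM IR values, not ints):

#include <cstdio>

// Scalar model of the swizzled shared-memory offset on the ldmatrix path: the
// 8x8-matrix index on the contiguous axis is XOR-ed with a phase derived from
// the row inside the matrix, spreading accesses across shared-memory banks.
int swizzled_offset(int lane, int c_mat_off, int s_mat_off,
                    int c_mat_shape, int s_mat_shape, int s_stride,
                    int per_phase, int max_phase) {
  int s_off_in_mat = lane % 8;                          // row inside an 8x8 matrix
  int phase = (s_off_in_mat / per_phase) % max_phase;
  int s_off = s_off_in_mat + s_mat_off * s_mat_shape;   // strided-axis element offset
  int c = (c_mat_off ^ phase) * c_mat_shape;            // swizzled contiguous-axis offset
  return c + s_off * s_stride;
}

int main() {
  for (int lane = 0; lane < 8; ++lane)                  // one eight-thread group
    std::printf("lane %d -> %d\n", lane,
                swizzled_offset(lane, /*c_mat_off=*/1, /*s_mat_off=*/0,
                                /*c_mat_shape=*/8, /*s_mat_shape=*/8,
                                /*s_stride=*/64, /*per_phase=*/1, /*max_phase=*/8));
}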
/** /**
* \brief Code Generation for `mma.16816` (A100) * \brief Code Generation for `mma.16816` (A100)
*/ */
@@ -1338,35 +1576,65 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
bool is_a_row = ord_a[0] == 1; bool is_a_row = ord_a[0] == 1;
bool is_b_row = ord_b[0] == 1; bool is_b_row = ord_b[0] == 1;
std::string a_trans = is_a_row ? "" : ".trans";
std::string b_trans = is_b_row ? ".trans" : ""; std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
int stride_a_m = is_a_row ? shape_a[1] : 1; const int mma_instr_m = mma_instr_shape[0];
int stride_a_k = is_a_row ? 1 : shape_a[0]; const int mma_instr_n = mma_instr_shape[1];
int stride_b_n = is_b_row ? 1 : shape_b[0]; const int mma_instr_k = mma_instr_shape[2];
int stride_b_k = is_b_row ? shape_b[1] : 1;
int stride_a0 = is_a_row ? stride_a_k : stride_a_m; std::vector<int> mat_shape = layout->get_mma_mat_shape();
int stride_a1 = is_a_row ? stride_a_m : stride_a_k; const int mat_shape_m = mat_shape[0];
int stride_b0 = is_b_row ? stride_b_n : stride_b_k; const int mat_shape_n = mat_shape[1];
int stride_b1 = is_b_row ? stride_b_k : stride_b_n; const int mat_shape_k = mat_shape[2];
int lda = is_a_row ? stride_a_m : stride_a_k;
int ldb = is_b_row ? stride_b_k : stride_b_n; const int per_phase_a = swizzle_->get_per_phase(layout_a);
int per_phase_a = swizzle_->get_per_phase(layout_a); const int max_phase_a = swizzle_->get_max_phase(layout_a);
int max_phase_a = swizzle_->get_max_phase(layout_a); const int per_phase_b = swizzle_->get_per_phase(layout_b);
int per_phase_b = swizzle_->get_per_phase(layout_b); const int max_phase_b = swizzle_->get_max_phase(layout_b);
int max_phase_b = swizzle_->get_max_phase(layout_b);
int num_ptr_a = 8; const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
int num_ptr_b = 8; const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
int vec_a = 8; const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);
int vec_b = 8;
Type *fp32_ty = f32_ty; Type *fp32_ty = f32_ty;
Type *fp16x2_ty = vec_ty(f16_ty, 2); Type *fp16x2_ty = vec_ty(f16_ty, 2);
Type *bf16x2_ty = vec_ty(bf16_ty, 2);
Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty}); Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *ld_x4_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{ptr_ty(f16_ty, 3)}, false);
FunctionType *ldmatrix_ty = nullptr;
FunctionType *mma_ty = nullptr;
Type *phi_ty = nullptr;
Type *smem_ptr_ty = nullptr;
ir::type *A_ir_ty = A->get_type()->get_scalar_ty();
ir::type *B_ir_ty = B->get_type()->get_scalar_ty();
if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(f16_ty, 3);
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp16x2_ty;
} else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) {
// FIXME: We should use bf16 here.
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(f16_ty, 3);
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp16x2_ty;
// mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
// smem_ptr_ty = ptr_ty(bf16_ty, 3);
// ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
// phi_ty = bf16x2_ty;
} else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(fp32_ty, 3);
ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp32_ty;
} else
throw std::runtime_error("mma16816 data type not supported");
// left-hand-side values // left-hand-side values
std::map<std::pair<unsigned, unsigned>, std::pair<Value*, Value*>> ha; std::map<std::pair<unsigned, unsigned>, Value*> ha;
std::map<std::pair<unsigned, unsigned>, Value*> hb; std::map<std::pair<unsigned, unsigned>, Value*> hb;
BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* CurrBB = builder_->GetInsertBlock();
@@ -1377,79 +1645,66 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value *lane = urem(thread, i32(32)); Value *lane = urem(thread, i32(32));
Value *warp = udiv(thread, i32(32)); Value *warp = udiv(thread, i32(32));
Value *warp12 = udiv(warp, i32(layout->wpt(0))); Value *warp_mn = udiv(warp, i32(layout->wpt(0)));
Value *warp0 = urem(warp, i32(layout->wpt(0))); Value *warp_m = urem(warp, i32(layout->wpt(0)));
Value *warp1 = urem(warp12, i32(layout->wpt(1))); Value *warp_n = urem(warp_mn, i32(layout->wpt(1)));
std::vector<Value *>& fc = fcs.begin()->second; std::vector<Value *>& fc = fcs.begin()->second;
Value *tidr8 = urem(lane, i32(8)); size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
Value *phase_a = urem(udiv(tidr8, i32(per_phase_a)), i32(max_phase_a)); size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
Value* off_a0 = mul(tidr8, i32(lda));
Value *off_am = mul(add(urem(udiv(lane, i32(8)), i32(2)), mul(warp0, i32(2))), i32(8));
Value *off_ak = mul(udiv(lane, i32(16)), i32(8));
off_am = urem(off_am, i32(shape_a[0]));
off_ak = urem(off_ak, i32(shape_a[1]));
off_a0 = add(off_a0, is_a_row ? off_ak : off_am);
Value* off_a1 = is_a_row ? off_am : off_ak;
std::vector<Value*> off_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++){
Value* off_a0i = add(off_a0, i32(i*16*(is_a_row?1:layout->wpt(0))));
off_a0i = exact_udiv(off_a0i, i32(vec_a));
off_a0i = xor_(off_a0i, phase_a);
off_a0i = mul(off_a0i, i32(vec_a));
off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1)));
}
Value *phase_b = urem(udiv(tidr8, i32(per_phase_b)), i32(max_phase_b)); // | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
Value* off_b0 = mul(tidr8, i32(ldb)); // v (s0_0(0), s1_0(2), | *num_rep_k
Value *off_bn = mul(add(mul(udiv(lane, i32(16)), i32(layout->wpt(1))), mul(warp1, i32(1))), i32(8)); // m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
Value *off_bk = mul(urem(udiv(lane, i32(8)), i32(2)), i32(8)); // -----------
off_bn = urem(off_bn, i32(shape_b[1])); // *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
off_bk = urem(off_bk, i32(shape_b[0])); mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
off_b0 = add(off_b0, is_b_row ? off_bn : off_bk); {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
Value* off_b1 = is_b_row ? off_bk : off_bn; per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
std::vector<Value*> off_b(num_ptr_b); std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
for(int i = 0; i < num_ptr_b; i++){ int num_ptr_a = a_loader.get_num_ptr();
Value* off_b0i = add(off_b0, i32(i*(is_b_row?8*layout->wpt(1):16)));
off_b0i = exact_udiv(off_b0i, i32(vec_b)); // | -> n (col-major)
off_b0i = xor_(off_b0i, phase_b); // v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
off_b0i = mul(off_b0i, i32(vec_b)); // k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1))
off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1))); // -----------
} // *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b,
{mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n},
per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
int num_ptr_b = b_loader.get_num_ptr();
builder_->SetInsertPoint(CurrBB); builder_->SetInsertPoint(CurrBB);
// A pointer // A pointer
std::vector<Value*> ptrs_a(num_ptr_a); std::vector<Value*> ptrs_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++) for(int i = 0; i < num_ptr_a; i++)
ptrs_a[i] = gep(shmems_[A], {off_a[i]}); ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
// B pointer // B pointer
std::vector<Value*> ptrs_b(num_ptr_b); std::vector<Value*> ptrs_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++) for(int i = 0; i < num_ptr_b; i++)
ptrs_b[i] = gep(shmems_[B], {off_b[i]}); ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);
FunctionType *mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
InlineAsm *mma_fn = InlineAsm::get(mma_ty, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " " {$0, $1, $2, $3},"
"{$0, $1, $2, $3}, " " {$4, $5, $6, $7},"
"{$4, $5, $6, $7}, " " {$8, $9},"
"{$8, $9}, " " {$10, $11, $12, $13};",
"{$10, $11, $12, $13};",
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true); "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0); // create mma & unpack result, m, n, k are offsets in mat
unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1); auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
unsigned cols_per_thread = num_rep_m * 2;
// create mma & unpack result
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
unsigned cols_per_thread = num_rep_0 * 2;
std::vector<size_t> idx = { std::vector<size_t> idx = {
(m*2 + 0) + (n*2 + 0)*cols_per_thread, (m + 0) + (n*2 + 0)*cols_per_thread,
(m*2 + 0) + (n*2 + 1)*cols_per_thread, (m + 0) + (n*2 + 1)*cols_per_thread,
(m*2 + 1) + (n*2 + 0)*cols_per_thread, (m + 1) + (n*2 + 0)*cols_per_thread,
(m*2 + 1) + (n*2 + 1)*cols_per_thread (m + 1) + (n*2 + 1)*cols_per_thread
}; };
Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second, Value *nc = call(mma_ty, mma_fn,
hb[{n, K}], hb[{n, K+8}], {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}],
fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); hb[{n, k}], hb[{n, k+1}],
fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0}); fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1}); fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2}); fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
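The per-warp repetition counts introduced above replace the old hard-coded stepping over K in units of 16. A quick numeric sketch for a hypothetical 64x64x32 tile with m16n8k16 and a 2x2 warp arrangement (all numbers are made up; shape_per_cta is assumed to be the instruction shape times the warps along that axis):

#include <algorithm>
#include <cstdio>

int main() {
  int shapes[2]    = {64, 64};        // hypothetical M, N of the tile
  int mma_instr[3] = {16, 8, 16};     // mma.m16n8k16
  int wpt[2]       = {2, 2};          // warps per CTA along m and n
  int NK           = 32;              // hypothetical K extent
  int num_rep_m = shapes[0] / (mma_instr[0] * wpt[0]);   // 64 / 32 = 2
  int num_rep_n = shapes[1] / (mma_instr[1] * wpt[1]);   // 64 / 16 = 4
  int num_rep_k = std::max(NK / mma_instr[2], 1);        // 32 / 16 = 2
  // each mma.m16n8k16 leaves 4 fp32 accumulator registers per thread
  std::printf("reps: m=%d n=%d k=%d, fp32 accumulators per thread=%d\n",
              num_rep_m, num_rep_n, num_rep_k, num_rep_m * num_rep_n * 4);
}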
@@ -1459,131 +1714,83 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A); ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B); ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
auto register_lds =
[&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
if (K <= 8 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
} else
vals[{m, K}] = {val0, val1};
};
auto register_lds2 = auto register_lds2 =
[&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) { [&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) {
if (K <= 8 && is_prefetch) { if (k < 2 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc); ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block)); lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block));
} else } else
vals[{m, K}] = val; vals[{n, k}] = val;
}; };
auto load_a = [&](int m, int K, int inc, bool is_prefetch) { auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
int offidx = (is_a_row ? K/16 : m) % num_ptr_a; auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
Value* ptra; shared_next_ptr_[layout_a], off_a, ptrs_a,
if(K == 0 && is_prefetch){ ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
if(inc == 0) register_lds2(ha, m, k, inc, ha0, is_prefetch);
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]); register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
else register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]); register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
}
else
ptra = ptrs_a[offidx];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " +
std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];",
"=r,=r,=r,=r,r", true);
Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
if(K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa);
Value *ha0 = extract_val(haa, std::vector<unsigned>{0});
Value *ha1 = extract_val(haa, std::vector<unsigned>{1});
Value *ha2 = extract_val(haa, std::vector<unsigned>{2});
Value *ha3 = extract_val(haa, std::vector<unsigned>{3});
register_lds(ha, m, K, inc, ha0, ha1, is_prefetch);
register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch);
}; };
auto load_b = [&](int n, int K, int inc, bool is_prefetch) { auto load_b = [&](int n, int k, int inc, bool is_prefetch) {
int offidx = (is_b_row ? n : K/16) % num_ptr_b; auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
Value* ptrb; shared_next_ptr_[layout_b], off_b, ptrs_b,
if(K == 0 && is_prefetch){ ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
if(inc == 0) register_lds2(hb, n, k, inc, hb0, is_prefetch);
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]); register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
else register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]); register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
}
else
ptrb = ptrs_b[offidx];
int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " +
std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];",
"=r,=r,=r,=r,r", true);
Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
if(K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb);
Value *hb0 = extract_val(hbb, std::vector<unsigned>{0});
Value *hb1 = extract_val(hbb, std::vector<unsigned>{1});
Value *hb2 = extract_val(hbb, std::vector<unsigned>{2});
Value *hb3 = extract_val(hbb, std::vector<unsigned>{3});
register_lds2(hb, n, K, inc, hb0, is_prefetch);
register_lds2(hb, n+1, K, inc, hb2, is_prefetch);
register_lds2(hb, n, K+8, inc, hb1, is_prefetch);
register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch);
}; };
    if (C->is_prefetched()) {
      // create phis
      builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
-      for(unsigned m = 0; m < num_rep_0; m++){
-        ha[{m, 0}].first = phi(fp16x2_ty, 2);
-        ha[{m, 0}].second = phi(fp16x2_ty, 2);
-        ha[{m, 8}].first = phi(fp16x2_ty, 2);
-        ha[{m, 8}].second = phi(fp16x2_ty, 2);
+      for(unsigned m = 0; m < num_rep_m; m++){
+        ha[{2*m, 0}] = phi(phi_ty, 2);
+        ha[{2*m+1, 0}] = phi(phi_ty, 2);
+        ha[{2*m, 1}] = phi(phi_ty, 2);
+        ha[{2*m+1, 1}] = phi(phi_ty, 2);
      }
-      for(unsigned n = 0; n < num_rep_1; n+=2){
-        hb[{n, 0}] = phi(fp16x2_ty, 2);
-        hb[{n+1, 0}] = phi(fp16x2_ty, 2);
-        hb[{n, 8}] = phi(fp16x2_ty, 2);
-        hb[{n+1, 8}] = phi(fp16x2_ty, 2);
+      for(unsigned n = 0; n < num_rep_n; n+=2){
+        hb[{n, 0}] = phi(phi_ty, 2);
+        hb[{n+1, 0}] = phi(phi_ty, 2);
+        hb[{n, 1}] = phi(phi_ty, 2);
+        hb[{n+1, 1}] = phi(phi_ty, 2);
      }
      // insert prefetched lds at the end of loop header
      builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
-      for(unsigned m = 0; m < num_rep_0; m++)
-        load_a(m, 0, 0, true);
-      for(unsigned n = 0; n < num_rep_1; n+=2)
+      for(unsigned m = 0; m < num_rep_m; m++)
+        load_a(2*m, 0, 0, true);
+      for(unsigned n = 0; n < num_rep_n; n+=2)
        load_b(n, 0, 0, true);
      // update accumulators
      builder_->SetInsertPoint(CurrBB);
-      for(unsigned K = 0; K < NK; K += 16){
-        int NEXTK = (K + 16) % NK;
+      for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2
+        int next_k = (k + 1) % num_rep_k;
        // prefetch A
-        for(unsigned m = 0; m < num_rep_0; m++)
-          load_a(m, NEXTK, 1, true);
+        for(unsigned m = 0; m < num_rep_m; m++)
+          load_a(2*m, 2*next_k, 1, true);
        // prefetch B
-        for(unsigned n = 0; n < num_rep_1; n+=2)
-          load_b(n, NEXTK, 1, true);
+        for(unsigned n = 0; n < num_rep_n; n+=2)
+          load_b(n, 2*next_k, 1, true);
        // tensor core ops
-        for(unsigned m = 0; m < num_rep_0; m++)
-          for(unsigned n = 0; n < num_rep_1; n++){
-            call_mma(m, n, K);
+        for(unsigned m = 0; m < num_rep_m; m++)
+          for(unsigned n = 0; n < num_rep_n; n++){
+            call_mma(2*m, n, 2*k);
          }
      }
    }
    else{
-      for(unsigned K = 0; K < NK; K += 16)
-        for(unsigned m = 0; m < num_rep_0; m++)
-          for(unsigned n = 0; n < num_rep_1; n++){
-            if(ha.find({m, K}) == ha.end())
-              load_a(m, K, 0, false);
-            if(hb.find({n, K})==hb.end())
-              load_b(n, K, 0, false);
-            call_mma(m, n, K);
-          }
+      for (unsigned k = 0; k < num_rep_k; k++) {
+        for (unsigned m = 0; m < num_rep_m; m++)
+          load_a(2*m, 2*k, 0, /*is_prefetch*/false);
+        for (unsigned n = 0; n < num_rep_n; n+=2)
+          load_b(n, 2*k, 0, /*is_prefetch*/false);
+        for (unsigned m = 0; m < num_rep_m; m++)
+          for (unsigned n = 0; n < num_rep_n; n++)
+            call_mma(2*m, n, 2*k);
+      }
    }
    // write back
    unsigned i = 0;
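The rewritten inner loop keeps the software-pipelining idea of the old fp16 path: fragments for iteration next_k are loaded from shared memory before the tensor-core calls for iteration k are issued, so ldmatrix latency overlaps with mma work. A minimal Python sketch of that double-buffering pattern, using plain stand-ins rather than Triton internals:

# Minimal sketch of the double-buffered k-loop above. The fragments and the
# multiply are plain Python stand-ins; only the scheduling order matters here.
def pipelined_dot(a_frags, b_frags):
    num_rep_k = len(a_frags)
    acc = 0
    # loop-header prefetch of the k == 0 fragments
    cur_a, cur_b = a_frags[0], b_frags[0]
    for k in range(num_rep_k):
        next_k = (k + 1) % num_rep_k
        # issue the loads for next_k first ...
        nxt_a, nxt_b = a_frags[next_k], b_frags[next_k]
        # ... then the "tensor core" work for the current k
        acc += cur_a * cur_b
        cur_a, cur_b = nxt_a, nxt_b
    return acc

assert pipelined_dot([1, 2, 3], [4, 5, 6]) == 1*4 + 2*5 + 3*6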
@@ -1714,7 +1921,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
  if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
    return visit_mma884(dot, A, B, D, NK);
  if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
-    return visit_mma16816(dot, A, B, D, NK);
+    return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
  return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
}
@@ -1752,13 +1959,13 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
  InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
  if(ty->getPrimitiveSizeInBits() <= 32)
    return call(shfl, {acc, i32(i)});
-  acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
+  acc = bit_cast(acc, vec_ty(f32_ty, 2));
  Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
  Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
  Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
  ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
  ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
-  return builder_->CreateBitCast(ret, ty);
+  return bit_cast(ret, ty);
}
/**
@@ -1936,6 +2143,10 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
  ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
  // pointer to temporary shared memory
  Type *ty = cvt(out->get_type()->get_scalar_ty());
+  if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
+    ty = f16_ty;
  // Orders
  analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
  analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
@@ -1976,7 +2187,7 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
        // stash value to shared mem
        indices_t idxs = {in_ax[0][i*max_ii + ii],
                          in_ax[1][j*max_jj + jj]};
-        store(vals_[in][idxs], ptr);
+        store(bit_cast(vals_[in][idxs], ty), ptr);
      }
  add_barrier();
  max_ii = out_ax[0].size()/n_reps[0];
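The bf16 workaround above only re-types the 16-bit payload for the shared-memory store; no numeric conversion takes place. A small sketch of why such a bit-level reinterpretation is lossless, assuming a PyTorch build with bfloat16 support (PyTorch is used here purely for illustration):

# Sketch: routing a bf16 value through an f16-typed 16-bit store is lossless,
# because only the raw bits are reinterpreted, never converted.
import torch

x = torch.randn(8, dtype=torch.float32).to(torch.bfloat16)
as_f16_bits = x.view(torch.float16)        # same 16-bit payload, f16 "type tag"
back = as_f16_bits.view(torch.bfloat16)    # reinterpret again on the way out
assert torch.equal(x.view(torch.int16), back.view(torch.int16))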

View File

@@ -80,7 +80,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
  ir::value *a = dot->get_operand(0);
  ir::value *b = dot->get_operand(1);
  builder.set_insert_point(add);
-  ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
+  ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
  add->replace_all_uses_with(new_dot);
  return true;
}

View File

@@ -29,8 +29,13 @@ void prefetch::run(ir::module &mod) {
  std::vector<ir::dot_inst*> to_prefetch;
  ir::for_each_instruction(mod, [&](ir::instruction *i) {
    if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
-      // Now only do prefetching when dot is fp16
-      if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID)
+      // Now only do prefetching when dot is using tensor cores
+      if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
+            dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
+            (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
+             && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
+           )
+          )
        return;
      auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
      auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
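Spelled out, the new predicate is: prefetch whenever the dot will lower to tensor cores, i.e. fp16 or bf16 operands, or fp32 operands with allow_tf32 on an NVIDIA target with sm >= 80. A rough Python restatement of the same condition (the names below are illustrative stand-ins, not the pass's API):

# Illustrative restatement of the C++ condition above; dtype/allow_tf32/sm are
# plain Python stand-ins for the operand type, the dot flag, and the target.
def uses_tensor_cores(dtype, allow_tf32, sm):
    if dtype in ("fp16", "bf16"):
        return True
    return dtype == "fp32" and allow_tf32 and sm is not None and sm >= 80

assert uses_tensor_cores("bf16", False, 70)
assert uses_tensor_cores("fp32", True, 80)
assert not uses_tensor_cores("fp32", True, 75)
assert not uses_tensor_cores("fp32", False, 80)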

View File

@@ -121,6 +121,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
  init_llvm();
  // verify and store llvm
  llvm::legacy::PassManager pm;
+  // module->print(llvm::outs(), nullptr);
  pm.add(llvm::createVerifierPass());
  pm.run(*module);
  // create machine

View File

@@ -367,8 +367,8 @@ value *builder::create_log(value *arg){
  return insert(log_inst::create(arg));
}
-value *builder::create_dot(value *A, value *B, value *C) {
-  return insert(dot_inst::create_nn(A, B, C));
+value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
+  return insert(dot_inst::create_nn(A, B, C, allow_tf32));
}
value *builder::create_trans(value *A, const std::vector<int>& perm) {

View File

@@ -718,12 +718,13 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask
// Linear Algebra
//===----------------------------------------------------------------------===//
-ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
+ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
  ir::value *_0 = builder->get_float32(0);
  unsigned M = lhs->get_type()->get_block_shapes()[0];
  unsigned N = rhs->get_type()->get_block_shapes()[1];
  _0 = builder->create_splat(_0, {M, N});
-  return builder->create_dot(lhs, rhs, _0);
+  bool _allow_tf32 = allow_tf32->get_value() != 0;
+  return builder->create_dot(lhs, rhs, _0, _allow_tf32);
}

View File

@@ -577,40 +577,41 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
// matmul_inst classes
//===----------------------------------------------------------------------===//
-dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
+dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
                   const std::string &name, instruction *next)
    : builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
  set_operand(0, A);
  set_operand(1, B);
  set_operand(2, C);
+  allow_tf32_ = allow_tf32;
}
instruction *dot_inst::create(value *A, value *B, value *C,
-                              bool AT, bool BT,
+                              bool AT, bool BT, bool allow_tf32,
                              const std::string &name, instruction *next) {
  TransT OPA = AT ? Trans : NoTrans;
  TransT OPB = BT ? Trans : NoTrans;
-  return new dot_inst(A, B, C, OPA, OPB, name, next);
+  return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next);
}
-instruction *dot_inst::create_nn(value *A, value *B, value *C,
+instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
                                 const std::string &name, instruction *next) {
-  return new dot_inst(A, B, C, NoTrans, NoTrans, name, next);
+  return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
}
-instruction *dot_inst::create_nt(value *A, value *B, value *C,
+instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
                                 const std::string &name, instruction *next) {
-  return new dot_inst(A, B, C, NoTrans, Trans, name, next);
+  return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
}
-instruction *dot_inst::create_tn(value *A, value *B, value *C,
+instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
                                 const std::string &name, instruction *next) {
-  return new dot_inst(A, B, C, Trans, NoTrans, name, next);
+  return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
}
-instruction *dot_inst::create_tt(value *A, value *B, value *C,
+instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
                                 const std::string &name, instruction *next) {
-  return new dot_inst(A, B, C, Trans, Trans, name, next);
+  return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
}
//===----------------------------------------------------------------------===//

View File

@@ -10,6 +10,7 @@ import torch
from numpy.random import RandomState
import triton
+import triton._C.libtriton.triton as _triton
import triton.language as tl
from triton.code_gen import TensorWrapper, reinterpret
@@ -660,22 +661,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# ---------------
-@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
-def test_dot(epilogue, device='cuda'):
+@pytest.mark.parametrize("epilogue, allow_tf32",
+                         [(epilogue, allow_tf32)
+                          for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
+                          for allow_tf32 in [True, False]])
+def test_dot(epilogue, allow_tf32, device='cuda'):
    # triton kernel
    @triton.jit
    def kernel(X, stride_xm, stride_xk,
               Y, stride_yk, stride_yn,
               Z, stride_zm, stride_zn,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-              ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
+              ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
+              ALLOW_TF32: tl.constexpr):
        off_m = tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, BLOCK_N)
        off_k = tl.arange(0, BLOCK_K)
        Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
        Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
-        z = tl.dot(tl.load(Xs), tl.load(Ys))
+        z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32)
        if ADD_MATRIX:
            z += tl.load(Zs)
        if ADD_ROWS:
@@ -690,6 +695,12 @@ def test_dot(epilogue, device='cuda'):
    rs = RandomState(17)
    x = numpy_random((M, K), dtype_str='float32', rs=rs)
    y = numpy_random((K, N), dtype_str='float32', rs=rs)
+    if allow_tf32:
+        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+        if cc < 80:
+            pytest.skip("Only test tf32 on devices with sm >= 80")
+        x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
+        y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
    x_tri = to_triton(x, device=device)
    y_tri = to_triton(y, device=device)
    # triton result
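The 0xffffe000 mask keeps the sign bit, the 8 exponent bits, and the top 10 mantissa bits of each fp32 value, which is exactly the precision tf32 retains, so the NumPy reference matmul operates on the same values the tensor cores will see. A standalone sketch of that truncation:

# Standalone illustration of the masking used in the test: zero the low 13
# mantissa bits so every fp32 input is exactly representable in tf32.
import numpy as np

def to_tf32(a):
    return (np.asarray(a, dtype=np.float32).view('uint32')
            & np.uint32(0xffffe000)).view('float32')

x = np.array([1.0 + 2.0**-12], dtype=np.float32)  # below tf32 precision
print(to_tf32(x))                                  # -> [1.], low bits dropped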
@@ -703,7 +714,8 @@ def test_dot(epilogue, device='cuda'):
                     BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                     ADD_MATRIX=epilogue == 'add-matrix',
                     ADD_ROWS=epilogue == 'add-rows',
-                     ADD_COLS=epilogue == 'add-cols')
+                     ADD_COLS=epilogue == 'add-cols',
+                     ALLOW_TF32=allow_tf32)
    # torch result
    z_ref = np.matmul(x, y)
    if epilogue == 'add-matrix':
@@ -718,6 +730,8 @@ def test_dot(epilogue, device='cuda'):
    ptx = pgm.asm['ptx']
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx
+    if allow_tf32:
+        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
def test_dot_without_load():

View File

@@ -524,18 +524,18 @@ def reshape(input, shape, _builder=None):
@builtin
-def dot(input, other, _builder=None):
+def dot(input, other, allow_tf32=True, _builder=None):
    """
    Returns the matrix product of two blocks.
    The two blocks must be two dimensionals and have compatible inner dimensions.
    :param input: The first block to be multiplied.
-    :type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
+    :type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
    :param other: The second block to be multiplied.
-    :type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
+    :type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
    """
-    return frontend.dot(input, other, _builder)
+    return frontend.dot(input, other, allow_tf32, _builder)
# -----------------------
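With the new keyword argument, a kernel can opt out of tf32 per call and keep the full-precision fp32 path. A minimal usage sketch (the kernel, sizes, and launch below are illustrative, not part of this commit):

# Illustrative use of the new flag; the kernel and shapes here are made up.
import torch
import triton
import triton.language as tl

@triton.jit
def dot_kernel(X, Y, Z, BLOCK: tl.constexpr, ALLOW_TF32: tl.constexpr):
    off = tl.arange(0, BLOCK)
    x = tl.load(X + off[:, None] * BLOCK + off[None, :])
    y = tl.load(Y + off[:, None] * BLOCK + off[None, :])
    z = tl.dot(x, y, allow_tf32=ALLOW_TF32)  # False keeps full fp32 precision
    tl.store(Z + off[:, None] * BLOCK + off[None, :], z)

x = torch.randn((32, 32), device='cuda', dtype=torch.float32)
y = torch.randn((32, 32), device='cuda', dtype=torch.float32)
z = torch.empty_like(x)
dot_kernel[(1,)](x, y, z, BLOCK=32, ALLOW_TF32=False)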