[BACKEND] Add bf16 & tf32 mma supports (on A100) (#426)
@@ -109,6 +109,63 @@ protected:
 };

 class mma_layout: public distributed_layout {
+public:
+  enum TensorCoreType : uint8_t {
+    // floating-point tensor core instr
+    FP32_FP16_FP16_FP32 = 0, // default
+    FP32_BF16_BF16_FP32,
+    FP32_TF32_TF32_FP32,
+    // integer tensor core instr
+    INT32_INT1_INT1_INT32, // Not implemented
+    INT32_INT4_INT4_INT32, // Not implemented
+    INT32_INT8_INT8_INT32, // Not implemented
+    //
+    NOT_APPLICABLE,
+  };
+
+  // Used on nvidia GPUs with sm >= 80
+  inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
+    {FP32_FP16_FP16_FP32, {16, 8, 16}},
+    {FP32_BF16_BF16_FP32, {16, 8, 16}},
+    {FP32_TF32_TF32_FP32, {16, 8, 8}},
+
+    {INT32_INT1_INT1_INT32, {16, 8, 256}},
+    {INT32_INT4_INT4_INT32, {16, 8, 64}},
+    {INT32_INT8_INT8_INT32, {16, 8, 32}},
+  };
+
+  // shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
+  inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
+    {FP32_FP16_FP16_FP32, {8, 8, 8}},
+    {FP32_BF16_BF16_FP32, {8, 8, 8}},
+    {FP32_TF32_TF32_FP32, {8, 8, 4}},
+
+    {INT32_INT1_INT1_INT32, {8, 8, 64}},
+    {INT32_INT4_INT4_INT32, {8, 8, 32}},
+    {INT32_INT8_INT8_INT32, {8, 8, 16}},
+  };
+
+  inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
+    {FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
+    {FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
+    {FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
+
+    {INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
+    {INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
+    {INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
+  };
+
+  // vector length per ldmatrix (16*8/element_size_in_bits)
+  inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
+    {FP32_FP16_FP16_FP32, 8},
+    {FP32_BF16_BF16_FP32, 8},
+    {FP32_TF32_TF32_FP32, 4},
+
+    {INT32_INT1_INT1_INT32, 128},
+    {INT32_INT4_INT4_INT32, 32},
+    {INT32_INT8_INT8_INT32, 16},
+  };
+
 public:
   mma_layout(size_t num_warps,
              const std::vector<int>& axes,
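Aside (not part of the diff): the mma_instr_vec_ table above just evaluates the comment's formula, 16*8 / element_size_in_bits, for each operand type. A minimal standalone check in C++:

#include <cassert>
int main() {
  auto vec = [](int element_bits) { return 16 * 8 / element_bits; };
  assert(vec(16) == 8);   // fp16 / bf16
  assert(vec(32) == 4);   // tf32 (stored in 32 bits)
  assert(vec(8)  == 16);  // int8
  assert(vec(4)  == 32);  // int4
  assert(vec(1)  == 128); // int1
  return 0;
}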
@@ -116,7 +173,8 @@ public:
              const std::vector<ir::value *> &values,
              analysis::align* align, target *tgt,
              shared_layout* layout_a,
-             shared_layout* layout_b);
+             shared_layout* layout_b,
+             ir::value *dot);
   void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
   // accessor
   int fpw(size_t k) { return fpw_.at(k); }
@@ -124,6 +182,16 @@ public:
   int spw(size_t k) { return spw_.at(k); }
   int rep(size_t k) { return rep_.at(k); }
+
+  // helpers for generator.cc
+  std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
+  std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
+  std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
+  int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
+  int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
+
+  // setter
+  void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }
 private:
   // fragment per warp
   std::vector<int> fpw_;
@@ -135,6 +203,8 @@ private:
   std::vector<int> spt_;
   // repetitions
   std::vector<int> rep_;
+
+  TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
 };

 struct scanline_layout: public distributed_layout {
@@ -182,7 +252,7 @@ public:
                 const std::vector<unsigned>& shapes,
                 const std::vector<ir::value *> &values_,
                 ir::type *ty,
-                analysis::align* align);
+                analysis::align* align, target *tgt);
   void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
   // accessors
   size_t get_size() { return size_; }
@@ -197,6 +267,7 @@ public:
   ir::value* hmma_dot_b() { return hmma_dot_b_; }
   void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
   int get_mma_vec() { return mma_vec_;}
+  int get_mma_strided() { return mma_strided_; }
   data_layout* get_arg_layout() { return arg_layout_; }

 private:
@@ -209,6 +280,8 @@ private:
   ir::value* hmma_dot_b_;
   data_layout* arg_layout_;
   int mma_vec_;
+  int mma_strided_;
+  target *tgt_;
 };

@@ -154,7 +154,7 @@ public:
   value *create_cos(value* arg);
   value *create_sin(value* arg);
   value *create_log(value* arg);
-  value *create_dot(value *A, value *B, value *C);
+  value *create_dot(value *A, value *B, value *C, bool allow_tf32);
   value *create_trans(value *A, const std::vector<int> &perm = {});
   value *create_sqrt(value *A);
   value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);

@@ -80,7 +80,7 @@ struct dispatch{
   static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);

   // linear algebra
-  static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
+  static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder);

   // indexing
   static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);

@@ -742,26 +742,29 @@ public:
   };

 private:
-  dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
+  dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
   std::string repr_impl() const { return "dot"; }

-  bool is_prefetched_ = false;
-  DataType C_type_ = DataType::FP32;
-  DataType A_type_ = DataType::FP16;
-  DataType B_type_ = DataType::FP16;

 public:
   bool is_prefetched() const { return is_prefetched_; }
   void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
+  bool allow_tf32() const { return allow_tf32_; }

 public:
-  static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
+  static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
-  static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
-  static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
-  static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
-  static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
   _TRITON_DEFINE_CLONE(dot_inst)
   _TRITON_DEFINE_ACCEPT(dot_inst)

+private:
+  bool is_prefetched_ = false;
+  bool allow_tf32_ = false;
+  DataType C_type_ = DataType::FP32;
+  DataType A_type_ = DataType::FP16;
+  DataType B_type_ = DataType::FP16;
 };

 //class outer_inst: public builtin_inst {
@@ -23,19 +23,65 @@ inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
   return std::min(std::max(x, lo), hi);
 }

-inline bool is_hmma_c(ir::value *v){
+inline bool is_hmma_c(ir::value *v, int sm){
   bool result = false;
   if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
     ir::value *a = x->get_operand(0);
     ir::type *a_ty = a->get_type();
     ir::value *b = x->get_operand(1);
     ir::type *b_ty = b->get_type();
-    result = a_ty->get_scalar_ty()->is_fp16_ty() &&
-             b_ty->get_scalar_ty()->is_fp16_ty();
+    result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
+             (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
+             (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
+              x->allow_tf32() && sm >= 80);
   }
   return result;
 }

+static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
+  mma_layout::TensorCoreType mma_type;
+  if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
+    ir::value* a = dot->get_operand(0);
+    ir::value* b = dot->get_operand(1);
+    ir::type* a_ty = a->get_type();
+    ir::type* b_ty = b->get_type();
+    ir::type* c_ty = v->get_type();
+
+    if (c_ty->get_scalar_ty()->is_fp32_ty()) {
+      // floating point tensor cores
+      if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
+        mma_type = mma_layout::FP32_FP16_FP16_FP32;
+        return mma_type;
+      }
+      if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
+        mma_type = mma_layout::FP32_BF16_BF16_FP32;
+        return mma_type;
+      }
+      if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
+          && dot->allow_tf32()) {
+        mma_type = mma_layout::FP32_TF32_TF32_FP32;
+        return mma_type;
+      }
+    } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
+      throw std::runtime_error("integer tensor cores are not yet supported");
+      // // integer tensor cores
+      // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
+      //   mma_type = mma_layout::INT32_INT1_INT1_INT32;
+      //   return mma_type;
+      // }
+      // if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
+      //   mma_type = mma_layout::INT32_INT4_INT4_INT32;
+      //   return mma_type;
+      // }
+      // if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
+      //   mma_type = mma_layout::INT32_INT8_INT8_INT32;
+      //   return mma_type;
+      // }
+    }
+  }
+  return mma_layout::NOT_APPLICABLE;
+}

 inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
   for(ir::user* u: v->get_users()){
     auto i = dynamic_cast<ir::io_inst*>(u);
@@ -52,11 +98,12 @@ inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
   }
 }

-inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
+inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
   for(ir::user* u: v->get_users()){
     auto i = dynamic_cast<ir::dot_inst*>(u);
-    if(i && is_hmma_c(i) && i->get_operand(n) == v)
+    if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
       result = i;
+    }
   }
 }

@@ -142,7 +189,9 @@ mma_layout::mma_layout(size_t num_warps,
                        const std::vector<unsigned>& shape,
                        const std::vector<ir::value *> &values,
                        analysis::align* align, target* tgt,
-                       shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) {
+                       shared_layout *layout_a, shared_layout *layout_b,
+                       ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
+  tensor_core_type_ = get_mma_type(dot);
   /* fragments per warp */
   // try to make things as square as possible to maximize data re-use
   if(tgt->as_nvidia()->sm() < 80){
@@ -159,9 +208,9 @@ mma_layout::mma_layout(size_t num_warps,
     spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
   }
   else{
-    fpw_ = {1, 1, 1};
-    spw_ = {16, 8, 1};
-    rep_ = {2, 2, 1};
+    // fpw_ = {1, 1, 1};
+    spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
+    // rep_ = {2, 2, 1};
   }
   order_ = {0, 1};

@@ -356,7 +405,8 @@ shared_layout::shared_layout(data_layout *arg,
                              const std::vector<unsigned>& shape,
                              const std::vector<ir::value *> &values,
                              ir::type *ty,
-                             analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
+                             analysis::align* align, target *tgt)
+    : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {

   size_ = 0;
   arg_layout_ = arg;
@@ -382,12 +432,25 @@ shared_layout::shared_layout(data_layout *arg,
   for(ir::value* v: values){
     extract_dot_use(v, dot_a, 0);
     extract_dot_use(v, dot_b, 1);
-    extract_hmma_dot_use(v, hmma_dot_a, 0);
-    extract_hmma_dot_use(v, hmma_dot_b, 1);
+    extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
+    extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
   }
   hmma_dot_a_ = hmma_dot_a;
   hmma_dot_b_ = hmma_dot_b;
+
+  // Update mma_vec
+  if (hmma_dot_a_) {
+    assert(order_.size() == 2);
+    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
+    mma_vec_     = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
+    mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
+  } else if (hmma_dot_b_) {
+    assert(order_.size() == 2);
+    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
+    mma_vec_     = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
+    mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
+  }

   // size
   size_ = ty_->get_primitive_size_in_bits() / 8;
   for(auto s: shape_)
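Aside (not part of the diff): a small standalone sketch of how mma_vec_/mma_strided_ above fall out of the mat-shape table for an A operand. The helper only mirrors the two ternaries in the hunk; it is not an API from the codebase, and the shapes are the {m, n, k} entries restated from layout.h.

#include <cassert>
#include <utility>
#include <vector>

// Mirrors the hmma_dot_a_ branch: contiguous_is_k stands in for (order_[0] == 1).
static std::pair<int, int> vec_and_stride_for_a(const std::vector<int> &mat, bool contiguous_is_k) {
  int vec     = contiguous_is_k ? mat[2] : mat[0]; // k : m
  int strided = contiguous_is_k ? mat[0] : mat[2];
  return {vec, strided};
}

int main() {
  assert((vec_and_stride_for_a({8, 8, 8}, true)  == std::make_pair(8, 8))); // fp16/bf16, row-major A
  assert((vec_and_stride_for_a({8, 8, 4}, true)  == std::make_pair(4, 8))); // tf32, row-major A
  assert((vec_and_stride_for_a({8, 8, 4}, false) == std::make_pair(8, 4))); // tf32, column-major A
  return 0;
}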
@@ -451,7 +514,8 @@ void layouts::make_graph(ir::instruction *i) {
 void layouts::create(size_t id, const std::vector<ir::value*>& values) {
 //  if(layouts_.find(id) != layouts_.end())
 //    return;
-  auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
+  auto it_hmma_c = std::find_if(values.begin(), values.end(),
+                                [&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
   auto cmp = [](ir::value* x, ir::value *y) {
     std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
     std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
@@ -473,13 +537,16 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
     ir::value *b = dot->get_operand(1);
     create(groups_.at(a), values_.at(groups_.at(a)));
     create(groups_.at(b), values_.at(groups_.at(b)));
-    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
+    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
+                                  (shared_layout*)layouts_.at(groups_.at(a)),
+                                  (shared_layout*)layouts_.at(groups_.at(b)),
+                                  dot);
   }
   else if(it_cts != values.end()){
     ir::instruction *cts = (ir::instruction*)*it_cts;
     ir::value *arg = cts->get_operand(0);
     create(groups_.at(arg), values_.at(groups_.at(arg)));
-    layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
+    layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
   }
   else{
     layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
@@ -516,7 +583,7 @@ void layouts::run(ir::module &mod) {
       scanline_layout *layout = get(arg)->to_scanline();
       shapes[axis] = layout->mts(axis);
       // create layout
-      layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
+      layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
       tmp_[red] = id;
     }
     if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
@@ -529,12 +596,12 @@ void layouts::run(ir::module &mod) {
         shape[k] = std::max(in_layout->shape_per_cta(k),
                             out_layout->shape_per_cta(k));
       }
-      layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_);
+      layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
       tmp_[val] = id;
     }
     if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
       id++;
-      layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
+      layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
       tmp_[atom] = id;
     }
   });

@@ -19,6 +19,7 @@ void swizzle::run(ir::module &) {
       continue;
     ir::value* mma_dot_a = layout->hmma_dot_a();
     ir::value* mma_dot_b = layout->hmma_dot_b();
+
     if(!mma_dot_a && !mma_dot_b){
       per_phase_[layout] = 1;
       max_phase_[layout] = 1;
@@ -39,10 +40,10 @@ void swizzle::run(ir::module &) {
       else
         vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
     }
-    else{
+    else {
       per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
-      max_phase_[layout] = 8 / per_phase_[layout];
-      vec_[layout] = 8;
+      max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
+      vec_[layout] = layout->get_mma_vec();
     }
   }
 }
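Worked example (not part of the diff; the scanline parameters mts = 32, nts = 1 are assumed, not taken from the commit): how the new swizzle formulas differ between an fp16 and a tf32 A operand, whose mma_vec/mma_strided are 8/8 and 4/8 per the tables above.

#include <algorithm>
#include <cassert>

int main() {
  auto per_phase = [](int mts, int nts, int dtsize) {
    return std::max(128 / (mts * nts * dtsize), 1);
  };
  // fp16: dtsize = 2, mma_strided = 8, mma_vec = 8
  int pp_fp16 = per_phase(32, 1, 2);          // 128 / 64 = 2
  assert(pp_fp16 == 2 && 8 / pp_fp16 == 4);   // max_phase = 4, vec stays 8
  // tf32: dtsize = 4, mma_strided = 8, mma_vec = 4
  int pp_tf32 = per_phase(32, 1, 4);          // 128 / 128 = 1
  assert(pp_tf32 == 1 && 8 / pp_tf32 == 8);   // max_phase = 8, vec becomes 4
  return 0;
}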
@@ -85,7 +85,6 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
   allocation.run(ir);
   prefetch_s.run(ir);
   barriers.run(ir);
-  // ir.print(std::cout);
   isel.visit(ir, *llvm);
   shared_static = allocation.allocated_size();
   return llvm;
lib/codegen/selection/common.h (new file, 78 lines)
@@ -0,0 +1,78 @@
+#pragma once
+
+#include <numeric>
+#include <sstream>
+#include <iomanip>
+#include "triton/codegen/selection/generator.h"
+#include "triton/codegen/target.h"
+#include "triton/codegen/analysis/axes.h"
+#include "triton/codegen/analysis/allocation.h"
+#include "triton/codegen/analysis/align.h"
+#include "triton/codegen/analysis/swizzle.h"
+#include "triton/codegen/transform/coalesce.h"
+#include "triton/ir/context.h"
+#include "triton/ir/module.h"
+#include "triton/ir/function.h"
+#include "triton/ir/type.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/InlineAsm.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+namespace triton::codegen {
+// types
+#define void_ty builder_->getVoidTy()
+#define f16_ty builder_->getHalfTy()
+#define bf16_ty builder_->getBFloatTy()
+#define f32_ty builder_->getFloatTy()
+#define i8_ty builder_->getInt8Ty()
+#define i32_ty builder_->getInt32Ty()
+#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
+#define ptr_ty(...) PointerType::get(__VA_ARGS__)
+// constants
+#define i32(...) builder_->getInt32(__VA_ARGS__)
+// ops
+#define and_(...) builder_->CreateAnd(__VA_ARGS__)
+#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
+#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
+#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
+#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
+#define br(...) builder_->CreateBr(__VA_ARGS__)
+#define call(...) builder_->CreateCall(__VA_ARGS__)
+#define cast(...) builder_->CreateCast(__VA_ARGS__)
+#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__)
+#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__)
+#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__)
+#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
+#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
+#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
+#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
+#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
+#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
+#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
+#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
+#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
+#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
+#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
+#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
+#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
+#define load(...) builder_->CreateLoad(__VA_ARGS__)
+#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
+#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
+#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
+#define neg(...) builder_->CreateNeg(__VA_ARGS__)
+#define phi(...) builder_->CreatePHI(__VA_ARGS__)
+#define ret(...) builder_->CreateRet(__VA_ARGS__)
+#define select(...) builder_->CreateSelect(__VA_ARGS__)
+#define store(...) builder_->CreateStore(__VA_ARGS__)
+#define sub(...) builder_->CreateSub(__VA_ARGS__)
+#define shl(...) builder_->CreateShl(__VA_ARGS__)
+#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
+#define urem(...) builder_->CreateURem(__VA_ARGS__)
+#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
+#define xor_(...) builder_->CreateXor(__VA_ARGS__)
+
+} // namespace triton::codegen
@@ -81,12 +81,13 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 //  return (*builder_)->CreateGEP(ty, ptr, vals, name);
 //}

 // types
 #define void_ty builder_->getVoidTy()
 #define f16_ty builder_->getHalfTy()
+#define bf16_ty builder_->getBFloatTy()
 #define f32_ty builder_->getFloatTy()
 #define i8_ty builder_->getInt8Ty()
+#define i16_ty builder_->getInt16Ty()
 #define i32_ty builder_->getInt32Ty()
 #define vec_ty(type, num_el) VectorType::get(type, num_el, false)
 #define ptr_ty(...) PointerType::get(__VA_ARGS__)
@@ -133,7 +134,6 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 #define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
 #define xor_(...) builder_->CreateXor(__VA_ARGS__)
-

 /**
  * \brief Convert Triton-IR Type to LLVM-IR Type
  */
@@ -162,7 +162,7 @@ Type *generator::cvt(ir::type *ty) {
     case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
     case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
     case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
-    case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
+    case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_);
     case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
     case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
     case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
@@ -457,19 +457,25 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
 }

 Value* generator::bf16_to_fp32(Value *in0){
-  Value *ret = UndefValue::get(vec_ty(builder_->getInt16Ty(), 2));
-  ret = insert_elt(ret, in0, (uint64_t)1);
-  ret = insert_elt(ret, builder_->getInt16(0), (uint64_t)0);
-  return bit_cast(ret, builder_->getFloatTy());
+  if (tgt_->as_nvidia()->sm() >= 80) {
+    InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
+                                    "cvt.rn.f32.bf16 $0, $1;", "=r,h", false);
+    return call(ptx, {in0});
+  } else {
+    Value *ret = UndefValue::get(vec_ty(i16_ty, 2));
+    ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1);
+    ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0);
+    return bit_cast(ret, f32_ty);
+  }
 }

 Value* generator::fp32_to_bf16(Value *in0){
   if(tgt_->as_nvidia()->sm() >= 80){
-    InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false),
+    InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false),
                                     "cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
     return call(ptx, {in0});
   }
-  return extract_elt(bit_cast(in0, vec_ty(builder_->getInt16Ty(), 2)), (uint64_t)1);
+  return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1);
 }

 /**
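Aside (not part of the diff): the pre-sm_80 fallback paths above work because a bfloat16 value is, bit for bit, the upper half of an IEEE fp32. A standalone check of that round trip, using the same truncating conversion as the insert_elt/extract_elt code:

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  float x = 3.140625f;                                 // exactly representable in bf16 (7 stored mantissa bits)
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  uint16_t bf = static_cast<uint16_t>(bits >> 16);     // fp32 -> bf16: keep the upper 16 bits
  uint32_t widened = static_cast<uint32_t>(bf) << 16;  // bf16 -> fp32: zero-fill the lower 16 bits
  float y;
  std::memcpy(&y, &widened, sizeof(y));
  assert(y == x);
  return 0;
}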
@@ -514,12 +520,16 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
   if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
     // FP32 -> BF16
     if(op_sca_ty->is_fp32_ty())
-      for(size_t i = 0; i < x_idxs.size(); i++)
-        vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
+      // for(size_t i = 0; i < x_idxs.size(); i++)
+      //   vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
+      for (indices_t idx: idxs_.at(x)) {
+        Value *arg = vals_[x->get_operand(0)][idx];
+        vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
+      }
     // BF16 -> FP32
     if(ret_sca_ty->is_fp32_ty())
       for(size_t i = 0; i < x_idxs.size(); i++)
        vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]);
     return;
   }

@@ -678,6 +688,7 @@ void generator::visit_load_inst(ir::load_inst* x){
     // ---
     std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width));
     Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0];
+    // ret_ty->print(llvm::outs());
     std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
     for(Value *v: others)
       arg_tys.push_back(v->getType());
@@ -747,15 +758,19 @@ void generator::visit_store_inst(ir::store_inst * x){
   }
   auto idxs = idxs_.at(val_op);
   Type *ty = cvt(val_op->get_type()->get_scalar_ty());
+  if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
+    ty = f16_ty;
   for(size_t i = 0; i < idxs.size(); i += vec){
     auto idx = idxs[i];
     // pointer
     Value *ptr = vals_[ptr_op][idx];
-    ptr = bit_cast(ptr, vec_ty(ty, vec)->getPointerTo(1));
+    // vectorize
+    Type *v_ty = vec_ty(ty, vec);
+    ptr = bit_cast(ptr, v_ty->getPointerTo(1));
     // value
-    Value* val = UndefValue::get(vec_ty(ty, vec));
+    Value* val = UndefValue::get(v_ty);
     for(size_t ii = 0; ii < vec; ii++)
-      val = insert_elt(val, vals_.at(val_op)[idxs[i + ii]], ii);
+      val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii);
     if(mx){
       Value *msk = vals_[mx->get_mask_operand()][idx];
       Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
@@ -1317,6 +1332,229 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
     vals_[C][idxs_[C][i]] = acc[i];
 }

+namespace {
+class mma16816_smem_loader {
+public:
+  mma16816_smem_loader(int wpt, std::vector<int> order, int k_order,
+                       std::vector<unsigned> tile_shape,
+                       std::vector<int> instr_shape, std::vector<int> mat_shape,
+                       int per_phase, int max_phase, int dtsize, Builder *builder,
+                       adder add, multiplier mul, geper gep)
+    : wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape),
+      instr_shape_(instr_shape), mat_shape_(mat_shape),
+      per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder),
+      add(add), mul(mul), gep(gep) {
+    // compute compile-time constant variables & types
+    c_mat_shape_ = mat_shape[order[0]];
+    s_mat_shape_ = mat_shape[order[1]];
+
+    c_stride_ = tile_shape[order[1]];
+    s_stride_ = tile_shape[order[0]];
+
+    // rule: k must be the fast-changing axis
+    need_trans_ = k_order_ != order_[0];
+    can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);
+
+    // std::cout << can_use_ldmatrix_ << std::endl;
+    // std::cout << need_trans_ << std::endl;
+
+    // we need more pointers at the fast-changing axis,
+    if (can_use_ldmatrix_)
+      num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
+    else // warning: this only works for tf32 & need transpose
+      num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
+    num_ptr_ = std::max<int>(num_ptr_, 2);
+
+    // load_v4 stride (in num of mats)
+    int load_stride_in_mat[2];
+    load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2
+    load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]);
+    p_load_stride_in_mat_ = load_stride_in_mat[order[0]];
+    // stride in mat, used by load_v4
+    s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]);
+  }
+
+  std::vector<Value*> compute_offs(Value *warp_off, Value *lane) {
+    // TODO: this needs to be moved to constructor (and extracted to arr_order)
+    mat_arr_stride_  = (k_order_ == 1) ? 1 : wpt_;
+    warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1];
+    // start matrix logic offset (rename it as base_mat_off?)
+    Value *mat_off[2] = {nullptr, nullptr};
+
+    if (can_use_ldmatrix_) {
+      // c: lane idx inside a group (a group is a collection of 8 contiguous threads)
+      // s: group idx (0,1,2,3) inside a warp
+      Value *c = urem(lane, i32(8));
+      Value *s = udiv(lane, i32(8));
+      // We can decompose s => s_0, s_1...
+      Value *s0 = urem(s, i32(2));
+      Value *s1 = udiv(s, i32(2));
+
+      // We use different orders for a & b for better performance.
+      Value *k_mat_arr  = (k_order_ == 1) ? s1 : s0;
+      Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1;
+      mat_off[k_order_^1] = add(mul(warp_off,   i32(warp_off_stride_)),
+                                mul(nk_mat_arr, i32(mat_arr_stride_)));
+      mat_off[k_order_]   = k_mat_arr;
+      // physical offset (before swizzling)
+      Value *c_mat_off = mat_off[order_[0]];
+      Value *s_mat_off = mat_off[order_[1]];
+      // offset inside a matrix
+      Value *s_off_in_mat = c;
+
+      std::vector<Value*> offs(num_ptr_);
+      Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
+      // pre-compute strided offset
+      Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
+      for (int i=0; i < num_ptr_; ++i) {
+        Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_));
+        c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle
+        offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_)));
+      }
+      return offs;
+    } else if (dtsize_ == 4 && need_trans_) {
+      // load tf32 matrices with lds32
+      Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]]
+      Value *s_off_in_mat = urem(lane, i32(4)); //
+
+      Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
+      std::vector<Value*> offs(num_ptr_);
+      for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
+        int k_mat_arr_int  = (k_order_ == 1) ? mat/2 : mat%2;
+        int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
+        if (k_mat_arr_int > 0) // we don't need pointers for k
+          continue;
+        Value *k_mat_arr  = i32(k_mat_arr_int);
+        Value *nk_mat_arr = i32(nk_mat_arr_int);
+        // physical offset (before swizzling)
+        Value *c_mat_off = add(mul(warp_off,   i32(warp_off_stride_)),
+                               mul(nk_mat_arr, i32(mat_arr_stride_)));
+        Value *s_mat_off = k_mat_arr; // always 0?
+        Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
+        // FIXME: (k_order_ == 1?) is really dirty hack
+        for (int i = 0; i < num_ptr_/2; ++i) {
+          Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
+          c_mat_off_i = xor_(c_mat_off_i, phase);
+          Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
+          // TODO: move this out of the loop
+          c_off = urem(c_off, i32(tile_shape_[order_[0]]));
+          s_off = urem(s_off, i32(tile_shape_[order_[1]]));
+          offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_)));
+        }
+      }
+      return offs;
+      // throw std::runtime_error("not implemented");
+    } else
+      throw std::runtime_error("invalid smem load config");
+  }
+
+  std::tuple<Value*, Value*, Value*, Value*>
+  load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn,
+          Value *pre_ptr, Value *next_ptr, std::vector<Value*> &off, std::vector<Value*> &ptrs,
+          FunctionType *ldmatrix_ty, Type *smem_ptr_ty,
+          std::map<ir::value*, std::vector<Value*>> &prefetch_latch_to_bb_) {
+    assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
+    int mat_idx[2] = {mat0, mat1};
+    int k = mat_idx[k_order_];
+
+    int ptr_idx = -1;
+    if (can_use_ldmatrix_)
+      ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
+    else // tf32 & trans
+      ptr_idx = mat_idx[order_[0]];
+
+    auto get_ptr = [&](int idx) -> Value* {
+      Value *ptr = nullptr;
+      if (k == 0 && is_prefetch) {
+        if (inc == 0)
+          ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty);
+        else
+          ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty);
+      } else
+        ptr = ptrs.at(idx);
+      return ptr;
+    };
+    Value *ptr = get_ptr(ptr_idx);
+
+    Value *res_v4 = nullptr;
+    if (can_use_ldmatrix_) {
+      std::string trans = need_trans_ ? ".trans" : "";
+      // the offset (in byte) on the strided axis is a constant
+      int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_;
+      InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty,
+                                        "ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 "
+                                        "{$0, $1, $2, $3}, "
+                                        "[$4 + " + std::to_string(s_offset) + "];",
+                                        "=r,=r,=r,=r,r", true);
+      assert(ptr);
+      res_v4 = call(ldmatrix_ty, ld_fn, {ptr});
+      if (k == 0 && inc == 1 && is_prefetch)
+        prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4);
+      return {extract_val(res_v4, std::vector<unsigned>{0}),
+              extract_val(res_v4, std::vector<unsigned>{1}),
+              extract_val(res_v4, std::vector<unsigned>{2}),
+              extract_val(res_v4, std::vector<unsigned>{3})};
+    } else {
+      // assert(false && "should not be here");
+      assert(dtsize_ == 4 && need_trans_);
+      Value *ptr2 = get_ptr(ptr_idx+1);
+      assert(s_mat_stride_ == 1);
+      int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
+      int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
+      Value *elem0, *elem1, *elem2, *elem3;
+      if (k_order_ == 1) {
+        elem0 = load(gep(ptr,  i32(s_offset_elem)));
+        elem1 = load(gep(ptr2, i32(s_offset_elem)));
+        elem2 = load(gep(ptr,  i32(s_offset_elem + s_offset_arr_elem)));
+        elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
+      } else { // for b (k first)
+        elem0 = load(gep(ptr,  i32(s_offset_elem)));
+        elem2 = load(gep(ptr2, i32(s_offset_elem)));
+        elem1 = load(gep(ptr,  i32(s_offset_elem + s_offset_arr_elem)));
+        elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
+      }
+      if (k == 0 && inc == 1 && is_prefetch) {
+        prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0);
+        prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1);
+        prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2);
+        prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
+      }
+      return {elem0, elem1, elem2, elem3};
+    }
+  }
+
+  int get_num_ptr() const { return num_ptr_; }
+
+private:
+  int wpt_;
+  std::vector<int> order_;
+  int k_order_;
+  std::vector<unsigned> tile_shape_;
+  std::vector<int> instr_shape_;
+  std::vector<int> mat_shape_;
+  int per_phase_, max_phase_;
+  int dtsize_;
+
+  // generated
+  int c_mat_shape_, s_mat_shape_;
+  int c_stride_, s_stride_;
+  // p_: on the pointer axis
+  int p_load_stride_in_mat_;
+  int s_mat_stride_;
+  // stride when moving to next not-k mat
+  int warp_off_stride_;
+  int mat_arr_stride_; // matrix arrangement (inside a load) stride
+  bool need_trans_, can_use_ldmatrix_;
+  int num_ptr_;
+
+  Builder *builder_;
+  adder add;
+  multiplier mul;
+  geper gep;
+};
+}
+
 /**
  * \brief Code Generation for `mma.16816` (A100)
  */
@@ -1338,35 +1576,65 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
|||||||
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
|
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
|
||||||
bool is_a_row = ord_a[0] == 1;
|
bool is_a_row = ord_a[0] == 1;
|
||||||
bool is_b_row = ord_b[0] == 1;
|
bool is_b_row = ord_b[0] == 1;
|
||||||
std::string a_trans = is_a_row ? "" : ".trans";
|
|
||||||
std::string b_trans = is_b_row ? ".trans" : "";
|
std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
|
||||||
int stride_a_m = is_a_row ? shape_a[1] : 1;
|
const int mma_instr_m = mma_instr_shape[0];
|
||||||
int stride_a_k = is_a_row ? 1 : shape_a[0];
|
const int mma_instr_n = mma_instr_shape[1];
|
||||||
int stride_b_n = is_b_row ? 1 : shape_b[0];
|
const int mma_instr_k = mma_instr_shape[2];
|
||||||
int stride_b_k = is_b_row ? shape_b[1] : 1;
|
|
||||||
int stride_a0 = is_a_row ? stride_a_k : stride_a_m;
|
std::vector<int> mat_shape = layout->get_mma_mat_shape();
|
||||||
int stride_a1 = is_a_row ? stride_a_m : stride_a_k;
|
const int mat_shape_m = mat_shape[0];
|
||||||
int stride_b0 = is_b_row ? stride_b_n : stride_b_k;
|
const int mat_shape_n = mat_shape[1];
|
||||||
int stride_b1 = is_b_row ? stride_b_k : stride_b_n;
|
const int mat_shape_k = mat_shape[2];
|
||||||
int lda = is_a_row ? stride_a_m : stride_a_k;
|
|
||||||
int ldb = is_b_row ? stride_b_k : stride_b_n;
|
const int per_phase_a = swizzle_->get_per_phase(layout_a);
|
||||||
int per_phase_a = swizzle_->get_per_phase(layout_a);
|
const int max_phase_a = swizzle_->get_max_phase(layout_a);
|
||||||
int max_phase_a = swizzle_->get_max_phase(layout_a);
|
const int per_phase_b = swizzle_->get_per_phase(layout_b);
|
||||||
int per_phase_b = swizzle_->get_per_phase(layout_b);
|
const int max_phase_b = swizzle_->get_max_phase(layout_b);
|
||||||
int max_phase_b = swizzle_->get_max_phase(layout_b);
|
|
||||||
int num_ptr_a = 8;
|
const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
|
||||||
int num_ptr_b = 8;
|
const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
|
||||||
int vec_a = 8;
|
const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);
|
||||||
int vec_b = 8;
|
|
||||||
|
|
||||||
Type *fp32_ty = f32_ty;
|
Type *fp32_ty = f32_ty;
|
||||||
Type *fp16x2_ty = vec_ty(f16_ty, 2);
|
Type *fp16x2_ty = vec_ty(f16_ty, 2);
|
||||||
|
Type *bf16x2_ty = vec_ty(bf16_ty, 2);
|
||||||
Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
|
Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
|
||||||
|
Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
|
||||||
Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
||||||
FunctionType *ld_x4_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{ptr_ty(f16_ty, 3)}, false);
|
|
||||||
|
FunctionType *ldmatrix_ty = nullptr;
|
||||||
|
FunctionType *mma_ty = nullptr;
|
||||||
|
Type *phi_ty = nullptr;
|
||||||
|
Type *smem_ptr_ty = nullptr;
|
||||||
|
|
||||||
|
ir::type *A_ir_ty = A->get_type()->get_scalar_ty();
|
||||||
|
ir::type *B_ir_ty = B->get_type()->get_scalar_ty();
|
||||||
|
if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) {
|
||||||
|
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
||||||
|
smem_ptr_ty = ptr_ty(f16_ty, 3);
|
||||||
|
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
|
||||||
|
phi_ty = fp16x2_ty;
|
||||||
|
} else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) {
|
||||||
|
// FIXME: We should use bf16 here.
|
||||||
|
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
||||||
|
smem_ptr_ty = ptr_ty(f16_ty, 3);
|
||||||
|
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
|
||||||
|
phi_ty = fp16x2_ty;
|
||||||
|
// mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
||||||
|
// smem_ptr_ty = ptr_ty(bf16_ty, 3);
|
||||||
|
// ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
|
||||||
|
// phi_ty = bf16x2_ty;
|
||||||
|
} else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) {
|
||||||
|
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
||||||
|
smem_ptr_ty = ptr_ty(fp32_ty, 3);
|
||||||
|
ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
|
||||||
|
phi_ty = fp32_ty;
|
||||||
|
} else
|
||||||
|
throw std::runtime_error("mma16816 data type not supported");
|
||||||
|
|
||||||
// left-hand-side values
|
// left-hand-side values
|
||||||
std::map<std::pair<unsigned, unsigned>, std::pair<Value*, Value*>> ha;
|
std::map<std::pair<unsigned, unsigned>, Value*> ha;
|
||||||
std::map<std::pair<unsigned, unsigned>, Value*> hb;
|
std::map<std::pair<unsigned, unsigned>, Value*> hb;
|
||||||
|
|
||||||
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
||||||
@@ -1377,79 +1645,66 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
|||||||
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
|
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
|
||||||
Value *lane = urem(thread, i32(32));
|
Value *lane = urem(thread, i32(32));
|
||||||
Value *warp = udiv(thread, i32(32));
|
Value *warp = udiv(thread, i32(32));
|
||||||
Value *warp12 = udiv(warp, i32(layout->wpt(0)));
|
Value *warp_mn = udiv(warp, i32(layout->wpt(0)));
|
||||||
Value *warp0 = urem(warp, i32(layout->wpt(0)));
|
Value *warp_m = urem(warp, i32(layout->wpt(0)));
|
||||||
Value *warp1 = urem(warp12, i32(layout->wpt(1)));
|
Value *warp_n = urem(warp_mn, i32(layout->wpt(1)));
|
||||||
std::vector<Value *>& fc = fcs.begin()->second;
|
std::vector<Value *>& fc = fcs.begin()->second;
|
||||||
|
|
||||||
Value *tidr8 = urem(lane, i32(8));
|
size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||||
Value *phase_a = urem(udiv(tidr8, i32(per_phase_a)), i32(max_phase_a));
|
size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||||
Value* off_a0 = mul(tidr8, i32(lda));
|
|
||||||
Value *off_am = mul(add(urem(udiv(lane, i32(8)), i32(2)), mul(warp0, i32(2))), i32(8));
|
|
||||||
Value *off_ak = mul(udiv(lane, i32(16)), i32(8));
|
|
||||||
off_am = urem(off_am, i32(shape_a[0]));
|
|
||||||
off_ak = urem(off_ak, i32(shape_a[1]));
|
|
||||||
off_a0 = add(off_a0, is_a_row ? off_ak : off_am);
|
|
||||||
Value* off_a1 = is_a_row ? off_am : off_ak;
|
|
||||||
std::vector<Value*> off_a(num_ptr_a);
|
|
||||||
for(int i = 0; i < num_ptr_a; i++){
|
|
||||||
Value* off_a0i = add(off_a0, i32(i*16*(is_a_row?1:layout->wpt(0))));
|
|
||||||
off_a0i = exact_udiv(off_a0i, i32(vec_a));
|
|
||||||
off_a0i = xor_(off_a0i, phase_a);
|
|
||||||
off_a0i = mul(off_a0i, i32(vec_a));
|
|
||||||
off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Value *phase_b = urem(udiv(tidr8, i32(per_phase_b)), i32(max_phase_b));
|
// | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
|
||||||
Value* off_b0 = mul(tidr8, i32(ldb));
|
// v (s0_0(0), s1_0(2), | *num_rep_k
|
||||||
Value *off_bn = mul(add(mul(udiv(lane, i32(16)), i32(layout->wpt(1))), mul(warp1, i32(1))), i32(8));
|
// m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
|
||||||
Value *off_bk = mul(urem(udiv(lane, i32(8)), i32(2)), i32(8));
|
// -----------
|
||||||
off_bn = urem(off_bn, i32(shape_b[1]));
|
// *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
|
||||||
off_bk = urem(off_bk, i32(shape_b[0]));
|
mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
|
||||||
off_b0 = add(off_b0, is_b_row ? off_bn : off_bk);
|
{mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
|
||||||
Value* off_b1 = is_b_row ? off_bk : off_bn;
|
per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
|
||||||
std::vector<Value*> off_b(num_ptr_b);
|
std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
|
||||||
for(int i = 0; i < num_ptr_b; i++){
|
int num_ptr_a = a_loader.get_num_ptr();
|
||||||
Value* off_b0i = add(off_b0, i32(i*(is_b_row?8*layout->wpt(1):16)));
|
|
||||||
off_b0i = exact_udiv(off_b0i, i32(vec_b));
|
// | -> n (col-major)
|
||||||
off_b0i = xor_(off_b0i, phase_b);
|
// v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
|
||||||
off_b0i = mul(off_b0i, i32(vec_b));
|
// k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1))
|
||||||
off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1)));
|
// -----------
|
||||||
}
|
// *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
|
||||||
|
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b,
|
||||||
|
{mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n},
|
||||||
|
per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
|
||||||
|
std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
|
||||||
|
int num_ptr_b = b_loader.get_num_ptr();
|
||||||
|
|
||||||
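The deleted block above built the swizzled shared-memory offsets by hand: derive a phase from the row, then xor it into the column at load-vector granularity. The new mma16816_smem_loader::compute_offs presumably encapsulates the same scheme; a standalone Python sketch of that index math, kept close to the removed lines:

def phase_of(row, per_phase, max_phase):
    # phase_a / phase_b above: urem(udiv(tid % 8, per_phase), max_phase)
    return (row // per_phase) % max_phase

def swizzle(offset, phase, vec):
    # the three removed steps per pointer: exact_udiv by vec, xor with phase, mul back by vec
    return ((offset // vec) ^ phase) * vec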
builder_->SetInsertPoint(CurrBB);
|
builder_->SetInsertPoint(CurrBB);
|
||||||
// A pointer
|
// A pointer
|
||||||
std::vector<Value*> ptrs_a(num_ptr_a);
|
std::vector<Value*> ptrs_a(num_ptr_a);
|
||||||
for(int i = 0; i < num_ptr_a; i++)
|
for(int i = 0; i < num_ptr_a; i++)
|
||||||
ptrs_a[i] = gep(shmems_[A], {off_a[i]});
|
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
|
||||||
// B pointer
|
// B pointer
|
||||||
std::vector<Value*> ptrs_b(num_ptr_b);
|
std::vector<Value*> ptrs_b(num_ptr_b);
|
||||||
for(int i = 0; i < num_ptr_b; i++)
|
for(int i = 0; i < num_ptr_b; i++)
|
||||||
ptrs_b[i] = gep(shmems_[B], {off_b[i]});
|
ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);
|
||||||
|
|
||||||
FunctionType *mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
|
||||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
" {$0, $1, $2, $3},"
|
||||||
"{$0, $1, $2, $3}, "
|
" {$4, $5, $6, $7},"
|
||||||
"{$4, $5, $6, $7}, "
|
" {$8, $9},"
|
||||||
"{$8, $9}, "
|
" {$10, $11, $12, $13};",
|
||||||
"{$10, $11, $12, $13};",
|
|
||||||
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
|
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
|
||||||
|
|
||||||
unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0);
|
// create mma & unpack result, m, n, k are offsets in mat
|
||||||
unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1);
|
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
|
||||||
|
unsigned cols_per_thread = num_rep_m * 2;
|
||||||
// create mma & unpack result
|
|
||||||
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
|
|
||||||
unsigned cols_per_thread = num_rep_0 * 2;
|
|
||||||
std::vector<size_t> idx = {
|
std::vector<size_t> idx = {
|
||||||
(m*2 + 0) + (n*2 + 0)*cols_per_thread,
|
(m + 0) + (n*2 + 0)*cols_per_thread,
|
||||||
(m*2 + 0) + (n*2 + 1)*cols_per_thread,
|
(m + 0) + (n*2 + 1)*cols_per_thread,
|
||||||
(m*2 + 1) + (n*2 + 0)*cols_per_thread,
|
(m + 1) + (n*2 + 0)*cols_per_thread,
|
||||||
(m*2 + 1) + (n*2 + 1)*cols_per_thread
|
(m + 1) + (n*2 + 1)*cols_per_thread
|
||||||
};
|
};
|
||||||
Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second,
|
Value *nc = call(mma_ty, mma_fn,
|
||||||
hb[{n, K}], hb[{n, K+8}],
|
{ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}],
|
||||||
fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
|
hb[{n, k}], hb[{n, k+1}],
|
||||||
|
fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
|
||||||
fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
|
fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
|
||||||
fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
|
fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
|
||||||
fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
|
fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
|
||||||
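For reference, the index arithmetic used by call_mma: each (m, n) tile owns four accumulator registers inside the flat fc vector, addressed with cols_per_thread = num_rep_m * 2. The same mapping as a small Python helper (illustrative only):

def acc_indices(m, n, num_rep_m):
    # mirrors the idx vector built in call_mma above
    cols_per_thread = num_rep_m * 2
    return [(m + 0) + (n * 2 + 0) * cols_per_thread,
            (m + 0) + (n * 2 + 1) * cols_per_thread,
            (m + 1) + (n * 2 + 0) * cols_per_thread,
            (m + 1) + (n * 2 + 1) * cols_per_thread]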
@@ -1459,131 +1714,83 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
|||||||
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
|
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
|
||||||
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
|
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
|
||||||
|
|
||||||
auto register_lds =
|
|
||||||
[&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
|
|
||||||
if (K <= 8 && is_prefetch) {
|
|
||||||
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
|
|
||||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
|
|
||||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
|
|
||||||
} else
|
|
||||||
vals[{m, K}] = {val0, val1};
|
|
||||||
};
|
|
||||||
|
|
||||||
auto register_lds2 =
|
auto register_lds2 =
|
||||||
[&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) {
|
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) {
|
||||||
if (K <= 8 && is_prefetch) {
|
if (k < 2 && is_prefetch) {
|
||||||
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
|
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
|
||||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block));
|
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block));
|
||||||
} else
|
} else
|
||||||
vals[{m, K}] = val;
|
vals[{n, k}] = val;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
|
auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
|
||||||
int offidx = (is_a_row ? K/16 : m) % num_ptr_a;
|
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
|
||||||
Value* ptra;
|
shared_next_ptr_[layout_a], off_a, ptrs_a,
|
||||||
if(K == 0 && is_prefetch){
|
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
|
||||||
if(inc == 0)
|
register_lds2(ha, m, k, inc, ha0, is_prefetch);
|
||||||
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
|
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
|
||||||
else
|
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
|
||||||
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
|
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
|
||||||
}
|
|
||||||
else
|
|
||||||
ptra = ptrs_a[offidx];
|
|
||||||
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
|
|
||||||
int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
|
|
||||||
InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
|
|
||||||
"{$0, $1, $2, $3}, [$4 + " +
|
|
||||||
std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];",
|
|
||||||
"=r,=r,=r,=r,r", true);
|
|
||||||
Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
|
|
||||||
if(K == 0 && inc == 1 && is_prefetch)
|
|
||||||
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa);
|
|
||||||
Value *ha0 = extract_val(haa, std::vector<unsigned>{0});
|
|
||||||
Value *ha1 = extract_val(haa, std::vector<unsigned>{1});
|
|
||||||
Value *ha2 = extract_val(haa, std::vector<unsigned>{2});
|
|
||||||
Value *ha3 = extract_val(haa, std::vector<unsigned>{3});
|
|
||||||
register_lds(ha, m, K, inc, ha0, ha1, is_prefetch);
|
|
||||||
register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
|
auto load_b = [&](int n, int k, int inc, bool is_prefetch) {
|
||||||
int offidx = (is_b_row ? n : K/16) % num_ptr_b;
|
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
|
||||||
Value* ptrb;
|
shared_next_ptr_[layout_b], off_b, ptrs_b,
|
||||||
if(K == 0 && is_prefetch){
|
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
|
||||||
if(inc == 0)
|
register_lds2(hb, n, k, inc, hb0, is_prefetch);
|
||||||
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
|
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
|
||||||
else
|
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
|
||||||
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
|
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
|
||||||
}
|
|
||||||
else
|
|
||||||
ptrb = ptrs_b[offidx];
|
|
||||||
int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
|
|
||||||
int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
|
|
||||||
InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
|
|
||||||
"{$0, $1, $2, $3}, [$4 + " +
|
|
||||||
std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];",
|
|
||||||
"=r,=r,=r,=r,r", true);
|
|
||||||
Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
|
|
||||||
if(K == 0 && inc == 1 && is_prefetch)
|
|
||||||
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb);
|
|
||||||
Value *hb0 = extract_val(hbb, std::vector<unsigned>{0});
|
|
||||||
Value *hb1 = extract_val(hbb, std::vector<unsigned>{1});
|
|
||||||
Value *hb2 = extract_val(hbb, std::vector<unsigned>{2});
|
|
||||||
Value *hb3 = extract_val(hbb, std::vector<unsigned>{3});
|
|
||||||
register_lds2(hb, n, K, inc, hb0, is_prefetch);
|
|
||||||
register_lds2(hb, n+1, K, inc, hb2, is_prefetch);
|
|
||||||
register_lds2(hb, n, K+8, inc, hb1, is_prefetch);
|
|
||||||
register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
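load_a and load_b register the four ldmatrix.x4 fragments under 2x2 keys; per the layout diagrams earlier in this hunk, A's sub-matrices are numbered m-first while B's are numbered k-first, which is presumably why load_b stores hb1 and hb2 in swapped slots relative to load_a. A compact Python sketch of that bookkeeping (not code from the commit):

def register_a(ha, m, k, frags):
    a0, a1, a2, a3 = frags
    ha[(m, k)], ha[(m + 1, k)] = a0, a1
    ha[(m, k + 1)], ha[(m + 1, k + 1)] = a2, a3

def register_b(hb, n, k, frags):
    b0, b1, b2, b3 = frags              # note the b1/b2 swap below
    hb[(n, k)], hb[(n + 1, k)] = b0, b2
    hb[(n, k + 1)], hb[(n + 1, k + 1)] = b1, b3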
if (C->is_prefetched()) {
|
if (C->is_prefetched()) {
|
||||||
// create phis
|
// create phis
|
||||||
builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
|
builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
|
||||||
for(unsigned m = 0; m < num_rep_0; m++){
|
for(unsigned m = 0; m < num_rep_m; m++){
|
||||||
ha[{m, 0}].first = phi(fp16x2_ty, 2);
|
ha[{2*m, 0}] = phi(phi_ty, 2);
|
||||||
ha[{m, 0}].second = phi(fp16x2_ty, 2);
|
ha[{2*m+1, 0}] = phi(phi_ty, 2);
|
||||||
ha[{m, 8}].first = phi(fp16x2_ty, 2);
|
ha[{2*m, 1}] = phi(phi_ty, 2);
|
||||||
ha[{m, 8}].second = phi(fp16x2_ty, 2);
|
ha[{2*m+1, 1}] = phi(phi_ty, 2);
|
||||||
}
|
}
|
||||||
for(unsigned n = 0; n < num_rep_1; n+=2){
|
for(unsigned n = 0; n < num_rep_n; n+=2){
|
||||||
hb[{n, 0}] = phi(fp16x2_ty, 2);
|
hb[{n, 0}] = phi(phi_ty, 2);
|
||||||
hb[{n+1, 0}] = phi(fp16x2_ty, 2);
|
hb[{n+1, 0}] = phi(phi_ty, 2);
|
||||||
hb[{n, 8}] = phi(fp16x2_ty, 2);
|
hb[{n, 1}] = phi(phi_ty, 2);
|
||||||
hb[{n+1, 8}] = phi(fp16x2_ty, 2);
|
hb[{n+1, 1}] = phi(phi_ty, 2);
|
||||||
}
|
}
|
||||||
// insert prefetched lds at the end of loop header
|
// insert prefetched lds at the end of loop header
|
||||||
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
|
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
|
||||||
for(unsigned m = 0; m < num_rep_0; m++)
|
for(unsigned m = 0; m < num_rep_m; m++)
|
||||||
load_a(m, 0, 0, true);
|
load_a(2*m, 0, 0, true);
|
||||||
for(unsigned n = 0; n < num_rep_1; n+=2)
|
for(unsigned n = 0; n < num_rep_n; n+=2)
|
||||||
load_b(n, 0, 0, true);
|
load_b(n, 0, 0, true);
|
||||||
// update accumulators
|
// update accumulators
|
||||||
builder_->SetInsertPoint(CurrBB);
|
builder_->SetInsertPoint(CurrBB);
|
||||||
for(unsigned K = 0; K < NK; K += 16){
|
for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2
|
||||||
int NEXTK = (K + 16) % NK;
|
int next_k = (k + 1) % num_rep_k;
|
||||||
// prefetch A
|
// prefetch A
|
||||||
for(unsigned m = 0; m < num_rep_0; m++)
|
for(unsigned m = 0; m < num_rep_m; m++)
|
||||||
load_a(m, NEXTK, 1, true);
|
load_a(2*m, 2*next_k, 1, true);
|
||||||
// prefetch B
|
// prefetch B
|
||||||
for(unsigned n = 0; n < num_rep_1; n+=2)
|
for(unsigned n = 0; n < num_rep_n; n+=2)
|
||||||
load_b(n, NEXTK, 1, true);
|
load_b(n, 2*next_k, 1, true);
|
||||||
// tensor core ops
|
// tensor core ops
|
||||||
for(unsigned m = 0; m < num_rep_0; m++)
|
for(unsigned m = 0; m < num_rep_m; m++)
|
||||||
for(unsigned n = 0; n < num_rep_1; n++){
|
for(unsigned n = 0; n < num_rep_n; n++){
|
||||||
call_mma(m, n, K);
|
call_mma(2*m, n, 2*k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
for(unsigned K = 0; K < NK; K += 16)
|
for (unsigned k = 0; k < num_rep_k; k++) {
|
||||||
for(unsigned m = 0; m < num_rep_0; m++)
|
for (unsigned m = 0; m < num_rep_m; m++)
|
||||||
for(unsigned n = 0; n < num_rep_1; n++){
|
load_a(2*m, 2*k, 0, /*is_prefetch*/false);
|
||||||
if(ha.find({m, K}) == ha.end())
|
for (unsigned n = 0; n < num_rep_n; n+=2)
|
||||||
load_a(m, K, 0, false);
|
load_b(n, 2*k, 0, /*is_prefetch*/false);
|
||||||
if(hb.find({n, K})==hb.end())
|
for (unsigned m = 0; m < num_rep_m; m++)
|
||||||
load_b(n, K, 0, false);
|
for (unsigned n = 0; n < num_rep_n; n++)
|
||||||
call_mma(m, n, K);
|
call_mma(2*m, n, 2*k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// write back
|
// write back
|
||||||
unsigned i = 0;
|
unsigned i = 0;
|
||||||
@@ -1714,7 +1921,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
|||||||
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
|
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
|
||||||
return visit_mma884(dot, A, B, D, NK);
|
return visit_mma884(dot, A, B, D, NK);
|
||||||
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
|
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
|
||||||
return visit_mma16816(dot, A, B, D, NK);
|
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
|
||||||
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1752,13 +1959,13 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
|
|||||||
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
|
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
|
||||||
if(ty->getPrimitiveSizeInBits() <= 32)
|
if(ty->getPrimitiveSizeInBits() <= 32)
|
||||||
return call(shfl, {acc, i32(i)});
|
return call(shfl, {acc, i32(i)});
|
||||||
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
|
acc = bit_cast(acc, vec_ty(f32_ty, 2));
|
||||||
Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
|
Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
|
||||||
Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
|
Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
|
||||||
Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
|
Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
|
||||||
ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
|
ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
|
||||||
ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
|
ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
|
||||||
return builder_->CreateBitCast(ret, ty);
|
return bit_cast(ret, ty);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -1936,6 +2143,10 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
|
|||||||
ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
|
ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
|
||||||
// pointer to temporary shared memory
|
// pointer to temporary shared memory
|
||||||
Type *ty = cvt(out->get_type()->get_scalar_ty());
|
Type *ty = cvt(out->get_type()->get_scalar_ty());
|
||||||
|
|
||||||
|
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
|
||||||
|
ty = f16_ty;
|
||||||
|
|
||||||
// Orders
|
// Orders
|
||||||
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
|
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
|
||||||
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
|
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
|
||||||
@@ -1976,7 +2187,7 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
|
|||||||
// stash value to shared mem
|
// stash value to shared mem
|
||||||
indices_t idxs = {in_ax[0][i*max_ii + ii],
|
indices_t idxs = {in_ax[0][i*max_ii + ii],
|
||||||
in_ax[1][j*max_jj + jj]};
|
in_ax[1][j*max_jj + jj]};
|
||||||
store(vals_[in][idxs], ptr);
|
store(bit_cast(vals_[in][idxs], ty), ptr);
|
||||||
}
|
}
|
||||||
add_barrier();
|
add_barrier();
|
||||||
max_ii = out_ax[0].size()/n_reps[0];
|
max_ii = out_ax[0].size()/n_reps[0];
|
||||||
|
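The two layout-convert hunks above work around llvm11-nvptx's inability to select bf16 stores by bit-casting bf16 values to fp16 for the temporary shared-memory store (the matching cast on the load side is not shown in this excerpt). The reinterpretation is harmless because only the 16-bit pattern matters; a small numpy check, with arbitrary values:

import numpy as np

bf16_bits = np.array([0x3FC0, 0xBF80], dtype=np.uint16)   # 1.5 and -1.0 encoded as bf16
as_f16 = bf16_bits.view(np.float16)                       # what actually gets stored/loaded
assert (as_f16.view(np.uint16) == bf16_bits).all()        # the payload survives the round trip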
@@ -80,7 +80,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
|||||||
ir::value *a = dot->get_operand(0);
|
ir::value *a = dot->get_operand(0);
|
||||||
ir::value *b = dot->get_operand(1);
|
ir::value *b = dot->get_operand(1);
|
||||||
builder.set_insert_point(add);
|
builder.set_insert_point(add);
|
||||||
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
|
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
|
||||||
add->replace_all_uses_with(new_dot);
|
add->replace_all_uses_with(new_dot);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@@ -29,8 +29,13 @@ void prefetch::run(ir::module &mod) {
|
|||||||
std::vector<ir::dot_inst*> to_prefetch;
|
std::vector<ir::dot_inst*> to_prefetch;
|
||||||
ir::for_each_instruction(mod, [&](ir::instruction *i) {
|
ir::for_each_instruction(mod, [&](ir::instruction *i) {
|
||||||
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
|
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
|
||||||
// Now only do prefetching when dot is fp16
|
// Now only do prefetching when dot is using tensor cores
|
||||||
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID)
|
if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
|
||||||
|
dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
|
||||||
|
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
|
||||||
|
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
||||||
|
)
|
||||||
|
)
|
||||||
return;
|
return;
|
||||||
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
|
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
|
||||||
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
|
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
|
||||||
|
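The prefetch pass now gates on whether the dot can use tensor cores at all: fp16 and bf16 always qualify, and fp32 only with allow_tf32 on an NVIDIA target with sm >= 80. The same condition as a standalone Python predicate (a sketch, not the pass's code):

def dot_uses_tensor_cores(a_dtype, allow_tf32, sm):
    # mirrors the rewritten check in prefetch::run (NVIDIA targets assumed)
    return (a_dtype in ("fp16", "bf16")
            or (a_dtype == "fp32" and allow_tf32 and sm >= 80))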
@@ -121,6 +121,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
|||||||
init_llvm();
|
init_llvm();
|
||||||
// verify and store llvm
|
// verify and store llvm
|
||||||
llvm::legacy::PassManager pm;
|
llvm::legacy::PassManager pm;
|
||||||
|
// module->print(llvm::outs(), nullptr);
|
||||||
pm.add(llvm::createVerifierPass());
|
pm.add(llvm::createVerifierPass());
|
||||||
pm.run(*module);
|
pm.run(*module);
|
||||||
// create machine
|
// create machine
|
||||||
|
@@ -367,8 +367,8 @@ value *builder::create_log(value *arg){
|
|||||||
return insert(log_inst::create(arg));
|
return insert(log_inst::create(arg));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_dot(value *A, value *B, value *C) {
|
value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
|
||||||
return insert(dot_inst::create_nn(A, B, C));
|
return insert(dot_inst::create_nn(A, B, C, allow_tf32));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_trans(value *A, const std::vector<int>& perm) {
|
value *builder::create_trans(value *A, const std::vector<int>& perm) {
|
||||||
|
@@ -718,12 +718,13 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask
|
|||||||
// Linear Algebra
|
// Linear Algebra
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
|
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
|
||||||
ir::value *_0 = builder->get_float32(0);
|
ir::value *_0 = builder->get_float32(0);
|
||||||
unsigned M = lhs->get_type()->get_block_shapes()[0];
|
unsigned M = lhs->get_type()->get_block_shapes()[0];
|
||||||
unsigned N = rhs->get_type()->get_block_shapes()[1];
|
unsigned N = rhs->get_type()->get_block_shapes()[1];
|
||||||
_0 = builder->create_splat(_0, {M, N});
|
_0 = builder->create_splat(_0, {M, N});
|
||||||
return builder->create_dot(lhs, rhs, _0);
|
bool _allow_tf32 = allow_tf32->get_value() != 0;
|
||||||
|
return builder->create_dot(lhs, rhs, _0, _allow_tf32);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -577,40 +577,41 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
|
|||||||
// matmul_inst classes
|
// matmul_inst classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
|
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
|
||||||
const std::string &name, instruction *next)
|
const std::string &name, instruction *next)
|
||||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
|
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
|
||||||
set_operand(0, A);
|
set_operand(0, A);
|
||||||
set_operand(1, B);
|
set_operand(1, B);
|
||||||
set_operand(2, C);
|
set_operand(2, C);
|
||||||
|
allow_tf32_ = allow_tf32;
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction *dot_inst::create(value *A, value *B, value *C,
|
instruction *dot_inst::create(value *A, value *B, value *C,
|
||||||
bool AT, bool BT,
|
bool AT, bool BT, bool allow_tf32,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
TransT OPA = AT ? Trans : NoTrans;
|
TransT OPA = AT ? Trans : NoTrans;
|
||||||
TransT OPB = BT ? Trans : NoTrans;
|
TransT OPB = BT ? Trans : NoTrans;
|
||||||
return new dot_inst(A, B, C, OPA, OPB, name, next);
|
return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction *dot_inst::create_nn(value *A, value *B, value *C,
|
instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
return new dot_inst(A, B, C, NoTrans, NoTrans, name, next);
|
return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction *dot_inst::create_nt(value *A, value *B, value *C,
|
instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
return new dot_inst(A, B, C, NoTrans, Trans, name, next);
|
return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction *dot_inst::create_tn(value *A, value *B, value *C,
|
instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
return new dot_inst(A, B, C, Trans, NoTrans, name, next);
|
return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction *dot_inst::create_tt(value *A, value *B, value *C,
|
instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
|
||||||
const std::string &name, instruction *next) {
|
const std::string &name, instruction *next) {
|
||||||
return new dot_inst(A, B, C, Trans, Trans, name, next);
|
return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -10,6 +10,7 @@ import torch
|
|||||||
from numpy.random import RandomState
|
from numpy.random import RandomState
|
||||||
|
|
||||||
import triton
|
import triton
|
||||||
|
import triton._C.libtriton.triton as _triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from triton.code_gen import TensorWrapper, reinterpret
|
from triton.code_gen import TensorWrapper, reinterpret
|
||||||
|
|
||||||
@@ -660,22 +661,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
|||||||
# ---------------
|
# ---------------
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
|
@pytest.mark.parametrize("epilogue, allow_tf32",
|
||||||
def test_dot(epilogue, device='cuda'):
|
[(epilogue, allow_tf32)
|
||||||
|
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
|
||||||
|
for allow_tf32 in [True, False]])
|
||||||
|
def test_dot(epilogue, allow_tf32, device='cuda'):
|
||||||
# triton kernel
|
# triton kernel
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def kernel(X, stride_xm, stride_xk,
|
def kernel(X, stride_xm, stride_xk,
|
||||||
Y, stride_yk, stride_yn,
|
Y, stride_yk, stride_yn,
|
||||||
Z, stride_zm, stride_zn,
|
Z, stride_zm, stride_zn,
|
||||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
|
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||||
|
ALLOW_TF32: tl.constexpr):
|
||||||
off_m = tl.arange(0, BLOCK_M)
|
off_m = tl.arange(0, BLOCK_M)
|
||||||
off_n = tl.arange(0, BLOCK_N)
|
off_n = tl.arange(0, BLOCK_N)
|
||||||
off_k = tl.arange(0, BLOCK_K)
|
off_k = tl.arange(0, BLOCK_K)
|
||||||
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
|
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
|
||||||
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
|
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
|
||||||
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||||
z = tl.dot(tl.load(Xs), tl.load(Ys))
|
z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32)
|
||||||
if ADD_MATRIX:
|
if ADD_MATRIX:
|
||||||
z += tl.load(Zs)
|
z += tl.load(Zs)
|
||||||
if ADD_ROWS:
|
if ADD_ROWS:
|
||||||
@@ -690,6 +695,12 @@ def test_dot(epilogue, device='cuda'):
|
|||||||
rs = RandomState(17)
|
rs = RandomState(17)
|
||||||
x = numpy_random((M, K), dtype_str='float32', rs=rs)
|
x = numpy_random((M, K), dtype_str='float32', rs=rs)
|
||||||
y = numpy_random((K, N), dtype_str='float32', rs=rs)
|
y = numpy_random((K, N), dtype_str='float32', rs=rs)
|
||||||
|
if allow_tf32:
|
||||||
|
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||||
|
if cc < 80:
|
||||||
|
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||||
|
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||||
|
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||||
x_tri = to_triton(x, device=device)
|
x_tri = to_triton(x, device=device)
|
||||||
y_tri = to_triton(y, device=device)
|
y_tri = to_triton(y, device=device)
|
||||||
# triton result
|
# triton result
|
||||||
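The masking added above is what keeps the tf32 test exact: tf32 retains fp32's 8 exponent bits but only 10 explicit mantissa bits, so clearing the low 13 bits of each input means the numpy reference multiplies exactly the operands the tensor cores will see. The truncation in isolation:

import numpy as np

x = np.array([1.2345678, 3.3333333], dtype=np.float32)
x_tf32 = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')   # keep 10 mantissa bits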
@@ -703,7 +714,8 @@ def test_dot(epilogue, device='cuda'):
|
|||||||
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||||
ADD_MATRIX=epilogue == 'add-matrix',
|
ADD_MATRIX=epilogue == 'add-matrix',
|
||||||
ADD_ROWS=epilogue == 'add-rows',
|
ADD_ROWS=epilogue == 'add-rows',
|
||||||
ADD_COLS=epilogue == 'add-cols')
|
ADD_COLS=epilogue == 'add-cols',
|
||||||
|
ALLOW_TF32=allow_tf32)
|
||||||
# torch result
|
# torch result
|
||||||
z_ref = np.matmul(x, y)
|
z_ref = np.matmul(x, y)
|
||||||
if epilogue == 'add-matrix':
|
if epilogue == 'add-matrix':
|
||||||
@@ -718,6 +730,8 @@ def test_dot(epilogue, device='cuda'):
|
|||||||
ptx = pgm.asm['ptx']
|
ptx = pgm.asm['ptx']
|
||||||
assert 'ld.global.v4' in ptx
|
assert 'ld.global.v4' in ptx
|
||||||
assert 'st.global.v4' in ptx
|
assert 'st.global.v4' in ptx
|
||||||
|
if allow_tf32:
|
||||||
|
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||||
|
|
||||||
|
|
||||||
def test_dot_without_load():
|
def test_dot_without_load():
|
||||||
|
@@ -524,18 +524,18 @@ def reshape(input, shape, _builder=None):
|
|||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def dot(input, other, _builder=None):
|
def dot(input, other, allow_tf32=True, _builder=None):
|
||||||
"""
|
"""
|
||||||
Returns the matrix product of two blocks.
|
Returns the matrix product of two blocks.
|
||||||
|
|
||||||
The two blocks must be two-dimensional and have compatible inner dimensions.
|
The two blocks must be two-dimensional and have compatible inner dimensions.
|
||||||
|
|
||||||
:param input: The first block to be multiplied.
|
:param input: The first block to be multiplied.
|
||||||
:type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
|
:type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
|
||||||
:param other: The second block to be multiplied.
|
:param other: The second block to be multiplied.
|
||||||
:type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
|
:type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
|
||||||
"""
|
"""
|
||||||
return frontend.dot(input, other, _builder)
|
return frontend.dot(input, other, allow_tf32, _builder)
|
||||||
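A hedged usage sketch of the new keyword from the kernel side, shaped like the test kernel earlier in this commit; the identifiers below are illustrative, not taken from the repository:

import triton
import triton.language as tl

@triton.jit
def dot_kernel(X, Y, Z, BLOCK: tl.constexpr, ALLOW_TF32: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs[:, None] * BLOCK + offs[None, :])
    y = tl.load(Y + offs[:, None] * BLOCK + offs[None, :])
    z = tl.dot(x, y, allow_tf32=ALLOW_TF32)   # tf32 lowering needs sm >= 80
    tl.store(Z + offs[:, None] * BLOCK + offs[None, :], z)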
|
|
||||||
|
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|