Merge branch 'master' into rcom52_fixes
@@ -115,6 +115,18 @@ std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_ins
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_dequantize(ir::dequantize_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value *op = x->get_operand(0);
  auto op_shapes = op->get_type()->get_block_shapes();
  auto op_cst = populate_is_constant(op);
  for(size_t d = 0; d < x_shapes.size(); d++) {
    result.push_back(op_cst[d]);
  }
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
@@ -129,6 +141,36 @@ std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value* lhs_op = x->get_operand(0);
  ir::value* rhs_op = x->get_operand(1);
  auto lhs = populate_is_constant(lhs_op);
  auto rhs = populate_is_constant(rhs_op);
  auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
  auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
  auto lhs_multiple_of = populate_starting_multiple(lhs_op);
  auto rhs_multiple_of = populate_starting_multiple(rhs_op);
  for(size_t d = 0; d < x_shapes.size(); d++) {
    cst_info ax = {1, 0};
    // Examples:
    //   16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8
    //   16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4
    //   16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16
    //
    // If LHS is a range of N contiguous (or equal) elements that starts at M,
    // and RHS is a set of N constants that start at K,
    // then the result is constant in groups of gcd(M, K).
    if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 ||
       rhs[d].num_cst % lhs[d].num_cst == 0)
      ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]);
    result.push_back(ax);
  }
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
@@ -136,12 +178,14 @@ std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operat
|
||||
ir::value* rhs_op = x->get_operand(1);
|
||||
auto lhs = populate_is_constant(lhs_op);
|
||||
auto rhs = populate_is_constant(rhs_op);
|
||||
auto max_contiguous = populate_max_contiguous(lhs_op);
|
||||
auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
|
||||
auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
|
||||
auto lhs_multiple_of = populate_starting_multiple(lhs_op);
|
||||
auto rhs_multiple_of = populate_starting_multiple(rhs_op);
|
||||
for(size_t d = 0; d < x_shapes.size(); d++) {
|
||||
cst_info ax;
|
||||
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
|
||||
// todo might not be entirely true
|
||||
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
|
||||
unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
|
||||
ax = {num_constants, 0};
|
||||
}
|
||||
else
|
||||
@@ -180,10 +224,14 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
||||
return populate_is_constant_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
return populate_is_constant_reshape(x);
|
||||
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
|
||||
return populate_is_constant_dequantize(x);
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
|
||||
return populate_is_constant_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
return populate_is_constant_binop(x);
|
||||
if(auto *x = dynamic_cast<ir::cmp_inst*>(v))
|
||||
return populate_is_constant_cmp(x);
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
return populate_is_constant_gep(x);
|
||||
return populate_is_constant_default(v);
|
||||
@@ -245,6 +293,23 @@ std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_dequantize(ir::dequantize_inst* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
|
||||
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
|
||||
auto op_mc = populate_max_contiguous(op);
|
||||
for(size_t d = 0; d < shapes.size(); d++) {
|
||||
unsigned factor = 1;
|
||||
if (d == shapes.size() - 1) {
|
||||
factor = ret_last_dim / op_last_dim;
|
||||
}
|
||||
result.push_back(factor * op_mc[d]);
|
||||
}
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
@@ -285,8 +350,8 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
|
||||
}
|
||||
if(x->is_int_add_sub()){
|
||||
unsigned lvalue = 1, rvalue = 1;
|
||||
lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
|
||||
rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
|
||||
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
|
||||
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
|
||||
value = std::max(lvalue, rvalue);
|
||||
}
|
||||
result.push_back(value);
|
||||
@@ -332,9 +397,9 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
|
||||
if(max_contiguous_.find(v) != max_contiguous_.end())
|
||||
return max_contiguous_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||
unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
|
||||
if(max_contiguous > 0)
|
||||
return add_to_cache(x, {max_contiguous}, max_contiguous_);
|
||||
std::vector<unsigned> max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
|
||||
if(!max_contiguous.empty())
|
||||
return add_to_cache(x, max_contiguous, max_contiguous_);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||
return populate_max_contiguous_cast(x);
|
||||
@@ -342,6 +407,8 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
|
||||
return populate_max_contiguous_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
return populate_max_contiguous_reshape(x);
|
||||
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
|
||||
return populate_max_contiguous_dequantize(x);
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
|
||||
return populate_max_contiguous_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
@@ -386,6 +453,23 @@ std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_dequantize(ir::dequantize_inst* x){
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
|
||||
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
|
||||
auto op_multiple = populate_starting_multiple(op);
|
||||
for(size_t d = 0; d < shapes.size(); d++) {
|
||||
unsigned factor = 1;
|
||||
if (d == shapes.size() - 1) {
|
||||
factor = ret_last_dim / op_last_dim;
|
||||
}
|
||||
result.push_back(factor * op_multiple[d]);
|
||||
}
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
|
||||
auto result = populate_starting_multiple(x->get_operand(0));
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
@@ -401,7 +485,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operato
|
||||
if(x->is_int_add_sub())
|
||||
result[d] = gcd(lhs[d], rhs[d]);
|
||||
if(x->is_int_div())
|
||||
result[d] = 1;
|
||||
result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1;
|
||||
if(x->is_int_rem() && rhs[d] > 1){
|
||||
result[d] = gcd(lhs[d], rhs[d]);
|
||||
}
|
||||
@@ -471,28 +555,42 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
  return add_to_cache(v, {1}, starting_multiple_);
}

unsigned get_max_multiple(int val){
  if(val == 0) return 1 << 31;
  if(val % 128 == 0) return 128;
  if(val % 64 == 0) return 64;
  if(val % 32 == 0) return 32;
  if(val % 16 == 0) return 16;
  if(val % 8 == 0) return 8;
  if(val % 4 == 0) return 4;
  if(val % 2 == 0) return 2;
  return 1;
}
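As a rough illustration (not from the original change), the helper above returns the largest power-of-two divisor of its argument, capped at 128, with 0 treated as divisible by anything. The assert-based sketch below is hypothetical and assumes <cassert> is available:

  // illustrative only -- the values follow from the cascade of modulo checks above
  assert(get_max_multiple(0)  == (1u << 31));  // 0 is "divisible by anything"
  assert(get_max_multiple(96) == 32);          // 96 = 32 * 3
  assert(get_max_multiple(24) == 8);           // 24 = 8 * 3
  assert(get_max_multiple(7)  == 1);           // odd values carry no power-of-two alignment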
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||
return starting_multiple_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||
if(multiple_of > 0)
|
||||
return add_to_cache(x, {multiple_of}, starting_multiple_);
|
||||
std::vector<unsigned> multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||
if(!multiple_of.empty())
|
||||
return add_to_cache(x, multiple_of, starting_multiple_);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||
return populate_starting_multiple_cast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
return populate_starting_multiple_binop(x);
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
|
||||
return add_to_cache(x, {get_max_multiple(x->get_value())}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::make_range*>(v))
|
||||
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
|
||||
return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
return populate_starting_multiple_gep(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
return populate_starting_multiple_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
return populate_starting_multiple_reshape(x);
|
||||
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
|
||||
return populate_starting_multiple_dequantize(x);
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
|
||||
return populate_starting_multiple_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||
@@ -511,12 +609,15 @@ std::vector<unsigned> align::contiguous(ir::value* v) const {
|
||||
return max_contiguous_.at(v);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::get_cst_info(ir::value* v) const {
|
||||
return is_constant_.at(v);
|
||||
}
|
||||
|
||||
|
||||
void align::populate(ir::value *v) {
|
||||
populate_is_constant(v);
|
||||
populate_starting_multiple(v);
|
||||
populate_max_contiguous(v);
|
||||
|
||||
}
|
||||
|
||||
void align::run(ir::module &mod) {
|
||||
|
@@ -92,8 +92,10 @@ void allocation::run(ir::module &mod) {
|
||||
}
|
||||
// Save maximum size of induced memory space
|
||||
allocated_size_ = 0;
|
||||
for(shared_layout* x: V)
|
||||
for(shared_layout* x: V){
|
||||
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
|
||||
// std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -56,6 +56,17 @@ void axes::update_graph_trans(ir::instruction *i) {
    graph_.add_edge({i, perm[d]}, {op, d});
}

void axes::update_graph_dequantize(ir::instruction *i) {
  auto *dequantize = static_cast<ir::dequantize_inst*>(i);
  auto shapes = dequantize->get_type()->get_block_shapes();
  ir::value *op = dequantize->get_operand(0);

  // add edge except the last axis
  for(unsigned d = 0; d < shapes.size() - 1; d ++){
    graph_.add_edge({i, d}, {op, d});
  }
}

void axes::update_graph_broadcast(ir::instruction *i) {
|
||||
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
|
||||
auto shapes = broadcast->get_type()->get_block_shapes();
|
||||
@@ -79,7 +90,7 @@ void axes::update_graph_dot(ir::instruction *i) {
|
||||
graph_.add_edge({dot, d}, {D, d});
|
||||
}
|
||||
|
||||
void axes::update_graph_elementwise(ir::instruction *i,
|
||||
void axes::update_graph_elementwise(ir::instruction *i,
|
||||
bool is_masked_load_async) {
|
||||
if(i->get_num_operands() == 0)
|
||||
return;
|
||||
@@ -119,6 +130,7 @@ void axes::update_graph(ir::instruction *i) {
|
||||
case ir::INST_SPLAT: return update_graph_no_edge(i);
|
||||
case ir::INST_CAT: return update_graph_elementwise(i, true);
|
||||
case ir::INST_TRANS: return update_graph_trans(i);
|
||||
case ir::INST_DEQUANTIZE: return update_graph_dequantize(i);
|
||||
case ir::INST_BROADCAST: return update_graph_broadcast(i);
|
||||
case ir::INST_DOT: return update_graph_dot(i);
|
||||
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
|
||||
|
@@ -23,19 +23,67 @@ inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
|
||||
return std::min(std::max(x, lo), hi);
|
||||
}
|
||||
|
||||
inline bool is_hmma_c(ir::value *v){
|
||||
inline bool is_hmma_c(ir::value *v, int sm){
|
||||
bool result = false;
|
||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
||||
ir::value *a = x->get_operand(0);
|
||||
ir::type *a_ty = a->get_type();
|
||||
ir::value *b = x->get_operand(1);
|
||||
ir::type *b_ty = b->get_type();
|
||||
result = a_ty->get_scalar_ty()->is_fp16_ty() &&
|
||||
b_ty->get_scalar_ty()->is_fp16_ty();
|
||||
result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
|
||||
(a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
|
||||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
|
||||
x->allow_tf32() && sm >= 80) ||
|
||||
(a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
|
||||
sm >= 80);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
|
||||
mma_layout::TensorCoreType mma_type;
|
||||
if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
|
||||
ir::value* a = dot->get_operand(0);
|
||||
ir::value* b = dot->get_operand(1);
|
||||
ir::type* a_ty = a->get_type();
|
||||
ir::type* b_ty = b->get_type();
|
||||
ir::type* c_ty = v->get_type();
|
||||
|
||||
if (c_ty->get_scalar_ty()->is_fp32_ty()) {
|
||||
// floating point tensor cores
|
||||
if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
|
||||
mma_type = mma_layout::FP32_FP16_FP16_FP32;
|
||||
return mma_type;
|
||||
}
|
||||
if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
|
||||
mma_type = mma_layout::FP32_BF16_BF16_FP32;
|
||||
return mma_type;
|
||||
}
|
||||
if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
|
||||
&& dot->allow_tf32()) {
|
||||
mma_type = mma_layout::FP32_TF32_TF32_FP32;
|
||||
return mma_type;
|
||||
}
|
||||
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
|
||||
// throw std::runtime_error("integer tensor cores are not yet supported");
|
||||
// // integer tensor cores
|
||||
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
|
||||
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
|
||||
// return mma_type;
|
||||
// }
|
||||
// if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
|
||||
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
|
||||
// return mma_type;
|
||||
// }
|
||||
if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
|
||||
mma_type = mma_layout::INT32_INT8_INT8_INT32;
|
||||
return mma_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
return mma_layout::NOT_APPLICABLE;
|
||||
}
|
||||
|
||||
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||
@@ -52,11 +100,12 @@ inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
||||
}
|
||||
}
|
||||
|
||||
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
||||
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::dot_inst*>(u);
|
||||
if(i && is_hmma_c(i) && i->get_operand(n) == v)
|
||||
if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
|
||||
result = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,7 +191,9 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align, target* tgt,
|
||||
shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) {
|
||||
shared_layout *layout_a, shared_layout *layout_b,
|
||||
ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
|
||||
tensor_core_type_ = get_mma_type(dot);
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
if(tgt->as_nvidia() && tgt->as_nvidia()->sm() < 80){
|
||||
@@ -157,25 +208,67 @@ mma_layout::mma_layout(size_t num_warps,
|
||||
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
|
||||
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
|
||||
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
|
||||
contig_per_thread_ = {1, 1};
|
||||
order_ = {0, 1};
|
||||
}
|
||||
else{
|
||||
fpw_ = {1, 1, 1};
|
||||
spw_ = {16, 8, 1};
|
||||
rep_ = {2, 2, 1};
|
||||
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
|
||||
contig_per_thread_ = {1, 2};
|
||||
order_ = {1, 0};
|
||||
}
|
||||
order_ = {0, 1};
|
||||
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
wpt_ = {1, 1, 1};
|
||||
std::vector<int> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt_;
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
|
||||
}while(wpt_nm1 != wpt_);
|
||||
// try to make warp-level tiles as square as possible to maximize data re-use
|
||||
if (tgt->as_nvidia()->sm() < 80) {
|
||||
std::vector<int> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt_;
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
|
||||
}while(wpt_nm1 != wpt_);
|
||||
} else {
|
||||
bool changed = false;
|
||||
// try to have a warp own entire rows of the output
|
||||
// this makes it easier to fuse multiple mmas by fusing
|
||||
// registers
|
||||
bool one_warp_per_row = false;
|
||||
for(ir::value* v: values)
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto* dot = dynamic_cast<ir::dot_inst*>(u);
|
||||
auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
|
||||
if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
|
||||
one_warp_per_row = shape[0] / spw_[0] >= num_warps;
|
||||
}
|
||||
// std::cout << one_warp_per_row << std::endl;
|
||||
|
||||
if(one_warp_per_row){
|
||||
wpt_[1] = 1;
|
||||
wpt_[0] = num_warps;
|
||||
}
|
||||
else{
|
||||
do {
|
||||
changed = false;
|
||||
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
|
||||
break;
|
||||
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
|
||||
if (wpt_[0] < shape_[0] / spw_[0]) {
|
||||
wpt_[0] *= 2;
|
||||
changed = true;
|
||||
}
|
||||
} else {
|
||||
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
|
||||
wpt_[1] *= 2;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
} while(changed);
|
||||
}
|
||||
}
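A rough worked example of the warp-partitioning loop above (illustrative only, under assumed inputs: num_warps = 8, shape_ = {128, 128}, spw_ = {16, 8, 16} as for f32.f16.f16.f32, and the one-warp-per-row shortcut not taken):

  // wpt_ starts at {1, 1, 1}; each iteration doubles the dimension that still has
  // the larger number of warp-tiles left to cover (dim 1 counted at 2 * spw_[1]):
  //   {1,1,1} -> {2,1,1} -> {2,2,1} -> {4,2,1}
  // 4 * 2 * 1 = 8 warps, so the loop stops: the 128x128 tile is split across
  // 4 warps along dim 0 and 2 warps along dim 1, keeping the tiling roughly square.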
|
||||
|
||||
// std::cout << wpt_[0] << " " << wpt_[1] << std::endl;
|
||||
|
||||
/* shape per block */
|
||||
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
|
||||
@@ -198,8 +291,6 @@ scanline_layout::scanline_layout(size_t num_warps,
|
||||
bool is_dot = std::any_of(values.begin(), values.end(),
|
||||
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
|
||||
|
||||
|
||||
|
||||
std::vector<ir::value*> ptrs;
|
||||
for(ir::value *v: values)
|
||||
for(ir::user *usr: v->get_users())
|
||||
@@ -215,7 +306,6 @@ scanline_layout::scanline_layout(size_t num_warps,
|
||||
contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
|
||||
}
|
||||
|
||||
|
||||
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
|
||||
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
|
||||
size /= shape_[i];
|
||||
@@ -277,12 +367,16 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
|
||||
res.reset(new double_buffer_info_t{value_1, value_0, phi});
|
||||
}
|
||||
|
||||
static bool is_smem(ir::value* v) {
|
||||
if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
|
||||
dynamic_cast<ir::masked_load_async_inst*>(v))
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
static bool is_smem_in(ir::value* v, const ir::basic_block* bb) {
|
||||
if (ir::instruction *instr = dynamic_cast<ir::instruction*>(v)) {
|
||||
if (instr->get_parent() != bb)
|
||||
return false;
|
||||
if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
|
||||
dynamic_cast<ir::masked_load_async_inst*>(v)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// param:
|
||||
@@ -297,14 +391,14 @@ static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::
|
||||
ir::basic_block *cbb0 = cphi->get_incoming_block(0);
|
||||
ir::basic_block *cbb1 = cphi->get_incoming_block(1);
|
||||
|
||||
if (is_smem(c0)) {
|
||||
if (is_smem_in(c0, cbb0)) {
|
||||
assert(cbb0 == bb0);
|
||||
values_0.push_back(c0);
|
||||
if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
|
||||
next = phi1;
|
||||
continue;
|
||||
} else {
|
||||
if (is_smem(c1)) {
|
||||
if (is_smem_in(c1, cbb1)) {
|
||||
value_1 = c1;
|
||||
assert(cbb1 == bb1);
|
||||
return true;
|
||||
@@ -359,7 +453,8 @@ shared_layout::shared_layout(data_layout *arg,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
ir::type *ty,
|
||||
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
|
||||
analysis::align* align, target *tgt, bool is_tmp)
|
||||
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){
|
||||
|
||||
size_ = 0;
|
||||
arg_layout_ = arg;
|
||||
@@ -385,12 +480,35 @@ shared_layout::shared_layout(data_layout *arg,
|
||||
for(ir::value* v: values){
|
||||
extract_dot_use(v, dot_a, 0);
|
||||
extract_dot_use(v, dot_b, 1);
|
||||
extract_hmma_dot_use(v, hmma_dot_a, 0);
|
||||
extract_hmma_dot_use(v, hmma_dot_b, 1);
|
||||
extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
|
||||
extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
|
||||
}
|
||||
hmma_dot_a_ = hmma_dot_a;
|
||||
hmma_dot_b_ = hmma_dot_b;
|
||||
|
||||
// Update mma_vec
|
||||
if (hmma_dot_a_) {
|
||||
assert(order_.size() == 2);
|
||||
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
|
||||
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
|
||||
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
|
||||
|
||||
// for now, disable swizzle when using lds.8
|
||||
if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
|
||||
if (order_[0] == 0) // need transpose
|
||||
allow_swizzle_ = false;
|
||||
} else if (hmma_dot_b_) {
|
||||
assert(order_.size() == 2);
|
||||
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
|
||||
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
|
||||
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
|
||||
|
||||
// for now, disable swizzle when using lds.8
|
||||
if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
|
||||
if (order_[0] == 1) // need transpose
|
||||
allow_swizzle_ = false;
|
||||
}
|
||||
|
||||
// size
|
||||
size_ = ty_->get_primitive_size_in_bits() / 8;
|
||||
for(auto s: shape_)
|
||||
@@ -454,7 +572,8 @@ void layouts::make_graph(ir::instruction *i) {
|
||||
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
// if(layouts_.find(id) != layouts_.end())
|
||||
// return;
|
||||
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
|
||||
auto it_hmma_c = std::find_if(values.begin(), values.end(),
|
||||
[&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
|
||||
auto cmp = [](ir::value* x, ir::value *y) {
|
||||
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
|
||||
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
|
||||
@@ -476,19 +595,61 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
ir::value *b = dot->get_operand(1);
|
||||
create(groups_.at(a), values_.at(groups_.at(a)));
|
||||
create(groups_.at(b), values_.at(groups_.at(b)));
|
||||
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
|
||||
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
|
||||
(shared_layout*)layouts_.at(groups_.at(a)),
|
||||
(shared_layout*)layouts_.at(groups_.at(b)),
|
||||
dot);
|
||||
}
|
||||
else if(it_cts != values.end()){
|
||||
ir::instruction *cts = (ir::instruction*)*it_cts;
|
||||
ir::value *arg = cts->get_operand(0);
|
||||
create(groups_.at(arg), values_.at(groups_.at(arg)));
|
||||
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
|
||||
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
|
||||
}
|
||||
else{
|
||||
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
|
||||
}
|
||||
}
|
||||
|
||||
// layout checkers
|
||||
bool layouts::is_scanline(ir::instruction *i) {
|
||||
return this->get(i->get_operand(0))->to_scanline() != nullptr;
|
||||
}
|
||||
|
||||
bool layouts::is_coalesced_scanline(ir::instruction *i) {
|
||||
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
|
||||
auto *scanline = this->get(i->get_operand(0))->to_scanline();
|
||||
return scanline && scanline->get_order()[0] == red->get_axis();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool layouts::is_mma(ir::instruction *i) {
|
||||
return this->get(i->get_operand(0))->to_mma() != nullptr;
|
||||
}
|
||||
|
||||
bool layouts::is_a100_mma(ir::instruction *i) {
|
||||
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
|
||||
return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
|
||||
(red->get_axis() == 1);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void layouts::create_tmp_layout(size_t id, data_layout *arg,
|
||||
const std::vector<int> &axes,
|
||||
const std::vector<unsigned> &shape,
|
||||
ir::instruction *i, bool is_index) {
|
||||
ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
|
||||
: i->get_type()->get_scalar_ty();
|
||||
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
|
||||
if (is_index) {
|
||||
tmp_index_[i] = id;
|
||||
} else {
|
||||
tmp_[i] = id;
|
||||
}
|
||||
}
|
||||
|
||||
void layouts::run(ir::module &mod) {
|
||||
// make graph
|
||||
graph_.clear();
|
||||
@@ -510,35 +671,47 @@ void layouts::run(ir::module &mod) {
|
||||
// create temporaries
|
||||
size_t id = values_.size();
|
||||
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
|
||||
// std::cout << "layout: " << std::endl;
|
||||
// i->print(std::cout);
|
||||
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
|
||||
id++;
|
||||
ir::value *arg = red->get_operand(0);
|
||||
unsigned axis = red->get_axis();
|
||||
distributed_layout *layout =
|
||||
dynamic_cast<analysis::distributed_layout *>(get(arg));
|
||||
// shape
|
||||
auto shapes = arg->get_type()->get_block_shapes();
|
||||
scanline_layout *layout = get(arg)->to_scanline();
|
||||
shapes[axis] = layout->mts(axis);
|
||||
unsigned axis = red->get_axis();
|
||||
shapes[axis] =
|
||||
layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
|
||||
// create layout
|
||||
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
|
||||
tmp_[red] = id;
|
||||
id++;
|
||||
create_tmp_layout(id, layout, axes_->get(arg), shapes, red);
|
||||
|
||||
if (red->with_index()) {
|
||||
id++;
|
||||
create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
|
||||
}
|
||||
}
|
||||
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
|
||||
distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
|
||||
distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
|
||||
id++;
|
||||
size_t dim = val->get_type()->get_tile_rank();
|
||||
ir::type::block_shapes_t shape(dim);
|
||||
for(size_t k = 0; k < dim; k++){
|
||||
shape[k] = std::max(in_layout->shape_per_cta(k),
|
||||
out_layout->shape_per_cta(k));
|
||||
}
|
||||
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_);
|
||||
tmp_[val] = id;
|
||||
auto in_ord = in_layout->get_order();
|
||||
auto out_ord = out_layout->get_order();
|
||||
int in_vec = in_layout->contig_per_thread(in_ord[0]);
|
||||
int out_vec = out_layout->contig_per_thread(out_ord[0]);
|
||||
int pad = std::max(in_vec, out_vec);
|
||||
shape[out_ord[0]] += pad;
|
||||
id++;
|
||||
create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
|
||||
}
|
||||
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
|
||||
id++;
|
||||
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
|
||||
tmp_[atom] = id;
|
||||
create_tmp_layout(id, nullptr, {}, {1}, atom);
|
||||
}
|
||||
});
|
||||
|
||||
|
@@ -14,43 +14,108 @@ namespace analysis{
|
||||
void liveness::run(ir::module &mod) {
|
||||
intervals_.clear();
|
||||
|
||||
// Assigns index to each instruction
|
||||
std::map<ir::value*, slot_index> indices;
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
slot_index index = 0;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *instr: block->get_inst_list()){
|
||||
index += 1;
|
||||
indices.insert({instr, index});
|
||||
std::map<ir::value*, std::set<shared_layout*>> layouts_map;
|
||||
for(auto &x: layouts_->get_all()){
|
||||
shared_layout* layout = x.second->to_shared();
|
||||
if(!layout || layout->is_tmp())
|
||||
continue;
|
||||
for(ir::value* v:layout->get_values()){
|
||||
layouts_map[v].insert(layout);
|
||||
}
|
||||
}
|
||||
|
||||
// create live intervals
|
||||
|
||||
|
||||
std::map<ir::user*, std::set<shared_layout*>> live_in;
|
||||
while(true){
|
||||
bool changed = false;
|
||||
ir::instruction* last_inst = nullptr;
|
||||
ir::for_each_instruction_backward(mod, [&](ir::instruction* i){
|
||||
// gen
|
||||
std::set<shared_layout*> gen;
|
||||
for(ir::value* v: i->ops())
|
||||
for(shared_layout* layout: layouts_map[v])
|
||||
gen.insert(layout);
|
||||
// kill
|
||||
std::set<shared_layout*> kill;
|
||||
for(shared_layout* layout: layouts_map[i])
|
||||
kill.insert(layout);
|
||||
// temporaries are handled separately
|
||||
if(layouts_->has_tmp(i)){
|
||||
gen.insert(layouts_->get(layouts_->tmp(i))->to_shared());
|
||||
kill.insert(layouts_->get(layouts_->tmp(i))->to_shared());
|
||||
}
|
||||
if(layouts_->has_tmp_index(i)){
|
||||
gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
|
||||
kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
|
||||
}
|
||||
// live-out
|
||||
std::set<shared_layout*> live_out;
|
||||
std::vector<ir::instruction*> succs = {last_inst};
|
||||
if(i == i->get_parent()->get_inst_list().back())
|
||||
for(ir::basic_block* succ: i->get_parent()->get_successors())
|
||||
succs.push_back(succ->get_inst_list().front());
|
||||
for(ir::instruction* succ: succs)
|
||||
for(shared_layout* layout: live_in[succ])
|
||||
if(!layout->is_tmp())
|
||||
live_out.insert(layout);
|
||||
|
||||
// new sets
|
||||
std::set<shared_layout*> live_out_minus_kill;
|
||||
std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(),
|
||||
std::inserter(live_out_minus_kill, live_out_minus_kill.end()));
|
||||
std::set<shared_layout*> new_live_in;
|
||||
std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(),
|
||||
std::inserter(new_live_in, new_live_in.end()));
|
||||
|
||||
changed = changed || (new_live_in != live_in[i]);
|
||||
live_in[i] = new_live_in;
|
||||
last_inst = i;
|
||||
});
|
||||
if(!changed)
|
||||
break;
|
||||
}
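Stated compactly (an annotation, not code from the change), the loop above is a standard backward liveness fixed-point over shared-memory layouts:

  // live_out(i) = union of live_in(s) over i's successors, where the successor is the
  //               next instruction, or the first instruction of each successor block
  //               when i is the last instruction of its block
  // live_in(i)  = gen(i) U (live_out(i) \ kill(i))
  // with temporaries folded into both gen(i) and kill(i); iterate backward over the
  // module until no live_in set changes.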
|
||||
|
||||
// ir::for_each_instruction(mod, [&](ir::instruction* i){
|
||||
// i->print(std::cout);
|
||||
// std::cout << " live_in: " << live_in[i].size() << std::endl;
|
||||
// });
|
||||
|
||||
|
||||
|
||||
// Assigns index to each instruction
|
||||
std::map<ir::value*, slot_index> indices;
|
||||
slot_index index = 0;
|
||||
ir::for_each_instruction(mod, [&](ir::instruction* instr){
|
||||
index += 1;
|
||||
indices.insert({instr, index});
|
||||
});
|
||||
|
||||
|
||||
for(auto &x: layouts_->get_all()){
|
||||
shared_layout* layout = x.second->to_shared();
|
||||
if(layout)
|
||||
intervals_[layout] = segment{INT32_MAX, 0};
|
||||
}
|
||||
|
||||
for(auto& x: live_in)
|
||||
for(shared_layout* layout: x.second)
|
||||
intervals_[layout].start = std::min<int>(intervals_[layout].start, indices[x.first]);
|
||||
|
||||
for(auto& x: live_in)
|
||||
for(shared_layout* layout: x.second){
|
||||
intervals_[layout].end = std::max<int>(intervals_[layout].end, indices[x.first] + 1);
|
||||
}
|
||||
|
||||
|
||||
for(auto &x: layouts_->get_all()) {
|
||||
shared_layout* layout = x.second->to_shared();
|
||||
if(!layout)
|
||||
continue;
|
||||
// users
|
||||
std::set<ir::user*> users;
|
||||
for(ir::value *v: layout->get_values()){
|
||||
for(ir::user *u: v->get_users())
|
||||
users.insert(u);
|
||||
}
|
||||
// compute intervals
|
||||
unsigned start = INT32_MAX;
|
||||
for(ir::value *v: layout->get_values())
|
||||
if(indices.find(v) != indices.end())
|
||||
start = std::min(start, indices.at(v));
|
||||
unsigned end = 0;
|
||||
for(ir::user *u: users)
|
||||
if(indices.find(u) != indices.end())
|
||||
end = std::max(end, indices.at(u));
|
||||
if(end == 0)
|
||||
end = start + 1;
|
||||
intervals_[layout] = segment{start, end};
|
||||
// std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
@@ -19,6 +19,7 @@ void swizzle::run(ir::module &) {
|
||||
continue;
|
||||
ir::value* mma_dot_a = layout->hmma_dot_a();
|
||||
ir::value* mma_dot_b = layout->hmma_dot_b();
|
||||
|
||||
if(!mma_dot_a && !mma_dot_b){
|
||||
per_phase_[layout] = 1;
|
||||
max_phase_[layout] = 1;
|
||||
@@ -27,22 +28,31 @@ void swizzle::run(ir::module &) {
|
||||
}
|
||||
auto ord = layout->get_order();
|
||||
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
|
||||
if(!in_layout)
|
||||
continue;
|
||||
int per_phase = 1;
|
||||
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
if(in_layout)
|
||||
per_phase = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
else
|
||||
per_phase = 1;
|
||||
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
|
||||
int inner = mma_dot_a ? 0 : 1;
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
per_phase_[layout] = per_phase;
|
||||
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
|
||||
if(mma_dot_a)
|
||||
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
|
||||
else
|
||||
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
|
||||
}
|
||||
else{
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
max_phase_[layout] = 8 / per_phase_[layout];
|
||||
vec_[layout] = 8;
|
||||
else {
|
||||
if (!layout->allow_swizzle()) {
|
||||
per_phase_[layout] = 1;
|
||||
max_phase_[layout] = 1;
|
||||
vec_[layout] = 1;
|
||||
} else {
|
||||
per_phase_[layout] = per_phase;
|
||||
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
|
||||
vec_[layout] = layout->get_mma_vec();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
lib/codegen/extern_lib.cc (new file, 63 lines)
@@ -0,0 +1,63 @@
#include "triton/codegen/extern_lib.h"

#include "llvm/IR/Constants.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Type.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "triton/codegen/pass.h"

namespace triton {

namespace codegen {

std::unique_ptr<llvm::Module> ExternLib::load(llvm::LLVMContext& ctx) {
  llvm::SMDiagnostic err;
  auto mod = llvm::parseIRFile(this->path_, err, ctx);
  if (!mod) {
    throw std::runtime_error("Failed to load extern lib " + this->name_ +
                             " at " + this->path_);
  }
  return mod;
}

void ExternLib::link(std::unique_ptr<llvm::Module>& llvm,
                     std::unique_ptr<llvm::Module>& mod) {
  // Set triple and data layout to match the target module
  mod->setTargetTriple(llvm->getTargetTriple());
  mod->setDataLayout(llvm->getDataLayout());
  if (llvm::Linker::linkModules(*llvm, std::move(mod))) {
    throw std::runtime_error("Failed to link extern lib " + this->name_ +
                             " at " + this->path_);
  }
}

void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr<llvm::Module>& llvm) {
  // Add nvvm reflect flags to llvm module
  // https://llvm.org/docs/LangRef.html#module-flags-metadata
  // i32 4: Override the other module.
  // i32 1: Emit an error
  // If both modules specify Override, but the values differ, an error
  // will be emitted.
  llvm::Type* I32 = llvm::Type::getInt32Ty(ctx);
  llvm::Metadata* md_four =
      llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
  llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
  llvm::Metadata* md_one =
      llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
  llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one});
  llvm->addModuleFlag(reflect);
}

std::unique_ptr<ExternLib> create_extern_lib(const std::string& lib_name,
                                             const std::string& lib_path) {
  if (lib_name == "libdevice") {
    return std::make_unique<LibDevice>(lib_name, lib_path);
  } else {
    throw std::runtime_error("Unknown external library: " + lib_name);
  }
}

} // namespace codegen
} // namespace triton
|
@@ -1,4 +1,14 @@
|
||||
#include "triton/codegen/pass.h"
|
||||
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Linker/Linker.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Transforms/IPO.h"
|
||||
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
@@ -9,6 +19,7 @@
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/disassociate.h"
|
||||
#include "triton/codegen/transform/inline.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
@@ -16,44 +27,90 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen {
|
||||
|
||||
static void link_extern_libs(const ExternLibMap& user_extern_lib_map,
|
||||
const ExternLibMap& target_extern_lib_map,
|
||||
ir::module& ir, llvm::LLVMContext& ctx,
|
||||
std::unique_ptr<llvm::Module>& llvm) {
|
||||
for (const auto& iter : target_extern_lib_map) {
|
||||
auto &lib_name = iter.first;
|
||||
if (user_extern_lib_map.count(lib_name) != 0 &&
|
||||
user_extern_lib_map.at(lib_name)->path() != "") {
|
||||
// If the user specified a path for this library, use it.
|
||||
user_extern_lib_map.at(lib_name)->install(ctx, llvm);
|
||||
} else {
|
||||
// Otherwise, use the default path.
|
||||
iter.second->install(ctx, llvm);
|
||||
}
|
||||
}
|
||||
|
||||
std::set<llvm::StringRef> function_names;
|
||||
for (auto& func : ir.get_function_list()) {
|
||||
function_names.insert(func->get_name());
|
||||
}
|
||||
llvm::legacy::PassManager pass;
|
||||
pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool {
|
||||
if (function_names.count(v.getName()) != 0) {
|
||||
// Preserve global functions
|
||||
return true;
|
||||
}
|
||||
// Internalize all device functions
|
||||
return false;
|
||||
}));
|
||||
|
||||
llvm::legacy::PassManager pm;
|
||||
pm.add(llvm::createVerifierPass());
|
||||
pm.run(*llvm);
|
||||
|
||||
llvm::PassManagerBuilder builder;
|
||||
builder.OptLevel = 3;
|
||||
builder.SizeLevel = 0;
|
||||
builder.populateModulePassManager(pass);
|
||||
|
||||
pass.run(*llvm);
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// There should be a proper pass manager there!
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
|
||||
int cc, int num_warps, int num_stages, int& shared_static) {
|
||||
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
|
||||
ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target,
|
||||
int num_warps, int num_stages, int& shared_static,
|
||||
const ExternLibMap& extern_lib_map) {
|
||||
// generate llvm code
|
||||
std::string name = ir.get_function_list()[0]->get_name();
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
|
||||
// optimizations
|
||||
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
|
||||
bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
|
||||
// create passes
|
||||
codegen::analysis::align align;
|
||||
codegen::transform::inliner inliner;
|
||||
codegen::analysis::axes axes;
|
||||
codegen::transform::cts cts(cts_use_async);
|
||||
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
|
||||
codegen::transform::pipeline pipeline(has_sm80, num_stages);
|
||||
codegen::transform::disassociate disassociate;
|
||||
codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
|
||||
codegen::transform::cts cts(&layouts, has_sm80);
|
||||
codegen::analysis::liveness liveness(&layouts);
|
||||
codegen::analysis::swizzle swizzle(&layouts, target);
|
||||
codegen::analysis::allocation allocation(&liveness);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole(target, &layouts);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts, has_sm80);
|
||||
codegen::transform::prefetch prefetch_s(target);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation,
|
||||
&prefetch_s, target);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle,
|
||||
target, num_warps);
|
||||
// run passes
|
||||
inliner.run(ir);
|
||||
dce.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
pipeline.run(ir);
|
||||
dce.run(ir);
|
||||
// ir.print(std::cout);
|
||||
disassociate.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
@@ -61,8 +118,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
layouts.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
if (target->is_gpu())
|
||||
cts.run(ir);
|
||||
if (target->is_gpu()) cts.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
@@ -70,8 +126,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
dce.run(ir);
|
||||
if (target->is_gpu())
|
||||
cts.run(ir);
|
||||
if (target->is_gpu()) cts.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
axes.run(ir);
|
||||
@@ -82,14 +137,27 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
axes.run(ir);
|
||||
layouts.run(ir);
|
||||
swizzle.run(ir);
|
||||
// std::cout << "---" << std::endl;
|
||||
// ir.print(std::cout);
|
||||
// std::cout << "---" << std::endl;
|
||||
// ir.print(std::cout);
|
||||
liveness.run(ir);
|
||||
allocation.run(ir);
|
||||
prefetch_s.run(ir);
|
||||
barriers.run(ir);
|
||||
// exit(1);
|
||||
// ir.print(std::cout);
|
||||
isel.visit(ir, *llvm);
|
||||
shared_static = allocation.allocated_size();
|
||||
|
||||
if (isel.get_extern_lib_map().size() > 0) {
|
||||
// If there's any extern lib calls,
|
||||
// we need to link them in.
|
||||
link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm);
|
||||
}
|
||||
|
||||
return llvm;
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
|
File diff suppressed because it is too large
@@ -12,46 +12,11 @@ namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
|
||||
: align_(align), layout_(layouts) { }
|
||||
|
||||
|
||||
// simplify layout conversions using the following simple rules:
|
||||
// - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
|
||||
// - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y))
|
||||
//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
|
||||
// ir::value* _op = inst->get_operand(0);
|
||||
// ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
|
||||
// analysis::mma_layout* mma_in = layout_->get(op) ->to_mma();
|
||||
// analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
|
||||
// std::cout << 1 << std::endl;
|
||||
// // i must be layout conversion instruction
|
||||
// if(!mma_in && !mma_out)
|
||||
// return inst;
|
||||
// // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
|
||||
// bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
|
||||
// if((mma_in || mma_out) && is_op_cvt &&
|
||||
// (layout_->get(inst) == layout_->get(op->get_operand(0))))
|
||||
// return op->get_operand(0);
|
||||
// // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
|
||||
// if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
|
||||
// return inst;
|
||||
// std::cout << 1 << std::endl;
|
||||
// for(size_t i = 0; i < op->get_num_operands(); i++){
|
||||
// ir::value* arg_i = op->get_operand(i);
|
||||
// builder.set_insert_point(op);
|
||||
// // create new layout transform
|
||||
// ir::instruction* new_arg_i = inst->clone();
|
||||
// builder.insert(new_arg_i);
|
||||
// // set the right args
|
||||
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
|
||||
// op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
|
||||
// }
|
||||
// std::cout << 2 << std::endl;
|
||||
// return op;
|
||||
//}
|
||||
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
|
||||
: align_(align), layout_(layouts), has_sm80_(has_sm80) { }
|
||||
|
||||
void coalesce::run(ir::module &mod) {
|
||||
std::set<analysis::data_layout*> invalidated;
|
||||
ir::builder& builder = mod.get_builder();
|
||||
// add layout conversion instructions
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
@@ -61,23 +26,43 @@ void coalesce::run(ir::module &mod) {
|
||||
if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
|
||||
if(ir::value* op = i->get_operand(1))
|
||||
if(op->get_type()->is_block_ty())
|
||||
if(layout_->get(op)->to_mma()){
|
||||
if(op->get_type()->get_tile_ranks1() == 2)
|
||||
if(invalidated.find(layout_->get(op)) == invalidated.end())
|
||||
if(layout_->get(op)->to_mma())
|
||||
if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
|
||||
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
|
||||
builder.set_insert_point(i);
|
||||
builder.insert(new_op);
|
||||
i->replace_uses_of_with(op, new_op);
|
||||
}
|
||||
// coalesce before copy_to_shared
|
||||
// only necessary for sm < 80 as Ampere+ can handle reduction
|
||||
// on MMA layout
|
||||
if(!has_sm80_)
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i))
|
||||
if(ir::value* op = i->get_operand(0))
|
||||
if(op->get_type()->is_block_ty())
|
||||
if(op->get_type()->get_tile_ranks1() == 2)
|
||||
if(invalidated.find(layout_->get(op)) == invalidated.end())
|
||||
if(layout_->get(op)->to_mma()){
|
||||
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
|
||||
builder.set_insert_point(i);
|
||||
builder.insert(new_op);
|
||||
op->replace_all_uses_with(new_op);
|
||||
new_op->replace_uses_of_with(new_op, op);
|
||||
invalidated.insert(layout_->get(op));
|
||||
}
|
||||
// uncoalesce after load
|
||||
if(auto x = dynamic_cast<ir::load_inst*>(i))
|
||||
if(x->get_type()->is_block_ty())
|
||||
if(x->get_type()->get_tile_rank()==2)
|
||||
if(layout_->get(x)->to_mma()){
|
||||
if(x->get_type()->get_tile_ranks1()==2)
|
||||
if(layout_->get(x)->to_mma())
|
||||
if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
|
||||
builder.set_insert_point_after(x);
|
||||
ir::instruction* new_x = ir::cvt_layout_inst::create(x);
|
||||
builder.insert(new_x);
|
||||
x->replace_all_uses_with(new_x);
|
||||
new_x->replace_uses_of_with(new_x, x);
|
||||
// new_x->replace_uses_of_with(new_x, new_x);
|
||||
}
|
||||
}
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
@@ -90,9 +75,11 @@ void coalesce::run(ir::module &mod) {
|
||||
auto out_contig = align_->contiguous(ptr);
|
||||
auto val_inst = dynamic_cast<ir::instruction*>(val);
|
||||
if(!val_inst)
|
||||
break;
|
||||
continue;
|
||||
if(dynamic_cast<ir::cvt_layout_inst*>(val))
|
||||
break;
|
||||
continue;
|
||||
if(!val->get_type()->is_block_ty() || val->get_type()->get_tile_ranks1()==1)
|
||||
continue;
|
||||
std::vector<unsigned> in_contig;
|
||||
std::vector<ir::instruction*> queue = {val_inst};
|
||||
std::set<ir::instruction*> seen;
|
||||
@@ -101,6 +88,8 @@ void coalesce::run(ir::module &mod) {
|
||||
ir::instruction* curr = queue.back();
|
||||
seen.insert(curr);
|
||||
queue.pop_back();
|
||||
if(auto dot_inst = dynamic_cast<ir::dot_inst*>(curr))
|
||||
break;
|
||||
if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
|
||||
in_contig = align_->contiguous(io_inst->get_pointer_operand());
|
||||
break;
|
||||
|
@@ -1,8 +1,10 @@
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
@@ -10,9 +12,9 @@ namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
inline bool is_shmem_op(ir::instruction* i, int op) {
|
||||
bool cts::is_shmem_op(ir::instruction* i, int op) {
|
||||
if(i->get_id() == ir::INST_DOT)
|
||||
return op==0 || op==1;
|
||||
return op == 0 || op == 1;
|
||||
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
|
||||
return op==0;
|
||||
if(i->get_id() == ir::INST_TRANS)
|
||||
@@ -20,7 +22,7 @@ inline bool is_shmem_op(ir::instruction* i, int op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool is_shmem_res(ir::value* v){
|
||||
bool cts::is_shmem_res(ir::value* v){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
return false;
|
||||
@@ -35,7 +37,7 @@ inline bool is_shmem_res(ir::value* v){
|
||||
|
||||
|
||||
// run pass on module
|
||||
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
|
||||
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*, ir::value*>& copies) {
|
||||
auto *i = dynamic_cast<ir::instruction*>(x);
|
||||
// not an instruction
|
||||
if(!i) {
|
||||
@@ -51,7 +53,7 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
|
||||
// phi node
|
||||
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
|
||||
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
|
||||
add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies);
|
||||
return;
|
||||
}
|
||||
// already in shared memory
|
||||
@@ -65,33 +67,52 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
|
||||
}
|
||||
else
|
||||
copy = builder.create_copy_from_shared(x);
|
||||
parent->replace_uses_of_with(x, copy);
|
||||
copies.insert({x, copy});
|
||||
parent->replace_uses_of_with(x, copies.at(x));
|
||||
}

void cts::run(ir::module &mod) {
  // Add shared copies
  ir::builder &builder = mod.get_builder();
  for(ir::function* fn: mod.get_function_list()){
    for(ir::basic_block* block: fn->blocks())
    for(ir::instruction* i: block->get_inst_list()){
      size_t num_op = i->get_num_operands();
      // copy to shared operands
      for(size_t k = 0; k < num_op; k++)
        if(is_shmem_op(i, k)){
          add_copy(i, i->get_operand(k), builder, true);
        }
      // copy from shared operands
      for(size_t k = 0; k < num_op; k++)
        if(!dynamic_cast<ir::phi_node*>(i) &&
           !is_shmem_op(i,k) &&
           is_shmem_res(i->get_operand(k))){
          add_copy(i, i->get_operand(k), builder, false);
        }
  // Precompute where copies should be added
  std::set<ir::value*> shmem_ops;
  std::set<ir::value*> shmem_res;
  ir::for_each_instruction(mod, [&](ir::instruction* i) {
    if(i->get_id() == ir::INST_DOT){
      ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(i);
      ir::value* lhs = i->get_operand(0);
      ir::type* ty = lhs->get_type()->get_scalar_ty();
      analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma();
      // TODO: V100
      bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a());
      if(is_lhs_shmem)
        shmem_ops.insert(lhs);
      shmem_ops.insert(i->get_operand(1));
    }
    }
    if(i->get_id() == ir::INST_COPY_FROM_SHARED)
      shmem_ops.insert(i->get_operand(0));
    if(i->get_id() == ir::INST_TRANS)
      shmem_ops.insert(i->get_operand(0));
    if(i->get_id() == ir::INST_TRANS ||
       i->get_id() == ir::INST_COPY_TO_SHARED ||
       i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
      shmem_res.insert(i);
  });

  // Add shared copies
  std::map<ir::value*, ir::value*> copies;
  ir::builder &builder = mod.get_builder();
  ir::for_each_instruction(mod, [&](ir::instruction* i) {
    size_t num_op = i->get_num_operands();
    for(size_t k = 0; k < num_op; k++){
      ir::value* op = i->get_operand(k);
      // copy to shared operands
      bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end();
      if(is_shmem_op)
        add_copy(i, op, builder, true, copies);
    }
  });
}

}
}
}
}
@@ -3,6 +3,7 @@
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
#include <iostream>

namespace triton {
namespace codegen{
@@ -28,6 +29,8 @@ void dce::run(ir::module &mod) {
      case ir::INST_ATOMIC_CAS:
      case ir::INST_ATOMIC_RMW:
      case ir::INST_ATOMIC_EXCH:
      case ir::INST_CALL:
      case ir::INST_LAUNCH:
      case ir::INST_BARRIER: {
        work_list.push_back(i);
        marked.insert(i);
@@ -65,6 +68,7 @@ void dce::run(ir::module &mod) {
  }
  }

  // delete
  for(ir::instruction* i: to_delete)
    i->erase_from_parent();

147
lib/codegen/transform/inline.cc
Normal file
@@ -0,0 +1,147 @@
#include <iostream>
#include "triton/codegen/transform/inline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/utils.h"

namespace triton{
namespace codegen{
namespace transform{


bool fncmp::operator()(ir::function* x, ir::function* y) const {
  auto fn_list = x->get_parent()->get_function_list();
  return std::find(fn_list.begin(), fn_list.end(), x) < std::find(fn_list.begin(), fn_list.end(), y);
};

void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder,
                        std::list<ir::call_inst*>& callsites){
  ir::basic_block* parent_block = callsite->get_parent();
  ir::function* parent_fn = parent_block->get_parent();
  // the parent block is split into block A and block B:
  // - block A (`new_blocks[0]`) is the entry block of the inlined function
  // - block B (`exit`) resumes execution of the parent function
  ir::basic_block* entry = parent_block->split_before(callsite, fn->get_name());
  ir::basic_block* exit = entry->get_successors()[0];
  std::vector<ir::basic_block*> new_blocks = {entry};
  for(size_t i = 1; i < fn->blocks().size(); i++){
    ir::basic_block* block = fn->blocks()[i];
    ir::context& ctx = block->get_context();
    const std::string& name = block->get_parent()->get_name() + "_" + block->get_name();
    new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn));
  }
  // a phi node holds the return values of the inlined function
  if(exit->get_inst_list().empty())
    builder.set_insert_point(exit);
  else
    builder.set_insert_point(exit->get_first_non_phi());
  ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0);
  callsite->replace_all_uses_with(exit_val);
  callsite->erase_from_parent();
  // get arguments `fn` is called with
  std::vector<ir::value*> tgt_args(callsite->op_begin(), callsite->op_end());
  std::vector<ir::argument*> src_args(fn->args().begin(), fn->args().end());
  // Actually generate the instructions:
  // - Remove the branch created by basic_block::split_before
  // - Clone all instructions
  // - Replace `ret` with incoming nodes to `exit_val` and branches to `exit`
  ir::instruction* terminator = new_blocks[0]->get_inst_list().back();
//  new_blocks[0]->get_inst_list().back()->erase_from_parent();
  terminator->erase_from_parent();
  std::map<ir::instruction*, ir::instruction*> inst_map;
  std::map<ir::argument*, ir::value*> arg_map;
  for(size_t k = 0; k < fn->args().size(); k++)
    arg_map[fn->args()[k]] = callsite->ops()[k];
  std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
  // clone instructions
  for(size_t i = 0; i < new_blocks.size(); i++){
    ir::basic_block* old_block = fn->blocks()[i];
    ir::basic_block* new_block = new_blocks[i];
    builder.set_insert_point(new_block);
    for(ir::instruction* old_inst: old_block->get_inst_list()){
      ir::instruction* new_inst = old_inst->clone();
      inst_map[old_inst] = new_inst;
      builder.insert(new_inst);
    }
  }
  // update basic blocks
  for(size_t i = 0; i < new_blocks.size(); i++) {
    for (ir::instruction* new_inst: new_blocks[i]->get_inst_list()) {
      // replace basic use cases
      for(size_t k = 0; k < new_blocks.size(); k++)
        new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
      if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(new_inst)) {
        // additionally replace basic blocks of phi-nodes since
        // replace_uses_of_with() does not replace them.
        for(unsigned in = 0; in < phi->get_num_incoming(); in++)
          for(size_t k = 0; k < new_blocks.size(); k++)
            if (phi->get_incoming_block(in) == fn->blocks()[k])
              phi->set_incoming_block(in, new_blocks[k]);
      }
    }
  }
  // replace operands of instructions after constructing inst_map
  for (auto& it: inst_map) {
    ir::instruction* new_inst = it.second;
    for(size_t k = 0; k < new_inst->get_num_operands(); k++) {
      ir::value* op = new_inst->get_operand(k);
      if(auto arg_op = dynamic_cast<ir::argument*>(op))
        new_inst->set_operand(k, arg_map.at(arg_op));
      if(auto inst_op = dynamic_cast<ir::instruction*>(op))
        if(inst_map.find(inst_op) != inst_map.end())
          new_inst->set_operand(k, inst_map.at(inst_op));
    }
    // handles a ret instruction.
    // instead of returning we need to branch to after the function call
    if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)) {
      if(ir::value* ret_val = ret->get_return_value())
        exit_val->add_incoming(ret_val, new_inst->get_parent());
      // replace ret with branch
      ir::instruction* new_br_inst = ir::branch_inst::create(exit);
      builder.set_insert_point(new_inst->get_parent());
      builder.insert(new_br_inst);
      new_inst->erase_from_parent();
    }
  }
  if(exit_val->get_num_incoming() == 1)
    exit_val->replace_all_uses_with(exit_val->get_incoming_value(0));
  // done -- make sure insert point is properly set to exit block
  builder.set_insert_point(exit);
}

void inliner::run(ir::module &mod) {

  // gather all call sites
  while(true){
    std::map<ir::function*, size_t> counts;
    for(ir::function* fn: mod.get_function_list())
      counts[fn] = 0;

    std::list<ir::call_inst*> callsites;
    for(ir::function* fn: mod.get_function_list()){
      for(ir::basic_block* block: fn->blocks())
      for(ir::instruction* instr: block->get_inst_list())
      if(ir::call_inst* call = dynamic_cast<ir::call_inst*>(instr)){
        callsites.push_back(call);
        counts[call->get_fn()] += 1;
      }
    }

    for(auto& count: counts){
      if(!count.first->get_is_kernel() && count.second == 0)
        count.first->get_parent()->remove_function(count.first);
    }

    if(callsites.empty())
      break;

    for(ir::call_inst* call: callsites)
      do_inline(call->get_fn(), call, mod.get_builder(), callsites);
  }

}

}
}
}
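A minimal sketch of how this new pass might be driven, assuming `inliner` keeps a default constructor and is invoked on a module like the other transform passes (the surrounding pipeline code is not shown in this commit, so the instantiation below is illustrative only):

  triton::codegen::transform::inliner inline_pass;  // hypothetical instantiation
  inline_pass.run(mod);  // repeatedly inlines call sites, then drops fully-inlined non-kernel functions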
@@ -36,6 +36,9 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
  else{
    if(layouts_->has_tmp(v))
      return async_write.size() - 1;
//    // Ignore copy_to_shared. It won't modify async behavior.
//    if(dynamic_cast<ir::copy_to_shared_inst*>(v))
//      return 0;
    auto it = std::find(async_write.begin(), async_write.end(), v);
    return std::distance(async_write.begin(), it);
  }
@@ -60,15 +63,22 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
      continue;
    analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
    analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
    analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ? layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr;
    for(ir::value* b: bs){
      if(!b->get_type()->is_block_ty())
        continue;
      analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
      analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
      analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr;
      if(intersect_with(a_layout, b_layout) ||
         intersect_with(a_layout, b_tmp) ||
         intersect_with(a_layout, b_tmp_index) ||
         intersect_with(a_tmp, b_layout) ||
         intersect_with(a_tmp, b_tmp))
         intersect_with(a_tmp, b_tmp) ||
         intersect_with(a_tmp, b_tmp_index) ||
         intersect_with(a_tmp_index, b_layout) ||
         intersect_with(a_tmp_index, b_tmp) ||
         intersect_with(a_tmp_index, b_tmp_index))
        ret.insert(b);
    }
  }

@@ -61,7 +61,8 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
  // dot(a, b, c) + d -> dot(a, b, c + d)
  // d + dot(a, b, c) -> dot(a, b, c + d)
  auto add = dynamic_cast<ir::binary_operator*>(value);
  if(add && add->get_op() == ir::binary_op_t::FAdd) {
  if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
    bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
    ir::value *lhs = add->get_operand(0);
    ir::value *rhs = add->get_operand(1);
    ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
@@ -72,15 +73,21 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
    ir::value *other = (dot == lhs) ? rhs : lhs;
    ir::value *acc = dot->get_operand(2);
    ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
    ir::constant_fp *_0 = nullptr;
    ir::constant *_0 = nullptr;
    if(splat)
      _0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0));
    if(!(_0 && _0->get_value() == 0.0))
      _0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
    if(!_0)
      return false;
    if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
      if (fp_0->get_value() != 0.0)
        return false;
    if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
      if (int_0->get_value() != 0)
        return false;
    ir::value *a = dot->get_operand(0);
    ir::value *b = dot->get_operand(1);
    builder.set_insert_point(add);
    ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
    ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name()));
    add->replace_all_uses_with(new_dot);
    return true;
  }
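Worked reasoning for the zero-accumulator check above (not stated in the commit, but it is why the splatted constant must be 0): dot(a, b, c) computes a @ b + c, so dot(a, b, 0) + d = (a @ b + 0) + d = a @ b + d = dot(a, b, d). If the accumulator were non-zero, folding `other` into it would change the result, hence the early `return false` for any non-zero fp or int constant.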
@@ -116,7 +123,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
    int nts = layout->nts(layout->get_order()[0]);
    int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
    if(nts*dtsize >= 4){
      ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
      ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
      copy_to_shared->replace_all_uses_with(new_load);
      return true;
    }
@@ -143,32 +150,53 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
}

bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
    auto binop = dynamic_cast<ir::binary_operator*>(value);
    if(binop && binop->get_op() == ir::binary_op_t::Mul) {
    ir::value *lhs = binop->get_operand(0);
    ir::value *rhs = binop->get_operand(1);
    ir::constant_int *_1_lhs = nullptr;
    if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
      auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
      if(cst && cst->get_value() == 1)
        _1_lhs = cst;
    }
    ir::constant_int *_1_rhs = nullptr;
    if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
      auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
      if(cst && cst->get_value() == 1)
        _1_rhs = cst;
    }
    if(_1_lhs){
      binop->replace_all_uses_with(rhs);
      return true;
    }
    else if(_1_rhs){
      binop->replace_all_uses_with(lhs);
      return true;
    }
  auto binop = dynamic_cast<ir::binary_operator*>(value);
  if(binop && binop->get_op() == ir::binary_op_t::Mul) {
    ir::value *lhs = binop->get_operand(0);
    ir::value *rhs = binop->get_operand(1);
    ir::constant_int *_1_lhs = nullptr;
    if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
      auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
      if(cst && cst->get_value() == 1)
        _1_lhs = cst;
    }
    ir::constant_int *_1_rhs = nullptr;
    if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
      auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
      if(cst && cst->get_value() == 1)
        _1_rhs = cst;
    }
    if(_1_lhs){
      binop->replace_all_uses_with(rhs);
      return true;
    }
    else if(_1_rhs){
      binop->replace_all_uses_with(lhs);
      return true;
    }
  }
  return false;
}

bool peephole::rewrite_insert_extract(ir::instruction *value, ir::builder& builder){
  auto extracted = dynamic_cast<ir::extract_value_inst*>(value);
  if(!extracted)
    return false;
  size_t extract_idx = extracted->get_idx();
  ir::value* agg = extracted->get_operand(0);
  auto insert = dynamic_cast<ir::insert_value_inst*>(agg);
  while(insert){
    agg = insert->get_operand(0);
    ir::value* inserted = insert->get_operand(1);
    size_t insert_idx = insert->get_idx();
    insert = dynamic_cast<ir::insert_value_inst*>(agg);
    if(extract_idx == insert_idx){
      extracted->replace_all_uses_with(inserted);
      return true;
    }
    insert = dynamic_cast<ir::insert_value_inst*>(agg);
  }
  return false;
}


@@ -207,7 +235,9 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
    ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
                                                     if_value->get_mask_operand(),
                                                     select->get_else_value_op(),
                                                     if_value->get_cache_modifier());
                                                     if_value->get_cache_modifier(),
                                                     if_value->get_eviction_policy(),
                                                     if_value->get_is_volatile());
    select->replace_all_uses_with(new_load);
    return true;
  }
@@ -219,22 +249,22 @@ bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){
  ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0));
  if(!op)
    return false;
  // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
  if(op->get_id() == ir::INST_BINOP){
    for(size_t i = 0; i < op->get_num_operands(); i++){
      ir::value* arg_i = op->get_operand(i);
      builder.set_insert_point(op);
      // create new layout transform
      ir::instruction* new_arg_i = cvt->clone();
      layouts_->copy(new_arg_i, op);
      builder.insert(new_arg_i);
      // set the right args
      new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
      op->replace_uses_of_with(arg_i, new_arg_i);
    }
    cvt->replace_all_uses_with(op);
    return true;
  }
//  // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
//  if(op->get_id() == ir::INST_BINOP){
//    for(size_t i = 0; i < op->get_num_operands(); i++){
//      ir::value* arg_i = op->get_operand(i);
//      builder.set_insert_point(op);
//      // create new layout transform
//      ir::instruction* new_arg_i = cvt->clone();
//      layouts_->copy(new_arg_i, op);
//      builder.insert(new_arg_i);
//      // set the right args
//      new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
//      op->replace_uses_of_with(arg_i, new_arg_i);
//    }
//    cvt->replace_all_uses_with(op);
//    return true;
//  }
  auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op);
  if(!cvt_op)
    return false;
@@ -282,9 +312,11 @@ void peephole::run(ir::module &mod) {
      was_modified = was_modified || rewrite_mult(i, builder);
      // was_modified = was_modified || rewrite_cts_cfs(i, builder);
      // was_modified = was_modified || rewrite_trans_phi(i, builder);
      was_modified = was_modified || rewrite_insert_extract(i, builder);
      was_modified = was_modified || rewrite_unit_red(i, builder);
      was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
      was_modified = was_modified || rewrite_select_masked_load(i, builder);
      // TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
      // was_modified = was_modified || rewrite_select_masked_load(i, builder);
      was_modified = was_modified || rewrite_cvt_layout(i, builder);
      if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
        was_modified = was_modified || rewrite_load_to_shared(i, builder);
@@ -134,6 +134,7 @@ void pipeline::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  const int num_stages = num_stages_;
  std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads

  for(auto info: to_pipeline){
    ir::load_inst* load = info.load;
    ir::phi_node* ptr = info.ptr;
@@ -178,7 +179,7 @@ void pipeline::run(ir::module &mod) {
        false_value = remat_false_value;
      } else
        false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
      first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier());
      first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());

      for (int stage = 1; stage < num_stages-1; ++stage) {
        // mask is the loop condition of the previous iteration
@@ -193,7 +194,7 @@ void pipeline::run(ir::module &mod) {
          first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
          false_value = remat_false_value;
        }
        first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier());
        first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
      }

      // create new phis for induction variables
@@ -222,7 +223,7 @@ void pipeline::run(ir::module &mod) {
        next_mask = builder.create_and(next_mask, remat_mask);
        false_value = remat_false_value;
      }
      ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
      ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());


      // phi node
@@ -257,7 +258,7 @@ void pipeline::run(ir::module &mod) {
      }
      else
        false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
      ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier());
      ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
      // pre-fetch next iteration
      builder.set_insert_point(block->get_inst_list().back());
      ir::value* next_ptr = ptr->get_value_for_block(block);
@@ -268,7 +269,7 @@ void pipeline::run(ir::module &mod) {
        next_mask = builder.create_and(next_mask, remat_mask);
        false_value = remat_false_value;
      }
      ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
      ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
      // phi node
      builder.set_insert_point(block->get_first_non_phi());
      ir::phi_node* new_load = builder.create_phi(ty, 2);

@@ -29,8 +29,16 @@ void prefetch::run(ir::module &mod) {
  std::vector<ir::dot_inst*> to_prefetch;
  ir::for_each_instruction(mod, [&](ir::instruction *i) {
    if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
      // Now only do prefetching when dot is fp16
      if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID)
      // Now only do prefetching when dot is using tensor cores
      if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
            dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
            (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
             && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
            (dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
             && dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
             && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
           )
          )
        return;
      auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
      auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
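The new prefetch condition above is easier to read when factored into a predicate. A sketch of an equivalent helper (a hypothetical refactor, not part of the commit; it only reuses the type and target queries already visible in the diff, and the `triton::codegen::target*` parameter type is assumed):

static bool uses_tensor_cores(ir::dot_inst *dot, triton::codegen::target *tgt) {
  ir::type *a_ty = dot->get_operand(0)->get_type()->get_scalar_ty();
  ir::type *b_ty = dot->get_operand(1)->get_type()->get_scalar_ty();
  bool sm80 = tgt->as_nvidia() && tgt->as_nvidia()->sm() >= 80;
  return a_ty->is_fp16_ty() || a_ty->is_bf16_ty() ||                  // fp16 / bf16 mma
         (a_ty->is_fp32_ty() && dot->allow_tf32() && sm80) ||         // tf32 mma on sm >= 80
         (a_ty->is_integer_ty(8) && b_ty->is_integer_ty(8) && sm80);  // int8 mma on sm >= 80
}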