Merge branch 'master' into rcom52_fixes
@@ -115,6 +115,18 @@ std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_ins
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_dequantize(ir::dequantize_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value *op = x->get_operand(0);
  auto op_shapes = op->get_type()->get_block_shapes();
  auto op_cst = populate_is_constant(op);
  for(size_t d = 0; d < x_shapes.size(); d++) {
    result.push_back(op_cst[d]);
  }
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
@@ -129,6 +141,36 @@ std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value* lhs_op = x->get_operand(0);
  ir::value* rhs_op = x->get_operand(1);
  auto lhs = populate_is_constant(lhs_op);
  auto rhs = populate_is_constant(rhs_op);
  auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
  auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
  auto lhs_multiple_of = populate_starting_multiple(lhs_op);
  auto rhs_multiple_of = populate_starting_multiple(rhs_op);
  for(size_t d = 0; d < x_shapes.size(); d++) {
    cst_info ax = {1, 0};
    // Examples:
    // 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8
    // 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4
    // 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16
    //
    // if LHS is a range of N continuous (or equal) elements that starts at M,
    // and RHS is a set of N constants that start at K
    // then the result is constant in groups of gcd(M, K)
    if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 ||
       rhs[d].num_cst % lhs[d].num_cst == 0)
      ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]);
    result.push_back(ax);
  }
  return add_to_cache(x, result, is_constant_);
}


std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
@@ -136,12 +178,14 @@ std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operat
  ir::value* rhs_op = x->get_operand(1);
  auto lhs = populate_is_constant(lhs_op);
  auto rhs = populate_is_constant(rhs_op);
  auto max_contiguous = populate_max_contiguous(lhs_op);
  auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
  auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
  auto lhs_multiple_of = populate_starting_multiple(lhs_op);
  auto rhs_multiple_of = populate_starting_multiple(rhs_op);
  for(size_t d = 0; d < x_shapes.size(); d++) {
    cst_info ax;
    if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
      // todo might not be entirely true
      unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
      unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
      ax = {num_constants, 0};
    }
    else
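Editor's note: the grouping argument in the comments of populate_is_constant_cmp above can be checked with a small standalone program. This sketch is not part of the commit; the names and values are illustrative only (M is the starting multiple of the contiguous LHS range, K the RHS constant).

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // LHS: 16 contiguous values starting at M = 16; RHS: the constant K = 24.
  const unsigned M = 16, K = 24, N = 16;
  std::vector<int> cmp(N);
  for (unsigned i = 0; i < N; ++i)
    cmp[i] = (int(M + i) < int(K));        // 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0
  // Claim from the comment: the result is constant in groups of gcd(M, K) = 8.
  unsigned g = std::gcd(M, K);
  for (unsigned start = 0; start < N; start += g)
    for (unsigned i = 1; i < g; ++i)
      if (cmp[start + i] != cmp[start])
        std::puts("group not constant");   // never printed for this input
  std::printf("constant in groups of %u\n", g);
  return 0;
}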
@@ -180,10 +224,14 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
    return populate_is_constant_splat(x);
  if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
    return populate_is_constant_reshape(x);
  if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
    return populate_is_constant_dequantize(x);
  if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_is_constant_broadcast(x);
  if(auto *x = dynamic_cast<ir::binary_operator*>(v))
    return populate_is_constant_binop(x);
  if(auto *x = dynamic_cast<ir::cmp_inst*>(v))
    return populate_is_constant_cmp(x);
  if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
    return populate_is_constant_gep(x);
  return populate_is_constant_default(v);
@@ -245,6 +293,23 @@ std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x
  return add_to_cache(x, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_dequantize(ir::dequantize_inst* x) {
  auto shapes = get_shapes(x);
  std::vector<unsigned> result;
  ir::value *op = x->get_operand(0);
  auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
  auto op_last_dim = (op->get_type()->get_block_shapes()).back();
  auto op_mc = populate_max_contiguous(op);
  for(size_t d = 0; d < shapes.size(); d++) {
    unsigned factor = 1;
    if (d == shapes.size() - 1) {
      factor = ret_last_dim / op_last_dim;
    }
    result.push_back(factor * op_mc[d]);
  }
  return add_to_cache(x, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
  auto shapes = get_shapes(x);
  std::vector<unsigned> result;
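Editor's note: a standalone sketch (not part of the commit) of the last-dimension scaling used by populate_max_contiguous_dequantize and populate_starting_multiple_dequantize: dequantize only widens the innermost dimension, so per-dimension contiguity scales by ret_last_dim / op_last_dim there and is passed through elsewhere. The shapes below are assumed example values.

#include <cstdio>
#include <vector>

int main() {
  std::vector<unsigned> op_shape  = {64, 16};   // packed operand (e.g., ints holding several low-bit values)
  std::vector<unsigned> ret_shape = {64, 32};   // dequantized result
  std::vector<unsigned> op_mc     = {1, 16};    // max-contiguous of the operand, per dimension
  unsigned factor = ret_shape.back() / op_shape.back();   // 32 / 16 = 2
  std::vector<unsigned> ret_mc;
  for (size_t d = 0; d < ret_shape.size(); ++d)
    ret_mc.push_back((d == ret_shape.size() - 1 ? factor : 1) * op_mc[d]);
  std::printf("factor=%u, ret_mc={%u, %u}\n", factor, ret_mc[0], ret_mc[1]); // factor=2, {1, 32}
  return 0;
}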
@@ -285,8 +350,8 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
    }
    if(x->is_int_add_sub()){
      unsigned lvalue = 1, rvalue = 1;
      lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
      rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
      lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
      rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
      value = std::max(lvalue, rvalue);
    }
    result.push_back(value);
@@ -332,9 +397,9 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
  if(max_contiguous_.find(v) != max_contiguous_.end())
    return max_contiguous_.at(v);
  if(auto *x = dynamic_cast<ir::instruction*>(v)){
    unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
    if(max_contiguous > 0)
      return add_to_cache(x, {max_contiguous}, max_contiguous_);
    std::vector<unsigned> max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
    if(!max_contiguous.empty())
      return add_to_cache(x, max_contiguous, max_contiguous_);
  }
  if(auto *x = dynamic_cast<ir::cast_inst*>(v))
    return populate_max_contiguous_cast(x);
@@ -342,6 +407,8 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
    return populate_max_contiguous_splat(x);
  if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
    return populate_max_contiguous_reshape(x);
  if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
    return populate_max_contiguous_dequantize(x);
  if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_max_contiguous_broadcast(x);
  if(auto *x = dynamic_cast<ir::binary_operator*>(v))
@@ -386,6 +453,23 @@ std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst
  return add_to_cache(x, result, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple_dequantize(ir::dequantize_inst* x){
  auto shapes = get_shapes(x);
  std::vector<unsigned> result;
  ir::value *op = x->get_operand(0);
  auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
  auto op_last_dim = (op->get_type()->get_block_shapes()).back();
  auto op_multiple = populate_starting_multiple(op);
  for(size_t d = 0; d < shapes.size(); d++) {
    unsigned factor = 1;
    if (d == shapes.size() - 1) {
      factor = ret_last_dim / op_last_dim;
    }
    result.push_back(factor * op_multiple[d]);
  }
  return add_to_cache(x, result, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
  auto result = populate_starting_multiple(x->get_operand(0));
  return add_to_cache(x, result, starting_multiple_);
@@ -401,7 +485,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operato
    if(x->is_int_add_sub())
      result[d] = gcd(lhs[d], rhs[d]);
    if(x->is_int_div())
      result[d] = 1;
      result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1;
    if(x->is_int_rem() && rhs[d] > 1){
      result[d] = gcd(lhs[d], rhs[d]);
    }
@@ -471,28 +555,42 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
  return add_to_cache(v, {1}, starting_multiple_);
}

unsigned get_max_multiple(int val){
  if(val == 0) return 1 << 31;
  if(val % 128 == 0) return 128;
  if(val % 64 == 0) return 64;
  if(val % 32 == 0) return 32;
  if(val % 16 == 0) return 16;
  if(val % 8 == 0) return 8;
  if(val % 4 == 0) return 4;
  if(val % 2 == 0) return 2;
  return 1;
}

std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
  if(starting_multiple_.find(v) != starting_multiple_.end())
    return starting_multiple_.at(v);
  if(auto *x = dynamic_cast<ir::instruction*>(v)){
    unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
    if(multiple_of > 0)
      return add_to_cache(x, {multiple_of}, starting_multiple_);
    std::vector<unsigned> multiple_of = x->get_metadata(ir::metadata::multiple_of);
    if(!multiple_of.empty())
      return add_to_cache(x, multiple_of, starting_multiple_);
  }
  if(auto *x = dynamic_cast<ir::cast_inst*>(v))
    return populate_starting_multiple_cast(x);
  if(auto *x = dynamic_cast<ir::binary_operator*>(v))
    return populate_starting_multiple_binop(x);
  if(auto *x = dynamic_cast<ir::constant_int*>(v))
    return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
    return add_to_cache(x, {get_max_multiple(x->get_value())}, starting_multiple_);
  if(auto *x = dynamic_cast<ir::make_range*>(v))
    return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
    return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_);
  if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
    return populate_starting_multiple_gep(x);
  if(auto *x = dynamic_cast<ir::splat_inst*>(v))
    return populate_starting_multiple_splat(x);
  if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
    return populate_starting_multiple_reshape(x);
  if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
    return populate_starting_multiple_dequantize(x);
  if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_starting_multiple_broadcast(x);
  if(auto *x = dynamic_cast<ir::phi_node*>(v))
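Editor's note: a standalone illustration (not part of the commit) of the new get_max_multiple helper: it returns the largest power of two (capped at 128) dividing val, with 0 treated as "aligned to everything" (1 << 31). The helper is copied here only so the sample compiles on its own.

#include <cstdio>

unsigned get_max_multiple(int val) {
  if (val == 0)       return 1u << 31;
  if (val % 128 == 0) return 128;
  if (val % 64 == 0)  return 64;
  if (val % 32 == 0)  return 32;
  if (val % 16 == 0)  return 16;
  if (val % 8 == 0)   return 8;
  if (val % 4 == 0)   return 4;
  if (val % 2 == 0)   return 2;
  return 1;
}

int main() {
  int samples[] = {0, 5, 24, 64, 384};
  for (int v : samples)
    std::printf("get_max_multiple(%d) = %u\n", v, get_max_multiple(v)); // 2^31, 1, 8, 64, 128
  return 0;
}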
@@ -511,12 +609,15 @@ std::vector<unsigned> align::contiguous(ir::value* v) const {
  return max_contiguous_.at(v);
}

std::vector<align::cst_info> align::get_cst_info(ir::value* v) const {
  return is_constant_.at(v);
}


void align::populate(ir::value *v) {
  populate_is_constant(v);
  populate_starting_multiple(v);
  populate_max_contiguous(v);

}

void align::run(ir::module &mod) {
@@ -92,8 +92,10 @@ void allocation::run(ir::module &mod) {
  }
  // Save maximum size of induced memory space
  allocated_size_ = 0;
  for(shared_layout* x: V)
  for(shared_layout* x: V){
    allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
    // std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl;
  }
}

}
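Editor's note: a standalone sketch (not part of the commit) of how allocated_size_ is derived above: the peak shared-memory footprint is the maximum end offset (start + size) over all buffers placed by the allocator. Offsets and sizes below are made-up inputs.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  struct buffer { size_t start, size; };
  std::vector<buffer> buffers = {{0, 4096}, {4096, 2048}, {1024, 8192}};
  size_t allocated_size = 0;
  for (const buffer &b : buffers)
    allocated_size = std::max<size_t>(allocated_size, b.start + b.size);
  std::printf("allocated_size = %zu bytes\n", allocated_size);   // 9216
  return 0;
}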
@@ -56,6 +56,17 @@ void axes::update_graph_trans(ir::instruction *i) {
    graph_.add_edge({i, perm[d]}, {op, d});
}

void axes::update_graph_dequantize(ir::instruction *i) {
  auto *dequantize = static_cast<ir::dequantize_inst*>(i);
  auto shapes = dequantize->get_type()->get_block_shapes();
  ir::value *op = dequantize->get_operand(0);

  // add edge except the last axis
  for(unsigned d = 0; d < shapes.size() - 1; d ++){
    graph_.add_edge({i, d}, {op, d});
  }
}

void axes::update_graph_broadcast(ir::instruction *i) {
  auto *broadcast = static_cast<ir::broadcast_inst*>(i);
  auto shapes = broadcast->get_type()->get_block_shapes();
@@ -79,7 +90,7 @@ void axes::update_graph_dot(ir::instruction *i) {
    graph_.add_edge({dot, d}, {D, d});
}

void axes::update_graph_elementwise(ir::instruction *i,
void axes::update_graph_elementwise(ir::instruction *i,
                                    bool is_masked_load_async) {
  if(i->get_num_operands() == 0)
    return;
@@ -119,6 +130,7 @@ void axes::update_graph(ir::instruction *i) {
    case ir::INST_SPLAT: return update_graph_no_edge(i);
    case ir::INST_CAT: return update_graph_elementwise(i, true);
    case ir::INST_TRANS: return update_graph_trans(i);
    case ir::INST_DEQUANTIZE: return update_graph_dequantize(i);
    case ir::INST_BROADCAST: return update_graph_broadcast(i);
    case ir::INST_DOT: return update_graph_dot(i);
    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
@@ -23,19 +23,67 @@ inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
  return std::min(std::max(x, lo), hi);
}

inline bool is_hmma_c(ir::value *v){
inline bool is_hmma_c(ir::value *v, int sm){
  bool result = false;
  if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
    ir::value *a = x->get_operand(0);
    ir::type *a_ty = a->get_type();
    ir::value *b = x->get_operand(1);
    ir::type *b_ty = b->get_type();
    result = a_ty->get_scalar_ty()->is_fp16_ty() &&
             b_ty->get_scalar_ty()->is_fp16_ty();
    result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
             (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
             (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
              x->allow_tf32() && sm >= 80) ||
             (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
              sm >= 80);
  }
  return result;
}

static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
  mma_layout::TensorCoreType mma_type;
  if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
    ir::value* a = dot->get_operand(0);
    ir::value* b = dot->get_operand(1);
    ir::type* a_ty = a->get_type();
    ir::type* b_ty = b->get_type();
    ir::type* c_ty = v->get_type();

    if (c_ty->get_scalar_ty()->is_fp32_ty()) {
      // floating point tensor cores
      if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
        mma_type = mma_layout::FP32_FP16_FP16_FP32;
        return mma_type;
      }
      if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
        mma_type = mma_layout::FP32_BF16_BF16_FP32;
        return mma_type;
      }
      if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
          && dot->allow_tf32()) {
        mma_type = mma_layout::FP32_TF32_TF32_FP32;
        return mma_type;
      }
    } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
      // throw std::runtime_error("integer tensor cores are not yet supported");
      // // integer tensor cores
      // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
      //   mma_type = mma_layout::INT32_INT1_INT1_INT32;
      //   return mma_type;
      // }
      // if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
      //   mma_type = mma_layout::INT32_INT4_INT4_INT32;
      //   return mma_type;
      // }
      if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
        mma_type = mma_layout::INT32_INT8_INT8_INT32;
        return mma_type;
      }
    }
  }
  return mma_layout::NOT_APPLICABLE;
}

inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::io_inst*>(u);
@@ -52,11 +100,12 @@ inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
  }
}

inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::dot_inst*>(u);
    if(i && is_hmma_c(i) && i->get_operand(n) == v)
    if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
      result = i;
    }
  }
}

@@ -142,7 +191,9 @@ mma_layout::mma_layout(size_t num_warps,
                       const std::vector<unsigned>& shape,
                       const std::vector<ir::value *> &values,
                       analysis::align* align, target* tgt,
                       shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) {
                       shared_layout *layout_a, shared_layout *layout_b,
                       ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
  tensor_core_type_ = get_mma_type(dot);
  /* fragments per warp */
  // try to make things as square as possible to maximize data re-use
  if(tgt->as_nvidia() && tgt->as_nvidia()->sm() < 80){
@@ -157,25 +208,67 @@ mma_layout::mma_layout(size_t num_warps,
    int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
    rep_ = {2*pack_size_0, 2*pack_size_1, 1};
    spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
    contig_per_thread_ = {1, 1};
    order_ = {0, 1};
  }
  else{
    fpw_ = {1, 1, 1};
    spw_ = {16, 8, 1};
    rep_ = {2, 2, 1};
    spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
    contig_per_thread_ = {1, 2};
    order_ = {1, 0};
  }
  order_ = {0, 1};

  /* warps per tile */
  // try to make things as square as possible to maximize data re-use
  wpt_ = {1, 1, 1};
  std::vector<int> wpt_nm1;
  do{
    wpt_nm1 = wpt_;
    if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
      wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
    if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
      wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
  }while(wpt_nm1 != wpt_);
  // try to make warp-level tiles as square as possible to maximize data re-use
  if (tgt->as_nvidia()->sm() < 80) {
    std::vector<int> wpt_nm1;
    do{
      wpt_nm1 = wpt_;
      if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
        wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
      if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
        wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
    }while(wpt_nm1 != wpt_);
  } else {
    bool changed = false;
    // try to have a warp own entire rows of the output
    // this makes it easier to fuse multiple mmas by fusing
    // registers
    bool one_warp_per_row = false;
    for(ir::value* v: values)
      for(ir::user* u: v->get_users()){
        auto* dot = dynamic_cast<ir::dot_inst*>(u);
        auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
        if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
          one_warp_per_row = shape[0] / spw_[0] >= num_warps;
      }
    // std::cout << one_warp_per_row << std::endl;

    if(one_warp_per_row){
      wpt_[1] = 1;
      wpt_[0] = num_warps;
    }
    else{
      do {
        changed = false;
        if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
          break;
        if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
          if (wpt_[0] < shape_[0] / spw_[0]) {
            wpt_[0] *= 2;
            changed = true;
          }
        } else {
          if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
            wpt_[1] *= 2;
            changed = true;
          }
        }
      } while(changed);
    }
  }

  // std::cout << wpt_[0] << " " << wpt_[1] << std::endl;

  /* shape per block */
  shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
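Editor's note: a standalone sketch (not part of the commit) of the warps-per-tile selection loop used on sm >= 80 above when no single warp can own an entire output row. The tile shape, per-warp shape, and warp count are made-up inputs; the loop alternately doubles wpt[0] and wpt[1] to keep warp-level tiles roughly square.

#include <cstdio>
#include <vector>

int main() {
  int num_warps = 8;
  std::vector<int> shape = {128, 128};     // output tile shape (M, N)
  std::vector<int> spw   = {16, 8, 1};     // shape per warp, e.g. for mma.m16n8k16
  std::vector<int> wpt   = {1, 1, 1};
  bool changed = false;
  do {
    changed = false;
    if (wpt[0] * wpt[1] * wpt[2] >= num_warps)
      break;
    // grow whichever dimension currently leaves more rows/columns per warp
    if (shape[0] / spw[0] / wpt[0] >= shape[1] / (spw[1] * 2) / wpt[1]) {
      if (wpt[0] < shape[0] / spw[0]) { wpt[0] *= 2; changed = true; }
    } else {
      if (wpt[1] < shape[1] / (spw[1] * 2)) { wpt[1] *= 2; changed = true; }
    }
  } while (changed);
  std::printf("wpt = {%d, %d, %d}\n", wpt[0], wpt[1], wpt[2]);  // {4, 2, 1}
  return 0;
}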
@@ -198,8 +291,6 @@ scanline_layout::scanline_layout(size_t num_warps,
  bool is_dot = std::any_of(values.begin(), values.end(),
                            [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });

  std::vector<ir::value*> ptrs;
  for(ir::value *v: values)
    for(ir::user *usr: v->get_users())
@@ -215,7 +306,6 @@ scanline_layout::scanline_layout(size_t num_warps,
      contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
    }

    nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
    mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
    size /= shape_[i];
@@ -277,12 +367,16 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
    res.reset(new double_buffer_info_t{value_1, value_0, phi});
}

static bool is_smem(ir::value* v) {
  if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
      dynamic_cast<ir::masked_load_async_inst*>(v))
    return true;
  else
    return false;
static bool is_smem_in(ir::value* v, const ir::basic_block* bb) {
  if (ir::instruction *instr = dynamic_cast<ir::instruction*>(v)) {
    if (instr->get_parent() != bb)
      return false;
    if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
        dynamic_cast<ir::masked_load_async_inst*>(v)) {
      return true;
    }
  }
  return false;
}

/// param:
@@ -297,14 +391,14 @@ static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::
    ir::basic_block *cbb0 = cphi->get_incoming_block(0);
    ir::basic_block *cbb1 = cphi->get_incoming_block(1);

    if (is_smem(c0)) {
    if (is_smem_in(c0, cbb0)) {
      assert(cbb0 == bb0);
      values_0.push_back(c0);
      if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
        next = phi1;
        continue;
      } else {
        if (is_smem(c1)) {
        if (is_smem_in(c1, cbb1)) {
          value_1 = c1;
          assert(cbb1 == bb1);
          return true;
@@ -359,7 +453,8 @@ shared_layout::shared_layout(data_layout *arg,
                             const std::vector<unsigned>& shape,
                             const std::vector<ir::value *> &values,
                             ir::type *ty,
                             analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
                             analysis::align* align, target *tgt, bool is_tmp)
    : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){

  size_ = 0;
  arg_layout_ = arg;
@@ -385,12 +480,35 @@ shared_layout::shared_layout(data_layout *arg,
  for(ir::value* v: values){
    extract_dot_use(v, dot_a, 0);
    extract_dot_use(v, dot_b, 1);
    extract_hmma_dot_use(v, hmma_dot_a, 0);
    extract_hmma_dot_use(v, hmma_dot_b, 1);
    extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
    extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
  }
  hmma_dot_a_ = hmma_dot_a;
  hmma_dot_b_ = hmma_dot_b;

  // Update mma_vec
  if (hmma_dot_a_) {
    assert(order_.size() == 2);
    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
    mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
    mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];

    // for now, disable swizzle when using lds.8
    if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
      if (order_[0] == 0) // need transpose
        allow_swizzle_ = false;
  } else if (hmma_dot_b_) {
    assert(order_.size() == 2);
    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
    mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
    mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];

    // for now, disable swizzle when using lds.8
    if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
      if (order_[0] == 1) // need transpose
        allow_swizzle_ = false;
  }

  // size
  size_ = ty_->get_primitive_size_in_bits() / 8;
  for(auto s: shape_)
@@ -454,7 +572,8 @@ void layouts::make_graph(ir::instruction *i) {
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
  // if(layouts_.find(id) != layouts_.end())
  //   return;
  auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
  auto it_hmma_c = std::find_if(values.begin(), values.end(),
                                [&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
  auto cmp = [](ir::value* x, ir::value *y) {
    std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
    std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
@@ -476,19 +595,61 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
    ir::value *b = dot->get_operand(1);
    create(groups_.at(a), values_.at(groups_.at(a)));
    create(groups_.at(b), values_.at(groups_.at(b)));
    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
                                  (shared_layout*)layouts_.at(groups_.at(a)),
                                  (shared_layout*)layouts_.at(groups_.at(b)),
                                  dot);
  }
  else if(it_cts != values.end()){
    ir::instruction *cts = (ir::instruction*)*it_cts;
    ir::value *arg = cts->get_operand(0);
    create(groups_.at(arg), values_.at(groups_.at(arg)));
    layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
    layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
  }
  else{
    layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
  }
}

// layout checkers
bool layouts::is_scanline(ir::instruction *i) {
  return this->get(i->get_operand(0))->to_scanline() != nullptr;
}

bool layouts::is_coalesced_scanline(ir::instruction *i) {
  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
    auto *scanline = this->get(i->get_operand(0))->to_scanline();
    return scanline && scanline->get_order()[0] == red->get_axis();
  }
  return false;
}

bool layouts::is_mma(ir::instruction *i) {
  return this->get(i->get_operand(0))->to_mma() != nullptr;
}

bool layouts::is_a100_mma(ir::instruction *i) {
  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
    return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
           (red->get_axis() == 1);
  }
  return false;
}

void layouts::create_tmp_layout(size_t id, data_layout *arg,
                                const std::vector<int> &axes,
                                const std::vector<unsigned> &shape,
                                ir::instruction *i, bool is_index) {
  ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
                          : i->get_type()->get_scalar_ty();
  layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
  if (is_index) {
    tmp_index_[i] = id;
  } else {
    tmp_[i] = id;
  }
}

void layouts::run(ir::module &mod) {
  // make graph
  graph_.clear();
@@ -510,35 +671,47 @@ void layouts::run(ir::module &mod) {
  // create temporaries
  size_t id = values_.size();
  ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
    // std::cout << "layout: " << std::endl;
    // i->print(std::cout);
    if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
      id++;
      ir::value *arg = red->get_operand(0);
      unsigned axis = red->get_axis();
      distributed_layout *layout =
          dynamic_cast<analysis::distributed_layout *>(get(arg));
      // shape
      auto shapes = arg->get_type()->get_block_shapes();
      scanline_layout *layout = get(arg)->to_scanline();
      shapes[axis] = layout->mts(axis);
      unsigned axis = red->get_axis();
      shapes[axis] =
          layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
      // create layout
      layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
      tmp_[red] = id;
      id++;
      create_tmp_layout(id, layout, axes_->get(arg), shapes, red);

      if (red->with_index()) {
        id++;
        create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
      }
    }
    if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
      distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
      distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
      id++;
      size_t dim = val->get_type()->get_tile_rank();
      ir::type::block_shapes_t shape(dim);
      for(size_t k = 0; k < dim; k++){
        shape[k] = std::max(in_layout->shape_per_cta(k),
                            out_layout->shape_per_cta(k));
      }
      layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_);
      tmp_[val] = id;
      auto in_ord = in_layout->get_order();
      auto out_ord = out_layout->get_order();
      int in_vec = in_layout->contig_per_thread(in_ord[0]);
      int out_vec = out_layout->contig_per_thread(out_ord[0]);
      int pad = std::max(in_vec, out_vec);
      shape[out_ord[0]] += pad;
      id++;
      create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
    }
    if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
      id++;
      layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
      tmp_[atom] = id;
      create_tmp_layout(id, nullptr, {}, {1}, atom);
    }
  });

@@ -14,43 +14,108 @@ namespace analysis{
void liveness::run(ir::module &mod) {
  intervals_.clear();

  // Assigns index to each instruction
  std::map<ir::value*, slot_index> indices;
  for(ir::function *fn: mod.get_function_list()){
    slot_index index = 0;
    for(ir::basic_block *block: fn->blocks())
      for(ir::instruction *instr: block->get_inst_list()){
        index += 1;
        indices.insert({instr, index});
  std::map<ir::value*, std::set<shared_layout*>> layouts_map;
  for(auto &x: layouts_->get_all()){
    shared_layout* layout = x.second->to_shared();
    if(!layout || layout->is_tmp())
      continue;
    for(ir::value* v:layout->get_values()){
      layouts_map[v].insert(layout);
    }
  }

  // create live intervals

  std::map<ir::user*, std::set<shared_layout*>> live_in;
  while(true){
    bool changed = false;
    ir::instruction* last_inst = nullptr;
    ir::for_each_instruction_backward(mod, [&](ir::instruction* i){
      // gen
      std::set<shared_layout*> gen;
      for(ir::value* v: i->ops())
        for(shared_layout* layout: layouts_map[v])
          gen.insert(layout);
      // kill
      std::set<shared_layout*> kill;
      for(shared_layout* layout: layouts_map[i])
        kill.insert(layout);
      // temporaries are handled separately
      if(layouts_->has_tmp(i)){
        gen.insert(layouts_->get(layouts_->tmp(i))->to_shared());
        kill.insert(layouts_->get(layouts_->tmp(i))->to_shared());
      }
      if(layouts_->has_tmp_index(i)){
        gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
        kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
      }
      // live-out
      std::set<shared_layout*> live_out;
      std::vector<ir::instruction*> succs = {last_inst};
      if(i == i->get_parent()->get_inst_list().back())
        for(ir::basic_block* succ: i->get_parent()->get_successors())
          succs.push_back(succ->get_inst_list().front());
      for(ir::instruction* succ: succs)
        for(shared_layout* layout: live_in[succ])
          if(!layout->is_tmp())
            live_out.insert(layout);

      // new sets
      std::set<shared_layout*> live_out_minus_kill;
      std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(),
                          std::inserter(live_out_minus_kill, live_out_minus_kill.end()));
      std::set<shared_layout*> new_live_in;
      std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(),
                     std::inserter(new_live_in, new_live_in.end()));

      changed = changed || (new_live_in != live_in[i]);
      live_in[i] = new_live_in;
      last_inst = i;
    });
    if(!changed)
      break;
  }

  // ir::for_each_instruction(mod, [&](ir::instruction* i){
  //   i->print(std::cout);
  //   std::cout << " live_in: " << live_in[i].size() << std::endl;
  // });

  // Assigns index to each instruction
  std::map<ir::value*, slot_index> indices;
  slot_index index = 0;
  ir::for_each_instruction(mod, [&](ir::instruction* instr){
    index += 1;
    indices.insert({instr, index});
  });

  for(auto &x: layouts_->get_all()){
    shared_layout* layout = x.second->to_shared();
    if(layout)
      intervals_[layout] = segment{INT32_MAX, 0};
  }

  for(auto& x: live_in)
    for(shared_layout* layout: x.second)
      intervals_[layout].start = std::min<int>(intervals_[layout].start, indices[x.first]);

  for(auto& x: live_in)
    for(shared_layout* layout: x.second){
      intervals_[layout].end = std::max<int>(intervals_[layout].end, indices[x.first] + 1);
    }

  for(auto &x: layouts_->get_all()) {
    shared_layout* layout = x.second->to_shared();
    if(!layout)
      continue;
    // users
    std::set<ir::user*> users;
    for(ir::value *v: layout->get_values()){
      for(ir::user *u: v->get_users())
        users.insert(u);
    }
    // compute intervals
    unsigned start = INT32_MAX;
    for(ir::value *v: layout->get_values())
      if(indices.find(v) != indices.end())
        start = std::min(start, indices.at(v));
    unsigned end = 0;
    for(ir::user *u: users)
      if(indices.find(u) != indices.end())
        end = std::max(end, indices.at(u));
    if(end == 0)
      end = start + 1;
    intervals_[layout] = segment{start, end};
    // std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl;
  }

}
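Editor's note: a minimal standalone sketch (not part of the commit) of the backward dataflow used above, live_in[i] = gen[i] + (live_out[i] - kill[i]), iterated to a fixed point. Here "layouts" are plain ints and the program is a straight line of four instructions; the instruction/layout assignments are made up.

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <set>
#include <vector>

int main() {
  // inst 0 defines layout 0; inst 2 uses layout 0 and defines layout 1; inst 3 uses layout 1
  std::vector<std::set<int>> gen  = {{}, {}, {0}, {1}};
  std::vector<std::set<int>> kill = {{0}, {}, {1}, {}};
  size_t n = gen.size();
  std::vector<std::set<int>> live_in(n);
  bool changed = true;
  while (changed) {
    changed = false;
    for (size_t k = n; k-- > 0;) {                       // backward walk
      std::set<int> live_out = (k + 1 < n) ? live_in[k + 1] : std::set<int>{};
      std::set<int> tmp, new_in;
      std::set_difference(live_out.begin(), live_out.end(), kill[k].begin(), kill[k].end(),
                          std::inserter(tmp, tmp.end()));
      std::set_union(gen[k].begin(), gen[k].end(), tmp.begin(), tmp.end(),
                     std::inserter(new_in, new_in.end()));
      if (new_in != live_in[k]) { live_in[k] = new_in; changed = true; }
    }
  }
  for (size_t k = 0; k < n; ++k)
    std::printf("live_in[%zu] = %zu layouts\n", k, live_in[k].size()); // 0, 1, 1, 1
  return 0;
}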
@@ -19,6 +19,7 @@ void swizzle::run(ir::module &) {
      continue;
    ir::value* mma_dot_a = layout->hmma_dot_a();
    ir::value* mma_dot_b = layout->hmma_dot_b();

    if(!mma_dot_a && !mma_dot_b){
      per_phase_[layout] = 1;
      max_phase_[layout] = 1;
@@ -27,22 +28,31 @@ void swizzle::run(ir::module &) {
    }
    auto ord = layout->get_order();
    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
    if(!in_layout)
      continue;
    int per_phase = 1;
    int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
    if(in_layout)
      per_phase = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
    else
      per_phase = 1;
    if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
      int inner = mma_dot_a ? 0 : 1;
      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
      per_phase_[layout] = per_phase;
      max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
      if(mma_dot_a)
        vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
      else
        vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
    }
    else{
      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
      max_phase_[layout] = 8 / per_phase_[layout];
      vec_[layout] = 8;
    else {
      if (!layout->allow_swizzle()) {
        per_phase_[layout] = 1;
        max_phase_[layout] = 1;
        vec_[layout] = 1;
      } else {
        per_phase_[layout] = per_phase;
        max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
        vec_[layout] = layout->get_mma_vec();
      }
    }
  }
}
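Editor's note: a standalone sketch (not part of the commit) of the swizzle parameters computed above for the sm >= 80 path. The scanline-layout numbers (threads and elements per thread along the contiguous axis, element size) and the tensor-core-derived mma_vec/mma_strided values are made-up inputs.

#include <algorithm>
#include <cstdio>

int main() {
  int mts = 4, nts = 8, dtsize = 2;                 // threads, elements per thread, fp16 bytes
  int mma_vec = 8, mma_strided = 8;                 // assumed values from the shared layout
  // one swizzle phase covers a 128-byte stretch of a shared-memory row
  int per_phase = std::max(128 / (mts * nts * dtsize), 1);   // 128 / 64 = 2
  int max_phase = mma_strided / per_phase;                   // 8 / 2 = 4
  std::printf("per_phase=%d max_phase=%d vec=%d\n", per_phase, max_phase, mma_vec);
  return 0;
}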
|
Reference in New Issue
Block a user