[BACKEND] Added Int8 mma (#440)

This commit is contained in:
daadaada
2022-01-28 01:12:44 +08:00
committed by GitHub
parent 3a23c1dd33
commit 59d371c6eb
11 changed files with 232 additions and 115 deletions

View File

@@ -33,7 +33,9 @@ inline bool is_hmma_c(ir::value *v, int sm){
result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
(a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
x->allow_tf32() && sm >= 80);
x->allow_tf32() && sm >= 80) ||
(a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
sm >= 80);
}
return result;
}
@@ -63,7 +65,7 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
return mma_type;
}
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
throw std::runtime_error("integer tensor cores are not yet supported");
// throw std::runtime_error("integer tensor cores are not yet supported");
// // integer tensor cores
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
@@ -73,10 +75,10 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
// return mma_type;
// }
// if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
// mma_type = mma_layout::INT32_INT8_INT8_INT32;
// return mma_type;
// }
if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
mma_type = mma_layout::INT32_INT8_INT8_INT32;
return mma_type;
}
}
}
return mma_layout::NOT_APPLICABLE;
@@ -444,11 +446,21 @@ shared_layout::shared_layout(data_layout *arg,
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 0) // need transpose
allow_swizzle_ = false;
} else if (hmma_dot_b_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 1) // need transpose
allow_swizzle_ = false;
}
// size

View File

@@ -41,9 +41,15 @@ void swizzle::run(ir::module &) {
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
}
else {
if (!layout->allow_swizzle()) {
per_phase_[layout] = 1;
max_phase_[layout] = 1;
vec_[layout] = 1;
} else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}
}
}
}