[dnn/gemm] added some bounds checking
This commit is contained in:
@@ -8,15 +8,15 @@
|
|||||||
|
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
bool AT = true;
|
bool AT = false;
|
||||||
bool BT = false;
|
bool BT = true;
|
||||||
typedef float T;
|
typedef float T;
|
||||||
std::string ty = "fp16";
|
std::string ty = "fp16";
|
||||||
size_t dt_nbytes = sizeof(T);
|
size_t dt_nbytes = sizeof(T);
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
// matrix multiplication parameters
|
// matrix multiplication parameters
|
||||||
int32_t M = 65536, N = 2048, K = 2048;
|
int32_t M = 4096, N = 4096, K = 4096;
|
||||||
std::vector<T> hc(M*N);
|
std::vector<T> hc(M*N);
|
||||||
std::vector<T> rc(M*N);
|
std::vector<T> rc(M*N);
|
||||||
std::vector<T> ha(M*K);
|
std::vector<T> ha(M*K);
|
||||||
|
@@ -31,9 +31,6 @@ public:
|
|||||||
// clone
|
// clone
|
||||||
base* clone() const;
|
base* clone() const;
|
||||||
|
|
||||||
// default params
|
|
||||||
std::vector<unsigned> default_params();
|
|
||||||
|
|
||||||
// CPU reference implementation
|
// CPU reference implementation
|
||||||
template<class T, bool AT, bool BT>
|
template<class T, bool AT, bool BT>
|
||||||
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
|
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
|
||||||
|
@@ -235,7 +235,7 @@ void tune::run(ir::module &mod) {
|
|||||||
continue;
|
continue;
|
||||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
|
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
|
||||||
*params_.at(i).at("nts.d0") = *tmp;
|
*params_.at(i).at("nts.d0") = *tmp;
|
||||||
}
|
}
|
||||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
|
@@ -51,6 +51,7 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
|||||||
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
||||||
unsigned TM = info.globals.at("TM");
|
unsigned TM = info.globals.at("TM");
|
||||||
unsigned TN = info.globals.at("TN");
|
unsigned TN = info.globals.at("TN");
|
||||||
|
unsigned TK = info.globals.at("TK");
|
||||||
unsigned grid_0 = (M_ + TM - 1)/TM;
|
unsigned grid_0 = (M_ + TM - 1)/TM;
|
||||||
unsigned grid_1 = (N_ + TN - 1)/TN;
|
unsigned grid_1 = (N_ + TN - 1)/TN;
|
||||||
unsigned grid_2 = 1;
|
unsigned grid_2 = 1;
|
||||||
@@ -67,23 +68,13 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
|||||||
kernel->setArg(6, lda);
|
kernel->setArg(6, lda);
|
||||||
kernel->setArg(7, ldb);
|
kernel->setArg(7, ldb);
|
||||||
kernel->setArg(8, ldc);
|
kernel->setArg(8, ldc);
|
||||||
kernel->setArg(9, locks_);
|
kernel->setArg(9, TK);
|
||||||
kernel->setArg(10, grid_0);
|
kernel->setArg(10, locks_);
|
||||||
kernel->setArg(11, grid_1);
|
kernel->setArg(11, grid_0);
|
||||||
|
kernel->setArg(12, grid_1);
|
||||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<unsigned> gemm::default_params() {
|
|
||||||
if(AT_ && BT_)
|
|
||||||
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
|
|
||||||
else if(AT_ && !BT_)
|
|
||||||
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
|
|
||||||
else if(!AT_ && BT_)
|
|
||||||
return {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1};
|
|
||||||
else
|
|
||||||
return {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1};
|
|
||||||
}
|
|
||||||
|
|
||||||
void gemm::triton_c_src(std::ostream &os) const {
|
void gemm::triton_c_src(std::ostream &os) const {
|
||||||
std::string AS0 = "TM", AS1 = "TK";
|
std::string AS0 = "TM", AS1 = "TK";
|
||||||
std::string BS0 = "TK", BS1 = "TN";
|
std::string BS0 = "TK", BS1 = "TN";
|
||||||
@@ -103,12 +94,14 @@ void gemm::triton_c_src(std::ostream &os) const {
|
|||||||
std::swap(bcb0, bcb1);
|
std::swap(bcb0, bcb1);
|
||||||
std::swap(ldb0, ldb1);
|
std::swap(ldb0, ldb1);
|
||||||
}
|
}
|
||||||
|
std::string AS = AS0 + ", " + AS1;
|
||||||
|
std::string BS = BS0 + ", " + BS1;
|
||||||
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
|
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
|
||||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
||||||
std::string res =
|
std::string res =
|
||||||
R"(
|
R"(
|
||||||
const tunable int32 TM = {16, 32, 64, 128};
|
const tunable int32 TM = {32, 64, 128, 256};
|
||||||
const tunable int32 TN = {16, 32, 64, 128};
|
const tunable int32 TN = {32, 64, 128, 256};
|
||||||
const tunable int32 TK = {32};
|
const tunable int32 TK = {32};
|
||||||
const tunable int32 GZ = {1};
|
const tunable int32 GZ = {1};
|
||||||
|
|
||||||
@@ -117,27 +110,36 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
|||||||
fp32 *C,
|
fp32 *C,
|
||||||
int32 M, int32 N, int32 K,
|
int32 M, int32 N, int32 K,
|
||||||
)" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"(" int32 ldb, int32 ldc,
|
)" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"(" int32 ldb, int32 ldc,
|
||||||
int32 *locks, int32 grid0, int32 grid1) {
|
int32 bound, int32 *locks, int32 grid0, int32 grid1) {
|
||||||
int32 rxa[TM] = get_global_range[TM](0);
|
int32 ridx = get_range_id(0);
|
||||||
int32 ryb[TN] = get_global_range[TN](1);
|
int32 ridy = get_range_id(1);
|
||||||
|
int32 rxa[TM] = ridx*TM + (0 ... TM);
|
||||||
|
int32 ryb[TN] = ridy*TN + (0 ... TN);
|
||||||
int32 rka[TK] = 0 ... TK;
|
int32 rka[TK] = 0 ... TK;
|
||||||
int32 rkb[TK] = 0 ... TK;
|
int32 rkb[TK] = 0 ... TK;
|
||||||
fp32 c[TM, TN] = 0;
|
fp32 c[TM, TN] = 0;
|
||||||
)" + a_ty_ + R"(* pa[)" + AS0 + ", " + AS1 + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
)" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||||
)" + b_ty_ + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||||
)" + a_ty_ + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa;
|
int1 checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
|
||||||
)" + b_ty_ + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb;
|
int1 checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
|
||||||
for(int32 k = K; k > TK; k = k - TK){
|
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
|
||||||
|
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
|
||||||
|
for(int32 k = K; k > 0; k = k - TK){
|
||||||
c = dot()" + usea + ", " + useb + R"(, c);
|
c = dot()" + usea + ", " + useb + R"(, c);
|
||||||
pa = pa + TK)" + lda0 + R"(;
|
pa = pa + TK)" + lda0 + R"(;
|
||||||
pb = pb + TK)" + ldb0 + R"(;
|
pb = pb + TK)" + ldb0 + R"(;
|
||||||
a = *pa;
|
int1 checka[)" + AS + R"(] = k > bound;
|
||||||
b = *pb;
|
int1 checkb[)" + BS + R"(] = k > bound;
|
||||||
|
@checka a = *pa;
|
||||||
|
@checkb b = *pb;
|
||||||
}
|
}
|
||||||
int32 rxc[TM] = get_global_range[TM](0);
|
int32 rxc[TM] = ridx*TM + (0 ... TM);
|
||||||
int32 ryc[TN] = get_global_range[TN](1);
|
int32 ryc[TN] = ridy*TN + (0 ... TN);
|
||||||
|
int1 checkc0[TM] = rxc < M;
|
||||||
|
int1 checkc1[TN] = ryc < N;
|
||||||
|
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||||
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||||
*pc = c;
|
@checkc *pc = c;
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
os << res;
|
os << res;
|
||||||
|
@@ -49,7 +49,7 @@ void loop_nest(std::vector<size_t> const & ranges,
|
|||||||
values[i--] = 0;
|
values[i--] = 0;
|
||||||
}
|
}
|
||||||
i = D - 1;
|
i = D - 1;
|
||||||
// Small sleep so that the thread pool doesn't grow too big
|
// Short sleep so that the thread pool doesn't grow too big
|
||||||
std::this_thread::sleep_for(std::chrono::microseconds(1));
|
std::this_thread::sleep_for(std::chrono::microseconds(1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user