[dnn/gemm]: fixed leading dimension in transposed variants

This commit is contained in:
Philippe Tillet
2019-07-18 16:35:48 -07:00
parent f0d8306437
commit 71594da66f
3 changed files with 18 additions and 14 deletions

View File

@@ -8,12 +8,12 @@
int main() {
bool AT = false;
bool AT = true;
bool BT = false;
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// matrix multiplication parameters
int32_t M = 1024, N = 1024, K = 1024;
int32_t M = 64, N = 128, K = 128;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
@@ -33,14 +33,14 @@ int main() {
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
triton::dnn::gemm gemm(M, N, K, AT, BT, "fp16", "fp16", 4, 4);
gemm.enqueue(stream, {da, db, dc}, true);
// stream->read(dc, true, 0, hc);
// gemm.cpu_ref<float>(rc, ha, hb);
// for(size_t i = 0; i < M*N; i++)
// if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
// exit(EXIT_FAILURE);
// }
triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4);
gemm.enqueue(stream, {da, db, dc}, false);
stream->read(dc, true, 0, hc);
gemm.cpu_ref<float>(rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
}

View File

@@ -1149,6 +1149,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
unsigned alignment = std::min(starting_multiple, max_contiguous);
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
// vector_size = result->axis(0).contiguous;
// vector_size = 1;
std::map<unsigned, Value*> packets;
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
result->for_each([&](indices_t idx){

View File

@@ -54,6 +54,9 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
unsigned grid_0 = (M_ + TM - 1)/TM;
unsigned grid_1 = (N_ + TN - 1)/TN;
unsigned grid_2 = 1;
int32_t lda = AT_ ? K_ : M_;
int32_t ldb = BT_ ? N_ : K_;
int32_t ldc = M_;
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
kernel->setArg(0, a);
kernel->setArg(1, b);
@@ -61,9 +64,9 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(3, M_);
kernel->setArg(4, N_);
kernel->setArg(5, K_);
kernel->setArg(6, M_);
kernel->setArg(7, N_);
kernel->setArg(8, M_);
kernel->setArg(6, lda);
kernel->setArg(7, ldb);
kernel->setArg(8, ldc);
kernel->setArg(9, locks_);
kernel->setArg(10, grid_0);
kernel->setArg(11, grid_1);