[dnn/gemm]: fixed leading dimension in transposed variants
This commit is contained in:
@@ -8,12 +8,12 @@
|
||||
|
||||
|
||||
int main() {
|
||||
bool AT = false;
|
||||
bool AT = true;
|
||||
bool BT = false;
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
// matrix multiplication parameters
|
||||
int32_t M = 1024, N = 1024, K = 1024;
|
||||
int32_t M = 64, N = 128, K = 128;
|
||||
std::vector<float> hc(M*N);
|
||||
std::vector<float> rc(M*N);
|
||||
std::vector<float> ha(M*K);
|
||||
@@ -33,14 +33,14 @@ int main() {
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
triton::dnn::gemm gemm(M, N, K, AT, BT, "fp16", "fp16", 4, 4);
|
||||
gemm.enqueue(stream, {da, db, dc}, true);
|
||||
// stream->read(dc, true, 0, hc);
|
||||
// gemm.cpu_ref<float>(rc, ha, hb);
|
||||
// for(size_t i = 0; i < M*N; i++)
|
||||
// if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
// exit(EXIT_FAILURE);
|
||||
// }
|
||||
triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4);
|
||||
gemm.enqueue(stream, {da, db, dc}, false);
|
||||
stream->read(dc, true, 0, hc);
|
||||
gemm.cpu_ref<float>(rc, ha, hb);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
@@ -1149,6 +1149,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
// vector_size = result->axis(0).contiguous;
|
||||
// vector_size = 1;
|
||||
std::map<unsigned, Value*> packets;
|
||||
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
|
||||
result->for_each([&](indices_t idx){
|
||||
|
@@ -54,6 +54,9 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
unsigned grid_0 = (M_ + TM - 1)/TM;
|
||||
unsigned grid_1 = (N_ + TN - 1)/TN;
|
||||
unsigned grid_2 = 1;
|
||||
int32_t lda = AT_ ? K_ : M_;
|
||||
int32_t ldb = BT_ ? N_ : K_;
|
||||
int32_t ldc = M_;
|
||||
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
@@ -61,9 +64,9 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, M_);
|
||||
kernel->setArg(7, N_);
|
||||
kernel->setArg(8, M_);
|
||||
kernel->setArg(6, lda);
|
||||
kernel->setArg(7, ldb);
|
||||
kernel->setArg(8, ldc);
|
||||
kernel->setArg(9, locks_);
|
||||
kernel->setArg(10, grid_0);
|
||||
kernel->setArg(11, grid_1);
|
||||
|
Reference in New Issue
Block a user