diff --git a/examples/matrix.cpp b/examples/matrix.cpp
index a6fb5c168..03d206ba0 100644
--- a/examples/matrix.cpp
+++ b/examples/matrix.cpp
@@ -31,19 +31,21 @@ extern translation_unit *ast_root;
 const char src[] =
 "\
 void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
-  int32 rx[16] = get_global_range[16](0);\
-  int32 ry[16] = get_global_range[16](1);\
+  int32 rxa[16] = get_global_range[16](0);\
+  int32 ryb[16] = get_global_range[16](1);\
   int32 rka[8] = 0 ... 8;\
   int32 rkb[8] = 0 ... 8;\
+  int32 rxc[16] = get_global_range[16](0);\
+  int32 ryc[16] = get_global_range[16](1);\
   fp32 C[16, 16] = 0;\
   int32 k;\
-  fp32* pa[16, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\
-  fp32* pb[16, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\
-  fp32* pc[16, 16] = c + rx[:, newaxis] + ry[newaxis, :]*M;\
+  fp32* pa[16, 8] = a + rxa[:, newaxis] + rka[newaxis, :]*M;\
+  fp32* pb[16, 8] = b + ryb[:, newaxis] + rkb[newaxis, :]*K;\
+  fp32* pc[16, 16] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
   for(k = K; k > 0; k = k - 8){\
     fp32 a[16, 8] = *pa;\
     fp32 b[16, 8] = *pb;\
-    C = C + 1;\
+    C = dot(a, b, C);\
     pa = pa + 8*M;\
     pb = pb + 8*K;\
   }\
@@ -127,6 +129,22 @@ static void compile_machine_code(CUdevice &device, CUcontext &context, CUmodule
   checkCudaErrors(cuModuleGetFunction(&function, module, name.c_str()));
 }
 
+// Host-side reference GEMM used to validate the kernel output.
+// Computes c[m + n*M] = sum over k of a[m + k*M] * b[n + k*N] for m < M, n < N,
+// i.e. a and c are indexed with leading dimension M, b with leading dimension N.
+// NOTE(review): the Triton-C kernel in src strides b by K rather than N, so the
+// kernel and this reference agree only when N == K (true for the sizes used in main).
+template<class T>
+void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
+  for(size_t m = 0; m < M; m++)
+  for(size_t n = 0; n < N; n++){
+    T acc = 0;
+    for(size_t k = 0; k < K; k++)
+      acc += a[m + k*M] * b[n + k*N];
+    c[m + n*M] = acc;
+  }
+}
+
 int main() {
   // create AST from Triton-C source
   YY_BUFFER_STATE buffer = yy_scan_string(src);
@@ -151,14 +169,18 @@ int main() {
   // tuning parameters
   tune.run(module);
   std::vector params = {
-    // asm
+    // a0
     2, 8, 1,
-    // bsn
+    // b0
     4, 4, 1,
-    // pa
+    // c0
+    2, 8, 1,
+    // c1
+    4, 4, 1,
+    // a1
     2, 4, 1,
-    // pb
-    1, 8, 1,
+    // b1
+    1, 8, 1
   };
   std::map> errors;
   unsigned i = 0;
@@ -194,12 +216,14 @@ int main() {
   CUstream cu_stream;
   int major, minor;
   compile_machine_code(cu_device, cu_context, cu_module, cu_kernel, cu_stream, major, minor, src, "test");
+// std::cout << src << std::endl;
 
   // execute machine code
   // Allocate buffers
   typedef float numeric_t;
-  size_t M = 256, N = 256, K = 256;
+  size_t M = 32, N = 32, K = 32;
   std::vector<numeric_t> c(M*N);
+  std::vector<numeric_t> rc(M*N);
   std::vector<numeric_t> a(M*K);
   std::vector<numeric_t> b(K*N);
   for(size_t i = 0; i < a.size(); i++)
@@ -222,14 +246,14 @@ int main() {
   cuFuncGetAttribute(&num_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, cu_kernel);
   unsigned TM = 16;
   unsigned TN = 16;
-  unsigned nthreads = params[1]*params[2]*params[7]*params[8];
+  unsigned nthreads = 32;
   checkCudaErrors(cuLaunchKernel(cu_kernel, M/TM, N/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
   checkCudaErrors(cuStreamSynchronize(cu_stream));
   // Write back
   checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));
-
-  std::cout << c[0] << " " << c[1] << " " << c[2] << " " << c[3] << std::endl;
-// for(size_t i = 0; i < M*N; i++)
-// if(c[i] != 32)
-// std::cout << i << " " << "success" << std::endl;
+  simple_gemm(rc, a, b, M, N, K);
+  // print entries whose relative error exceeds 1e-4; no division, and the negated <= also flags NaN
+  for(size_t i = 0; i < M*N; i++)
+    if(!(std::abs(c[i] - rc[i]) <= 1e-4*std::max(std::abs(c[i]), std::abs(rc[i]))))
+      std::cout << i << " " << c[i] << " " << rc[i] << std::endl;
 }
diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp
index d4d6394a1..1f29275f9 100644
--- a/lib/codegen/selection.cpp
+++ b/lib/codegen/selection.cpp
@@ -317,6 +317,8 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
     if(is_shared){
       size_t offset = alloc_->get_offset(v);
       Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
+      // retype the offset pointer as a pointer to ty, keeping its address space
+      ptr = builder.CreateBitCast(ptr, ty->getPointerTo(ptr->getType()->getPointerAddressSpace()));
       tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)});
     }
     // create distributed tile
@@ -445,6 +447,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
       ir::value *A = ins->get_operand(0);
       ir::value *B = ins->get_operand(1);
       ir::value *C = ins->get_operand(2);
+      // llvm.fmuladd declaration overloaded on the scalar type of the accumulator C
+      Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)});
       result->for_each([&](indices_t idx){
         Value *res = tmap_.at(C)->get_value(idx);
         unsigned NK = A->get_type()->get_tile_shapes()[1];
@@ -453,7 +457,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
           indices_t b_idx = {idx[1], builder.getInt32(K)};
           Value *a = tmap_.at(A)->get_value(a_idx);
           Value *b = tmap_.at(B)->get_value(b_idx);
-          res = builder.CreateAdd(res, builder.CreateMul(a, b));
+          res = builder.CreateCall(f_mul_add, {a, b, res});
         }
         result->set_value(idx, res);
       });