From 5db3a7adfebcb220e510600a2c6fa0d9261e447f Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 30 Aug 2019 17:05:03 -0700 Subject: [PATCH] [python][examples] some more cleaning of dot product example --- lib/runtime/function.cc | 2 +- python/examples/dot.py | 32 +++++++++++++------------------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 9b2072974..54d6af4c1 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -217,7 +217,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c dce.run(module); vectorize.run(module); dce.run(module); -// ir::print(module, std::cout); + ir::print(module, std::cout); // generate llvm code llvm::LLVMContext ctx; std::unique_ptr llvm(new llvm::Module(module.get_name(), ctx)); diff --git a/python/examples/dot.py b/python/examples/dot.py index 1eb7867af..ffb93fd33 100644 --- a/python/examples/dot.py +++ b/python/examples/dot.py @@ -35,45 +35,39 @@ src = """ #define SHAPE_B TK, TN #endif -void dot(TYPE * A, - TYPE * B, - TYPE * C, +void dot(TYPE * A, TYPE * B, TYPE * C, int M, int N, int K, int lda __multipleof(8), int ldb __multipleof(8), int ldc) { - - /* prologue */ + // prologue int ridx = get_program_id(0); int ridy = get_program_id(1); int rxa[TM] = ridx * TM + 0 ... TM; int ryb[TN] = ridy * TN + 0 ... TN; int rka[TK] = 0 ... TK; int rkb[TK] = 0 ... TK; - float xc[TM, TN] = 0; - /* pointers for operands */ + float c[TM, TN] = 0; + // pointers to operands TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM; TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN; - /* prefetches operands */ + // prefetches operands TYPE a[SHAPE_A] = *pa; TYPE b[SHAPE_B] = *pb; - /* reduction loop */ - for(int k = K; k > 0; k = k - TK){ - xc = USEA @ USEB + xc; + // reduction loop + for(int k = K; k > 0; k-= TK){ + c += USEA @ USEB; pa = pa + TK * STRIDE_AK; pb = pb + TK * STRIDE_BK; a = *pa; b = *pb; } - /* epilogue */ - int rxc[TM] = ridx * TM + (0 ... TM); - int ryc[TN] = ridy * TN + (0 ... TN); + // epilogue + int rxc[TM] = ridx * TM + 0 ... TM; + int ryc[TN] = ridy * TN + 0 ... TN; TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc; - TYPE c[TM, TN] = xc; - bool checkc0[TM] = rxc < M; - bool checkc1[TN] = ryc < N; - bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - *pc = c; + bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :]; + *?(checkc) pc = c; } """