diff --git a/docs/tutorials/matrix-multiplication.rst b/docs/tutorials/matrix-multiplication.rst index b42a6dc92..8cf5c79e5 100644 --- a/docs/tutorials/matrix-multiplication.rst +++ b/docs/tutorials/matrix-multiplication.rst @@ -10,7 +10,7 @@ The purpose of this section is to present a Triton-C implementation of matrix mu Compute Kernel ============== -Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below: +Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below: .. code-block:: C @@ -35,7 +35,7 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa TYPE a[TM, TK] = *pa; //(9) TYPE b[TK, TN] = *pb; //(10) // matrix-multiply accumulate - c += a @ b; //(11) + c += dot(a, b); //(11) // increment pointers pa = pa + TK * 1; //(12) pb = pb + TK * ldb; //(13) @@ -88,7 +88,7 @@ The purpose of pre-fetching is to overlap the update of the accumulator `c` with TYPE a[TM, TK] = *pa; //(9) TYPE b[TK, TN] = *pb; //(10) for(int k = K; k > 0; k-= TK){ - c += a @ b; + c += dot(a, b); pa = pa + TK * 1; pb = pb + TK * ldb; // don't prefetch last iteration @@ -144,7 +144,7 @@ It is common for optimized matrix-multiplication implementations (e.g., BLAS) to TYPE b[SHAPE_B] = (*pb); // reduction loop for(int k = K; k > 0; k-= TK){ - c += USE_A @ USE_B; + c += dot(USE_A, USE_B); pa = pa + TK * STRIDE_AK; pb = pb + TK * STRIDE_BK; a = *pa; diff --git a/docs/tutorials/matrix-transposition.rst b/docs/tutorials/matrix-transposition.rst index 6e7fdb1db..22894e7d8 100644 --- a/docs/tutorials/matrix-transposition.rst +++ b/docs/tutorials/matrix-transposition.rst @@ -53,21 +53,21 @@ which will be used in statements (5) and (6) to construct tiles of pointers - Statements (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics: - .. code-block:: C +:: │ X + (pidm*TM + 0) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + 0) + (pidn*TN + TN - 1)*ldx) │ - │ ⋮ ⋮ │ - │ ⋮ ⋮ │ + │ ⋮ ⋮ │ + │ ⋮ ⋮ │ │ X + (pidm*TM + TM - 1) + (pidn*TN + 0)*ldx, ..., ..., X + (pidm*TM + TM - 1) + (pidn*TN + TN - 1)*ldx) │ - Statement (6) constructs the following array of pointers `py` using numpy-style broadcasting semantics: - .. code-block:: C +:: │ Y + (pidn*TN + 0) + (pidm*TM + 0)*ldy, ..., ..., Y + (pidn*TN + 0) + (pidm*TM + TM - 1)*ldy) │ - │ ⋮ ⋮ │ - │ ⋮ ⋮ │ + │ ⋮ ⋮ │ + │ ⋮ ⋮ │ │ Y + (pidn*TN + TN - 1) + (pidn*TN + 0)*ldy, ..., ..., Y + (pidn*TN + TN - 1) + (pidm*TM + TM - 1)*ldy) │ - Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.