[DOCS] Transposition fix
Committed by Philippe Tillet
Parent: 0920da6fae
Commit: 180ed26b61
@@ -25,8 +25,8 @@ In Triton, however, kernels are single-threaded and the compiler automatically d
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
-TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
-TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:, newaxis]; //(6)
+TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
+TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
// write back using the transposition operator '^'
*py = ^(*px); //(7)
}
@@ -73,6 +73,65 @@ which will be used in statements (5) and (6) to construct tiles of pointers
- Statement (7) dereferences the array of pointers `px` element-wise, transposes the result using the unary transposition operator `^`, and writes it back at the locations specified by `py`.


==================================
A Note on Numpy-style Broadcasting
==================================

The construction statements (5) and (6) are a little subtle. To help understand them, consider the following numpy example.

First, we create a row vector of the numbers 0 to 11, which we reshape into a 4x3 matrix.

.. code-block:: python

    import numpy as np

    vec = np.linspace(0,11,12)
    mat = vec.reshape((4,3))

Imagine that we would like to process this in two 2x3 tiles (i.e. tile 0 will consider the top half, and tile 1 will consider the bottom half).

::

    [[ 0,  1,  2],
     [ 3,  4,  5],
     [ 6,  7,  8],
     [ 9, 10, 11]]

Given `pidm=0`, `pidn=0`, `TM=2`, `TN=3`, we would like tile 0 to have the values:

::

    [[ 0,  1,  2],
     [ 3,  4,  5]]

We construct ranges `rm` and `rn` as:

::

    rm = [0, 1]
    rn = [0, 1, 2]

Using numpy-style broadcasting, we can add these together to create a matrix:

::

    rm[:, np.newaxis] + rn[np.newaxis, :]

    rn ->         [0, 1, 2]
    rm -> [0.,   [[0, 1, 2],
           1.]    [1, 2, 3]]

The bottom row is incorrect: it should read `[3, 4, 5]`. Notice that `rm` indexes the rows of the matrix; we need to scale it so that each element gives the index of the start of that row. For instance, to access row 1, column 0, we need to access location 3. To access row 2, column 0, we need to access location 6. To translate row i, column 0 into a location, we need to multiply i by the number of elements in each row (the leading dimension). In this case this is 3, so what we really need is:

::

    ldx = 3
    px = rm[:, np.newaxis] * ldx + rn[np.newaxis,:]

`newaxis` is built into Triton, and pointer arrays can be constructed in just the same way (as in this example).
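
As a rough check of the reasoning above, the following numpy sketch builds the offsets for both tiles and gathers them from the flat buffer. The helper name `tile_offsets` is hypothetical and exists only for this illustration; `np.arange` is used instead of `np.linspace` so that the offsets are integers and can be used as indices. The last lines evaluate the pre-fix form of statement (5) shown in the diff, to see that it does not address tile 0 of a row-major matrix.

```python
import numpy as np

# The 4x3 matrix from the text, plus its flat (row-major) storage.
vec = np.arange(12)
mat = vec.reshape((4, 3))
ldx = mat.shape[1]  # leading dimension: number of elements per row

TM, TN = 2, 3  # tile shape used in the text


def tile_offsets(pidm, pidn):
    """Linear offsets of tile (pidm, pidn), mirroring statement (5)."""
    rm = pidm * TM + np.arange(TM)
    rn = pidn * TN + np.arange(TN)
    return rm[:, np.newaxis] * ldx + rn[np.newaxis, :]


px0 = tile_offsets(0, 0)
px1 = tile_offsets(1, 0)

# Gathering from the flat buffer with these offsets recovers each tile.
tile0 = vec[px0]
tile1 = vec[px1]
print(tile0)
print(tile1)

# Pre-fix form of statement (5): the stride is applied to the wrong axis.
rm, rn = np.arange(TM), np.arange(TN)
old_px = rm[:, np.newaxis] + rn[np.newaxis, :] * ldx
print(old_px)
```

In this sketch the offsets play the role of the pointer tile `px` with the base address `X` left out; adding a base pointer shifts every entry by the same amount.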

==========================
The __multipleof attribute
==========================
@@ -111,3 +170,5 @@ You might have noticed that the above code will fail when `M` and `N` are not mu

Here, statement (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.

A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_transpose.py>`_.
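
The masking logic can also be emulated on the CPU. The sketch below is a numpy approximation, not the Triton kernel: it loops over the same grid of tiles, builds offsets as in statements (5) and (6), and uses boolean masks in place of `*?()`. The names `cdiv` and `transpose_tiled` are hypothetical helpers introduced for this illustration, and both matrices are assumed to be row-major with `ldx = N` and `ldy = M`.

```python
import numpy as np


def cdiv(a, b):
    """Ceiling division (hypothetical helper playing the role of triton.cdiv)."""
    return (a + b - 1) // b


def transpose_tiled(x, TM=2, TN=2):
    """Transpose x tile by tile, masking lanes that fall outside the matrix."""
    M, N = x.shape
    ldx, ldy = N, M  # assumed row-major leading dimensions of X and Y
    X = x.ravel()
    Y = np.zeros(M * N, dtype=x.dtype)
    for pidm in range(cdiv(M, TM)):
        for pidn in range(cdiv(N, TN)):
            rm = pidm * TM + np.arange(TM)
            rn = pidn * TN + np.arange(TN)
            # offsets as in statements (5) and (6)
            px = rm[:, np.newaxis] * ldx + rn[np.newaxis, :]  # shape [TM, TN]
            py = rm[np.newaxis, :] + rn[:, np.newaxis] * ldy  # shape [TN, TM]
            # bounds-checking masks as in statements (7a) and (7b)
            checkx = (rm[:, np.newaxis] < M) & (rn[np.newaxis, :] < N)
            checky = (rn[:, np.newaxis] < N) & (rm[np.newaxis, :] < M)
            # masked load: out-of-bounds lanes keep a placeholder value of 0
            tile = np.zeros((TM, TN), dtype=x.dtype)
            tile[checkx] = X[px[checkx]]
            # masked store of the transposed tile
            Y[py[checky]] = tile.T[checky]
    return Y.reshape(N, M)


# 3x5 is not a multiple of the 2x2 tile, so the edge tiles are partially masked.
x = np.arange(15).reshape(3, 5)
y = transpose_tiled(x)
print(np.array_equal(y, x.T))
```

Without the masks, the edge tiles would compute offsets past the end of the flat buffers, which is the failure mode the conditional dereference is meant to avoid.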
@@ -14,10 +14,15 @@ class _transpose(torch.autograd.Function):
int rn[TN] = pidn * TN + 0 ... TN; //(4)

// create 2D array of pointers
-TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
-TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:,newaxis]; //(6)
+TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
+TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)

-*py = ^*px;
+// create bounds-checking mask
+bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
+bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b)
+
+// conditional write-back using the conditional dereferencing operator '*?()'
+*?(checky)py = ^(*?(checkx)px); //(7)
}
"""

@@ -40,6 +45,7 @@ class _transpose(torch.autograd.Function):
'TM' : [32,64,128],
'TN' : [32,64,128],
}

grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]

if _transpose.kernel is None:
@@ -53,7 +59,7 @@ transpose = _transpose.apply

# test
torch.manual_seed(0)
-x = torch.randn(128,200).cuda()
+x = torch.randn(1024,128).cuda()

print(x)
