Merge pull request #38 from jack-willturner/master
Add working examples to tutorials and python examples folder
This commit is contained in:
@@ -58,7 +58,7 @@ As you will see, a wrapper for the above Triton function can be created in just
|
|||||||
"""
|
"""
|
||||||
# create callable kernel for the source-code
|
# create callable kernel for the source-code
|
||||||
# options: 4 warps and a -DTILE=1024
|
# options: 4 warps and a -DTILE=1024
|
||||||
kernel = triton.kernel(src, defines = {'TILE': 1024}; num_warps = [4])
|
kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4])
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -88,6 +88,7 @@ As you will see, a wrapper for the above Triton function can be created in just
|
|||||||
zb = add(x, y)
|
zb = add(x, y)
|
||||||
diff = (za - zb).abs().max()
|
diff = (za - zb).abs().max()
|
||||||
print(diff)
|
print(diff)
|
||||||
|
print(torch.allclose(za,zb))
|
||||||
|
|
||||||
Executing the above code will:
|
Executing the above code will:
|
||||||
|
|
||||||
@@ -97,3 +98,5 @@ Executing the above code will:
|
|||||||
- Call the resulting custom op
|
- Call the resulting custom op
|
||||||
|
|
||||||
In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
|
In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
|
||||||
|
|
||||||
|
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/vec_add.py>`_.
|
||||||
|
@@ -35,7 +35,7 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa
|
|||||||
TYPE a[TM, TK] = *pa; //(9)
|
TYPE a[TM, TK] = *pa; //(9)
|
||||||
TYPE b[TK, TN] = *pb; //(10)
|
TYPE b[TK, TN] = *pb; //(10)
|
||||||
// matrix-multiply accumulate
|
// matrix-multiply accumulate
|
||||||
c += dot(a, b); //(11)
|
c += a @ b; //(11)
|
||||||
// increment pointers
|
// increment pointers
|
||||||
pa = pa + TK * 1; //(12)
|
pa = pa + TK * 1; //(12)
|
||||||
pb = pb + TK * ldb; //(13)
|
pb = pb + TK * ldb; //(13)
|
||||||
@@ -88,7 +88,7 @@ The purpose of pre-fetching is to overlap the update of the accumulator `c` with
|
|||||||
TYPE a[TM, TK] = *pa; //(9)
|
TYPE a[TM, TK] = *pa; //(9)
|
||||||
TYPE b[TK, TN] = *pb; //(10)
|
TYPE b[TK, TN] = *pb; //(10)
|
||||||
for(int k = K; k > 0; k-= TK){
|
for(int k = K; k > 0; k-= TK){
|
||||||
c += dot(a, b);
|
c += a @ b;
|
||||||
pa = pa + TK * 1;
|
pa = pa + TK * 1;
|
||||||
pb = pb + TK * ldb;
|
pb = pb + TK * ldb;
|
||||||
// don't prefetch last iteration
|
// don't prefetch last iteration
|
||||||
@@ -144,7 +144,7 @@ It is common for optimized matrix-multiplication implementations (e.g., BLAS) to
|
|||||||
TYPE b[SHAPE_B] = (*pb);
|
TYPE b[SHAPE_B] = (*pb);
|
||||||
// reduction loop
|
// reduction loop
|
||||||
for(int k = K; k > 0; k-= TK){
|
for(int k = K; k > 0; k-= TK){
|
||||||
c += dot(USE_A, USE_B);
|
c += USE_A @ USE_B;
|
||||||
pa = pa + TK * STRIDE_AK;
|
pa = pa + TK * STRIDE_AK;
|
||||||
pb = pb + TK * STRIDE_BK;
|
pb = pb + TK * STRIDE_BK;
|
||||||
a = *pa;
|
a = *pa;
|
||||||
@@ -182,3 +182,5 @@ Auto-tuning can also be handled using pre-processor macros:
|
|||||||
|
|
||||||
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
|
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
|
||||||
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
|
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
|
||||||
|
|
||||||
|
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_mul.py>`_.
|
||||||
|
@@ -25,8 +25,8 @@ In Triton, however, kernels are single-threaded and the compiler automatically d
|
|||||||
int rm[TM] = pidm * TM + 0 ... TM; //(3)
|
int rm[TM] = pidm * TM + 0 ... TM; //(3)
|
||||||
int rn[TN] = pidn * TN + 0 ... TN; //(4)
|
int rn[TN] = pidn * TN + 0 ... TN; //(4)
|
||||||
// create 2D array of pointers
|
// create 2D array of pointers
|
||||||
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
|
TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
|
||||||
TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:, newaxis]; //(6)
|
TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
|
||||||
// write back using the transposition operator '^'
|
// write back using the transposition operator '^'
|
||||||
*py = ^(*px); //(7)
|
*py = ^(*px); //(7)
|
||||||
}
|
}
|
||||||
@@ -73,6 +73,65 @@ which will be used in statements (5) and (6) to construct tiles of pointers
|
|||||||
- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.
|
- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the location specified by `py`.
|
||||||
|
|
||||||
|
|
||||||
|
==================================
|
||||||
|
A Note on Numpy-style Broadcasting
|
||||||
|
==================================
|
||||||
|
|
||||||
|
The construction statements (5) and (6) are a little subtle. To help understand them, consider the following numpy example.
|
||||||
|
|
||||||
|
First, we create a row vector of numbers 0 to 11, which we reshape into a 4x3 matrix.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
vec = np.linspace(0,11,12)
|
||||||
|
mat = vec.reshape((4,3))
|
||||||
|
|
||||||
|
Imagine that we would like to process this in two 2x3 tiles (i.e. tile 0 will consider the top half, and tile 1 will consider the bottom).
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
[[ 0, 1, 2],
|
||||||
|
[ 3, 4, 5],
|
||||||
|
[ 6, 7, 8],
|
||||||
|
[ 9, 10, 11]]
|
||||||
|
|
||||||
|
Given `pidm=0`, `pidn=0`, `TM=2`, `TN=3`, we would like for tile 0 to have the values:
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
[ 0, 1, 2],
|
||||||
|
[ 3, 4, 5],
|
||||||
|
|
||||||
|
We construct ranges `rm` and `rn` as:
|
||||||
|
::
|
||||||
|
|
||||||
|
rm = [0, 1]
|
||||||
|
rn = [0, 1, 2]
|
||||||
|
|
||||||
|
Using numpy-style broadcasting, we can add these together to create a matrix:
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
rm[:, np.newaxis] + rn[np.newaxis, :]
|
||||||
|
|
||||||
|
rn -> [0, 1, 2]
|
||||||
|
rm -> [0., [[0, 1, 2],
|
||||||
|
1.] [1, 2, 3]]
|
||||||
|
|
||||||
|
The bottom row is incorrect. Notice that `rm` indexes the rows of the matrix; we need to offset it so that each element gives the index
|
||||||
|
of the start of that row. For instance, to access row 1 column 0, we need to access location 3. To access row 2 column 0, we need
|
||||||
|
to access location 6. To translate from row N, column 0, we need to multiply N by the number of columns in each row (the leading dimension).
|
||||||
|
In this case this is 3, so what we really need is:
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
ldx = 3
|
||||||
|
px = rm[:, np.newaxis] * ldx + rn[np.newaxis,:]
|
||||||
|
|
||||||
|
`newaxis` is built into Triton, and pointer arrays can be constructed in just the same way (as in this example).
|
||||||
|
|
||||||
==========================
|
==========================
|
||||||
The __multipleof attribute
|
The __multipleof attribute
|
||||||
==========================
|
==========================
|
||||||
@@ -95,7 +154,7 @@ Bounds Checking
|
|||||||
==========================
|
==========================
|
||||||
|
|
||||||
|
|
||||||
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle thie situation, as shown below:
|
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below:
|
||||||
|
|
||||||
.. code-block:: C
|
.. code-block:: C
|
||||||
|
|
||||||
@@ -111,3 +170,5 @@ You might have noticed that the above code will fail when `M` and `N` are not mu
|
|||||||
|
|
||||||
|
|
||||||
Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.
|
Here, statements (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.
|
||||||
|
|
||||||
|
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_transpose.py>`_.
|
||||||
|
@@ -110,9 +110,9 @@ However, in practice only A, B are provided by the user, and all the other :code
|
|||||||
'TYPE' : dtype,
|
'TYPE' : dtype,
|
||||||
'AT' : transpose_a,
|
'AT' : transpose_a,
|
||||||
'BT' : transpose_b,
|
'BT' : transpose_b,
|
||||||
'TM' : [32, 64, 128]
|
'TM' : [32, 64, 128],
|
||||||
'TN' : [32, 64, 128]
|
'TN' : [32, 64, 128],
|
||||||
'TK' : [8]
|
'TK' : [8],
|
||||||
# handle A transposition
|
# handle A transposition
|
||||||
'USE_A' : '^a' if transpose_a else 'a',
|
'USE_A' : '^a' if transpose_a else 'a',
|
||||||
'STRIDE_AK' : 'lda' if transpose_a else '1',
|
'STRIDE_AK' : 'lda' if transpose_a else '1',
|
||||||
|
68
python/examples/tutorials/mat_copy.py
Normal file
68
python/examples/tutorials/mat_copy.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
|
||||||
|
class _copy(torch.autograd.Function):
|
||||||
|
src = """
|
||||||
|
__global__ void copy(TYPE * X, TYPE * Y,
|
||||||
|
int M, int N, int ldx __multipleof(8)) {
|
||||||
|
// extract program ID
|
||||||
|
int pidm = get_program_id(0); //(1)
|
||||||
|
int pidn = get_program_id(1); //(2)
|
||||||
|
|
||||||
|
// create 1D range along the two matrix's axes
|
||||||
|
int rm[TM] = pidm * TM + 0 ... TM; //(3)
|
||||||
|
int rn[TN] = pidn * TN + 0 ... TN; //(4)
|
||||||
|
|
||||||
|
// create 2D array of pointers
|
||||||
|
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
|
||||||
|
TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldx; //(6)
|
||||||
|
|
||||||
|
*py = *px;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
kernel = None ### initialize later when we know the sizes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x):
|
||||||
|
|
||||||
|
M, N = x.shape
|
||||||
|
|
||||||
|
ldx = N;
|
||||||
|
|
||||||
|
dtype = x.dtype
|
||||||
|
|
||||||
|
y = torch.empty((M,N)).cuda()
|
||||||
|
|
||||||
|
defines= {
|
||||||
|
'TYPE' : dtype,
|
||||||
|
'TM' : [32,64,128],
|
||||||
|
'TN' : [32,64,128],
|
||||||
|
}
|
||||||
|
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
||||||
|
|
||||||
|
if _copy.kernel is None:
|
||||||
|
_copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4])
|
||||||
|
|
||||||
|
_copy.kernel(x, y, M, N, ldx, grid=grid)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
copy = _copy.apply
|
||||||
|
|
||||||
|
# test
|
||||||
|
torch.manual_seed(0)
|
||||||
|
x = torch.randn(8,4).cuda()
|
||||||
|
|
||||||
|
print(x)
|
||||||
|
|
||||||
|
ya = x
|
||||||
|
yb = copy(x)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(ya)
|
||||||
|
print()
|
||||||
|
print(yb)
|
||||||
|
print(torch.allclose(ya, yb))
|
||||||
|
|
||||||
|
print(ya == yb)
|
84
python/examples/tutorials/mat_mul.py
Normal file
84
python/examples/tutorials/mat_mul.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
|
||||||
|
class _dot(torch.autograd.Function):
|
||||||
|
src = """
|
||||||
|
__global__ void dot(TYPE *A, TYPE *B, TYPE *C, int M, int N, int K,
|
||||||
|
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
|
||||||
|
int pm = get_program_id(0);
|
||||||
|
int pn = get_program_id(1);
|
||||||
|
|
||||||
|
// ranges
|
||||||
|
int rm[TM] = pm * TM + 0 ... TM;
|
||||||
|
int rn[TN] = pn * TN + 0 ... TN;
|
||||||
|
int rk[TK] = 0 ... TK;
|
||||||
|
|
||||||
|
// accumulator
|
||||||
|
float c[TM, TN] = 0;
|
||||||
|
|
||||||
|
//pointers
|
||||||
|
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda;
|
||||||
|
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1;
|
||||||
|
|
||||||
|
for(int k=K; k>0; k-=TK) {
|
||||||
|
TYPE a[TM, TK] = *pa;
|
||||||
|
TYPE b[TK, TN] = *pb;
|
||||||
|
|
||||||
|
c += a @ b;
|
||||||
|
|
||||||
|
pa = pa + TK * 1;
|
||||||
|
pb = pb + TK * ldb;
|
||||||
|
}
|
||||||
|
|
||||||
|
TYPE* pc[TM,TN] = C + rn[newaxis, :] + rm[:,newaxis] * ldc;
|
||||||
|
*pc = c;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, a, b):
|
||||||
|
c = _dot._call(a,b)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _call(a, b):
|
||||||
|
M, K = a.shape
|
||||||
|
K, N = b.shape
|
||||||
|
|
||||||
|
lda = K
|
||||||
|
ldb = N
|
||||||
|
ldc = N
|
||||||
|
|
||||||
|
dtype = a.dtype
|
||||||
|
|
||||||
|
c = triton.empty([M,N], dtype=dtype)
|
||||||
|
|
||||||
|
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
||||||
|
|
||||||
|
defines= {
|
||||||
|
'TYPE' : dtype,
|
||||||
|
'TM' : [32,64,128],
|
||||||
|
'TN' : [32,64,128],
|
||||||
|
'TK' : [8],
|
||||||
|
}
|
||||||
|
|
||||||
|
_dot.kernel = triton.kernel(_dot.src, defines=defines)
|
||||||
|
_dot.kernel(a, b, c, M, N, K, lda, ldb, ldc,
|
||||||
|
grid=grid, num_warps=4, defines=defines)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
dot = _dot.apply
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
M, N, K = 128, 512, 256
|
||||||
|
a = torch.rand((M, K)).cuda()
|
||||||
|
b = torch.rand((K, N)).cuda()
|
||||||
|
|
||||||
|
|
||||||
|
zc = torch.matmul(a,b)
|
||||||
|
zc_ = dot(a,b)
|
||||||
|
|
||||||
|
print(torch.allclose(zc, zc_))
|
74
python/examples/tutorials/mat_transpose.py
Normal file
74
python/examples/tutorials/mat_transpose.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
|
||||||
|
class _transpose(torch.autograd.Function):
|
||||||
|
src = """
|
||||||
|
__global__ void transpose(TYPE * X, TYPE * Y,
|
||||||
|
int M, int N, int ldx __multipleof(8), int ldy __multipleof(8)) {
|
||||||
|
// extract program ID
|
||||||
|
int pidm = get_program_id(0); //(1)
|
||||||
|
int pidn = get_program_id(1); //(2)
|
||||||
|
|
||||||
|
// create 1D range along the two matrix's axes
|
||||||
|
int rm[TM] = pidm * TM + 0 ... TM; //(3)
|
||||||
|
int rn[TN] = pidn * TN + 0 ... TN; //(4)
|
||||||
|
|
||||||
|
// create 2D array of pointers
|
||||||
|
TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
|
||||||
|
TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
|
||||||
|
|
||||||
|
// create bounds-checking mask
|
||||||
|
bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
|
||||||
|
bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b)
|
||||||
|
|
||||||
|
// conditional write-back using the conditional dereferencing operatior '*?()'
|
||||||
|
*?(checky)py = ^(*?(checkx)px); //(7)
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
kernel = None ### initialize later when we know the sizes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x):
|
||||||
|
|
||||||
|
M, N = x.shape
|
||||||
|
|
||||||
|
ldx = N
|
||||||
|
ldy = M
|
||||||
|
|
||||||
|
dtype = x.dtype
|
||||||
|
|
||||||
|
y = torch.empty((N,M)).cuda()
|
||||||
|
|
||||||
|
defines= {
|
||||||
|
'TYPE' : dtype,
|
||||||
|
'TM' : [32,64,128],
|
||||||
|
'TN' : [32,64,128],
|
||||||
|
}
|
||||||
|
|
||||||
|
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
||||||
|
|
||||||
|
if _transpose.kernel is None:
|
||||||
|
_transpose.kernel = triton.kernel(_transpose.src, defines=defines, num_warps=[4])
|
||||||
|
|
||||||
|
_transpose.kernel(x, y, M, N, ldx, ldy, grid=grid)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
transpose = _transpose.apply
|
||||||
|
|
||||||
|
# test
|
||||||
|
torch.manual_seed(0)
|
||||||
|
x = torch.randn(1024,128).cuda()
|
||||||
|
|
||||||
|
print(x)
|
||||||
|
|
||||||
|
ya = torch.t(x)
|
||||||
|
yb = transpose(x)
|
||||||
|
print()
|
||||||
|
print(ya)
|
||||||
|
print()
|
||||||
|
print(yb)
|
||||||
|
print(torch.allclose(ya, yb))
|
||||||
|
|
||||||
|
print(ya == yb)
|
43
python/examples/tutorials/vec_add.py
Normal file
43
python/examples/tutorials/vec_add.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
|
||||||
|
class _add(torch.autograd.Function):
|
||||||
|
src = """
|
||||||
|
__global__ void add(float* z, float* x, float* y, int N) {
|
||||||
|
|
||||||
|
int pid = get_program_id(0);
|
||||||
|
|
||||||
|
int offset[TILE] = pid * TILE + 0 ... TILE;
|
||||||
|
float* pz[TILE] = z + offset;
|
||||||
|
float* px[TILE] = x + offset;
|
||||||
|
float* py[TILE] = y + offset;
|
||||||
|
|
||||||
|
bool check[TILE] = offset < N;
|
||||||
|
|
||||||
|
*?(check)pz = *?(check)px + *?(check)py;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, y):
|
||||||
|
z = torch.empty_like(x).cuda()
|
||||||
|
|
||||||
|
N = x.numel()
|
||||||
|
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
|
||||||
|
|
||||||
|
_add.kernel(z,x,y, N, grid=grid)
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
add = _add.apply
|
||||||
|
|
||||||
|
# test
|
||||||
|
torch.manual_seed(0)
|
||||||
|
x = torch.rand(98432).cuda()
|
||||||
|
y = torch.rand(98432).cuda()
|
||||||
|
za = x + y
|
||||||
|
zb = add(x, y)
|
||||||
|
|
||||||
|
print(torch.allclose(za,zb))
|
Reference in New Issue
Block a user