[DOCS] Matrix copy and transpose

This commit is contained in:
jack-willturner
2020-05-05 14:30:49 +01:00
committed by Philippe Tillet
parent 32819dea51
commit a98a2db2c2
2 changed files with 136 additions and 0 deletions

View File

@@ -0,0 +1,68 @@
import torch
import triton
class _copy(torch.autograd.Function):
src = """
__global__ void copy(TYPE * X, TYPE * Y,
int M, int N, int ldx __multipleof(8)) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D range along the two matrix's axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldx; //(6)
*py = *px;
}
"""
kernel = None ### initialize later when we know the sizes
@staticmethod
def forward(ctx, x):
M, N = x.shape
ldx = N;
dtype = x.dtype
y = torch.empty((M,N)).cuda()
defines= {
'TYPE' : dtype,
'TM' : [32,64,128],
'TN' : [32,64,128],
}
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
if _copy.kernel is None:
_copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4])
_copy.kernel(x, y, M, N, ldx, grid=grid)
return y
copy = _copy.apply
# test
torch.manual_seed(0)
x = torch.randn(8,4).cuda()
print(x)
ya = x
yb = copy(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))
print(ya == yb)

View File

@@ -0,0 +1,68 @@
import torch
import triton
class _transpose(torch.autograd.Function):
src = """
__global__ void transpose(TYPE * X, TYPE * Y,
int M, int N, int ldx __multipleof(8), int ldy __multipleof(8)) {
// extract program ID
int pidm = get_program_id(0); //(1)
int pidn = get_program_id(1); //(2)
// create 1D range along the two matrix's axes
int rm[TM] = pidm * TM + 0 ... TM; //(3)
int rn[TN] = pidn * TN + 0 ... TN; //(4)
// create 2D array of pointers
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
TYPE* py[TN, TM] = Y + rm[newaxis, :] * ldy + rn[:,newaxis]; //(6)
*py = ^*px;
}
"""
kernel = None ### initialize later when we know the sizes
@staticmethod
def forward(ctx, x):
M, N = x.shape
ldx = N
ldy = M
dtype = x.dtype
y = torch.empty((N,M)).cuda()
defines= {
'TYPE' : dtype,
'TM' : [32,64,128],
'TN' : [32,64,128],
}
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
if _transpose.kernel is None:
_transpose.kernel = triton.kernel(_transpose.src, defines=defines, num_warps=[4])
_transpose.kernel(x, y, M, N, ldx, ldy, grid=grid)
return y
transpose = _transpose.apply
# test
torch.manual_seed(0)
x = torch.randn(128,200).cuda()
print(x)
ya = torch.t(x)
yb = transpose(x)
print()
print(ya)
print()
print(yb)
print(torch.allclose(ya, yb))
print(ya == yb)