[PYTHON] Added option to show PTX source code in Python

This commit is contained in:
Philippe Tillet
2020-11-07 02:55:48 -05:00
committed by Philippe Tillet
parent cf80ccc798
commit 8f3ee53f24
10 changed files with 68 additions and 33 deletions

View File

@@ -121,8 +121,8 @@ dot = _dot.apply
# Seed torch's RNG so the random operands below are reproducible.
torch.manual_seed(0)
# Problem dimensions: `a` is created as (M, K) and `b` as (K, N).
M, N, K = 2048, 2048, 2048
# NOTE(review): this hunk is rendered without +/- markers. The next two
# assignments look like the pre-change lines that the `.half()` versions
# below replace -- confirm against the real file before treating all four
# assignments as live code.
a = torch.rand((M, K)).cuda()
b = torch.rand((K, N)).cuda()
# Random operands moved to CUDA via `.cuda()` and converted with `.half()`.
a = torch.rand((M, K)).cuda().half()
b = torch.rand((K, N)).cuda().half()
# Commented-out overrides that would fill both operands with ones.
#a[:] = 1
#b[:] = 1

View File

@@ -23,12 +23,9 @@ __global__ void add(float* z, float* x, float* y, int N) {
# Forward pass: allocates an output shaped like `x`, launches `_add.kernel`
# on (z, x, y, N) and returns the output tensor `z`.
# `ctx` is accepted but not used in the visible body.
# NOTE(review): indentation was lost in this rendering; this is presumably a
# method of the `_add` class whose header is outside the visible hunk -- confirm.
@staticmethod
def forward(ctx, x, y):
# Output buffer created like `x`, then moved to CUDA with `.cuda()`.
z = torch.empty_like(x).cuda()
# Total number of elements in `x`; passed to the kernel as its `N` argument.
N = x.numel()
# Launch grid as a 1-tuple computed from N and the 'TILE' value read from the
# options object `opt`; `triton.cdiv` is presumably ceiling division -- confirm.
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
_add.kernel(z,x,y, N, grid=grid)
return z
# Bind `_add.apply` to the module-level name `add` so callers can use `add(...)`.
add = _add.apply