Init
@@ -1469,6 +1469,7 @@ class CompiledKernel:
        def runner(*args, stream=None):
            if stream is None:
                stream = torch.cuda.current_stream().cuda_stream
            #print(args)
            self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
        return runner

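The runner above closes over the launch grid and the loaded CUDA function; a minimal usage sketch (hypothetical variable names, mirroring the ttgir_kernel launch in the matmul tutorial below) looks like:

    import torch
    # `k` is assumed to be a CompiledKernel, e.g. k = triton.compile("matmul.ttgir", num_warps=8)
    launch = k[(2048, 1, 1)]          # indexing with a grid tuple returns the runner defined above
    s = torch.cuda.Stream()
    launch(a.data_ptr(), b.data_ptr(), c.data_ptr(),
           M, N, K,
           a.stride(0), b.stride(0), c.stride(0),
           stream=s.cuda_stream)      # omit `stream` to launch on the current stream
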
@@ -11,6 +11,161 @@ You will specifically learn about:
- Automatic performance tuning
"""

IR = """
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 4]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
  func public @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
    %c3_i32 = arith.constant 3 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %c2_i32 = arith.constant 2 : i32
    %c64 = arith.constant 64 : index
    %cst = arith.constant dense<64> : tensor<128x64xi32, #blocked0>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
    %c8_i32 = arith.constant 8 : i32
    %c255_i32 = arith.constant 255 : i32
    %c256_i32 = arith.constant 256 : i32
    %c127_i32 = arith.constant 127 : i32
    %c128_i32 = arith.constant 128 : i32
    %c0 = arith.constant 0 : index
    %c64_i32 = arith.constant 64 : i32
    %0 = tt.get_program_id {axis = 0 : i32} : i32
    %1 = arith.addi %arg3, %c127_i32 : i32
    %2 = arith.divsi %1, %c128_i32 : i32
    %3 = arith.addi %arg4, %c255_i32 : i32
    %4 = arith.divsi %3, %c256_i32 : i32
    %5 = arith.muli %4, %c8_i32 : i32
    %6 = arith.divsi %0, %5 : i32
    %7 = arith.muli %6, %c8_i32 : i32
    %8 = arith.subi %2, %7 : i32
    %9 = arith.cmpi slt, %8, %c8_i32 : i32
    %10 = select %9, %8, %c8_i32 : i32
    %11 = arith.remsi %0, %10 : i32
    %12 = arith.addi %7, %11 : i32
    %13 = arith.remsi %0, %5 : i32
    %14 = arith.divsi %13, %10 : i32
    %15 = arith.muli %12, %c128_i32 : i32
    %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %18 = tt.splat %15 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %19 = tt.splat %15 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %20 = arith.muli %14, %c256_i32 : i32
    %21 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %22 = tt.splat %20 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %23 = arith.addi %18, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %24 = arith.addi %19, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %25 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<128x1xi32, #blocked0>
    %26 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
    %27 = tt.splat %arg6 : (i32) -> tensor<128x1xi32, #blocked0>
    %28 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
    %29 = tt.expand_dims %28 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x64xi32, #blocked0>
    %30 = tt.broadcast %29 : (tensor<1x64xi32, #blocked0>) -> tensor<128x64xi32, #blocked0>
    %31 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked0>
    %32 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %33 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
    %34 = tt.splat %arg7 : (i32) -> tensor<64x1xi32, #blocked1>
    %35 = arith.addi %22, %21 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %36 = tt.expand_dims %35 {axis = 0 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x256xi32, #blocked1>
    %37 = tt.broadcast %36 : (tensor<1x256xi32, #blocked1>) -> tensor<64x256xi32, #blocked1>
    %38 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<64x256x!tt.ptr<f16>, #blocked1>
    %39 = arith.index_cast %arg5 : i32 to index
    %40 = arith.muli %arg7, %c64_i32 : i32
    %41 = tt.splat %40 : (i32) -> tensor<64x256xi32, #blocked1>
    %42 = arith.muli %25, %27 : tensor<128x1xi32, #blocked0>
    %43 = tt.broadcast %42 : (tensor<128x1xi32, #blocked0>) -> tensor<128x64xi32, #blocked0>
    %44 = arith.addi %43, %30 : tensor<128x64xi32, #blocked0>
    %45 = tt.addptr %31, %44 : tensor<128x64x!tt.ptr<f16>, #blocked0>
    %46 = arith.muli %33, %34 : tensor<64x1xi32, #blocked1>
    %47 = tt.broadcast %46 : (tensor<64x1xi32, #blocked1>) -> tensor<64x256xi32, #blocked1>
    %48 = arith.addi %47, %37 : tensor<64x256xi32, #blocked1>
    %49 = tt.addptr %38, %48 : tensor<64x256x!tt.ptr<f16>, #blocked1>
    %50 = arith.cmpi slt, %c0, %39 : index
    %51 = triton_gpu.alloc_tensor : tensor<3x128x64xf16, #shared>
    %52 = tt.splat %50 : (i1) -> tensor<128x64xi1, #blocked0>
    %53 = triton_gpu.insert_slice_async %45, %51, %c0_i32, %52 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64x!tt.ptr<f16>, #blocked0> -> tensor<3x128x64xf16, #shared>
    %54 = triton_gpu.alloc_tensor : tensor<3x64x256xf16, #shared>
    %55 = tt.splat %50 : (i1) -> tensor<64x256xi1, #blocked1>
    %56 = triton_gpu.insert_slice_async %49, %54, %c0_i32, %55 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x256x!tt.ptr<f16>, #blocked1> -> tensor<3x64x256xf16, #shared>
    %57 = tt.addptr %45, %cst : tensor<128x64x!tt.ptr<f16>, #blocked0>
    %58 = tt.addptr %49, %41 : tensor<64x256x!tt.ptr<f16>, #blocked1>
    %59 = arith.cmpi slt, %c64, %39 : index
    %60 = tt.splat %59 : (i1) -> tensor<128x64xi1, #blocked0>
    %61 = triton_gpu.insert_slice_async %57, %53, %c1_i32, %60 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64x!tt.ptr<f16>, #blocked0> -> tensor<3x128x64xf16, #shared>
    %62 = tt.splat %59 : (i1) -> tensor<64x256xi1, #blocked1>
    %63 = triton_gpu.insert_slice_async %58, %56, %c1_i32, %62 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x256x!tt.ptr<f16>, #blocked1> -> tensor<3x64x256xf16, #shared>
    %64 = tt.addptr %57, %cst : tensor<128x64x!tt.ptr<f16>, #blocked0>
    %65 = tt.addptr %58, %41 : tensor<64x256x!tt.ptr<f16>, #blocked1>
    triton_gpu.async_wait {num = 2 : i32}
    %66 = tensor.extract_slice %61[0, 0, 0] [1, 128, 64] [1, 1, 1] : tensor<3x128x64xf16, #shared> to tensor<128x64xf16, #shared>
    %67 = tensor.extract_slice %63[0, 0, 0] [1, 64, 256] [1, 1, 1] : tensor<3x64x256xf16, #shared> to tensor<64x256xf16, #shared>
    %68 = tensor.extract_slice %66[0, 0] [128, 16] [1, 1] : tensor<128x64xf16, #shared> to tensor<128x16xf16, #shared>
    %69 = triton_gpu.convert_layout %68 : (tensor<128x16xf16, #shared>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
    %70 = tensor.extract_slice %67[0, 0] [16, 256] [1, 1] : tensor<64x256xf16, #shared> to tensor<16x256xf16, #shared>
    %71 = triton_gpu.convert_layout %70 : (tensor<16x256xf16, #shared>) -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
    %72:14 = scf.for %arg9 = %c0 to %39 step %c64 iter_args(%arg10 = %cst_0, %arg11 = %45, %arg12 = %49, %arg13 = %61, %arg14 = %63, %arg15 = %66, %arg16 = %67, %arg17 = %64, %arg18 = %65, %arg19 = %c64, %arg20 = %c2_i32, %arg21 = %c1_i32, %arg22 = %69, %arg23 = %71) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<3x128x64xf16, #shared>, tensor<3x64x256xf16, #shared>, tensor<128x64xf16, #shared>, tensor<64x256xf16, #shared>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>) {
      %89 = tt.dot %arg22, %arg23, %arg10 {allowTF32 = true, transA = false, transB = false} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x256xf32, #mma>
      %90 = tensor.extract_slice %arg15[0, 16] [128, 32] [1, 1] : tensor<128x64xf16, #shared> to tensor<128x32xf16, #shared>
      %91 = triton_gpu.convert_layout %90 : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
      %92 = tensor.extract_slice %arg16[16, 0] [32, 256] [1, 1] : tensor<64x256xf16, #shared> to tensor<32x256xf16, #shared>
      %93 = triton_gpu.convert_layout %92 : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
      %94 = tt.dot %91, %93, %89 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x256xf32, #mma>
      %95 = tensor.extract_slice %arg15[0, 48] [128, 16] [1, 1] : tensor<128x64xf16, #shared> to tensor<128x16xf16, #shared>
      %96 = triton_gpu.convert_layout %95 : (tensor<128x16xf16, #shared>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
      %97 = tensor.extract_slice %arg16[48, 0] [16, 256] [1, 1] : tensor<64x256xf16, #shared> to tensor<16x256xf16, #shared>
      %98 = triton_gpu.convert_layout %97 : (tensor<16x256xf16, #shared>) -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
      %99 = tt.dot %96, %98, %94 {allowTF32 = true, transA = false, transB = false} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x256xf32, #mma>
      %100 = tt.addptr %arg11, %cst : tensor<128x64x!tt.ptr<f16>, #blocked0>
      %101 = tt.addptr %arg12, %41 : tensor<64x256x!tt.ptr<f16>, #blocked1>
      %102 = arith.addi %arg19, %c64 : index
      %103 = arith.cmpi slt, %102, %39 : index
      %104 = arith.remsi %arg20, %c3_i32 : i32
      %105 = arith.remsi %arg21, %c3_i32 : i32
      %106 = arith.index_cast %105 : i32 to index
      %107 = tt.splat %103 : (i1) -> tensor<128x64xi1, #blocked0>
      %108 = triton_gpu.insert_slice_async %arg17, %arg13, %104, %107 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64x!tt.ptr<f16>, #blocked0> -> tensor<3x128x64xf16, #shared>
      %109 = tt.splat %103 : (i1) -> tensor<64x256xi1, #blocked1>
      %110 = triton_gpu.insert_slice_async %arg18, %arg14, %104, %109 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x256x!tt.ptr<f16>, #blocked1> -> tensor<3x64x256xf16, #shared>
      %111 = tt.addptr %arg17, %cst : tensor<128x64x!tt.ptr<f16>, #blocked0>
      %112 = tt.addptr %arg18, %41 : tensor<64x256x!tt.ptr<f16>, #blocked1>
      triton_gpu.async_wait {num = 2 : i32}
      %113 = tensor.extract_slice %108[%106, 0, 0] [1, 128, 64] [1, 1, 1] : tensor<3x128x64xf16, #shared> to tensor<128x64xf16, #shared>
      %114 = tensor.extract_slice %110[%106, 0, 0] [1, 64, 256] [1, 1, 1] : tensor<3x64x256xf16, #shared> to tensor<64x256xf16, #shared>
      %115 = arith.addi %arg20, %c1_i32 : i32
      %116 = arith.addi %arg21, %c1_i32 : i32
      %117 = tensor.extract_slice %113[0, 0] [128, 16] [1, 1] : tensor<128x64xf16, #shared> to tensor<128x16xf16, #shared>
      %118 = triton_gpu.convert_layout %117 : (tensor<128x16xf16, #shared>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
      %119 = tensor.extract_slice %114[0, 0] [16, 256] [1, 1] : tensor<64x256xf16, #shared> to tensor<16x256xf16, #shared>
      %120 = triton_gpu.convert_layout %119 : (tensor<16x256xf16, #shared>) -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
      scf.yield %99, %100, %101, %108, %110, %113, %114, %111, %112, %102, %115, %116, %118, %120 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<3x128x64xf16, #shared>, tensor<3x64x256xf16, #shared>, tensor<128x64xf16, #shared>, tensor<64x256xf16, #shared>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
    }
    triton_gpu.async_wait {num = 0 : i32}
    %73 = triton_gpu.convert_layout %72#0 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked1>
    %74 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked1>
    %75 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #blocked1>
    %76 = tt.broadcast %36 : (tensor<1x256xi32, #blocked1>) -> tensor<128x256xi32, #blocked1>
    %77 = tt.splat %arg3 : (i32) -> tensor<128x1xi32, #blocked1>
    %78 = tt.splat %arg4 : (i32) -> tensor<1x256xi32, #blocked1>
    %79 = "triton_gpu.cmpi"(%36, %78) {predicate = 2 : i64} : (tensor<1x256xi32, #blocked1>, tensor<1x256xi32, #blocked1>) -> tensor<1x256xi1, #blocked1>
    %80 = tt.broadcast %79 : (tensor<1x256xi1, #blocked1>) -> tensor<128x256xi1, #blocked1>
    %81 = arith.muli %74, %26 : tensor<128x1xi32, #blocked1>
    %82 = tt.addptr %75, %81 : tensor<128x1x!tt.ptr<f16>, #blocked1>
    %83 = tt.broadcast %82 : (tensor<128x1x!tt.ptr<f16>, #blocked1>) -> tensor<128x256x!tt.ptr<f16>, #blocked1>
    %84 = tt.addptr %83, %76 : tensor<128x256x!tt.ptr<f16>, #blocked1>
    %85 = arith.truncf %73 : tensor<128x256xf32, #blocked1> to tensor<128x256xf16, #blocked1>
    %86 = "triton_gpu.cmpi"(%26, %77) {predicate = 2 : i64} : (tensor<128x1xi32, #blocked1>, tensor<128x1xi32, #blocked1>) -> tensor<128x1xi1, #blocked1>
    %87 = tt.broadcast %86 : (tensor<128x1xi1, #blocked1>) -> tensor<128x256xi1, #blocked1>
    %88 = arith.andi %87, %80 : tensor<128x256xi1, #blocked1>
    tt.store %84, %85, %88 : tensor<128x256xf16, #blocked1>
    return
  }
}
"""
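For orientation, the TTGIR above is the GPU-level form of a tiled FP16 matmul: each program owns one 128x256 tile of C and marches over K in steps of 64, issuing tt.dot on 16-wide and 32-wide K sub-slices while triton_gpu.insert_slice_async / async_wait implement a three-stage shared-memory pipeline. A simplified triton.jit kernel that computes the same result (omitting the grouped program ordering and the pipelining, and with argument names that are assumptions since the IR only exposes %arg0..%arg8) might look like:

    import triton
    import triton.language as tl

    @triton.jit
    def matmul_kernel_sketch(a_ptr, b_ptr, c_ptr, M, N, K,
                             stride_am, stride_bk, stride_cm,   # row strides; unit column strides assumed, as in the IR
                             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        pid = tl.program_id(axis=0)
        num_pid_n = tl.cdiv(N, BLOCK_N)
        pid_m = pid // num_pid_n          # the IR additionally groups programs 8 tile-rows at a time
        pid_n = pid % num_pid_n
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :]
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :]
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Loads are unmasked: K is assumed to be a multiple of BLOCK_K, as the IR assumes.
        for _ in range(0, K, BLOCK_K):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
            acc += tl.dot(a, b)
            a_ptrs += BLOCK_K               # advance A by BLOCK_K columns
            b_ptrs += BLOCK_K * stride_bk   # advance B by BLOCK_K rows
        c = acc.to(tl.float16)
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :]
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        tl.store(c_ptrs, c, mask=mask)
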


# %%
# Motivations
# -------------
@@ -144,6 +299,7 @@ import torch

import triton
import triton.language as tl
import triton.testing

# %%
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune`
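The tile configuration hard-coded into the TTGIR above (128x256 output tiles, K-steps of 64, group size 8, three pipeline stages, eight warps) corresponds to a single autotuning config. A hedged sketch of how that config would be declared on the triton.jit kernel, with parameter names taken from the standard matmul tutorial and therefore assumptions here:

    # One autotune config matching the constants baked into the hand-written TTGIR.
    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64,
                           'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        ],
        key=['M', 'N', 'K'],   # re-tune whenever the problem shape changes
    )
    @triton.jit
    def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                      stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                      ACTIVATION: tl.constexpr,
                      BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                      BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        pass  # body elided; the full kernel is defined earlier in this tutorial
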
@@ -250,6 +406,7 @@ def leaky_relu(x):
# We can now create a convenience wrapper function that only takes two input tensors
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel

ttgir_kernel = None

def matmul(a, b, activation=None):
    # checks constraints
@@ -267,14 +424,28 @@ def matmul(a, b, activation=None):
     grid = lambda META: (
         triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
     )
-    matmul_kernel[grid](
-        a, b, c,
+    global ttgir_kernel
+    if ttgir_kernel is None:
+        import tempfile
+        with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
+            f.write(IR)
+            f.flush()
+            ttgir_kernel = triton.compile(f.name, num_warps=8)
+    ttgir_kernel[(2048, 1, 1)](
+        a.data_ptr(), b.data_ptr(), c.data_ptr(),
         M, N, K,
-        a.stride(0), a.stride(1),
-        b.stride(0), b.stride(1),
-        c.stride(0), c.stride(1),
-        ACTIVATION=activation,
+        a.stride(0),
+        b.stride(0),
+        c.stride(0)
     )
+    #k = matmul_kernel[grid](
+    #    a, b, c,
+    #    M, N, K,
+    #    a.stride(0), a.stride(1),
+    #    b.stride(0), b.stride(1),
+    #    c.stride(0), c.stride(1),
+    #    ACTIVATION=None,
+    #)
     return c

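The hard-coded launch grid of (2048, 1, 1) follows from the tile sizes fixed in the TTGIR and the 8192x8192 test matrices used below: one program per 128x256 output tile. A sketch of deriving it rather than hard-coding it (M and N are the output dimensions, as inside matmul above):

    # Deriving the (2048, 1, 1) grid used above; BLOCK_M/BLOCK_N are the IR's tile sizes.
    BLOCK_M, BLOCK_N = 128, 256
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)
    # For M = N = 8192: cdiv(8192, 128) * cdiv(8192, 256) = 64 * 32 = 2048 programs.
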
@@ -285,8 +456,8 @@ def matmul(a, b, activation=None):
 # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS)

 torch.manual_seed(0)
-a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
+a = torch.randn((8192, 8192), device='cuda', dtype=torch.float16)
+b = torch.randn((8192, 8192), device='cuda', dtype=torch.float16)
 triton_output = matmul(a, b, activation=None)
 torch_output = torch.matmul(a, b)
 print(f"triton_output={triton_output}")
@@ -326,10 +497,11 @@ else:
 def benchmark(M, N, K, provider):
     a = torch.randn((M, K), device='cuda', dtype=torch.float16)
     b = torch.randn((K, N), device='cuda', dtype=torch.float16)
+    with triton.testing.set_gpu_clock():
     if provider == 'cublas':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=1000)
     if provider == 'triton':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=1000)
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)

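For reference, the perf lambda converts a runtime in milliseconds into TFLOP/s: a matmul performs 2*M*N*K floating-point operations (one multiply and one add per inner-product term). A worked example with the 8192 shape used in the test above and a hypothetical runtime:

    # Worked example of the perf() conversion for M = N = K = 8192.
    M = N = K = 8192
    ms = 10.0                                  # hypothetical measured time in milliseconds
    tflops = 2 * M * N * K * 1e-12 / (ms * 1e-3)
    print(f"{tflops:.1f} TFLOP/s")             # ~110 TFLOP/s for a 10 ms run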