[PYTHON] CUTLASS wrapper for fair benchmarks (#75)
Before this commit, the benchmarking infrastructure used heterogeneous protocols across libraries (e.g., CUTLASS was timed via a C++ profiler binary that reports mean TFLOPS, while torch and triton were timed from a Python call that reports the 10th, 50th and 90th percentiles). For the sake of uniformity and fair benchmarking practices, this PR adds a Python wrapper for auto-tuned CUTLASS matrix multiplication. Benchmarks have been rewritten to use this wrapper with `triton.testing.do_bench` rather than system calls to the CUTLASS profiler. Importantly, this also ensures that all the matmuls run on the *same* input data, which should stabilize clocks across providers.
committed by Philippe Tillet · parent d6f18742b1 · commit eacbb73968
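The core idea of the change — time every provider with `triton.testing.do_bench` on the same CUDA tensors instead of shelling out to the CUTLASS profiler — can be sketched as follows. This is an illustrative example, not code from this commit: the `tflops` helper and the 2048×2048×2048 problem size are assumptions, and `triton.testing.cutlass_matmul` is only available when the CUTLASS bindings added below are built.

```python
import torch
import triton

# Illustrative sketch: assumes triton was built with WITH_CUTLASS_BINDINGS so
# that triton.testing.cutlass_matmul is importable; shapes are arbitrary.
M = N = K = 2048
a = torch.rand((M, K), device="cuda", dtype=torch.float16)
b = torch.rand((K, N), device="cuda", dtype=torch.float16)

def tflops(ms):
    # 2*M*N*K flops for a matmul; ms is wall-clock time in milliseconds
    return 2.0 * M * N * K / (ms * 1e-3) * 1e-12

providers = {
    "torch": lambda: torch.matmul(a, b),
    "triton": lambda: triton.ops.matmul(a, b),
    "cutlass": lambda: triton.testing.cutlass_matmul(a, b),
}
for name, fn in providers.items():
    # do_bench returns the median and outer quantiles in milliseconds, so every
    # provider is measured with the same protocol and on the same inputs
    ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=10, rep=50)
    print(f"{name}: {tflops(ms):.1f} TFLOPS ({tflops(max_ms):.1f} - {tflops(min_ms):.1f})")
```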
@@ -34,10 +34,19 @@ if(BUILD_PYTHON_MODULE)
     message(STATUS "Adding Python module")
     # PyBind11 wrapper source file
     file(GLOB_RECURSE TORCH_SRC torch/*.cc)
+    # Build CUTLASS python wrapper if requested
+    set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
+    set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
+    if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
+        set(TORCH_SRC ${TORCH_SRC} cutlass.cc)
+        add_definitions(-DWITH_CUTLASS_BINDINGS)
+        set(CUTLASS_LIBRARIES "cutlass")
+    endif()
+    message(STATUS ${CUTLASS_INCLUDE_PATH})
     set(PYTHON_SRC main.cc triton.cc ${TORCH_SRC})
-    set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
-    include_directories("." ${PYTHON_INCLUDE_DIRS})
-    link_directories(${PYTHON_LINK_DIRS})
+    set_source_files_properties(${TORCH_SRC} PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ${CUTLASS_OPT}")
+    include_directories("." ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
+    link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
 endif()
@@ -47,5 +56,5 @@ add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
 target_link_libraries(triton ${LLVM_LIBRARIES} ${LLVM_SYSTEM_LIBS})

 if(BUILD_PYTHON_MODULE)
-    target_link_libraries(triton ${TORCH_LIBRARIES})
+    target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES})
 endif()
@@ -48,7 +48,7 @@ transformer_confs = [
 ]

-@triton.testing.perf_report(transformer_confs)
+@triton.testing.perf_report(square_confs)
 def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
@@ -62,37 +62,11 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
     if provider == "triton":
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
         return tflops(ms), tflops(max_ms), tflops(min_ms)
-    if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
-        import subprocess
-        import tempfile
-        import pandas as pd
-        # run program specified by CUTLASS_PROFILER env variable
-        layout_a = "column" if AT else "row"
-        layout_b = "column" if BT else "row"
-        # create temporary file name
-        fd, fname = tempfile.mkstemp()
-        # run program and gets its output
-        cmd = [
-            os.environ["CUTLASS_PROFILER"],
-            f"--m={M}",
-            f"--n={N}",
-            f"--k={K}",
-            f"--A=f16:{layout_a}",
-            f"--B=f16:{layout_b}",
-            "--C=f16:column",
-            "--accum=f32",
-            "--operation=gemm",
-            "--verification-enabled=false",
-            f"--warmup-iterations={warmup}",
-            f"--profiling-iterations={rep}",
-            f"--output={fname}",
-            "--dist=uniform,min:0,max:1,scale:-1",
-            "--verbose=false",
-        ]
-        # run cmd
-        subprocess.run(cmd, stdout=subprocess.PIPE)
-        # read CSV output
-        df_c = pd.read_csv(f"{fname}.gemm.csv")
-        tflops = max(df_c["GFLOPs"]) / 1e3
-        return tflops
+    if provider == "cutlass":
+        cutlass_matmul = triton.testing.cutlass_matmul
+        try:
+            ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
+            return tflops(ms), tflops(max_ms), tflops(min_ms)
+        except:
+            return None
     return None
@@ -14,6 +14,7 @@ from setuptools.command.test import test as TestCommand
 import distutils.spawn
 import torch


 def find_llvm():
     versions = ["-10", "-10.0", ""]
     supported = ["llvm-config{v}".format(v=v) for v in versions]
@@ -28,19 +29,23 @@ def find_llvm():
         version = os.popen("{config} --version".format(config=config)).read()
     raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)


 class CMakeExtension(Extension):
     def __init__(self, name, path, sourcedir=""):
         Extension.__init__(self, name, sources=[])
         self.sourcedir = os.path.abspath(sourcedir)
         self.path = path


 class CMakeBuild(build_ext):
     def run(self):
         try:
             out = subprocess.check_output(["cmake", "--version"])
         except OSError:
-            raise RuntimeError("CMake must be installed to build the following extensions: " +
-                               ", ".join(e.name for e in self.extensions))
+            raise RuntimeError(
+                "CMake must be installed to build the following extensions: " +
+                ", ".join(e.name for e in self.extensions)
+            )

         if platform.system() == "Windows":
             cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
@@ -92,6 +97,7 @@ class CMakeBuild(build_ext):
         subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
         subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)


 setup(
     name="triton",
     version="1.0.0",
@@ -101,7 +107,10 @@ setup(
     long_description="",
     packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
     install_requires=["numpy", "torch"],
-    package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
+    package_data={
+        "triton/ops": ["*.c"],
+        "triton/ops/blocksparse": ["*.c"]
+    },
     include_package_data=True,
     ext_modules=[CMakeExtension("triton", "triton/_C/")],
     cmdclass={"build_ext": CMakeBuild},
python/src/cutlass.cc — new file (206 lines)
@@ -0,0 +1,206 @@
#include "cutlass/library/handle.h"
#include "cutlass/library/library.h"
#include "cutlass/library/operation_table.h"
#include "cutlass/library/singleton.h"
#include "pybind11/pybind11.h"
#include "triton/tools/bench.hpp"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

using namespace cutlass;
using namespace cutlass::library;

std::map<std::vector<size_t>, const Operation *> op_cache_;

static int const kHostWorkspaceSize = (4 << 10);
static int const kDeviceWorkspaceSize = (4 << 20);

void run(int M, int N, int K,
         int lda, int ldb, int ldc, int ldd,
         void const *ptr_A, void const *ptr_B, void const *ptr_C, void *ptr_D,
         void const *alpha, void const *beta,
         ScalarPointerMode scalar_mode,
         const Operation *operation,
         cudaStream_t stream) {

  GemmUniversalConfiguration configuration{
      GemmUniversalMode::kGemm,
      {M, N, K},
      1,
      lda,
      ldb,
      ldc,
      ldd};

  // host workspace size
  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);
  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed)
    throw std::runtime_error("Unable to find gemm operation");
  char host_workspace[kHostWorkspaceSize];

  // device workspace size
  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
  if (uint64_t(kDeviceWorkspaceSize) < device_workspace_size_needed)
    throw std::runtime_error("Unable to find gemm operation");
  static void *device_workspace;

  // Initialize host and device workspaces
  Status status = operation->initialize(&configuration, host_workspace, device_workspace, stream);
  if (status != cutlass::Status::kSuccess)
    throw std::runtime_error("Unable to initialize workspace");

  // Run the operator
  GemmArguments arguments{ptr_A, ptr_B, ptr_C, ptr_D, alpha, beta, scalar_mode};
  operation->run(&arguments, host_workspace, device_workspace, stream);
}

const Operation *autotune(int M, int N, int K,
                          NumericTypeID element_compute,
                          NumericTypeID element_scalar,
                          void const *alpha,
                          NumericTypeID element_A,
                          LayoutTypeID layout_A,
                          ComplexTransform transform_A,
                          void const *ptr_A,
                          int lda,
                          NumericTypeID element_B,
                          LayoutTypeID layout_B,
                          ComplexTransform transform_B,
                          void const *ptr_B,
                          int ldb,
                          void const *beta,
                          NumericTypeID element_C,
                          void const *ptr_C,
                          int ldc,
                          void *ptr_D,
                          int ldd,
                          ScalarPointerMode scalar_mode,
                          int device_id,
                          cudaStream_t stream) {

  // index operation table with functional key
  GemmFunctionalKey key(
      Provider::kCUTLASS,
      GemmKind::kUniversal,
      element_compute,
      element_scalar,
      element_A,
      layout_A,
      transform_A,
      element_B,
      layout_B,
      transform_B,
      element_C);

  auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);
  if (operators_it == Singleton::get().operation_table.gemm_operations.end())
    throw std::runtime_error("Unable to find gemm operation");
  if (operators_it->second.empty())
    throw std::runtime_error("Unable to find gemm operation");

  cudaDeviceProp device_prop;
  cudaError_t error = cudaGetDeviceProperties(&device_prop, device_id);
  if (error != cudaSuccess)
    throw std::runtime_error("Unable to get device properties");
  int cc = device_prop.major * 10 + device_prop.minor;

  // index operation table with preference key
  // assume 8-bytes aligned memory pointers
  int alignment = 8;
  GemmPreferenceKey preference_key(cc, alignment);
  auto autotune_it = operators_it->second.find(preference_key);
  if (autotune_it == operators_it->second.end())
    throw std::runtime_error("Unable to find gemm operation");
  const std::vector<const Operation *> &operations = autotune_it->second;
  if (operations.empty())
    throw std::runtime_error("Unable to find gemm operation");

  // auto-tune
  const Operation *best = nullptr;
  double best_ms = std::numeric_limits<double>::max();
  for (const Operation *op : operations) {
    auto fn = [&]() { run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D,
                          alpha, beta, scalar_mode, op, stream); };
    triton::driver::cu_stream tt_stream((CUstream)stream, false);
    double ms = triton::tools::bench(fn, &tt_stream, 10, 25);
    if (ms < best_ms) {
      best_ms = ms;
      best = op;
    }
  }
  return best;
}

// map of torch datatypes to cutlass datatypes
std::map<caffe2::TypeIdentifier, NumericTypeID> type_map = {
    {caffe2::TypeMeta::Id<at::Half>(), NumericTypeID::kF16},
    {caffe2::TypeMeta::Id<float>(), NumericTypeID::kF32},
    {caffe2::TypeMeta::Id<double>(), NumericTypeID::kF64}};

void cutlass_matmul(torch::Tensor A, torch::Tensor B, torch::Tensor C) {
  size_t M = A.size(0);
  size_t N = B.size(1);
  size_t K = A.size(1);
  size_t lda = A.stride(0);
  size_t ldb = B.stride(0);
  size_t ldc = C.stride(1);
  size_t ldd = C.stride(1);
  void *ptr_A = A.data_ptr();
  void *ptr_B = B.data_ptr();
  void *ptr_C = C.data_ptr();
  void *ptr_D = ptr_C;
  float alpha = 1.0f;
  float beta = 0.0f;
  // layout for A
  LayoutTypeID layout_A;
  if (A.stride(0) == 1)
    layout_A = LayoutTypeID::kColumnMajor;
  else if (A.stride(1) == 1)
    layout_A = LayoutTypeID::kRowMajor;
  else {
    A = A.contiguous();
    layout_A = LayoutTypeID::kRowMajor;
  }
  // layout for B
  LayoutTypeID layout_B;
  if (B.stride(0) == 1)
    layout_B = LayoutTypeID::kColumnMajor;
  else if (B.stride(1) == 1)
    layout_B = LayoutTypeID::kRowMajor;
  else {
    B = B.contiguous();
    layout_B = LayoutTypeID::kRowMajor;
  }
  // data types
  NumericTypeID element_compute = NumericTypeID::kF32;
  NumericTypeID element_A = type_map[A.dtype().id()];
  NumericTypeID element_B = type_map[B.dtype().id()];
  NumericTypeID element_C = type_map[C.dtype().id()];
  // misc. flags
  ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
  NumericTypeID element_scalar = NumericTypeID::kF32;
  ComplexTransform transform_A = ComplexTransform::kNone;
  ComplexTransform transform_B = ComplexTransform::kNone;
  // runtime flags
  size_t dev_id = C.device().index();
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(dev_id).stream();
  // auto-tune
  std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
                                  dev_id, (size_t)element_compute, (size_t)scalar_mode};
  auto it = op_cache_.find(tune_key);
  if (it == op_cache_.end()) {
    const Operation *op = autotune(M, N, K, element_compute, element_scalar, &alpha,
                                   element_A, layout_A, transform_A, ptr_A, lda,
                                   element_B, layout_B, transform_B, ptr_B, ldb,
                                   &beta, element_C, ptr_C, ldc, ptr_D, ldd, scalar_mode,
                                   dev_id, stream);
    it = op_cache_.insert({tune_key, op}).first;
  }
  run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D, &alpha, &beta,
      scalar_mode, it->second, stream);
}

void init_cutlass(pybind11::module &m) {
  pybind11::module subm = m.def_submodule("cutlass");
  subm.def("matmul", &cutlass_matmul, "matrix multiplication");
}
@@ -3,10 +3,14 @@
 void init_superblocking(pybind11::module &m);
 void init_torch_utils(pybind11::module &m);
 void init_triton(pybind11::module &m);
+void init_cutlass(pybind11::module &m);

 PYBIND11_MODULE(libtriton, m) {
   m.doc() = "Python bindings to the C++ Triton API";
   init_triton(m);
   init_torch_utils(m);
   init_superblocking(m);
+#ifdef WITH_CUTLASS_BINDINGS
+  init_cutlass(m);
+#endif
 }
@@ -1,6 +1,11 @@
 import torch
 import os

+try:
+    import triton._C.libtriton.cutlass as _cutlass
+except ImportError:
+    _cutlass = None


 def sparsify_tensor(x, mask, block):
     ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
@@ -9,6 +14,15 @@ def sparsify_tensor(x, mask, block):
     return ret


+def cutlass_matmul(a, b):
+    if _cutlass is None:
+        raise RuntimeError("Cannot find cutlass library")
+    M, N = a.shape[0], b.shape[1]
+    c = torch.empty_strided((M, N), (1, M), dtype=a.dtype, device=a.device)
+    _cutlass.matmul(a, b, c)
+    return c


 def mask_tensor(x, mask, block, value=0):
     ret = x.clone()
     for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):