[RUNTIME] Added auto-alignment mechanism (#71)
This PR adds an automatic memory alignment mechanism to the Triton runtime. Specifically, the JIT compiler now detects the alignment (in bytes) of each pointer argument as well as the largest power-of-two divisor (between 1 and 16) of each integer argument. The appropriate .aligned and .multipleof attributes are then added to the Triton-IR on the fly for all auto-tunable kernels, and a cache remembers the kernels compiled for each possible configuration. This PR also includes a substantial cleanup of the Python API. The mechanism adds 2-3 µs of overhead, mostly due to accessing integer #defines from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so this approach is preferred for now.
committed by Philippe Tillet
parent ff62f7fffc
commit 62835a0979
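The hints inferred here correspond to the __aligned(N) and __multipleof(N) qualifiers that the diff below removes from the kernel signatures. As a rough sketch of the detection logic only (not the actual JIT code; the helper names are hypothetical), the attributes can be derived from the raw argument values like this:

def largest_pow2_divisor(x, max_div=16):
    # largest power of two (capped at max_div) that divides x; the runtime
    # attaches this value as a .multipleof attribute on integer arguments
    div = 1
    while div < max_div and x % (div * 2) == 0:
        div *= 2
    return div

def pointer_alignment(ptr, max_align=16):
    # byte alignment of a raw address, attached as an .aligned attribute
    return largest_pow2_divisor(ptr, max_align)

# e.g. a 256-byte-aligned data_ptr() yields 16; an integer stride of 24 yields 8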
@@ -1,12 +1,19 @@
import triton
import torch
import os

# square benchmarks
def rounded_linspace(low, high, steps, div):
ret = torch.linspace(low, high, steps)
ret = (ret.int() + div - 1) // div * div
ret = torch.unique(ret)
return list(map(int, ret))

# Square benchmarks
nt = {False: "n", True: "t"}
square_confs = [
triton.testing.Benchmark(
x_names=["M", "N", "K"],
x_vals=[512 * i for i in range(1, 16)],
x_vals=rounded_linspace(512, 8192, 17, 128),
y_name="provider",
y_vals=["torch", "triton", "cutlass"],
y_lines=["Torch", "Triton", "CUTLASS"],
@@ -17,16 +24,29 @@ square_confs = [
) for AT in [False, True] for BT in [False, True]
]

@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
import os
# Transformer training benchmarks
transformer_confs = [
triton.testing.Benchmark(
x_names=[x],
x_vals = rounded_linspace(NK//16, NK, 33, 128),
y_name="provider",
y_vals=["torch", "triton", "cutlass"],
y_lines=["Torch", "Triton", "CUTLASS"],
ylabel="TFLOPS",
loglog=False,
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [8192]\
for i, x in enumerate(["N", "K"])\
for M in [2048]
]

@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT:
a = a.t()
if BT:
b = b.t()
if AT: a = a.t()
if BT: b = b.t()
num_flops = 2 * M * N * K
if provider == "torch":
torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
@@ -40,7 +60,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
import subprocess
import tempfile
import pandas as pd

# run program specified by CUTLASS_PROFILER env variable
layout_a = "column" if AT else "row"
layout_b = "column" if BT else "row"
@@ -61,6 +80,7 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
f"--warmup-iterations={warmup}",
f"--profiling-iterations={rep}",
f"--output={fname}",
"--dist=uniform,min:0,max:1,scale:-1",
"--verbose=false",
]
# run cmd
@@ -70,6 +90,3 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
return cutlass_tflops
return None

if __name__ == "__main__":
bench_op.run()

@@ -38,4 +38,4 @@ def main(args):
run_all(args.result_dir, args.with_plots, args.names)

if __name__ == '__main__':
main(sys.argv[1:])
main(sys.argv[1:])
@@ -5,38 +5,16 @@
#include <cuda_runtime_api.h>
#include <torch/extension.h>

std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;

namespace torch_utils {

void register_device(int64_t dev_id) {
if (tt_devices.find(dev_id) != tt_devices.end())
return;
triton::driver::device *device;
if (dev_id >= 0) {
CUdevice handle;
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
device = new triton::driver::cu_device(handle, false);
} else
device = new triton::driver::host_device();
tt_devices[dev_id].reset(device);
uint64_t cu_device(int64_t dev_id) {
CUdevice handle;
triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
return (uint64_t)handle;
}

void register_stream(int64_t dev_id) {
if (tt_streams.find(dev_id) != tt_streams.end())
return;
triton::driver::stream *stream;
if (dev_id >= 0) {
CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
stream = new triton::driver::cu_stream(handle, false);
} else
stream = new triton::driver::host_stream();
tt_streams[dev_id].reset(stream);
}

void synchronize(int64_t dev_id) {
tt_streams[dev_id]->synchronize();
uint64_t cu_stream(int64_t dev_id) {
return (uint64_t)c10::cuda::getCurrentCUDAStream(dev_id).stream();
}

void set_device(int64_t dev_id) {
@@ -44,23 +22,11 @@ void set_device(int64_t dev_id) {
C10_CUDA_CHECK(cudaSetDevice(dev_id));
}

torch::Tensor move_out_of_pool(torch::Tensor x) {
if (x.nbytes() == 0)
return torch::empty_like(x);
void *data;
cudaMalloc(&data, x.nbytes());
auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options());
ret.copy_(x);
return ret;
}

} // namespace torch_utils

void init_torch_utils(pybind11::module &m) {
pybind11::module subm = m.def_submodule("torch_utils");
subm.def("register_device", &torch_utils::register_device);
subm.def("register_stream", &torch_utils::register_stream);
subm.def("cu_device", &torch_utils::cu_device);
subm.def("cu_stream", &torch_utils::cu_stream);
subm.def("set_device", &torch_utils::set_device);
subm.def("synchronize", &torch_utils::synchronize);
subm.def("move_out_of_pool", &torch_utils::move_out_of_pool);
}
@@ -1,10 +1,4 @@
#include "triton/driver/stream.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/lang/code_gen.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
#include "triton/runtime/arg.h"
#include "triton/runtime/function.h"
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
@@ -13,72 +7,22 @@
#include <regex>
#include <string>

namespace py = pybind11;

using namespace triton;
namespace rt = triton::runtime;
namespace drv = triton::driver;
namespace lng = triton::lang;

std::unordered_map<const rt::options_t *, pybind11::object> opt_cache_;
std::map<int, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
extern std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
/*****************************************************************************/
/* Python bindings for triton::tools */
/*****************************************************************************/

/* Function utilities */

void register_fn(int op_id, int dev_id,
const std::string &src, const rt::options_t &opt,
const rt::function::autotune_vals_t &autotune_vals,
const std::vector<std::string> &autotune_key) {
if (id_fn_map.find(op_id) == id_fn_map.end()) {
id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
}
for (const auto &k : id_fn_map[op_id]->get_kernels()) {
const rt::options_t *opt = &k.first;
pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
for (auto x : opt->defines)
if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
obj.attr(x.first.c_str()) = std::stoi(x.second);
opt_cache_[&k.second->opt] = obj;
}
}

void delete_fn(int op_id) {
id_fn_map.erase(op_id);
}

void cleanup() {
id_fn_map.clear();
opt_cache_.clear();
}

size_t make_op_id() {
return id_fn_map.size();
}

std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
}

// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
// as a string constructed with struct.pack in python
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string &args, size_t grid_0, size_t grid_1, size_t grid_2) {
rt::function *fn = id_fn_map.at(op_id).get();
(*fn)((void **)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
}

pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string &args, const rt::function::grid_fn_ty &grid) {
rt::function *fn = id_fn_map.at(op_id).get();
auto wrapper = [&grid](const rt::options_t &opt) {
pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
for (auto x : opt.defines)
if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
obj.attr(x.first.c_str()) = std::stoi(x.second);
return grid(*obj.cast<rt::options_t *>());
};
rt::kernel *kernel = fn->autotune((void **)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
return opt_cache_.at(&kernel->opt);
}
/*!
@brief Function for extracting kernels out of a given source-string

This can be important to enable pre-processor macros (or tunable parameters) that should only
be defined within the scope of a single kernel function
*/
std::string extract_kernels(const std::string &str, const std::vector<std::string> &names) {
if (names.empty())
return str;
@@ -94,50 +38,82 @@ std::string extract_kernels(const std::string &str, const std::vector<std::strin
std::string name = it->str(1);
kernels.push_back(std::make_tuple(name, pos, len));
}

// check that all the kernels provided actually exist
for (const std::string &name : names) {
// check that str matches any string in kernels using std::any_of
auto pred = [&name](const std::tuple<std::string, int, int> &t) { return std::get<0>(t) == name; };
bool found = std::any_of(kernels.begin(), kernels.end(), pred);
if (!found)
throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
}

// extract functions
// simple parsing logic to extract the declaration and body of each specified kernel
std::string ret;
for (const auto &k : kernels) {
std::string name;
int pos, len;
std::tie(name, pos, len) = k;
if (std::find(names.begin(), names.end(), name) != names.end()) {
std::string def = str.substr(pos, str.size() - pos);
int count, pos;
// skip over declaration
count = 1;
pos = def.find('(');
while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
count = 1;
pos = def.find('{', pos);
while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
if (std::find(names.begin(), names.end(), name) == names.end())
continue;
std::string def = str.substr(pos, str.size() - pos);
// skip over declaration
// by finding matching ')' for first '('
int count = 1;
pos = def.find('(');
while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
count += def[pos] == '(';
count -= def[pos] == ')';
}
// skip over definition
// by finding matching '{' for first '}'
count = 1;
pos = def.find('{', pos);
while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
count += def[pos] == '{';
count -= def[pos] == '}';
}
ret += def.substr(0, pos);
ret += '\n';
}

return ret;
}

void init_triton(pybind11::module &m) {
pybind11::module subm = m.def_submodule("triton");
// bindings for triton classes
pybind11::enum_<rt::arg_type>(subm, "arg_type")
void init_triton_tools(py::module &&m) {
m.def("extract_kernels", &extract_kernels);
}

/*****************************************************************************/
/* Python bindings for triton::driver */
/*****************************************************************************/

void init_triton_driver(py::module &&m) {
// base device
py::class_<drv::device>(m, "device");
// cuda device
py::class_<drv::cu_device, driver::device>(m, "cu_device")
.def(py::init<CUdevice, bool>());
// host device
py::class_<drv::host_device, driver::device>(m, "host_device")
.def(py::init<>());

// base stream
py::class_<drv::stream>(m, "stream");
// host stream
py::class_<drv::host_stream, drv::stream>(m, "host_stream")
.def(py::init<>());
// cuda stream
py::class_<drv::cu_stream, drv::stream>(m, "cu_stream")
// py doesn't support opaque pointer (e.g., CUstream) so
// we assume it has been converted to uint64_t
.def(py::init([](uint64_t handle, bool take_ownership) {
return std::unique_ptr<driver::cu_stream>(new driver::cu_stream((CUstream)handle, take_ownership));
}));
}

/*****************************************************************************/
/* Python bindings for triton::runtime */
/*****************************************************************************/
void init_triton_runtime(py::module &&m) {
// argument type
py::enum_<rt::arg_type>(m, "arg_type")
.value("int1", rt::INT1_T)
.value("int8", rt::INT8_T)
.value("int16", rt::INT16_T)
@@ -147,23 +123,38 @@ void init_triton(pybind11::module &m) {
.value("float", rt::FLOAT_T)
.value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T);

pybind11::enum_<rt::asm_mode_t>(subm, "asm_mode")
// assembly mode
py::enum_<rt::asm_mode_t>(m, "asm_mode")
.value("ptx", rt::ASM_NV_PTX)
.value("sass", rt::ASM_NV_SASS);

pybind11::class_<rt::options_t>(subm, "options", pybind11::dynamic_attr())
.def(pybind11::init<>())
// compilation options
py::class_<rt::options_t>(m, "options", py::dynamic_attr())
.def(py::init<>())
.def_readwrite("defines", &rt::options_t::defines)
.def_readwrite("num_warps", &rt::options_t::num_warps);
.def_readwrite("num_warps", &rt::options_t::num_warps)
.def("__getattr__", [](rt::options_t *opt, const std::string &name) {
return opt->D<int>(name);
});
// kernel
py::class_<rt::kernel>(m, "kernel")
.def("__call__", &rt::kernel::operator())
.def_readonly("opt", &rt::kernel::opt);
// tune conf
py::class_<rt::config>(m, "config")
.def(py::init<std::map<std::string, std::string>, int>(),
py::arg("defines") = std::map<std::string, std::string>(),
py::arg("num_warps"));

// hooks into triton constructs since frameworks may not use pybind11
subm.def("extract_kernels", &extract_kernels);
subm.def("get_fn_signature", &get_fn_signature);
subm.def("register_fn", &register_fn);
subm.def("delete_fn", &delete_fn);
subm.def("make_op_id", &make_op_id);
subm.def("cleanup", &cleanup);
subm.def("autotune", &autotune, pybind11::return_value_policy::reference);
subm.def("launch_kernel", &launch_kernel);
// function
py::class_<rt::function>(m, "function")
.def(py::init<const std::string &, const rt::options_t &, driver::device *, const std::vector<rt::config> &, const std::vector<std::string> &>())
.def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal)
.def("signature", &rt::function::get_signature);
}

void init_triton(py::module &m) {
py::module subm = m.def_submodule("triton");
init_triton_driver(std::move(subm.def_submodule("driver")));
init_triton_runtime(std::move(subm.def_submodule("runtime")));
init_triton_tools(std::move(subm.def_submodule("tools")));
}
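The __getattr__ binding above is what lets Python-side launch-grid lambdas read integer #defines of the selected configuration as plain attributes; a minimal sketch (TM/TN are assumed tile-size defines, as in the matmul configs further down):

import triton

def matmul_grid(M, N):
    # `opt` is the runtime options object of the configuration picked by the
    # autotuner; opt.TM / opt.TN resolve through the __getattr__ hook above
    return lambda opt: (triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), )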
@@ -50,8 +50,9 @@ import torch
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
torch.manual_seed(0)
defines = {"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}
triton.ops._matmul._kernels = dict()
triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
triton.ops._matmul._CONFIGS = [triton.config(defines=defines, num_warps=NWARP)]
if M is None:
M = TM
if N is None:
@@ -1,27 +1,21 @@
import os
import struct
from typing import Optional, Dict, List

import torch
# C bindings
import triton._C.libtriton.triton as _triton
import triton._C.libtriton.torch_utils as _torch_utils
# Make sure internal C resources are cleaned up upon exit
import atexit

@atexit.register
def cleanup():
_triton.cleanup()

codes = {
_triton.arg_type.int1: 'B', _triton.arg_type.int8: 'B', _triton.arg_type.int32: 'I', _triton.arg_type.int64: 'Q',
_triton.arg_type.half: 'H', _triton.arg_type.float: 'f', _triton.arg_type.double: 'd', _triton.arg_type.buffer: 'P'
_triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I',
_triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f',
_triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P'
}

def th_to_triton(obj):
tys = {
torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long', torch.float16: 'half',
torch.float32: 'float', torch.float64: 'double'
torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long',\
torch.float16: 'half', torch.float32: 'float', torch.float64: 'double'
}
if isinstance(obj, torch.dtype):
return tys[obj]
@@ -30,69 +24,54 @@ def th_to_triton(obj):
def cdiv(a, b):
return (a + b - 1) // b

def synchronize(device):
dev_id = device.index
dev_id = -1 if dev_id is None else dev_id
_torch_utils.synchronize(dev_id)

def read(path, kernel_names:Optional[List]=None):
def read(path, kernel_names: Optional[List] = None):
if kernel_names is None:
kernel_names = []
with open(path, 'r') as f:
source = f.read()
source = _triton.extract_kernels(source, kernel_names)
source = _triton.tools.extract_kernels(source, kernel_names)
return source

class kernel:
def __init__(self,
src,
device,
defines: Optional[Dict]=None,
num_warps:int=4,
autotune_vals:Optional[List]=None,
autotune_key:Optional[List]=None):
config = _triton.runtime.config

class kernel:
def __init__(self, src, device, defines: Optional[Dict] = None, num_warps: int = 4,
autotune_vals: Optional[List] = None, autotune_key: Optional[List] = None):
if defines is None:
defines = {}
if autotune_vals is None:
autotune_vals = []
if autotune_key is None:
autotune_key = []

# check if src is empty
if src == '':
raise ValueError('Kernel source code is empty')
self.src = src
self.opt = _triton.options()
self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
self.opt.num_warps = num_warps
# device
assert device.type in ['cuda', 'cpu']
if device.type == 'cuda':
self.device = torch.cuda.current_device() if device.index is None else device.index
self.device_id = torch.cuda.current_device() if device.index is None else device.index
self.device = _triton.driver.cu_device(_torch_utils.cu_device(self.device_id), False)
self.stream = _triton.driver.cu_stream(_torch_utils.cu_stream(self.device_id), False)
if device.type == 'cpu':
self.device = -1
_torch_utils.register_device(self.device)
_torch_utils.register_stream(self.device)
# C++ function wrapper
self.op_id = _triton.make_op_id()
_torch_utils.set_device(self.device)
_triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
# debug mode
self.is_debug = 'TRITON_DEBUG' in os.environ
# signature
arg_types = _triton.get_fn_signature(self.op_id)
self.tys = ''.join([codes[x] for x in arg_types])
self.device_id = -1
self.device = _triton.driver.host_device()
self.device = _triton.driver.host_stream()
_torch_utils.set_device(self.device_id)
# function
self.opt = _triton.runtime.options()
self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
self.opt.num_warps = num_warps
# autotune_vals = [({}, 4)]
self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_vals, autotune_key)
self.tys = ''.join([codes[x] for x in self.fn.signature()])

def __call__(self, *args, grid):
_torch_utils.set_device(self.device)
# make sure that the executing thread is on the right device
_torch_utils.set_device(self.device_id)
# pack parameters into a byte buffer
params = struct.pack(self.tys, *args)
opt = _triton.autotune(self.op_id, self.device, params, grid)
kernel = self.fn.autotune(params, grid, self.stream)
# run kernel
grid = grid(opt)
grid_0 = grid[0]
grid_1 = 1 if len(grid) < 2 else grid[1]
grid_2 = 1 if len(grid) < 3 else grid[2]
_triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
grid = grid(kernel.opt)
kernel(params, self.stream, grid)
@@ -1,17 +1,17 @@
__global__ void NAME(TYPE *A __readonly __noalias __aligned(16),
TYPE *B __readonly __noalias __aligned(16),
TYPE *C __noalias __aligned(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
long stride_za __multipleof(8),
long stride_zb __multipleof(8),
long stride_zc __multipleof(8),
long stride_ha __multipleof(8),
long stride_hb __multipleof(8),
long stride_hc __multipleof(8),
__global__ void NAME(TYPE *A __readonly __noalias,
TYPE *B __readonly __noalias,
TYPE *C __noalias,
int lda,
int ldb,
int ldc,
long stride_za,
long stride_zb,
long stride_zc,
long stride_ha,
long stride_hb,
long stride_hc,
int DS0, int DS1,
int SDD_K __multipleof(16),
int SDD_K,
int SDD_off_width,
int *lut, int *locks, int nlocks) {
/* ---------------- */
@@ -1,17 +1,16 @@
__global__ void forward(TYPE *X __readonly __noalias __aligned(16),
__global__ void forward(TYPE *X __readonly __noalias,
float scale,
int *LUT __readonly __noalias __aligned(16),
TYPE *RPE __readonly __noalias __aligned(16),
TYPE *KP_M __readonly __noalias __aligned(16),
TYPE *ATTN_M __readonly __noalias __aligned(16),
int *LUT __readonly __noalias,
TYPE *RPE __readonly __noalias,
TYPE *KP_M __readonly __noalias,
TYPE *ATTN_M __readonly __noalias,
int sizemax,
long stride_zx __multipleof(4),
long stride_zrpe __multipleof(BLOCK),
int stride_hrpe __multipleof(BLOCK),
int stride_srpe __multipleof(BLOCK),
int stride_zkpm __multipleof(BLOCK),
int stride_zattnm __multipleof(BLOCK))
{
long stride_zx,
long stride_zrpe,
int stride_hrpe,
int stride_srpe,
int stride_zkpm,
int stride_zattnm) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
@@ -97,14 +96,13 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
*? (check)px = y / ysum;
}

__global__ void backward(TYPE *X __readonly __noalias __aligned(16),
__global__ void backward(TYPE *X __readonly __noalias,
float scale,
TYPE *DX __readonly __noalias __aligned(16),
TYPE *DX __readonly __noalias,
int *LUT,
int sizemax,
long stride_zx __multipleof(BLOCK),
long stride_zdx __multipleof(BLOCK))
{
long stride_zx,
long stride_zdx) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
@@ -1,126 +1,131 @@
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
// equivalent matmul
int M, int N, int K,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
int *ADELTA,
// memory strides
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
__global__ void conv(TYPE *A __noalias __readonly,
TYPE *B __noalias __readonly,
TYPE *C __noalias,
float alpha,
// equivalent matmul
int M, int N, int K,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
int *ADELTA,
// memory strides
int lda_z, int lda_ci, int lda_h, int lda_w,
int ldb_ci, int ldb_r, int ldb_s, int ldb_co,
int ldc_z, int ldc_co, int ldc_p, int ldc_q) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;

// unpack aggregate rows
// m = (z, p, q)
int rq[TM] = rm % QQ;
int rzp[TM] = rm / QQ;
int rp[TM] = rzp % PP;
int rz[TM] = rzp / PP;
// unpack aggregate reduction
// k = (ci, r, s)
int rs [TK] = rk % SS;
int rcir[TK] = rk / SS;
int rr [TK] = rcir % RR;
int rci [TK] = rcir / RR;
// unpack aggregate rows
// m = (z, p, q)
int rq[TM] = rm % QQ;
int rzp[TM] = rm / QQ;
int rp[TM] = rzp % PP;
int rz[TM] = rzp / PP;
// unpack aggregate reduction
// k = (ci, r, s)
int rs[TK] = rk % SS;
int rcir[TK] = rk / SS;
int rr[TK] = rcir % RR;
int rci[TK] = rcir / RR;

// padding / striding
int rh_0[TM] = rp * stride_h - pad_h;
int rw_0[TM] = rq * stride_w - pad_w;
int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];
// padding / striding
int rh_0[TM] = rp * stride_h - pad_h;
int rw_0[TM] = rq * stride_w - pad_w;
int rh[TM, TK] = rh_0[:, newaxis] + rr [newaxis, :];
int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :];

// pointers to lhs
int offa[TM, TK] = rz [:, newaxis] * lda_z +
rci[newaxis, :] * lda_ci +
rh * lda_h +
rw * 1;
TYPE* pa[TM, TK] = A + offa;
int* padelta[TK] = ADELTA + rk;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
rr [:, newaxis] * ldb_r +
rs [:, newaxis] * ldb_s +
rn [newaxis, :] * 1;
TYPE* pb[TK, TN] = B + offb;
// pointers to lhs
int offa[TM, TK] = rz[:, newaxis] * lda_z +
rci [newaxis, :] * lda_ci +
rh * lda_h +
rw * 1;
TYPE *pa[TM, TK] = A + offa;
int *padelta[TK] = ADELTA + rk;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
rr
[:, newaxis] * ldb_r +
rs
[:, newaxis] * ldb_s +
rn [newaxis, :] * 1;
TYPE *pb[TK, TN] = B + offb;

// prefetches operands
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
int total = 0;
// prefetches operands
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
int total = 0;

// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
acc += a @ b;
// increment A
int adelta[TK] = *padelta;
padelta += TK;
pa += adelta[newaxis, :];
// bounds-checking A
rk += TK;
rs = rk % SS;
rcir = rk / SS;
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr[newaxis, :];
rw = rw_0[:, newaxis] + rs[newaxis, :];
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkb[TK, TN] = k > TK;
a = checka ? *pa : 0;
b = *?(checkb)pb;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// reduction loop
float acc[TM, TN] = 0;
for (int k = K; k > 0; k -= TK) {
acc += a @b;
// increment A
int adelta[TK] = *padelta;
padelta += TK;
pa += adelta [newaxis, :];
// bounds-checking A
rk += TK;
rs = rk % SS;
rcir = rk / SS;
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr [newaxis, :];
rw = rw_0[:, newaxis] + rs [newaxis, :];
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkb[TK, TN] = k > TK;
a = checka ? *pa : 0;
b = *? (checkb)pb;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;

// epilogue
rm = ridx * TM + 0 ... TM;
rn = ridy * TN + 0 ... TN;
rq = rm % QQ;
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz [:, newaxis] * ldc_z +
rn [newaxis, :] * ldc_co+
rp [:, newaxis] * ldc_p +
rq [:, newaxis] * 1;
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;
// epilogue
rm = ridx * TM + 0 ... TM;
rn = ridy * TN + 0 ... TN;
rq = rm % QQ;
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz[:, newaxis] * ldc_z +
rn [newaxis, :] * ldc_co +
rp
[:, newaxis] * ldc_p +
rq
[:, newaxis] * 1;
TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N;

#if (TZ==1)
*?(checkc) pc = c;
#if (TZ == 1)
*? (checkc)pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
@@ -1,8 +1,4 @@
__global__ void forward(TYPE *logit __aligned(16),
TYPE *modified_logit __aligned(16),
long *indices __readonly,
TYPE *result __aligned(16),
int n_cols __multipleof(N_COLS_MULT)) {
__global__ void forward(TYPE *logit, TYPE *modified_logit, long *indices, TYPE *result, int n_cols) {
int row = get_program_id(0);

bool check[TILE] = ((0 ... TILE) < n_cols);
@@ -19,10 +15,7 @@ __global__ void forward(TYPE *logit __aligned(16),
*(result + row) = *(modified_logit + (local_ind + n_cols * row));
}

__global__ void backward(TYPE *neg_logprobs __aligned(16),
long *indices __aligned(16),
TYPE *dneg_logprobs __aligned(16),
int n_cols __multipleof(N_COLS_MULT)) {
__global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs, int n_cols) {

int row = get_program_id(0);
// pointer arithmetic
@@ -1,16 +1,12 @@
#define STM 8
#define STN 8

__global__ void matmul(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
__global__ void matmul(TYPE *A __noalias __readonly,
TYPE *B __noalias __readonly,
TYPE *C __noalias,
float alpha,
int M,
int N,
int K __multipleof(16),
int lda __multipleof(LDA_POW2_DIV),
int ldb __multipleof(LDB_POW2_DIV),
int ldc __multipleof(LDC_POW2_DIV),
int M, int N, int K,
int lda, int ldb, int ldc,
int *locks) {
// prologue
int pid = get_program_id(0);
@@ -6,18 +6,18 @@ class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))

_DEFAULT_CONFIGS = [
({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
({'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, 4),
({'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, 4),
({'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 2),
({'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 2),
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
triton.config(defines={"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
triton.config(defines={'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
triton.config(defines={'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
]
_CONFIGS = _DEFAULT_CONFIGS
@@ -1,4 +1,5 @@
import torch
import os

def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
@@ -77,8 +78,12 @@ class Mark:
df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))

def run(self, result_path, with_plot):
for bench in self.benchmarks:
self._run(bench, result_path, with_plot)
with open(os.path.join(result_path, "results.html"), "w") as html:
html.write("<html><body>\n")
for bench in self.benchmarks:
self._run(bench, result_path, with_plot)
html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
html.write("</body></html>\n")

def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks)
python/tutorials/01-vector-add.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import torch
import triton

# source-code for Triton compute kernel
# here we just copy-paste the above code without the extensive comments.
# you may prefer to store it in a .c file and load it from there instead.
_src = """
__global__ void add(float* z, float* x, float* y, int N){
    // program id
    int pid = get_program_id(0);
    // create arrays of pointers
    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
    float* pz[BLOCK] = z + offset;
    float* px[BLOCK] = x + offset;
    float* py[BLOCK] = y + offset;
    // bounds checking
    bool check[BLOCK] = offset < N;
    // write-back
    *?(check)pz = *?(check)px + *?(check)py;
}
"""
# This function returns a callable `triton.kernel` object
# created from the above source code.
# For portability, we maintain a cache of kernels for different `torch.device`
# We compile the kernel with -DBLOCK=1024
_kernels = dict()

def make_add_kernel(device):
    if device not in _kernels:
        defines = {'BLOCK': 1024}
        autotune_vals = [({'BLOCK': '1024'}, 4), ({'BLOCK': '2048'}, 4)]
        autotune_key = ["N"]
        _kernels[device] = triton.kernel(_src, device=device, defines=defines, autotune_vals=autotune_vals,
                                         autotune_key=autotune_key)
    return _kernels[device]

# This is a standard torch custom autograd Function
# The only difference is that we can now use the above kernel
# in the `forward` and `backward` functions.
class _add(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # constraints of the op
        assert x.dtype == torch.float32
        # *allocate output*
        z = torch.empty_like(x)
        # *create launch grid*:
        # this is a function which takes compilation parameters `opt`
        # as input and returns a tuple of int (i.e., launch grid) for the kernel.
        # triton.cdiv is a shortcut for ceil division:
        # triton.cdiv(a, b) = (a + b - 1) // b
        grid = lambda opt: (triton.cdiv(z.shape[0], opt.BLOCK), )
        # *launch kernel*:
        # pointer to the data of torch tensors can be retrieved with
        # the `.data_ptr()` method
        kernel = make_add_kernel(z.device)
        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), z.shape[0], grid=grid)
        return z

# Just like with standard PyTorch ops,
# we use the `.apply` method to create a
# callable object for our function
add = _add.apply

torch.manual_seed(0)
x = torch.rand(32, device='cuda')
y = torch.rand(32, device='cuda')
za = x + y
zb = add(x, y)
print(za)
print(zb)
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')

th_ms = triton.testing.do_bench(lambda: x + y)
tr_ms = triton.testing.do_bench(lambda: add(x, y))
print(th_ms, tr_ms)