[RUNTIME] Added auto-alignment mechanism (#71)
This PR adds an automatic memory-alignment mechanism to the Triton runtime. Specifically, the JIT compiler detects the alignment (in bytes) of each pointer argument, as well as the largest power-of-two divisor (between 1 and 16) of each integer argument. The corresponding .aligned and .multipleof attributes are then added to the Triton-IR on the fly for all auto-tunable kernels, and a cache remembers every kernel compiled for each possible configuration. This PR also includes a substantial cleanup of the Python API. The new mechanism adds 2-3µs of launch overhead, mostly due to accessing integer #defines from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so this approach is preferred for now.
Committed by: Philippe Tillet
Parent: ff62f7fffc
Commit: 62835a0979
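The heart of the mechanism is visible in the diff below: for every pointer and integer argument, the runtime computes the largest power-of-two divisor of the raw argument value (capped at 16), uses the result as part of the kernel-cache key, and attaches it to the Triton-IR as an .aligned (pointers) or .multipleof (integers) attribute. A minimal standalone sketch of that computation (pow2_divisor mirrors the helper added in this PR; the main driver and its sample values are illustrative only):

    #include <cstdint>
    #include <cstdio>

    // Largest power-of-two divisor of N, capped at 16; mirrors the
    // pow2_divisor() helper introduced in lib/runtime by this PR.
    uint64_t pow2_divisor(uint64_t N) {
      if (N % 16 == 0) return 16;
      if (N % 8 == 0) return 8;
      if (N % 4 == 0) return 4;
      if (N % 2 == 0) return 2;
      return 1;
    }

    int main() {
      // A pointer aligned to 256 bytes still reports 16, the cap.
      uint64_t ptr = 0x7f0000000100; // hypothetical device address
      printf("aligned(ptr)         = %llu\n", (unsigned long long)pow2_divisor(ptr));
      // A leading dimension of 1024 is a multiple of 16.
      printf("multipleof(lda=1024) = %llu\n", (unsigned long long)pow2_divisor(1024));
      // An odd size defeats vectorization: the divisor is 1.
      printf("multipleof(n=1023)   = %llu\n", (unsigned long long)pow2_divisor(1023));
      return 0;
    }

Because the divisor feeds the cache key, two launches whose arguments share the same alignment profile reuse the same compiled kernel; a launch with a differently-aligned pointer triggers a recompilation with different attributes.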
@@ -10,6 +10,7 @@
 #include <functional>
 #include "triton/ir/builder.h"
 #include "triton/ir/metadata.h"
+#include "triton/ir/context.h"

 namespace triton{

@@ -60,7 +61,7 @@ private:
   void push_function(function *fn) { functions_.push_back(fn); }

 public:
-  module(const std::string &name, context &ctx);
+  module(const std::string &name);
   context& get_context();
   builder& get_builder();
   // Setters
@@ -94,7 +95,7 @@ public:

 private:
   std::string name_;
-  context &context_;
+  context context_;
   builder builder_;
   std::map<val_key_t, value*> values_;
   std::map<val_key_t, type*> types_;
@@ -5,6 +5,7 @@

 #include <string>
 #include <stdexcept>
+#include <sstream>

 namespace triton{
 namespace ir{
@@ -17,73 +18,8 @@ namespace driver{

 namespace runtime {

-enum arg_type {
-  INT1_T,
-  INT8_T,
-  INT16_T,
-  INT32_T,
-  INT64_T,
-  HALF_T,
-  FLOAT_T,
-  DOUBLE_T,
-  BUFFER_T
-};
-
-arg_type convert(ir::type *ty);
-
-inline size_t size_of(arg_type ty){
-  switch(ty){
-    case INT1_T: return 1;
-    case INT8_T: return 1;
-    case INT16_T: return 2;
-    case INT32_T: return 4;
-    case INT64_T: return 8;
-    case HALF_T: return 2;
-    case FLOAT_T: return 4;
-    case DOUBLE_T: return 8;
-    case BUFFER_T: return 8;
-    default: throw std::runtime_error("unknown type");
-  }
-}
-
-inline bool is_int_type(arg_type ty){
-  return ty == INT1_T || ty == INT8_T || ty == INT16_T ||
-         ty == INT32_T || ty == INT64_T;
-}
-
-class arg {
-public:
-  union value_t {
-    bool int1;
-    int8_t int8;
-    int16_t int16;
-    int32_t int32;
-    int64_t int64;
-    uint16_t fp16;
-    float fp32;
-    double fp64;
-    driver::buffer* buf;
-  };
-
-public:
-  // construct from primitive types
-  arg(arg_type ty, value_t val): ty_(ty) { val_ = val; }
-  arg(int32_t x): ty_(INT32_T) { val_.int32 = x; }
-  arg(int64_t x): ty_(INT64_T) { val_.int64 = x; }
-  arg(float x): ty_(FLOAT_T) { val_.fp32 = x; }
-  arg(double x): ty_(DOUBLE_T) { val_.fp64 = x; }
-  arg(driver::buffer* x): ty_(BUFFER_T) { val_.buf = x; }
-  // accessors
-  arg_type type() const { return ty_; }
-  void* data() const { return (void*)&val_; }
-  driver::buffer* buffer() const { return val_.buf; }
-
-private:
-  arg_type ty_;
-  value_t val_;
-};
-
 }
 }
@@ -11,6 +11,7 @@
 #include <memory>
 #include <functional>
 // codegen
+#include "triton/ir/function.h"
 #include "triton/ir/context.h"
 #include "triton/runtime/arg.h"
 #include "triton/runtime/error.h"
@@ -37,63 +38,86 @@ class context;
 namespace triton{
 namespace runtime{

-typedef std::vector<size_t> grid_t;
-typedef std::map<std::string, size_t> params_t;
-template<typename T> inline T convert(const std::string& name);
-template<> inline long convert<long>(const std::string& name) { return std::stol(name); }
-template<> inline int convert<int>(const std::string& name) { return std::stoi(name); }
+/* ------------------------- */
+/* Compilation options       */
+/* ------------------------- */
+
+struct options_t {
+  template<class T>
+  T D(const std::string& name) const {
+    return std::stoi(defines.at(name));
+  }
+  std::unordered_map<std::string, std::string> defines;
+  int num_warps;
+};
+
+/* ------------------------- */
+/* Runtime arguments         */
+/* ------------------------- */
+
+enum arg_type {
+  INT1_T,
+  INT8_T,
+  INT16_T,
+  INT32_T,
+  INT64_T,
+  HALF_T,
+  FLOAT_T,
+  DOUBLE_T,
+  BUFFER_T
+};
+
+inline size_t size_of(arg_type ty){
+  switch(ty){
+    case INT1_T  : return 1;
+    case INT8_T  : return 1;
+    case INT16_T : return 2;
+    case INT32_T : return 4;
+    case INT64_T : return 8;
+    case HALF_T  : return 2;
+    case FLOAT_T : return 4;
+    case DOUBLE_T: return 8;
+    case BUFFER_T: return 8;
+    default: throw std::runtime_error("unknown type");
+  }
+}

 template<class T>
 void add_arg(std::stringstream& ss, T arg) {
   ss.write((char*)&arg, sizeof(T));
 }

+/* ------------------------- */
+/* ------------------------- */

 enum asm_mode_t {
   ASM_LLIR,
   ASM_NV_PTX,
   ASM_NV_SASS
 };

-struct options_t {
-  template<class T>
-  T D(const std::string& name) const {
-    return convert<T>(defines.at(name));
-  }
-  std::unordered_map<std::string, std::string> defines;
-  int num_warps;
-};
-
-/* ------------------------- */
-
 class kernel{
-private:
-  static std::string preheader();
-  static arg_type convert(ir::type *ty);
+public:
+  typedef std::vector<size_t> grid_t;

 public:
-  kernel(const std::string& src, const options_t& opt, driver::device *device);
-  void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector<size_t>& grid) const;
-  // getters
-  const std::vector<arg_type>& get_sig() const { return sig_; }
-  const std::vector<std::string>& get_arg_names() const { return arg_names_; }
-  std::string get_asm(asm_mode_t mode);
+  static std::shared_ptr<ir::module> src_to_ir(const std::string& src, const options_t& opt);
+  static std::tuple<std::shared_ptr<driver::module>,
+                    std::shared_ptr<driver::kernel>,
+                    size_t> ir_to_bin(ir::module& ir, driver::device *dev, const options_t &opt);

-private:
-  void init_ir (const std::string &src);
-  void init_ker();
-  void init_sig();
+public:
+  kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map<int, triton::ir::attribute> &attrs = {});
+  void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const;
+  std::string get_asm(asm_mode_t mode);

 public:
   const options_t opt;

 private:
   driver::device* dev_;
-  // signature
-  std::vector<arg_type> sig_;
-  std::vector<std::string> arg_names_;
-  // triton context for parsing
-  ir::context ctx_;
   // handles
   std::shared_ptr<ir::module> ir_;
   std::shared_ptr<driver::module> mod_;
@@ -102,36 +126,37 @@ private:
   size_t shared_mem_;
 };

+struct config {
+  std::map<std::string, std::string> defines;
+  int num_warps;
+};
+
 class function {
 public:
-  typedef std::function<grid_t(const options_t&)> grid_fn_ty;
+  typedef std::function<kernel::grid_t(const options_t&)> grid_fn_ty;
   typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
   typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
-  typedef std::vector<std::pair<std::map<std::string, std::string>, int>> autotune_vals_t;
+  typedef std::vector<config> autotune_confs_t;

-private:
-  static void do_loop_nest(std::vector<size_t> const & ranges,
-                           std::function<void(std::vector<size_t> const &)> const & f);
 public:
   function(const std::string& src, const options_t& opt, driver::device *device,
-           const autotune_vals_t& autotune_vals = {}, const std::vector<std::string> &autotune_key = {});
-  void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
-  void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
-  // auto-tuning
-  cache_t::iterator find_in_cache(void* args, size_t args_size);
-  kernel* autotune(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
-  // getters
-  const std::vector<kernel_pair_t> get_kernels() { return kernels_; }
+           const std::vector<config>& tune_confs = {}, const std::vector<std::string> &tune_key = {});
+  kernel* autotune(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
+  void operator()(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
+  const std::vector<arg_type> get_signature() { return sig_; }

 private:
-  void init_kernels(const std::string& src, const options_t& opt, const autotune_vals_t& autotune_vals, driver::device *device);
-
-private:
-  std::vector<kernel_pair_t> kernels_;
+  std::map<std::vector<uint64_t>, std::vector<std::shared_ptr<kernel>>> kernels_;
   std::map<std::vector<uint64_t>, kernel*> cache_;
+  std::vector<arg_type> sig_;
+  std::vector<int> align_idxs_;
+  std::vector<int> int_idxs_;
   std::vector<int> key_idxs_;
   std::vector<int> arg_size_;
   std::vector<int> arg_off_;
+  std::vector<options_t> opts_;
+  std::string src_;
+  driver::device* device_;
 };

 }
@@ -10,8 +10,8 @@ namespace triton{
 namespace ir{

 /* Module */
-module::module(const std::string &name, context &ctx)
-  : name_(name), context_(ctx), builder_(ctx) {
+module::module(const std::string &name)
+  : name_(name), builder_(context_) {
   sealed_blocks_.insert(nullptr);
 }

@@ -40,7 +40,6 @@
 #include <mutex>
 #include <fstream>

-std::mutex mut;

 namespace triton{
 namespace runtime {
@@ -49,22 +48,9 @@ namespace runtime {
 /* --------------------------------- */
 /* --------------------------------- */

-arg_type kernel::convert(ir::type *ty) {
-  if(ty->is_integer_ty(1)) return INT1_T;
-  if(ty->is_integer_ty(8)) return INT8_T;
-  if(ty->is_integer_ty(16)) return INT16_T;
-  if(ty->is_integer_ty(32)) return INT32_T;
-  if(ty->is_integer_ty(64)) return INT64_T;
-  if(ty->is_half_ty()) return HALF_T;
-  if(ty->is_float_ty()) return FLOAT_T;
-  if(ty->is_double_ty()) return DOUBLE_T;
-  if(ty->is_pointer_ty()) return BUFFER_T;
-  throw std::runtime_error("unknown type");
-}
-
-std::string kernel::preheader() {
-  return R"(
+std::shared_ptr<ir::module> kernel::src_to_ir(const std::string& _src, const options_t& opt) {
+  std::string src =
+  R"(
 #define bool _Bool
 #define true 1
 #define false 0
@@ -116,9 +102,7 @@ typedef short int16;
 typedef int int32;
 typedef long int64;
 )";
-}
-
-void kernel::init_ir(const std::string& src) {
+  src += _src;
   // pre-process
   TokenSequence tokens;
   Preprocessor cpp(&src, true);
@@ -129,21 +113,21 @@ void kernel::init_ir(const std::string& src) {
   Parser parser(tokens);
   parser.Parse();
   // ast -> triton-ir
-  ir::module* module = new ir::module("", ctx_);
+  auto ret = std::make_shared<ir::module>("");
   Generator gen(&parser);
-  gen.Gen(module);
-  ir_.reset(module);
+  gen.Gen(&*ret);
+  return ret;
 }

-void kernel::init_ker(){
-  // triton-ir -> binary
-  std::unique_ptr<driver::module> bin;
-  std::unique_ptr<codegen::target> target = dev_->make_target();
+std::tuple<std::shared_ptr<driver::module>,
+           std::shared_ptr<driver::kernel>,
+           size_t> kernel::ir_to_bin(ir::module &ir, driver::device* dev, const options_t& opt) {
   // generate llvm code
   llvm::LLVMContext ctx;
-  std::string name = ir_->get_function_list()[0]->get_name();
+  std::string name = ir.get_function_list()[0]->get_name();
   std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
   // optimizations
+  std::unique_ptr<codegen::target> target = dev->make_target();
   bool cts_use_async = target->as_nvidia()->sm() >= 80;
   // create passes
   codegen::analysis::align align;
@@ -162,73 +146,61 @@ void kernel::init_ker(){
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
   // run passes
-  dce.run(*ir_);
-  pipeline.run(*ir_);
-  dce.run(*ir_);
-  disassociate.run(*ir_);
-  dce.run(*ir_);
-  align.run(*ir_);
-  axes.run(*ir_);
-  layouts.run(*ir_);
-  peephole.run(*ir_);
-  dce.run(*ir_);
+  dce.run(ir);
+  pipeline.run(ir);
+  dce.run(ir);
+  disassociate.run(ir);
+  dce.run(ir);
+  align.run(ir);
+  axes.run(ir);
+  layouts.run(ir);
+  peephole.run(ir);
+  dce.run(ir);
   if(target->is_gpu())
-    cts.run(*ir_);
-  align.run(*ir_);
-  axes.run(*ir_);
-  layouts.run(*ir_);
-  coalesce.run(*ir_);
-  dce.run(*ir_);
-  align.run(*ir_);
-  dce.run(*ir_);
+    cts.run(ir);
+  align.run(ir);
+  axes.run(ir);
+  layouts.run(ir);
+  coalesce.run(ir);
+  dce.run(ir);
+  align.run(ir);
+  dce.run(ir);
   if(target->is_gpu()){
-    reassociate.run(*ir_);
-    cts.run(*ir_);
+    reassociate.run(ir);
+    cts.run(ir);
   }
-  dce.run(*ir_);
-  // ir::print(*ir_, std::cout);
-  align.run(*ir_);
-  axes.run(*ir_);
-  layouts.run(*ir_);
-  peephole.run(*ir_);
-  dce.run(*ir_);
-  align.run(*ir_);
-  axes.run(*ir_);
-  layouts.run(*ir_);
-  swizzle.run(*ir_);
-  liveness.run(*ir_);
-  allocation.run(*ir_);
-  shared_mem_ = allocation.allocated_size();
-  // if(allocation.allocated_size() > dev_->max_shared_memory())
-  //   throw exception::out_of_shared_memory();
-  barriers.run(*ir_);
-  isel.visit(*ir_, *llvm);
-  //if(res->spilled() > 256)
-  //  throw exception::out_of_registers();
-  mod_.reset(driver::module::create(dev_, std::move(llvm)));
-  ker_.reset(driver::kernel::create(&*mod_, name.c_str()));
+  dce.run(ir);
+  align.run(ir);
+  axes.run(ir);
+  layouts.run(ir);
+  peephole.run(ir);
+  dce.run(ir);
+  align.run(ir);
+  axes.run(ir);
+  layouts.run(ir);
+  swizzle.run(ir);
+  liveness.run(ir);
+  allocation.run(ir);
+  barriers.run(ir);
+  isel.visit(ir, *llvm);
+  std::shared_ptr<driver::module> mod(driver::module::create(dev, std::move(llvm)));
+  std::shared_ptr<driver::kernel> ker(driver::kernel::create(&*mod, name.c_str()));
+  size_t shared_mem = allocation.allocated_size();
+  return std::make_tuple(mod, ker, shared_mem);
 }

-void kernel::init_sig() {
-  ir::function* fn = ir_->get_function_list()[0];
-  ir::function_type* ty = fn->get_fn_type();
-  for(size_t i = 0; i < ty->get_num_params(); i++){
-    sig_.push_back(convert(ty->get_param_ty(i)));
-    if(!fn->has_attr(i+1))
-      continue;
-  }
-}
-
-kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev):
+kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev, const std::map<int, ir::attribute> &attrs):
   opt(opt), dev_(dev) {
-  init_ir(preheader() + src);
-  init_ker();
-  init_sig();
-  for(auto arg: ir_->get_function_list()[0]->args())
-    arg_names_.push_back(arg->get_name());
+  // compile to Triton IR
+  ir_ = src_to_ir(src, opt);
+  // add attributes
+  for(const auto&x: attrs)
+    ir_->get_function_list()[0]->add_attr(x.first, x.second);
+  // compile to binary
+  std::tie(mod_, ker_, shared_mem_) = ir_to_bin(*ir_, dev, opt);
 }

-void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector<size_t>& _grid) const{
+void kernel::operator()(const std::string& args, driver::stream *stream, const std::vector<size_t>& _grid) const{
   // set grid
   if(_grid.size() > 3)
     throw std::runtime_error("grid size must be no greater than 3");
@@ -236,7 +208,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
   for(size_t i = 0; i < 3; i++)
     grid[i] = (i < _grid.size()) ? _grid[i] : 1;
   // enqueue
-  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_);
+  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, (void*)args.data(), args.size(), shared_mem_);
 }

 std::string kernel::get_asm(asm_mode_t mode) {
@@ -282,124 +254,124 @@ std::string kernel::get_asm(asm_mode_t mode) {
 /* --------------------------------- */
 /* --------------------------------- */

-void function::do_loop_nest(std::vector<size_t> const & ranges,
-                            std::function<void(std::vector<size_t> const &)> const & f){
-  size_t D = ranges.size();
-  std::vector<size_t> values(D, 0);
-  size_t i = D - 1;
-  while(true){
-    f(values);
-    while(values[i]++ == ranges[i] - 1){
-      if(i == 0)
-        return;
-      values[i--] = 0;
-    }
-    i = D - 1;
-  }
-}
-
-void function::init_kernels(const std::string& src, const options_t& opt,
-                            const autotune_vals_t& confs, driver::device *device) {
-  // list of all possible configs
-  // just augment `opt` with each define of `confs`
-  // and override warp count
-  size_t num_opts = std::max(confs.size(), (size_t)1);
-  std::vector<options_t> opts(num_opts, opt);
-  for(size_t i = 0; i < confs.size(); i++){
-    opts[i].defines.insert(confs[i].first.begin(), confs[i].first.end());
-    opts[i].num_warps = confs[i].second;
-  }
-  // compile all possible configs
-  // compilation errors (e.g., too much shared mem)
-  // will populate `err`
-  std::vector<std::pair<options_t, std::string>> err;
-  for(const options_t& opt: opts) {
-    try{
-      kernels_.push_back({opt, std::make_shared<kernel>(src, opt, device)});
-    }catch(const exception::base& e){
-      err.push_back({opt, e.what()});
-    }
-  }
-  // throw an exception if `err` is not empty
-  if(kernels_.empty()){
-    std::ostringstream dbg;
-    dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
-    for(auto x: err){
-      dbg << "[ ";
-      dbg << x.first.num_warps << ", ";
-      dbg << "{ ";
-      for(const auto& y: x.first.defines)
-        dbg << '"' << y.first << "\"= \"" << y.second << "\", ";
-      dbg << " } ] -> " << x.second << std::endl;
-    }
-    throw exception::no_valid_configuration(dbg.str());
-  }
-}
-
-kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream* stream) {
-  // fast path -- no autotuning necessary
-  if(kernels_.size() == 1)
-    return &*kernels_.begin()->second;
-  // auto-tuning key
-  std::vector<uint64_t> key(key_idxs_.size());
-  for(size_t i = 0; i < key.size(); i++){
-    int idx = key_idxs_[i];
-    std::memcpy((void*)&key[i], (void*)((char*)args + arg_off_[idx]), arg_size_[idx]);
-  }
-  auto it = cache_.find(key);
-  if(it != cache_.end())
-    return it->second;
-  // run auto-tuner
-  double best_ts = INFINITY;
-  kernel* ret = nullptr;
-  for(auto &x : kernels_){
-    kernel* current = &*x.second;
-    auto grid = grid_fn(x.first);
-    while(grid.size() < 3)
-      grid.push_back(1);
-    double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); },
-                             stream, 5, 20);
-    ret = (ts < best_ts) ? current : ret;
-    best_ts = std::min(ts, best_ts);
-  }
-  stream->synchronize();
-  it = cache_.insert({key, ret}).first;
-  return it->second;
-}
-
 function::function(const std::string& src, const options_t &opt, driver::device *device,
-                   const autotune_vals_t& autotune_vals, const std::vector<std::string>& autotune_key) {
-  // pre-compile all kernels
-  init_kernels(src, opt, autotune_vals, device);
-  // find indices of autotune keys
-  auto arg_names = kernels_.at(0).second->get_arg_names();
-  for(const std::string& name: autotune_key){
-    auto it = std::find(arg_names.begin(), arg_names.end(), name);
-    if(it == arg_names.end())
-      throw std::runtime_error(name + " is not a valid argument name");
-    key_idxs_.push_back(std::distance(arg_names.begin(), it));
+                   const std::vector<config> &tune_confs, const std::vector<std::string>& tune_key)
+  : src_(src), device_(device) {
+  // kernel options
+  size_t num_opts = std::max(tune_confs.size(), (size_t)1);
+  opts_ = std::vector<options_t>(num_opts, opt);
+  for(size_t i = 0; i < tune_confs.size(); i++){
+    opts_[i].defines.insert(tune_confs[i].defines.begin(), tune_confs[i].defines.end());
+    opts_[i].num_warps = tune_confs[i].num_warps;
   }
+  std::shared_ptr<ir::module> ir = kernel::src_to_ir(src, opts_[0]);
+  std::vector<ir::argument*> args = ir->get_function_list()[0]->args();
+  // signature
+  auto convert = [](ir::type *ty) {
+    if(ty->is_integer_ty(1)) return INT1_T;
+    if(ty->is_integer_ty(8)) return INT8_T;
+    if(ty->is_integer_ty(16)) return INT16_T;
+    if(ty->is_integer_ty(32)) return INT32_T;
+    if(ty->is_integer_ty(64)) return INT64_T;
+    if(ty->is_half_ty()) return HALF_T;
+    if(ty->is_float_ty()) return FLOAT_T;
+    if(ty->is_double_ty()) return DOUBLE_T;
+    if(ty->is_pointer_ty()) return BUFFER_T;
+    throw std::runtime_error("unknown type");
+  };
+  for(ir::argument* arg: args)
+    sig_.push_back(convert(arg->get_type()));
+  // find indices of autotune keys
+  for(const std::string& name: tune_key){
+    auto pred = [&](ir::argument* arg) { return arg->get_name() == name; };
+    auto it = std::find_if(args.begin(), args.end(), pred);
+    if(it == args.end())
+      throw std::runtime_error(name + " is not a valid argument name");
+    key_idxs_.push_back(std::distance(args.begin(), it));
+  }
+  // find indices of pointer
+  for(size_t i = 0; i < args.size(); i++)
+    if(args[i]->get_type()->is_pointer_ty() ||
+       args[i]->get_type()->is_integer_ty())
+      align_idxs_.push_back(i);
   // argument size and offset
-  auto tys = kernels_.at(0).second->get_sig();
   size_t curr = 0;
-  for(arg_type ty: tys){
+  for(arg_type ty: sig_){
     arg_size_.push_back(size_of(ty));
     arg_off_.push_back(curr);
    curr += arg_size_.back();
   }
 }

-void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
-  runtime::kernel* fn = autotune(args, args_size, grid_fn, stream);
-  (*fn)(args, args_size, stream, grid_fn(fn->opt));
+uint64_t pow2_divisor(uint64_t N){
+  if(N % 16 == 0) return 16;
+  if(N % 8 == 0) return 8;
+  if(N % 4 == 0) return 4;
+  if(N % 2 == 0) return 2;
+  return 1;
 }

-void function::operator()(void* args, size_t args_size, const grid_t& grid, driver::stream* stream) {
-  return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream);
+kernel* function::autotune(const std::string &args, const grid_fn_ty& grid_fn, driver::stream* stream) {
+  // align key
+  std::vector<uint64_t> rt_key(align_idxs_.size(), 0);
+  for(size_t i = 0; i < align_idxs_.size(); i++){
+    int idx = align_idxs_[i];
+    uint64_t tmp = 0;
+    std::memcpy((void*)&tmp, (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]);
+    rt_key[i] = pow2_divisor(tmp);
+  }
+  // auto-tuning key
+  std::vector<uint64_t> at_key(key_idxs_.size(), 0);
+  for(size_t i = 0; i < at_key.size(); i++){
+    int idx = key_idxs_[i];
+    std::memcpy((void*)&at_key[i], (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]);
+  }
+  // cache key
+  std::vector<uint64_t> cache_key;
+  cache_key.reserve(rt_key.size() + at_key.size());
+  cache_key.insert(cache_key.end(), rt_key.begin(), rt_key.end());
+  cache_key.insert(cache_key.end(), at_key.begin(), at_key.end());
+  auto it = cache_.find(cache_key);
+  if(it != cache_.end())
+    return it->second;
+  // compile kernels
+  if(kernels_.find(rt_key) == kernels_.end()){
+    std::map<int, ir::attribute> attrs;
+    for(size_t i = 0; i < align_idxs_.size(); i++){
+      bool is_ptr = sig_[align_idxs_[i]] == BUFFER_T;
+      attrs.insert({align_idxs_[i] + 1, ir::attribute(is_ptr ? ir::aligned : ir::multiple_of, rt_key[i])});
+    }
+    for(const options_t& opt: opts_)
+      kernels_[rt_key].emplace_back(new kernel(src_, opt, device_, attrs));
+  }
+  // run auto-tuner
+  double best_ts = INFINITY;
+  auto& kernels = kernels_.at(rt_key);
+  kernel* ret = nullptr;
+  if(kernels.size() == 1)
+    ret = &*kernels.back();
+  else{
+    for(auto &current : kernels_.at(rt_key)){
+      auto grid = grid_fn(current->opt);
+      while(grid.size() < 3)
+        grid.push_back(1);
+      double ts = tools::bench([&]() { (*current)(args, stream, grid); },
+                               stream, 5, 20);
+      ret = (ts < best_ts) ? &*current : ret;
+      best_ts = std::min(ts, best_ts);
+    }
+    stream->synchronize();
+  }
+  it = cache_.insert({cache_key, ret}).first;
+  return it->second;
+}
+
+void function::operator()(const std::string& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
+  runtime::kernel* fn = autotune(args, grid_fn, stream);
+  (*fn)(args, stream, grid_fn(fn->opt));
 }

@@ -1,12 +1,19 @@
 import triton
 import torch
+import os

-# square benchmarks
+def rounded_linspace(low, high, steps, div):
+    ret = torch.linspace(low, high, steps)
+    ret = (ret.int() + div - 1) // div * div
+    ret = torch.unique(ret)
+    return list(map(int, ret))
+
+# Square benchmarks
 nt = {False: "n", True: "t"}
 square_confs = [
     triton.testing.Benchmark(
         x_names=["M", "N", "K"],
-        x_vals=[512 * i for i in range(1, 16)],
+        x_vals=rounded_linspace(512, 8192, 17, 128),
         y_name="provider",
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
@@ -17,16 +24,29 @@ square_confs = [
     ) for AT in [False, True] for BT in [False, True]
 ]

-@triton.testing.perf_report(square_confs)
-def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
-    import os
+# Transformer training benchmarks
+transformer_confs = [
+    triton.testing.Benchmark(
+        x_names=[x],
+        x_vals = rounded_linspace(NK//16, NK, 33, 128),
+        y_name="provider",
+        y_vals=["torch", "triton", "cutlass"],
+        y_lines=["Torch", "Triton", "CUTLASS"],
+        ylabel="TFLOPS",
+        loglog=False,
+        plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
+        args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16}
+    ) for NK in [8192]\
+        for i, x in enumerate(["N", "K"])\
+        for M in [2048]
+]
+
+@triton.testing.perf_report(square_confs)
+def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
-    if AT:
-        a = a.t()
-    if BT:
-        b = b.t()
+    if AT: a = a.t()
+    if BT: b = b.t()
     num_flops = 2 * M * N * K
     if provider == "torch":
         torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
@@ -40,7 +60,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
         import subprocess
         import tempfile
         import pandas as pd
-
         # run program specified by CUTLASS_PROFILER env variable
         layout_a = "column" if AT else "row"
         layout_b = "column" if BT else "row"
@@ -61,6 +80,7 @@
             f"--warmup-iterations={warmup}",
             f"--profiling-iterations={rep}",
             f"--output={fname}",
+            "--dist=uniform,min:0,max:1,scale:-1",
             "--verbose=false",
         ]
         # run cmd
@@ -70,6 +90,3 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
         cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
         return cutlass_tflops
     return None
-
-if __name__ == "__main__":
-    bench_op.run()
@@ -5,38 +5,16 @@
 #include <cuda_runtime_api.h>
 #include <torch/extension.h>

-std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
-std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
-
 namespace torch_utils {

-void register_device(int64_t dev_id) {
-  if (tt_devices.find(dev_id) != tt_devices.end())
-    return;
-  triton::driver::device *device;
-  if (dev_id >= 0) {
+uint64_t cu_device(int64_t dev_id) {
   CUdevice handle;
   triton::driver::dispatch::cuDeviceGet(&handle, dev_id);
-    device = new triton::driver::cu_device(handle, false);
-  } else
-    device = new triton::driver::host_device();
-  tt_devices[dev_id].reset(device);
+  return (uint64_t)handle;
 }

-void register_stream(int64_t dev_id) {
-  if (tt_streams.find(dev_id) != tt_streams.end())
-    return;
-  triton::driver::stream *stream;
-  if (dev_id >= 0) {
-    CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
-    stream = new triton::driver::cu_stream(handle, false);
-  } else
-    stream = new triton::driver::host_stream();
-  tt_streams[dev_id].reset(stream);
-}
-
-void synchronize(int64_t dev_id) {
-  tt_streams[dev_id]->synchronize();
+uint64_t cu_stream(int64_t dev_id) {
+  return (uint64_t)c10::cuda::getCurrentCUDAStream(dev_id).stream();
 }

 void set_device(int64_t dev_id) {
@@ -44,23 +22,11 @@ void set_device(int64_t dev_id) {
   C10_CUDA_CHECK(cudaSetDevice(dev_id));
 }

-torch::Tensor move_out_of_pool(torch::Tensor x) {
-  if (x.nbytes() == 0)
-    return torch::empty_like(x);
-  void *data;
-  cudaMalloc(&data, x.nbytes());
-  auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options());
-  ret.copy_(x);
-  return ret;
-}
-
 } // namespace torch_utils

 void init_torch_utils(pybind11::module &m) {
   pybind11::module subm = m.def_submodule("torch_utils");
-  subm.def("register_device", &torch_utils::register_device);
-  subm.def("register_stream", &torch_utils::register_stream);
+  subm.def("cu_device", &torch_utils::cu_device);
+  subm.def("cu_stream", &torch_utils::cu_stream);
   subm.def("set_device", &torch_utils::set_device);
-  subm.def("synchronize", &torch_utils::synchronize);
-  subm.def("move_out_of_pool", &torch_utils::move_out_of_pool);
 }
@@ -1,10 +1,4 @@
 #include "triton/driver/stream.h"
-#include "triton/ir/function.h"
-#include "triton/ir/module.h"
-#include "triton/lang/code_gen.h"
-#include "triton/lang/cpp.h"
-#include "triton/lang/parser.h"
-#include "triton/runtime/arg.h"
 #include "triton/runtime/function.h"
 #include <pybind11/buffer_info.h>
 #include <pybind11/functional.h>
@@ -13,72 +7,22 @@
 #include <regex>
 #include <string>

+namespace py = pybind11;
+
 using namespace triton;
 namespace rt = triton::runtime;
 namespace drv = triton::driver;
-namespace lng = triton::lang;

-std::unordered_map<const rt::options_t *, pybind11::object> opt_cache_;
-std::map<int, std::shared_ptr<rt::function>> id_fn_map;
-extern std::map<int, std::shared_ptr<triton::driver::device>> tt_devices;
-extern std::map<int, std::shared_ptr<triton::driver::stream>> tt_streams;
-
-/* Function utilities */
-void register_fn(int op_id, int dev_id,
-                 const std::string &src, const rt::options_t &opt,
-                 const rt::function::autotune_vals_t &autotune_vals,
-                 const std::vector<std::string> &autotune_key) {
-  if (id_fn_map.find(op_id) == id_fn_map.end()) {
-    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key));
-  }
-  for (const auto &k : id_fn_map[op_id]->get_kernels()) {
-    const rt::options_t *opt = &k.first;
-    pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference);
-    for (auto x : opt->defines)
-      if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
-        obj.attr(x.first.c_str()) = std::stoi(x.second);
-    opt_cache_[&k.second->opt] = obj;
-  }
-}
-
-void delete_fn(int op_id) {
-  id_fn_map.erase(op_id);
-}
-
-void cleanup() {
-  id_fn_map.clear();
-  opt_cache_.clear();
-}
-
-size_t make_op_id() {
-  return id_fn_map.size();
-}
-
-std::vector<rt::arg_type> get_fn_signature(size_t op_id) {
-  return id_fn_map[op_id]->get_kernels()[0].second->get_sig();
-}
-
-// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments
-// as a string constructed with struct.pack in python
-void launch_kernel(int64_t op_id, int64_t dev_id, const std::string &args, size_t grid_0, size_t grid_1, size_t grid_2) {
-  rt::function *fn = id_fn_map.at(op_id).get();
-  (*fn)((void **)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]);
-}
-
-pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string &args, const rt::function::grid_fn_ty &grid) {
-  rt::function *fn = id_fn_map.at(op_id).get();
-  auto wrapper = [&grid](const rt::options_t &opt) {
-    pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference);
-    for (auto x : opt.defines)
-      if (std::all_of(x.second.begin(), x.second.end(), ::isdigit))
-        obj.attr(x.first.c_str()) = std::stoi(x.second);
-    return grid(*obj.cast<rt::options_t *>());
-  };
-  rt::kernel *kernel = fn->autotune((void **)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]);
-  return opt_cache_.at(&kernel->opt);
-}
-
+/*****************************************************************************/
+/* Python bindings for triton::tools                                         */
+/*****************************************************************************/
+
+/*!
+  @brief Function for extracting kernels out of a given source-string
+
+  This can be important to enable pre-processor macros (or tunable parameters) that should only
+  be defined within the scope of a single kernel function
+*/
 std::string extract_kernels(const std::string &str, const std::vector<std::string> &names) {
   if (names.empty())
     return str;
@@ -94,32 +38,32 @@ std::string extract_kernels(const std::string &str, const std::vector<std::strin
     std::string name = it->str(1);
     kernels.push_back(std::make_tuple(name, pos, len));
   }
+  // check that all the kernels provided actually exist
   for (const std::string &name : names) {
-    // check that str matches any string in kernels using std::any_of
     auto pred = [&name](const std::tuple<std::string, int, int> &t) { return std::get<0>(t) == name; };
     bool found = std::any_of(kernels.begin(), kernels.end(), pred);
     if (!found)
       throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str);
   }
-  // extract functions
+  // simple parsing logic to extract the declaration and body of each specified kernel
   std::string ret;
   for (const auto &k : kernels) {
     std::string name;
     int pos, len;
     std::tie(name, pos, len) = k;
-    if (std::find(names.begin(), names.end(), name) != names.end()) {
+    if (std::find(names.begin(), names.end(), name) == names.end())
+      continue;
     std::string def = str.substr(pos, str.size() - pos);
-    int count, pos;
     // skip over declaration
-    count = 1;
+    // by finding matching ')' for first '('
+    int count = 1;
     pos = def.find('(');
     while (!(def[pos++] == ')' && count == 0) && pos < def.size()) {
       count += def[pos] == '(';
       count -= def[pos] == ')';
     }
     // skip over definition
+    // by finding matching '{' for first '}'
     count = 1;
     pos = def.find('{', pos);
     while (!(def[pos++] == '}' && count == 0) && pos < def.size()) {
@@ -129,15 +73,47 @@ std::string extract_kernels(const std::string &str, const std::vector<std::strin
     ret += def.substr(0, pos);
     ret += '\n';
   }
-  }

   return ret;
 }

-void init_triton(pybind11::module &m) {
-  pybind11::module subm = m.def_submodule("triton");
-  // bindings for triton classes
-  pybind11::enum_<rt::arg_type>(subm, "arg_type")
+void init_triton_tools(py::module &&m) {
+  m.def("extract_kernels", &extract_kernels);
+}
+
+/*****************************************************************************/
+/* Python bindings for triton::driver                                        */
+/*****************************************************************************/
+
+void init_triton_driver(py::module &&m) {
+  // base device
+  py::class_<drv::device>(m, "device");
+  // cuda device
+  py::class_<drv::cu_device, driver::device>(m, "cu_device")
+      .def(py::init<CUdevice, bool>());
+  // host device
+  py::class_<drv::host_device, driver::device>(m, "host_device")
+      .def(py::init<>());

+  // base stream
+  py::class_<drv::stream>(m, "stream");
+  // host stream
+  py::class_<drv::host_stream, drv::stream>(m, "host_stream")
+      .def(py::init<>());
+  // cuda stream
+  py::class_<drv::cu_stream, drv::stream>(m, "cu_stream")
+      // py doesn't support opaque pointer (e.g., CUstream) so
+      // we assume it has been converted to uint64_t
+      .def(py::init([](uint64_t handle, bool take_ownership) {
+        return std::unique_ptr<driver::cu_stream>(new driver::cu_stream((CUstream)handle, take_ownership));
+      }));
+}
+
+/*****************************************************************************/
+/* Python bindings for triton::runtime                                       */
+/*****************************************************************************/
+void init_triton_runtime(py::module &&m) {
+  // argument type
+  py::enum_<rt::arg_type>(m, "arg_type")
       .value("int1", rt::INT1_T)
       .value("int8", rt::INT8_T)
       .value("int16", rt::INT16_T)
@@ -147,23 +123,38 @@ void init_triton(pybind11::module &m) {
       .value("float", rt::FLOAT_T)
       .value("double", rt::DOUBLE_T)
       .value("buffer", rt::BUFFER_T);
-  pybind11::enum_<rt::asm_mode_t>(subm, "asm_mode")
+  // assembly mode
+  py::enum_<rt::asm_mode_t>(m, "asm_mode")
       .value("ptx", rt::ASM_NV_PTX)
       .value("sass", rt::ASM_NV_SASS);
-  pybind11::class_<rt::options_t>(subm, "options", pybind11::dynamic_attr())
-      .def(pybind11::init<>())
+  // compilation options
+  py::class_<rt::options_t>(m, "options", py::dynamic_attr())
+      .def(py::init<>())
       .def_readwrite("defines", &rt::options_t::defines)
-      .def_readwrite("num_warps", &rt::options_t::num_warps);
+      .def_readwrite("num_warps", &rt::options_t::num_warps)
+      .def("__getattr__", [](rt::options_t *opt, const std::string &name) {
+        return opt->D<int>(name);
+      });
+  // kernel
+  py::class_<rt::kernel>(m, "kernel")
+      .def("__call__", &rt::kernel::operator())
+      .def_readonly("opt", &rt::kernel::opt);
+  // tune conf
+  py::class_<rt::config>(m, "config")
+      .def(py::init<std::map<std::string, std::string>, int>(),
+           py::arg("defines") = std::map<std::string, std::string>(),
+           py::arg("num_warps"));
+  // function
+  py::class_<rt::function>(m, "function")
+      .def(py::init<const std::string &, const rt::options_t &, driver::device *, const std::vector<rt::config> &, const std::vector<std::string> &>())
+      .def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal)
+      .def("signature", &rt::function::get_signature);
+}

-  // hooks into triton constructs since frameworks may not use pybind11
-  subm.def("extract_kernels", &extract_kernels);
-  subm.def("get_fn_signature", &get_fn_signature);
-  subm.def("register_fn", &register_fn);
-  subm.def("delete_fn", &delete_fn);
-  subm.def("make_op_id", &make_op_id);
-  subm.def("cleanup", &cleanup);
-  subm.def("autotune", &autotune, pybind11::return_value_policy::reference);
-  subm.def("launch_kernel", &launch_kernel);
+void init_triton(py::module &m) {
+  py::module subm = m.def_submodule("triton");
+  init_triton_driver(std::move(subm.def_submodule("driver")));
+  init_triton_runtime(std::move(subm.def_submodule("runtime")));
+  init_triton_tools(std::move(subm.def_submodule("tools")));
 }
@@ -50,8 +50,9 @@ import torch
 def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
     DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
     torch.manual_seed(0)
+    defines = {"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}
     triton.ops._matmul._kernels = dict()
-    triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
+    triton.ops._matmul._CONFIGS = [triton.config(defines=defines, num_warps=NWARP)]
     if M is None:
         M = TM
     if N is None:
|
@@ -1,27 +1,21 @@
|
|||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
from typing import Optional, Dict, List
|
from typing import Optional, Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
# C bindings
|
# C bindings
|
||||||
import triton._C.libtriton.triton as _triton
|
import triton._C.libtriton.triton as _triton
|
||||||
import triton._C.libtriton.torch_utils as _torch_utils
|
import triton._C.libtriton.torch_utils as _torch_utils
|
||||||
# Make sure internal C resources are cleaned up upon exit
|
|
||||||
import atexit
|
|
||||||
|
|
||||||
@atexit.register
|
|
||||||
def cleanup():
|
|
||||||
_triton.cleanup()
|
|
||||||
|
|
||||||
codes = {
|
codes = {
|
||||||
_triton.arg_type.int1: 'B', _triton.arg_type.int8: 'B', _triton.arg_type.int32: 'I', _triton.arg_type.int64: 'Q',
|
_triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I',
|
||||||
_triton.arg_type.half: 'H', _triton.arg_type.float: 'f', _triton.arg_type.double: 'd', _triton.arg_type.buffer: 'P'
|
_triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f',
|
||||||
|
_triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P'
|
||||||
}
|
}
|
||||||
|
|
||||||
def th_to_triton(obj):
|
def th_to_triton(obj):
|
||||||
tys = {
|
tys = {
|
||||||
torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long', torch.float16: 'half',
|
torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long',\
|
||||||
torch.float32: 'float', torch.float64: 'double'
|
torch.float16: 'half', torch.float32: 'float', torch.float64: 'double'
|
||||||
}
|
}
|
||||||
if isinstance(obj, torch.dtype):
|
if isinstance(obj, torch.dtype):
|
||||||
return tys[obj]
|
return tys[obj]
|
||||||
@@ -30,69 +24,54 @@ def th_to_triton(obj):

 def cdiv(a, b):
     return (a + b - 1) // b

-def synchronize(device):
-    dev_id = device.index
-    dev_id = -1 if dev_id is None else dev_id
-    _torch_utils.synchronize(dev_id)
-
-def read(path, kernel_names:Optional[List]=None):
+def read(path, kernel_names: Optional[List] = None):
     if kernel_names is None:
         kernel_names = []
     with open(path, 'r') as f:
         source = f.read()
-    source = _triton.extract_kernels(source, kernel_names)
+    source = _triton.tools.extract_kernels(source, kernel_names)
     return source

-class kernel:
-    def __init__(self,
-                 src,
-                 device,
-                 defines: Optional[Dict]=None,
-                 num_warps:int=4,
-                 autotune_vals:Optional[List]=None,
-                 autotune_key:Optional[List]=None):
+config = _triton.runtime.config
+
+class kernel:
+    def __init__(self, src, device, defines: Optional[Dict] = None, num_warps: int = 4,
+                 autotune_vals: Optional[List] = None, autotune_key: Optional[List] = None):
         if defines is None:
             defines = {}
         if autotune_vals is None:
             autotune_vals = []
         if autotune_key is None:
             autotune_key = []

         # check if src is empty
         if src == '':
             raise ValueError('Kernel source code is empty')
         self.src = src
-        self.opt = _triton.options()
-        self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
-        self.opt.num_warps = num_warps
         # device
         assert device.type in ['cuda', 'cpu']
         if device.type == 'cuda':
-            self.device = torch.cuda.current_device() if device.index is None else device.index
+            self.device_id = torch.cuda.current_device() if device.index is None else device.index
+            self.device = _triton.driver.cu_device(_torch_utils.cu_device(self.device_id), False)
+            self.stream = _triton.driver.cu_stream(_torch_utils.cu_stream(self.device_id), False)
         if device.type == 'cpu':
-            self.device = -1
-        _torch_utils.register_device(self.device)
-        _torch_utils.register_stream(self.device)
-        # C++ function wrapper
-        self.op_id = _triton.make_op_id()
-        _torch_utils.set_device(self.device)
-        _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key)
-        # debug mode
-        self.is_debug = 'TRITON_DEBUG' in os.environ
-        # signature
-        arg_types = _triton.get_fn_signature(self.op_id)
-        self.tys = ''.join([codes[x] for x in arg_types])
+            self.device_id = -1
+            self.device = _triton.driver.host_device()
+            self.stream = _triton.driver.host_stream()
+        _torch_utils.set_device(self.device_id)
+        # function
+        self.opt = _triton.runtime.options()
+        self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()}
+        self.opt.num_warps = num_warps
+        # autotune_vals = [({}, 4)]
+        self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_vals, autotune_key)
+        self.tys = ''.join([codes[x] for x in self.fn.signature()])

     def __call__(self, *args, grid):
-        _torch_utils.set_device(self.device)
+        # make sure that the executing thread is on the right device
+        _torch_utils.set_device(self.device_id)
         # pack parameters into a byte buffer
         params = struct.pack(self.tys, *args)
-        opt = _triton.autotune(self.op_id, self.device, params, grid)
+        kernel = self.fn.autotune(params, grid, self.stream)
         # run kernel
-        grid = grid(opt)
-        grid_0 = grid[0]
-        grid_1 = 1 if len(grid) < 2 else grid[1]
-        grid_2 = 1 if len(grid) < 3 else grid[2]
-        _triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2)
+        grid = grid(kernel.opt)
+        kernel(params, self.stream, grid)
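
Note: the `grid` argument to `__call__` is a callable that receives the chosen compilation options and returns the launch grid. A minimal sketch with a stand-in options object (`BLOCK` is a hypothetical `#define`, mirroring the tutorial added at the bottom of this diff):

    from types import SimpleNamespace

    N = 4096
    grid = lambda opt: ((N + opt.BLOCK - 1) // opt.BLOCK,)   # same as triton.cdiv(N, opt.BLOCK)
    opt = SimpleNamespace(BLOCK=1024)                        # stand-in for kernel.opt
    print(grid(opt))                                         # (4,)
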
@@ -1,17 +1,17 @@
-__global__ void NAME(TYPE *A __readonly __noalias __aligned(16),
-                     TYPE *B __readonly __noalias __aligned(16),
-                     TYPE *C __noalias __aligned(16),
-                     int lda __multipleof(8),
-                     int ldb __multipleof(8),
-                     int ldc __multipleof(8),
-                     long stride_za __multipleof(8),
-                     long stride_zb __multipleof(8),
-                     long stride_zc __multipleof(8),
-                     long stride_ha __multipleof(8),
-                     long stride_hb __multipleof(8),
-                     long stride_hc __multipleof(8),
+__global__ void NAME(TYPE *A __readonly __noalias,
+                     TYPE *B __readonly __noalias,
+                     TYPE *C __noalias,
+                     int lda,
+                     int ldb,
+                     int ldc,
+                     long stride_za,
+                     long stride_zb,
+                     long stride_zc,
+                     long stride_ha,
+                     long stride_hb,
+                     long stride_hc,
                      int DS0, int DS1,
-                     int SDD_K __multipleof(16),
+                     int SDD_K,
                      int SDD_off_width,
                      int *lut, int *locks, int nlocks) {
     /* ---------------- */
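
Note: this and the following kernel hunks all make the same change — the hand-written `__aligned(16)` and `__multipleof(...)` hints are dropped from the kernel signatures. For reference, a minimal sketch (not the runtime's actual code) of the property `__multipleof` used to assert, i.e. the largest power-of-two divisor of an integer argument, capped at 16:

    def pow2_divisor(n, cap=16):
        # largest power of two (up to cap) that divides n
        div = 1
        while div < cap and n % (div * 2) == 0:
            div *= 2
        return div

    assert pow2_divisor(24) == 8
    assert pow2_divisor(64) == 16
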
@@ -1,17 +1,16 @@
-__global__ void forward(TYPE *X __readonly __noalias __aligned(16),
+__global__ void forward(TYPE *X __readonly __noalias,
                         float scale,
-                        int *LUT __readonly __noalias __aligned(16),
-                        TYPE *RPE __readonly __noalias __aligned(16),
-                        TYPE *KP_M __readonly __noalias __aligned(16),
-                        TYPE *ATTN_M __readonly __noalias __aligned(16),
+                        int *LUT __readonly __noalias,
+                        TYPE *RPE __readonly __noalias,
+                        TYPE *KP_M __readonly __noalias,
+                        TYPE *ATTN_M __readonly __noalias,
                         int sizemax,
-                        long stride_zx __multipleof(4),
-                        long stride_zrpe __multipleof(BLOCK),
-                        int stride_hrpe __multipleof(BLOCK),
-                        int stride_srpe __multipleof(BLOCK),
-                        int stride_zkpm __multipleof(BLOCK),
-                        int stride_zattnm __multipleof(BLOCK))
-{
+                        long stride_zx,
+                        long stride_zrpe,
+                        int stride_hrpe,
+                        int stride_srpe,
+                        int stride_zkpm,
+                        int stride_zattnm) {
     int pidhm = get_program_id(0);
     int pidz = get_program_id(1);
     // create index ranges
@@ -97,14 +96,13 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16),
     *? (check)px = y / ysum;
 }

-__global__ void backward(TYPE *X __readonly __noalias __aligned(16),
+__global__ void backward(TYPE *X __readonly __noalias,
                          float scale,
-                         TYPE *DX __readonly __noalias __aligned(16),
+                         TYPE *DX __readonly __noalias,
                          int *LUT,
                          int sizemax,
-                         long stride_zx __multipleof(BLOCK),
-                         long stride_zdx __multipleof(BLOCK))
-{
+                         long stride_zx,
+                         long stride_zdx) {
     int pidhm = get_program_id(0);
     int pidz = get_program_id(1);
     // create index ranges
@@ -1,6 +1,6 @@
-__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
-                     TYPE *B __noalias __readonly __aligned(16),
-                     TYPE *C __noalias __aligned(16),
+__global__ void conv(TYPE *A __noalias __readonly,
+                     TYPE *B __noalias __readonly,
+                     TYPE *C __noalias,
                      float alpha,
                      // equivalent matmul
                      int M, int N, int K,
@@ -9,9 +9,9 @@ __global__ void conv(TYPE *A __noalias __readonly __aligned(16),
                      // pointer increment
                      int *ADELTA,
                      // memory strides
-                     int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
-                     int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
-                     int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
+                     int lda_z, int lda_ci, int lda_h, int lda_w,
+                     int ldb_ci, int ldb_r, int ldb_s, int ldb_co,
+                     int ldc_z, int ldc_co, int ldc_p, int ldc_q) {
     // prologue
     int ridx = get_program_id(0);
     int ridy = get_program_id(1);
@@ -35,30 +35,32 @@ __global__ void conv(TYPE *A __noalias __readonly __aligned(16),
     int rz[TM] = rzp / PP;
     // unpack aggregate reduction
     // k = (ci, r, s)
-    int rs [TK] = rk % SS;
+    int rs[TK] = rk % SS;
     int rcir[TK] = rk / SS;
-    int rr [TK] = rcir % RR;
-    int rci [TK] = rcir / RR;
+    int rr[TK] = rcir % RR;
+    int rci[TK] = rcir / RR;

     // padding / striding
     int rh_0[TM] = rp * stride_h - pad_h;
     int rw_0[TM] = rq * stride_w - pad_w;
-    int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
-    int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];
+    int rh[TM, TK] = rh_0[:, newaxis] + rr [newaxis, :];
+    int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :];

     // pointers to lhs
-    int offa[TM, TK] = rz [:, newaxis] * lda_z +
-                       rci[newaxis, :] * lda_ci +
+    int offa[TM, TK] = rz[:, newaxis] * lda_z +
+                       rci [newaxis, :] * lda_ci +
                        rh * lda_h +
                        rw * 1;
-    TYPE* pa[TM, TK] = A + offa;
-    int* padelta[TK] = ADELTA + rk;
+    TYPE *pa[TM, TK] = A + offa;
+    int *padelta[TK] = ADELTA + rk;
     // pointers to rhs
     int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
-                       rr [:, newaxis] * ldb_r +
-                       rs [:, newaxis] * ldb_s +
+                       rr
+                       [:, newaxis] * ldb_r +
+                       rs
+                       [:, newaxis] * ldb_s +
                        rn [newaxis, :] * 1;
-    TYPE* pb[TK, TN] = B + offb;
+    TYPE *pb[TK, TN] = B + offb;

     // prefetches operands
     bool checkam[TM, TK] = rm[:, newaxis] < M;
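
Note: the `[:, newaxis]` / `[newaxis, :]` indexing in these Triton-C kernels follows NumPy-style broadcasting; the whitespace churn around it in this hunk is formatter noise only. A rough Python analogue (assuming NumPy semantics, not Triton's compiler):

    import numpy as np

    TM, TK = 4, 3
    rh_0 = np.arange(TM)                            # per-row base offsets, shape [TM]
    rr = np.arange(TK)                              # per-column filter offsets, shape [TK]
    rh = rh_0[:, np.newaxis] + rr[np.newaxis, :]    # broadcast to shape [TM, TK]
    assert rh.shape == (TM, TK)
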
@@ -70,26 +72,26 @@ __global__ void conv(TYPE *A __noalias __readonly __aligned(16),

     // reduction loop
     float acc[TM, TN] = 0;
-    for(int k = K; k > 0; k -= TK){
-        acc += a @ b;
+    for (int k = K; k > 0; k -= TK) {
+        acc += a @b;
         // increment A
         int adelta[TK] = *padelta;
         padelta += TK;
-        pa += adelta[newaxis, :];
+        pa += adelta [newaxis, :];
         // bounds-checking A
         rk += TK;
         rs = rk % SS;
         rcir = rk / SS;
         rr = rcir % RR;
-        rh = rh_0[:, newaxis] + rr[newaxis, :];
-        rw = rw_0[:, newaxis] + rs[newaxis, :];
+        rh = rh_0[:, newaxis] + rr [newaxis, :];
+        rw = rw_0[:, newaxis] + rs [newaxis, :];
         bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
         // increment B
         pb += TK * ldb_s;
         // bounds-checking B
         bool checkb[TK, TN] = k > TK;
         a = checka ? *pa : 0;
-        b = *?(checkb)pb;
+        b = *? (checkb)pb;
     }
     acc = acc * alpha;
     TYPE c[TM, TN] = acc;
@@ -101,25 +103,28 @@ __global__ void conv(TYPE *A __noalias __readonly __aligned(16),
     rzp = rm / QQ;
     rp = rzp % PP;
     rz = rzp / PP;
-    int offc[TM, TN] = rz [:, newaxis] * ldc_z +
-                       rn [newaxis, :] * ldc_co+
-                       rp [:, newaxis] * ldc_p +
-                       rq [:, newaxis] * 1;
-    TYPE* pc[TM, TN] = C + offc;
-    bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;
+    int offc[TM, TN] = rz[:, newaxis] * ldc_z +
+                       rn [newaxis, :] * ldc_co +
+                       rp
+                       [:, newaxis] * ldc_p +
+                       rq
+                       [:, newaxis] * 1;
+    TYPE *pc[TM, TN] = C + offc;
+    bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N;

-#if (TZ==1)
-    *?(checkc) pc = c;
+#if (TZ == 1)
+    *? (checkc)pc = c;
 #else
     // accumulate partial result using spin-locks
     int *plock = locks + rid;
     int *pcount = plock + get_num_programs(0) * get_num_programs(1);
-    for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
+    for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
+        ;
     int count = *pcount;
-    if(count == 0)
-        *?(checkc) pc = c;
+    if (count == 0)
+        *? (checkc)pc = c;
     else
-        *?(checkc) pc = c + *?(checkc)pc;
+        *? (checkc)pc = c + *? (checkc)pc;
     atomic_xchg(pcount, (count + 1) % TZ);
     atomic_xchg(plock, 0);
 #endif
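
Note: the `#else` branch above implements split-K accumulation behind a spin-lock: the first program to acquire the lock stores its partial tile, later ones add to it, and a counter cycling modulo TZ resets the protocol for the next use. A minimal Python sketch of the same protocol, with a `threading.Lock` standing in for the `atomic_cas`/`atomic_xchg` pair:

    import threading

    TZ = 4                               # partial producers per output tile
    lock = threading.Lock()              # stands in for atomic_cas/atomic_xchg on *plock
    count = 0                            # stands in for *pcount
    out = [0.0]                          # stands in for the output tile

    def accumulate(partial):
        global count
        with lock:
            if count == 0:
                out[0] = partial         # first writer stores
            else:
                out[0] += partial        # later writers accumulate
            count = (count + 1) % TZ     # reset protocol for the next tile

    threads = [threading.Thread(target=accumulate, args=(1.0,)) for _ in range(TZ)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert out[0] == 4.0
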
@@ -1,8 +1,4 @@
-__global__ void forward(TYPE *logit __aligned(16),
-                        TYPE *modified_logit __aligned(16),
-                        long *indices __readonly,
-                        TYPE *result __aligned(16),
-                        int n_cols __multipleof(N_COLS_MULT)) {
+__global__ void forward(TYPE *logit, TYPE *modified_logit, long *indices, TYPE *result, int n_cols) {
     int row = get_program_id(0);

     bool check[TILE] = ((0 ... TILE) < n_cols);
@@ -19,10 +15,7 @@ __global__ void forward(TYPE *logit __aligned(16),
     *(result + row) = *(modified_logit + (local_ind + n_cols * row));
 }

-__global__ void backward(TYPE *neg_logprobs __aligned(16),
-                         long *indices __aligned(16),
-                         TYPE *dneg_logprobs __aligned(16),
-                         int n_cols __multipleof(N_COLS_MULT)) {
+__global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs, int n_cols) {

     int row = get_program_id(0);
     // pointer arithmetic
@@ -1,16 +1,12 @@
 #define STM 8
 #define STN 8

-__global__ void matmul(TYPE *A __noalias __readonly __aligned(16),
-                       TYPE *B __noalias __readonly __aligned(16),
-                       TYPE *C __noalias __aligned(16),
+__global__ void matmul(TYPE *A __noalias __readonly,
+                       TYPE *B __noalias __readonly,
+                       TYPE *C __noalias,
                        float alpha,
-                       int M,
-                       int N,
-                       int K __multipleof(16),
-                       int lda __multipleof(LDA_POW2_DIV),
-                       int ldb __multipleof(LDB_POW2_DIV),
-                       int ldc __multipleof(LDC_POW2_DIV),
+                       int M, int N, int K,
+                       int lda, int ldb, int ldc,
                        int *locks) {
     // prologue
     int pid = get_program_id(0);
@@ -6,18 +6,18 @@ class _matmul(torch.autograd.Function):
     src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))

     _DEFAULT_CONFIGS = [
-        ({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
-        ({'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, 4),
-        ({'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, 4),
-        ({'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 4),
-        ({'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, 4),
-        ({'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 4),
-        ({'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 2),
-        ({'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 2),
-        # ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
-        # ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
-        # ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
-        # ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
+        triton.config(defines={"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4),
+        triton.config(defines={'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
+        triton.config(defines={'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
+        triton.config(defines={'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
+        triton.config(defines={'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
+        triton.config(defines={'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=4),
+        triton.config(defines={'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
+        triton.config(defines={'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=2),
+        triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
+        triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4),
+        triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
+        triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4),
     ]
     _CONFIGS = _DEFAULT_CONFIGS
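
Note: the tuple form `({defines}, num_warps)` is replaced here by `triton.config` objects. A sketch of building an autotuning space this way; passing it through `autotune_vals` is an assumption based on the `triton.kernel` constructor shown earlier in this diff:

    import triton

    configs = [
        triton.config(defines={'TM': '64', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
        triton.config(defines={'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4),
    ]
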
@@ -1,4 +1,5 @@
 import torch
+import os

 def sparsify_tensor(x, mask, block):
     ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
@@ -77,8 +78,12 @@ class Mark:
         df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))

     def run(self, result_path, with_plot):
+        with open(os.path.join(result_path, "results.html"), "w") as html:
+            html.write("<html><body>\n")
             for bench in self.benchmarks:
                 self._run(bench, result_path, with_plot)
+                html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
+            html.write("</body></html>\n")

 def perf_report(benchmarks):
     wrapper = lambda fn: Mark(fn, benchmarks)

python/tutorials/01-vector-add.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+import torch
+import triton
+
+# source-code for Triton compute kernel
+# here we just copy-paste the above code without the extensive comments.
+# you may prefer to store it in a .c file and load it from there instead.
+_src = """
+__global__ void add(float* z, float* x, float* y, int N){
+    // program id
+    int pid = get_program_id(0);
+    // create arrays of pointers
+    int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK;
+    float* pz[BLOCK] = z + offset;
+    float* px[BLOCK] = x + offset;
+    float* py[BLOCK] = y + offset;
+    // bounds checking
+    bool check[BLOCK] = offset < N;
+    // write-back
+    *?(check)pz = *?(check)px + *?(check)py;
+}
+"""
+# This function returns a callable `triton.kernel` object
+# created from the above source code.
+# For portability, we maintain a cache of kernels for different `torch.device`.
+# We compile the kernel with -DBLOCK=1024
+_kernels = dict()
+
+
+def make_add_kernel(device):
+    if device not in _kernels:
+        defines = {'BLOCK': 1024}
+        autotune_vals = [({'BLOCK': '1024'}, 4), ({'BLOCK': '2048'}, 4)]
+        autotune_key = ["N"]
+        _kernels[device] = triton.kernel(_src, device=device, defines=defines, autotune_vals=autotune_vals,
+                                         autotune_key=autotune_key)
+    return _kernels[device]
+
+
+# This is a standard torch custom autograd Function.
+# The only difference is that we can now use the above kernel
+# in the `forward` and `backward` functions.
+class _add(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, y):
+        # constraints of the op
+        assert x.dtype == torch.float32
+        # *allocate output*
+        z = torch.empty_like(x)
+        # *create launch grid*:
+        # this is a function which takes compilation parameters `opt`
+        # as input and returns a tuple of int (i.e., launch grid) for the kernel.
+        # triton.cdiv is a shortcut for ceil division:
+        # triton.cdiv(a, b) = (a + b - 1) // b
+        grid = lambda opt: (triton.cdiv(z.shape[0], opt.BLOCK), )
+        # *launch kernel*:
+        # pointer to the data of torch tensors can be retrieved with
+        # the `.data_ptr()` method
+        kernel = make_add_kernel(z.device)
+        kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), z.shape[0], grid=grid)
+        return z
+
+
+# Just like standard PyTorch ops,
+# we use the `.apply` method to create a
+# callable object for our function
+add = _add.apply
+
+torch.manual_seed(0)
+x = torch.rand(32, device='cuda')
+y = torch.rand(32, device='cuda')
+za = x + y
+zb = add(x, y)
+print(za)
+print(zb)
+print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')
+
+th_ms = triton.testing.do_bench(lambda: x + y)
+tr_ms = triton.testing.do_bench(lambda: add(x, y))
+print(th_ms, tr_ms)