[PYTHON][TESTS][DOC] Various improvements to the API and code quality:
* Simplified the `triton.kernel` API to achieve lower latency:
  > `.data_ptr()` must now be passed as a kernel argument; there is no more implicit conversion from `torch.tensor`.
  > Compilation options are now constant attributes, i.e., `opt.d('VAR')` becomes `opt.VAR`.
  > `torch.device` must now be passed explicitly to `triton.kernel` (it is no longer inferred from `torch.tensor` arguments).
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* Added the `python/triton/ops/` package for pre-written Triton ops
This commit is contained in:
@@ -29,7 +29,7 @@ public:
|
||||
static driver::stream* create(backend_t backend);
|
||||
// methods
|
||||
virtual void synchronize() = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args = NULL, size_t args_size = 0) = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args = NULL, size_t args_size = 0) = 0;
|
||||
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
|
||||
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
|
||||
// template helpers
|
||||
@@ -44,7 +44,7 @@ class host_stream: public stream {
|
||||
public:
|
||||
host_stream();
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args, size_t args_size);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
@@ -55,7 +55,7 @@ public:
|
||||
cu_stream(CUstream str, bool take_ownership);
|
||||
cu_stream();
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void **args, size_t args_size);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
|
@@ -4,27 +4,19 @@
|
||||
#define _TRITON_RUNTIME_FUNCTION_H_
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <set>
|
||||
// codegen
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/runtime/arg.h"
|
||||
#include "triton/runtime/error.h"
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
class LLVMContext;
|
||||
}
|
||||
|
||||
class Parser;
|
||||
|
||||
// driver forward declaration
|
||||
namespace triton {
|
||||
|
||||
namespace driver{
|
||||
class module;
|
||||
class stream;
|
||||
@@ -32,26 +24,19 @@ namespace driver{
|
||||
class context;
|
||||
class device;
|
||||
}
|
||||
|
||||
namespace lang{
|
||||
class translation_unit;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
class tiles;
|
||||
}
|
||||
}
|
||||
|
||||
// ir forward declaration
|
||||
namespace triton{
|
||||
namespace ir {
|
||||
class module;
|
||||
class function;
|
||||
class context;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace runtime{
|
||||
|
||||
|
||||
typedef std::vector<size_t> grid_t;
|
||||
typedef std::map<std::string, size_t> params_t;
|
||||
template<typename T> inline T convert(const std::string& name);
|
||||
@@ -72,8 +57,7 @@ enum asm_mode_t {
|
||||
// Aggregate of candidate option values: a list of named defines, each paired
// with a list of string values, plus integer lists for `num_warps` and
// `recompile_key`.
// NOTE(review): this text comes from a rendered diff hunk, and `num_warps` is
// listed twice below — presumably once as a removed line and once as an added
// line. Confirm against the repository which members the struct really has.
struct options_space_t {
|
||||
// One define: its name (first) and the values listed for it (second).
typedef std::pair<std::string, std::vector<std::string>> define_t;
|
||||
std::vector<define_t> defines;
|
||||
std::vector<int> num_warps;
|
||||
std::vector<int> recompile_key;
|
||||
std::vector<int> num_warps;
|
||||
};
|
||||
|
||||
struct options_t {
|
||||
@@ -81,88 +65,69 @@ struct options_t {
|
||||
T D(const std::string& name) const {
|
||||
return convert<T>(defines.at(name));
|
||||
}
|
||||
bool operator<(const options_t& other) const {
|
||||
return std::make_pair(defines, num_warps) <
|
||||
std::make_pair(other.defines, other.num_warps);
|
||||
}
|
||||
std::string to_str() const;
|
||||
|
||||
std::map<std::string, std::string> defines;
|
||||
std::unordered_map<std::string, std::string> defines;
|
||||
size_t num_warps;
|
||||
};
|
||||
|
||||
|
||||
/* ------------------------- */
|
||||
|
||||
// A kernel built from a source string, one concrete set of options and a
// target device (see the constructor). Invoking it takes an untyped argument
// buffer with its size, a stream and a grid.
// Only declarations are visible here; the member functions are defined elsewhere.
class kernel{
|
||||
private:
|
||||
// Returns a string. NOTE(review): presumably text prepended to the user
// source before parsing — confirm in the implementation.
static std::string preheader();
|
||||
// Translates an IR type into an `arg_type`.
// NOTE(review): presumably used to fill `sig_` — confirm.
static arg_type convert(ir::type *ty);
|
||||
|
||||
public:
|
||||
// Constructs a kernel from `src` with the options `opt` for `device`.
kernel(const std::string& src, const options_t& opt, driver::device *device);
|
||||
// Call operator: receives the argument buffer `args` of size `args_size`
// (presumably in bytes — TODO confirm), the stream to use and the grid.
// Declared const, so it cannot modify the kernel's non-mutable members.
void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector<size_t>& grid) const;
|
||||
// getters
|
||||
// Returns the stored signature by const reference (no copy).
const std::vector<arg_type>& get_sig() const { return sig_; }
|
||||
|
||||
private:
|
||||
// Initialization helpers.
// NOTE(review): presumably called by the constructor to populate `ir_`,
// `mod_`/`ker_` and `sig_` respectively — confirm in the implementation.
void init_ir (const std::string &src);
|
||||
void init_ker();
|
||||
void init_sig();
|
||||
|
||||
public:
|
||||
// Options of this kernel; public, and declared const so it cannot be
// reassigned after construction.
const options_t opt;
|
||||
|
||||
private:
|
||||
// Non-owning raw pointer to the device.
driver::device* dev_;
|
||||
// signature
|
||||
std::vector<arg_type> sig_;
|
||||
// triton context for parsing
|
||||
ir::context ctx_;
|
||||
// handles
|
||||
// Shared-ownership handles to the IR module, the driver module and the
// driver kernel.
std::shared_ptr<ir::module> ir_;
|
||||
std::shared_ptr<driver::module> mod_;
|
||||
std::shared_ptr<driver::kernel> ker_;
|
||||
};
|
||||
|
||||
// Wrapper around a source string and a space of options; it exposes call
// operators that take an argument buffer, a grid (or grid callback) and a stream.
// NOTE(review): this text comes from a rendered diff hunk, and the class body
// appears to interleave two revisions: `cache_` is declared twice with
// different types, and there are two constructors and two families of call
// operators (`void**` vs `void*`). Presumably the `caller`-based members are
// the removed revision and the `kernel`-based members the added one — confirm
// against the repository before relying on any single declaration.
class function {
|
||||
public:
|
||||
// Callback that computes a grid from a concrete set of options.
typedef std::function<grid_t(const options_t&)> grid_fn_ty;
|
||||
// A set of options paired with a shared handle to a kernel.
typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
|
||||
// Maps a vector of 64-bit keys to a non-owning kernel pointer.
typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
|
||||
|
||||
private:
|
||||
// Nested helper bundling a driver kernel, its parent module, parameter types
// and options; it can be built from an input file stream or from an IR function.
class caller {
|
||||
public:
|
||||
// constructors
|
||||
caller(std::ifstream& ifs, const options_t& opt);
|
||||
caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt);
|
||||
// serialization
|
||||
void write(std::ofstream& ofs);
|
||||
void read(std::ifstream& ifs);
|
||||
// accessors
|
||||
// Returns a copy of the stored options (return type is not a reference).
const options_t opt() const { return opt_; }
|
||||
// `&*ptr` dereferences the shared_ptr and takes the address, yielding a
// non-owning raw pointer to the managed object.
const driver::module* parent() const { return &*parent_; }
|
||||
const driver::kernel* bin() const { return &*bin_; }
|
||||
// Bounds-checked access: `at` throws std::out_of_range for an invalid index.
arg_type param_ty(size_t i) const { return param_tys_.at(i);}
|
||||
const std::vector<arg_type>& param_tys() const { return param_tys_; }
|
||||
|
||||
// Returns a copy of `retune_`.
std::vector<int> retune() const { return retune_; }
|
||||
// entry points
|
||||
// The last, unnamed parameter is a map from string to byte vector that
// defaults to empty.
void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size, const std::map<std::string, std::vector<char>>& = {}) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<driver::kernel> bin_;
|
||||
std::shared_ptr<driver::module> parent_;
|
||||
std::vector<arg_type> param_tys_;
|
||||
std::vector<int> retune_;
|
||||
options_t opt_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
private:
|
||||
// Cache key: a device pointer paired with a vector of 32-bit integers.
typedef std::pair<driver::device*, std::vector<int32_t>> cache_key_t;
|
||||
|
||||
private:
|
||||
// cache
|
||||
static std::string get_cache_prefix();
|
||||
// make
|
||||
// NOTE(review): the three `make_*` signatures suggest a source -> AST -> IR ->
// binary pipeline; the actual call order is not visible here — confirm.
triton::lang::translation_unit *make_ast(const std::string &src);
|
||||
std::unique_ptr<ir::module> make_ir(Parser &parser);
|
||||
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::device *device, const options_t &opt);
|
||||
void make(driver::device *device, options_t opt);
|
||||
void precompile(driver::device *device, const options_space_t& tuning_space);
|
||||
// autotune
|
||||
caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);
|
||||
|
||||
// Takes a list of ranges and a callback receiving a vector of indices.
// NOTE(review): presumably invokes `f` for every index combination within
// `ranges` — confirm in the implementation.
static void do_loop_nest(std::vector<size_t> const & ranges,
|
||||
std::function<void(std::vector<size_t> const &)> const & f);
|
||||
public:
|
||||
static std::string preheader();
|
||||
|
||||
public:
|
||||
// Constructor taking an optional cache reference string (defaults to "").
function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
|
||||
// Call operators taking `void**` arguments and an explicit device; the grid is
// given either directly or as a callback.
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device);
|
||||
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device);
|
||||
void set_cst(const char* name, void* data, size_t n_bytes);
|
||||
std::string get_asm(asm_mode_t mode, driver::device *device, const options_t& opt);
|
||||
// Constructor taking the device up front; the call operators below take a
// single `void*` argument buffer and no device parameter.
function(const std::string& src, const options_space_t& opt, driver::device *device);
|
||||
void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
|
||||
void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
|
||||
// auto-tuning
|
||||
cache_t::iterator find_in_cache(void* args, size_t args_size);
|
||||
kernel* autotune(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
|
||||
// getters
|
||||
// Returns the kernel list by value, i.e. the caller receives a copy.
const std::vector<kernel_pair_t> get_kernels() { return kernels_; }
|
||||
|
||||
private:
|
||||
std::map<std::string, std::vector<char>> cst_;
|
||||
// pre-compilation
|
||||
ir::context ctx_;
|
||||
std::string src_;
|
||||
options_space_t opt_;
|
||||
std::set<options_t> compiled_;
|
||||
std::map<options_t, std::unique_ptr<caller>> callers_;
|
||||
std::vector<int> args_off_;
|
||||
size_t args_size_;
|
||||
// caching
|
||||
std::string cache_ref_;
|
||||
std::string cache_path_;
|
||||
// Non-owning pointers keyed by `cache_key_t`.
// NOTE(review): a second member named `cache_` with a different type appears
// further down; see the note at the top of the class.
std::map<cache_key_t, caller*> cache_;
|
||||
void init_kernels(const std::string& src, const options_space_t& opt, driver::device *device);
|
||||
|
||||
private:
|
||||
std::vector<kernel_pair_t> kernels_;
|
||||
// Same mapped type as `cache_t`, spelled out in full.
std::map<std::vector<uint64_t>, kernel*> cache_;
|
||||
};
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user