[GENERAL] Merged einsum feature branch. Various features, performance improvements, and bugfixes:
* Added preliminary support for extended Einstein summation in PyTriton
* Significant performance improvements for FP32 kernels containing matrix multiplication
* Added a re-coalescing pass for FP16 kernels containing matrix multiplication
* Various bugfixes
@@ -1,83 +0,0 @@
-#ifndef _TRITON_CODEGEN_INSTRUCTIONS_H_
-#define _TRITON_CODEGEN_INSTRUCTIONS_H_
-
-#include "triton/ir/enums.h"
-#include <map>
-#include <vector>
-
-namespace triton{
-
-namespace ir{
-class instruction;
-}
-
-namespace codegen{
-
-enum storage_info_t {
-NONE,
-ANY,
-SHARED,
-DISTRIBUTED,
-REPLICATED
-};
-
-typedef std::pair<storage_info_t, std::vector<storage_info_t>> inst_storage_info_t;
-static const std::map<ir::value_id_t, inst_storage_info_t> storage_info = {
-// scalars
-{ ir::INST_GET_PROGRAM_ID, {REPLICATED, {}}},
-{ ir::INST_GET_NUM_PROGRAMS, {REPLICATED, {}}},
-// scalar/array
-{ ir::INST_PHI, {ANY, {ANY, ANY}}},
-{ ir::INST_BINOP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
-{ ir::INST_GETELEMENTPTR, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
-{ ir::INST_SELECT, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED, DISTRIBUTED}}},
-{ ir::INST_SQRT, {DISTRIBUTED, {DISTRIBUTED}}},
-// cmp
-{ ir::INST_ICMP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
-{ ir::INST_FCMP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
-// cast
-{ ir::INST_CAST_TRUNC, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_ZEXT, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_SEXT, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_FP_TRUNC, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_FP_EXT, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_UI_TO_FP, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_SI_TO_FP, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_FP_TO_UI, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_FP_TO_SI, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_PTR_TO_INT, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_INT_TO_PTR, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_BIT_CAST, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_CAST_ADDR_SPACE_CAST, {DISTRIBUTED, {DISTRIBUTED}}},
-// io
-{ ir::INST_UNMASKED_LOAD, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_MASKED_LOAD, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
-{ ir::INST_UNMASKED_STORE, {NONE , {DISTRIBUTED, DISTRIBUTED}}},
-{ ir::INST_MASKED_STORE, {NONE , {DISTRIBUTED, DISTRIBUTED, DISTRIBUTED}}},
-// retile
-{ ir::INST_RESHAPE, {DISTRIBUTED, {DISTRIBUTED}}},
-{ ir::INST_SPLAT, {DISTRIBUTED, {REPLICATED}}},
-{ ir::INST_BROADCAST, {DISTRIBUTED, {REPLICATED}}},
-{ ir::INST_DOWNCAST, {DISTRIBUTED, {REPLICATED}}},
-// array arithmetic
-{ ir::INST_TRANS, {SHARED, {SHARED}}},
-{ ir::INST_REDUCE, {SHARED, {DISTRIBUTED}}},
-{ ir::INST_DOT, {DISTRIBUTED, {SHARED, SHARED, DISTRIBUTED}}},
-// terminator
-{ ir::INST_RETURN, {NONE, {}}},
-{ ir::INST_UNCOND_BRANCH, {NONE, {}}},
-{ ir::INST_COND_BRANCH, {NONE, {REPLICATED}}},
-// intrinsics
-{ ir::INST_COPY_TO_SHARED, {SHARED, {DISTRIBUTED}}},
-{ ir::INST_COPY_FROM_SHARED, {DISTRIBUTED, {SHARED}}},
-{ ir::INST_BARRIER, {NONE, {}}},
-{ ir::INST_MAKE_RANGE_DYN, {DISTRIBUTED, {}}},
-{ ir::INST_MAKE_RANGE_STA, {DISTRIBUTED, {}}},
-{ ir::INST_MAKE_RANGE, {DISTRIBUTED, {}}}
-};
-
-
-}
-}
-
-#endif
@@ -122,6 +122,7 @@ public:
 void visit_reduce_inst(ir::reduce_inst*);
 void visit_select_inst(ir::select_inst*);

+void visit_recoalesce_inst(ir::recoalesce_inst*);
 void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
 void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
 void visit_barrier_inst(ir::barrier_inst*);

@@ -128,22 +128,22 @@ private:
 Type *make_vector_ty(Type *ty, size_t vector_size);

 public:
-distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder, bool vectorize);
+distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder);
 void set_value(indices_t idx, Value *v);
 Value* get_value(indices_t idx);
+const std::vector<int>& get_order() { return order_; }
 unsigned get_linear_index(indices_t idx);
 indices_t get_ordered_indices(unsigned id);
-void for_each(std::function<void(indices_t)> fn);
-const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
+void for_each(std::function<void(indices_t)> fn, int start = 0, int end = -1);
+void for_each(std::function<void(indices_t)> fn, std::vector<int> start, std::vector<int> size);
+
+const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
 private:
 axes_t axes_;
 std::vector<int> order_;
 indices_map_t indices_;
 values_map_t values_;
 ordered_indices_vec_t ordered_indices_;
 size_t vector_size_;
 Builder &builder_;
 };

@@ -45,6 +45,8 @@ public:
 llvm::SmallVectorImpl<char> &buffer,
 const std::string &features,
 file_type_t file_type);
+virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
+

 protected:
 driver::context* ctx_;
@@ -54,13 +56,14 @@ protected:
 class host_module: public module{
 public:
 host_module(driver::context* context, std::unique_ptr<llvm::Module> module);
+std::unique_ptr<buffer> symbol(const char * name) const;
 };

 // OpenCL
 class ocl_module: public module{

 public:
 ocl_module(driver::context* context, std::unique_ptr<llvm::Module> module);
+std::unique_ptr<buffer> symbol(const char * name) const;
 };

 // CUDA
@@ -70,7 +73,7 @@ class cu_module: public module {
 public:
 cu_module(driver::context* context, std::unique_ptr<llvm::Module> module);
 cu_module(driver::context* context, const std::string& source);
-cu_buffer* symbol(const char * name) const;
+std::unique_ptr<buffer> symbol(const char * name) const;

 private:
 std::string source_;

@@ -37,6 +37,7 @@ public:
 // Constants
 value *get_int32(unsigned val);
 // Types
+type *get_void_ty();
 type *get_int1_ty();
 type *get_int8_ty();
 type *get_int16_ty();
@@ -115,10 +116,10 @@ public:
 value *create_and(value *lhs, value *rhs, const std::string &name = "");
 value *create_xor(value *lhs, value *rhs, const std::string &name = "");
 value *create_or(value *lhs, value *rhs, const std::string &name = "");
-// Side effects
-value *create_fneg(value *arg, const std::string &name = "");
-value *create_neg(value *arg, const std::string &name = "");
-value *create_not(value *arg, const std::string &name = "");
+// Unary
+// value *create_fneg(value *arg, const std::string &name = "");
+// value *create_neg(value *arg, const std::string &name = "");
+// value *create_not(value *arg, const std::string &name = "");
 // Input/Output
 value *create_load(value *arg, const std::string &name = "");
 value *create_store(value *ptr, value *val, const std::string &name = "");

@@ -134,6 +134,7 @@ enum value_id_t: unsigned {
 // intrinsics
 INST_COPY_TO_SHARED,
 INST_COPY_FROM_SHARED,
+INST_RECOALESCE,
 INST_BARRIER,
 INST_MAKE_RANGE_DYN,
 INST_MAKE_RANGE_STA,

@@ -148,9 +148,9 @@ public:
 // Factory methods
 static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
 const std::string &name = "", instruction *next = nullptr);
-static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
-static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
-static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
+// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
+// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
+// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);

 _TRITON_DEFINE_CLONE(binary_operator)
 _TRITON_DEFINE_ACCEPT(binary_operator)
@@ -732,6 +732,17 @@ public:
 _TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
 };

+class recoalesce_inst: public unary_inst{
+private:
+using unary_inst::unary_inst;
+std::string repr_impl() const { return "recoalesce_inst"; }
+
+public:
+static recoalesce_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
+_TRITON_DEFINE_CLONE(recoalesce_inst)
+_TRITON_DEFINE_ACCEPT(recoalesce_inst)
+};
+
 class barrier_inst: public instruction{
 private:
 barrier_inst(context &ctx, const std::string &name, instruction *next);

@@ -59,6 +59,7 @@ class sqrt_inst;
 class reduce_inst;
 class select_inst;

+class recoalesce_inst;
 class copy_to_shared_inst;
 class copy_from_shared_inst;
 class barrier_inst;
@@ -129,6 +130,7 @@ public:
 virtual void visit_reduce_inst(reduce_inst*) = 0;
 virtual void visit_select_inst(select_inst*) = 0;

+virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
 virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
 virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
 virtual void visit_barrier_inst(barrier_inst*) = 0;

@@ -47,6 +47,7 @@ protected:
 };

 void set_ret(ir::value* value);
+ir::value *GenUnaryMinus(ir::value* arg);

 public:
 Generator(Parser* parser) : parser_(parser) {}

@@ -145,6 +145,7 @@ public:
 THREAD, // _Thread_local
 AUTO,
 GLOBAL,
+CMEM, // constant memory

 // STORAGE CLASS SPECIFIER END
 BREAK,

@@ -39,7 +39,7 @@ enum {
 S_EXTERN = 0x02,
 S_STATIC = 0x04,
 S_THREAD = 0x08,
-S_AUTO = 0x10,
+S_CONSTANT = 0x10,
 S_GLOBAL = 0x20,

 // Type specifier
@@ -73,7 +73,8 @@ struct Qualifier {
 CONST = 0x01,
 RESTRICT = 0x02,
 VOLATILE = 0x04,
-MASK = CONST | RESTRICT | VOLATILE
+CMEM = 0x08,
+MASK = CONST | RESTRICT | VOLATILE | CMEM
 };
 };

@@ -111,6 +112,7 @@ public:
 bool IsConstQualified() const { return ptr_ & Qualifier::CONST; }
 bool IsRestrictQualified() const { return ptr_ & Qualifier::RESTRICT; }
 bool IsVolatileQualified() const { return ptr_ & Qualifier::VOLATILE; }
+bool IsConstantQualified() const { return ptr_ & Qualifier::CMEM; }

 private:
 intptr_t ptr_;

@@ -8,11 +8,9 @@
 #include <string>
 #include <memory>
 #include <functional>
 #include <mutex>
 // codegen
 #include "triton/ir/context.h"
 #include "triton/codegen/target.h"
-#include "triton/lang/parser.h"
 #include "triton/runtime/arg.h"

 namespace llvm {
@@ -20,6 +18,8 @@ namespace llvm {
 class LLVMContext;
 }

+class Parser;
+
 namespace triton {

 namespace driver{
@@ -106,14 +106,14 @@ public:
 function(const std::string& src, const options_space_t& opt = options_space_t());
 void operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream* stream);
 void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
 std::string make_tensorflow_src(const std::vector<size_t> &outputs, const std::string &macro);
+void set_cst(const std::string& name, void* data, size_t n_bytes);

 private:
 ir::context ctx_;
 std::string src_;
 options_space_t opt_space_;
 std::map<cache_key_t, caller> cache_;
 std::mutex src_mutex_;
+std::map<std::string, std::vector<char>> cst_;
 };

 }

@@ -30,7 +30,7 @@ private:
 high_resolution_clock::time_point _start;
 };

-inline double bench(std::function<void()> const & op, driver::stream * stream)
+inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
 {
 // const driver::device * device = stream->context()->device();
 timer tmr;
@@ -38,9 +38,10 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
 double total_time = 0;
 op();
 stream->synchronize();
-while(total_time*1e-9 < 1e-2){
+while(total_time*1e-9 < 1e-1){
+float norm = 1;
+// normalize clock if possible to reduce noise in auto-tuning
+if(normalize)
+if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
+norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
 tmr.start();

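The benchmark loop now runs ten times longer (0.1 s instead of 0.01 s) and optionally reads the ratio of current to maximum SM clock. The tail of the hunk is truncated, so exactly how `norm` is applied is not visible here; a minimal standalone sketch of the idea, under the assumption that the measured time is rescaled by the clock ratio (numbers and helper are illustrative, not the commit's code):

    #include <cstdio>

    // Illustrative only: scale a measured time by the SM clock ratio so that
    // a throttled GPU and a full-speed GPU yield comparable numbers.
    double normalized_ns(double measured_ns, int current_sm_clock, int max_sm_clock) {
      double norm = (double)current_sm_clock / max_sm_clock; // <= 1 when throttled
      // assumption: bench() multiplies the measured time by norm, estimating
      // what the kernel would cost at the maximum clock
      return measured_ns * norm;
    }

    int main() {
      // 90 us measured at 1200 MHz out of a 1500 MHz maximum -> ~72 us at full clock
      std::printf("%.0f ns\n", normalized_ns(90e3, 1200, 1500));
    }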
@@ -5,24 +5,41 @@
 #include "triton/ir/basic_block.h"
 #include "triton/ir/instructions.h"
 #include "triton/ir/type.h"
+#include <iostream>

 namespace triton {
 namespace codegen{
 namespace analysis{


-inline int gcd(int a, int b) {
-if (a == 0)
-return b;
-if (b == 0)
-return a;
-if (a == b)
-return a;
-if (a > b)
-return gcd(a - b, b);
-return gcd(a, b - a);
+// Function for extended Euclidean Algorithm
+int gcd_impl(int a, int b, int *x, int *y)
+{
+// Base Case
+if (a == 0)
+{
+*x = 0;
+*y = 1;
+return b;
+}
+
+int x1, y1; // To store results of recursive call
+int gcd = gcd_impl(b%a, a, &x1, &y1);
+
+// Update x and y using results of
+// recursive call
+*x = y1 - (b/a) * x1;
+*y = x1;
+
+return gcd;
+}
+
+int gcd(int a, int b) {
+int x, y;
+return gcd_impl(a, b, &x, &y);
 }


 inline int lcm(int a, int b) {
 return (a * b) / gcd(a, b);
 }
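The replacement is the textbook extended Euclidean algorithm: besides gcd(a, b), gcd_impl returns Bezout coefficients x, y with a*x + b*y = gcd(a, b), and it uses the modulo recurrence instead of the old repeated subtraction, which degraded to enormous recursion depth on skewed inputs such as gcd(1, 10^9). A quick standalone check of the invariant (test harness mine, not part of the commit):

    #include <cassert>
    #include <cstdio>

    // Same algorithm as gcd_impl in the hunk above.
    int gcd_impl(int a, int b, int *x, int *y) {
      if (a == 0) { *x = 0; *y = 1; return b; }
      int x1, y1;
      int g = gcd_impl(b % a, a, &x1, &y1);
      *x = y1 - (b / a) * x1; // Bezout coefficient for a
      *y = x1;                // Bezout coefficient for b
      return g;
    }

    int main() {
      int x, y;
      int g = gcd_impl(252, 105, &x, &y);
      assert(g == 21 && 252 * x + 105 * y == g); // 252*(-2) + 105*5 = 21
      std::printf("gcd=%d x=%d y=%d\n", g, x, y);
    }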
@@ -156,7 +173,7 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
 if(is_constant_.find(v) != is_constant_.end())
 return is_constant_.at(v);
 if(auto *x = dynamic_cast<ir::constant_int*>(v))
-return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_);
+return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
 if(dynamic_cast<ir::make_range_sta*>(v))
 return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
 if(auto *x = dynamic_cast<ir::phi_node*>(v))
@@ -448,7 +465,7 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
 if(auto *x = dynamic_cast<ir::binary_operator*>(v))
 return populate_starting_multiple_binop(x);
 if(auto *x = dynamic_cast<ir::constant_int*>(v))
-return add_to_cache(x, {(unsigned)x->get_value()}, starting_multiple_);
+return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
 if(auto *x = dynamic_cast<ir::make_range*>(v))
 return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
 if(auto *x = dynamic_cast<ir::make_range_dyn*>(v))
@@ -484,6 +501,7 @@ void align::populate(ir::value *v) {
 populate_is_constant(v);
 populate_starting_multiple(v);
 populate_max_contiguous(v);

 }

 void align::run(ir::module &mod) {

@@ -113,6 +113,7 @@ void axes::update_graph(ir::instruction *i) {
 case ir::INST_DOT: return update_graph_dot(i);
 case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
 case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
+case ir::INST_RECOALESCE: return update_graph_no_edge(i);
 default: return update_graph_elementwise(i);
 }
 return;

@@ -1,9 +1,9 @@
 #include <algorithm>
+#include <numeric>
 #include <iostream>
 #include "triton/codegen/analysis/axes.h"
 #include "triton/codegen/analysis/align.h"
 #include "triton/codegen/analysis/layout.h"
-#include "triton/codegen/instructions.h"
 #include "triton/ir/function.h"
 #include "triton/ir/module.h"
 #include "triton/ir/utils.h"
@@ -148,8 +148,11 @@ layout_t::layout_t(layout_type_t _type,
 extract_io_use(v, ptr);
 order.resize(axes.size());
 std::iota(order.begin(), order.end(), 0);
-for(ir::value *v: ptr){
-auto max_contiguous = align->contiguous(v);
+auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
+return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
+});
+if(*largest){
+auto max_contiguous = align->contiguous(*largest);
 std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
 return max_contiguous[a] > max_contiguous[b];
 });
@@ -166,9 +169,8 @@ layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
 const std::vector<unsigned>& _shapes,
 const std::vector<ir::value *> &values, ir::type *_ty, size_t _id,
 analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, _id, align) {

-unsigned shape_0 = shapes[order[0]];
-unsigned shape_1 = shapes[order[1]];
+unsigned shape_0 = shapes[0];
+unsigned shape_1 = shapes[1];
 /* fragments per warp */
 // try to make things as square as possible to maximize data re-use
 fpw = {1, 1, 1};
@@ -196,6 +198,7 @@ layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
 unsigned effective_num_warps = 1;
 for(size_t d = 0; d < shapes.size(); d++)
 effective_num_warps *= wpt[d];
+
 if(num_warps != effective_num_warps)
 throw std::runtime_error("cannot create a kernel with this amount of warps");
 }
|
||||
unsigned num_threads = num_warps * 32;
|
||||
nts.resize(shapes.size());
|
||||
mts.resize(shapes.size());
|
||||
bool is_dot = std::any_of(values.begin(), values.end(),
|
||||
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
|
||||
|
||||
ir::value *ptr = nullptr;
|
||||
for(ir::value *v: values)
|
||||
for(ir::user *usr: v->get_users())
|
||||
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
|
||||
ptr = st->get_pointer_operand();
|
||||
|
||||
unsigned i = order[0];
|
||||
nts[i] = clamp(size / num_threads, 1, 4);
|
||||
int contiguous = 4;
|
||||
if(ptr)
|
||||
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
|
||||
|
||||
nts[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shapes[i]));
|
||||
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
|
||||
num_threads = num_threads / mts[i];
|
||||
size /= shapes[i];
|
||||
num_threads /= mts[i];
|
||||
if(is_dot)
|
||||
nts[order[1]] = clamp(size / num_threads, 1, std::min<int>(4, shapes[order[1]]));
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
i = order[d];
|
||||
nts[i] = 1;
|
||||
mts[i] = clamp(num_threads, 1, shapes[i]);
|
||||
if(d > 1 || !is_dot)
|
||||
nts[i] = 1;
|
||||
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
|
||||
num_threads = num_threads / mts[i];
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_threads *= mts[d];
|
||||
|
||||
if(num_warps * 32 != effective_num_threads)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
@@ -259,8 +280,8 @@ void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_
|
||||
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
|
||||
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
|
||||
if(!i_0 || !i_1 ||
|
||||
storage_info.at(i_0->get_id()).first != codegen::SHARED ||
|
||||
storage_info.at(i_1->get_id()).first != codegen::SHARED)
|
||||
!dynamic_cast<ir::copy_to_shared_inst*>(i_0) ||
|
||||
!dynamic_cast<ir::copy_to_shared_inst*>(i_1) )
|
||||
return;
|
||||
if(is_latch_1)
|
||||
res.reset(new double_buffer_info_t{value_0, value_1, phi});
|
||||
@@ -284,10 +305,9 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
||||
extract_double_bufferable(v, double_buffer);
|
||||
|
||||
// order
|
||||
if(arg->type == SCANLINE)
|
||||
order = arg->order;
|
||||
else
|
||||
order = arg->order;
|
||||
std::vector<int> arg_order = arg ? arg->order : std::vector<int>{0};
|
||||
order = arg_order;
|
||||
|
||||
ir::value* dot_a = nullptr;
|
||||
ir::value* dot_b = nullptr;
|
||||
ir::value* hmma_dot_a = nullptr;
|
||||
@@ -304,24 +324,27 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
||||
col.push_back(s);
|
||||
row.push_back(s);
|
||||
}
|
||||
|
||||
|
||||
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
|
||||
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
|
||||
if(is_nonhmma_dot_a)
|
||||
order = is_trans(dot_a) ? row : col;
|
||||
if(is_nonhmma_dot_b)
|
||||
else if(is_nonhmma_dot_b)
|
||||
order = is_trans(dot_b) ? col : row;
|
||||
|
||||
// else
|
||||
// order = row;
|
||||
// padding
|
||||
pad = 0;
|
||||
if(hmma_dot_a){
|
||||
bool row = is_trans(hmma_dot_a) ^ order[0] != 0;
|
||||
pad = 24 - shapes[row ? order[0] : order[1]] % 32;
|
||||
pad = 24 - shapes[row ? 0 : 1] % 32;
|
||||
}
|
||||
else if(hmma_dot_b){
|
||||
bool row = is_trans(hmma_dot_b) ^ order[0] != 0;
|
||||
pad = 24 - shapes[row ? order[1] : order[0]] % 32;
|
||||
pad = 24 - shapes[row ? 1 : 0] % 32;
|
||||
}
|
||||
else if(order != arg->order) {
|
||||
else if(order != arg_order) {
|
||||
pad = 4;
|
||||
}
|
||||
shapes[order[0]] += pad;
|
||||
@@ -395,6 +418,29 @@ void layout::run(ir::module &mod) {
|
||||
layouts_[id] = new layout_shared_t(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), id, align_);
|
||||
tmp_[red] = id;
|
||||
}
|
||||
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
|
||||
ir::value *val = recoalasce->get_operand(0);
|
||||
const layout_t* in_layout = get(val);
|
||||
const layout_t* out_layout = get(i);
|
||||
if(in_layout->type != HMMA_884)
|
||||
return;
|
||||
id++;
|
||||
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
|
||||
ir::type::tile_shapes_t shape(in_shape.size());
|
||||
size_t ld = out_layout->order[0];
|
||||
shape[ld] = in_shape[ld];
|
||||
for(size_t k = 0; k < in_shape.size(); k++)
|
||||
if(k != ld)
|
||||
shape[k] = 4*in_layout->fpw[k]*in_layout->wpt[k];
|
||||
// create layout
|
||||
layouts_[id] = new layout_shared_t(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), id, align_);
|
||||
tmp_[recoalasce] = id;
|
||||
}
|
||||
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
|
||||
id++;
|
||||
layouts_[id] = new layout_shared_t(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), id, align_);
|
||||
tmp_[atom] = id;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
@@ -7,7 +7,6 @@
 #include "triton/codegen/analysis/allocation.h"
 #include "triton/codegen/analysis/align.h"
 #include "triton/codegen/transform/coalesce.h"
-#include "triton/codegen/instructions.h"
 #include "triton/ir/context.h"
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"
@@ -351,10 +350,9 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
 unsigned id = linear / vector_size;
 if(linear % vector_size == 0) {
 Value *ptr = pointers->get_value(idx);
-

 ptr = builder_->CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
 ptr->getType()->getPointerAddressSpace()));

 Value *mask = masks->get_value(idx);
 BasicBlock *current_bb = builder_->GetInsertBlock();
 Function *parent = builder_->GetInsertBlock()->getParent();
@@ -386,9 +384,9 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
 // if(cst)
 // offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size);
 // Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2);
-// Type *fp16x2_pack4_ty = StructType::get(ctx, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
+// Type *fp16x2_pack4_ty = StructType::get(*ctx_, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
 // FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false);
-// std::string asm_str = "@$0 ld.global.nc.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
+// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
 // if(false_values)
 // asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 0, 0, 0};";
 // InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true);
@@ -420,31 +418,83 @@ void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* st) {

 void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
 distributed_tile* ptrs = (distributed_tile*)tmap_.at(st->get_pointer_operand());
-distributed_tile* scalars = (distributed_tile*)tmap_.at(st->get_value_operand());
-ir::value *mask = st->get_mask_operand();
-distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
-ptrs->for_each([&](indices_t idx){
-Value *scalar = scalars->get_value(idx);
-Value *ptr = ptrs->get_value(idx);
-Value *pred = preds->get_value(idx);
-Function *parent = builder_->GetInsertBlock()->getParent();
-BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
-BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
-builder_->CreateCondBr(pred, mask_then_bb, mask_done_bb);
-builder_->SetInsertPoint(mask_then_bb);
-builder_->CreateStore(scalar, ptr);
-builder_->CreateBr(mask_done_bb);
-builder_->SetInsertPoint(mask_done_bb);
-// std::string offset = "";
-// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
-// if(gep->getNumIndices() == 1)
-// if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
-// offset = " + " + std::to_string(cst->getValue().getSExtValue()*4);
-// }
-// FunctionType *ty = FunctionType::get(Type::getVoidTy(ctx), {pred->getType(), ptr->getType(), scalar->getType()}, false);
-// std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;";
-// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true);
-// builder.CreateCall(iasm, {pred, ptr, scalar});
+distributed_tile* masks = (distributed_tile*)tmap_.at(st->get_mask_operand());
+// vector size
+int vector_size = 1;
+int ld = ptrs->get_order()[0];
+unsigned alignment = alignment_->get(st->get_pointer_operand(), ld);
+vector_size = std::min<unsigned>(ptrs->axis(ld).contiguous, alignment);
+// create packets
+std::map<unsigned, Value*> packets;
+ir::value *arg = st->get_value_operand();
+for_each(arg, [&](indices_t idx){
+distributed_tile* in = (distributed_tile*)tmap_.at(arg);
+unsigned linear = in->get_linear_index(idx);
+unsigned id = linear / vector_size;
+Value *in_value = in->get_value(idx);
+if(linear % vector_size == 0)
+packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size));
+packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size);
+});
+// write-back packets
+for_each(arg, [&](indices_t idx){
+distributed_tile* in = (distributed_tile*)tmap_.at(arg);
+unsigned linear = in->get_linear_index(idx);
+unsigned id = linear / vector_size;
+if(linear % vector_size == 0){
+// fetch tile elements
+Value *elt = packets[id];
+Value *ptr = ptrs->get_value(idx);
+Value *pred = masks->get_value(idx);
+// type information
+Type *ty = elt->getType();
+unsigned nbits = ty->getScalarSizeInBits();
+unsigned nbytes = nbits / 8;
+// extract pointer offset
+std::string offset = "";
+if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
+if(gep->getNumIndices() == 1)
+if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
+offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbytes);
+ptr = gep->getPointerOperand();
+}
+ptr = builder_->CreateBitCast(ptr, ty->getPointerTo(1));
+// asm argument type
+std::vector<Type*> arg_ty = {pred->getType(), ptr->getType()};
+for(int v = 0; v < vector_size; v++)
+arg_ty.push_back(ty->getScalarType());
+// asm function type
+FunctionType *fn_ty = FunctionType::get(builder_->getVoidTy(), arg_ty, false);
+// asm string
+std::string asm_str;
+asm_str += "@$0 st.global";
+if(vector_size > 1)
+asm_str += ".v" + std::to_string(vector_size);
+asm_str += ".b" + std::to_string(nbits) + " [$1" + offset + "],";
+if(vector_size > 1)
+asm_str += "{";
+for(int v = 0; v < vector_size; v++){
+if(v > 0)
+asm_str += ", ";
+asm_str += "$" + std::to_string(2 + v);
+}
+if(vector_size > 1)
+asm_str += "}";
+asm_str += ";";
+// asm constraint
+std::string constraint = "b,l";
+for(int v = 0; v < vector_size; v++){
+constraint += ",";
+constraint += (nbits == 32 ? "r" : "h");
+}
+// create inline asm
+InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
+// call asm
+std::vector<Value*> args = {pred, ptr};
+for(int v = 0; v < vector_size; v++)
+args.push_back(builder_->CreateExtractElement(elt, builder_->getInt32(v)));
+builder_->CreateCall(iasm, args);
+}
+});
 }

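The string-builder loop above emits one predicated PTX store per packet instead of branching per scalar. A standalone replica of the construction (same logic, no LLVM dependency) shows the exact strings it produces:

    #include <cstdio>
    #include <string>

    // Standalone replica of the asm-string construction in the hunk above.
    std::string make_store_asm(int vector_size, unsigned nbits, const std::string& offset) {
      std::string s = "@$0 st.global";
      if(vector_size > 1) s += ".v" + std::to_string(vector_size);
      s += ".b" + std::to_string(nbits) + " [$1" + offset + "],";
      if(vector_size > 1) s += "{";
      for(int v = 0; v < vector_size; v++)
        s += (v > 0 ? ", " : "") + std::string("$") + std::to_string(2 + v);
      if(vector_size > 1) s += "}";
      return s + ";";
    }

    int main() {
      // 4 x fp16 packet -> "@$0 st.global.v4.b16 [$1],{$2, $3, $4, $5};"
      // with constraint string "b,l,h,h,h,h"
      std::puts(make_store_asm(4, 16, "").c_str());
      // scalar fp32 -> "@$0 st.global.b32 [$1],$2;" with constraint "b,l,r"
      std::puts(make_store_asm(1, 32, "").c_str());
    }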
@@ -504,23 +554,27 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
 Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0));
 BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
 BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
-Value *ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(cas))));
-ptr = builder_->CreateBitCast(ptr, PointerType::get(builder_->getInt32Ty(), ptr->getType()->getPointerAddressSpace()));
-tgt_->add_memfence(module, *builder_);
 tgt_->add_barrier(module, *builder_);
+tgt_->add_memfence(module, *builder_);
 builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
 builder_->SetInsertPoint(tid_0_bb);
 Value *cas_ptr = vmap_.at(cas->get_operand(0));
 Value *cas_cmp = vmap_.at(cas->get_operand(1));
 Value *cas_val = vmap_.at(cas->get_operand(2));
-Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
+Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val,
+AtomicOrdering::Monotonic,
+AtomicOrdering::Monotonic);
 old = builder_->CreateExtractValue(old, {0});
-builder_->CreateStore(old, ptr);
+Value *atom_ptr;
+atom_ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))));
+atom_ptr = builder_->CreateBitCast(atom_ptr, PointerType::get(old->getType(), 3));
+
+builder_->CreateStore(old, atom_ptr);
 builder_->CreateBr(tid_0_done_bb);
 builder_->SetInsertPoint(tid_0_done_bb);
 tgt_->add_memfence(module, *builder_);
 tgt_->add_barrier(module, *builder_);
-vmap_[cas] = builder_->CreateLoad(ptr);
+vmap_[cas] = builder_->CreateLoad(atom_ptr);
 }

 void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
@@ -533,14 +587,14 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
 BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
 BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
 tgt_->add_memfence(module, *builder_);
 tgt_->add_barrier(module, *builder_);
 builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
 builder_->SetInsertPoint(tid_0_bb);
-vmap_[xchg] = builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System);
+builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val,
+AtomicOrdering::Monotonic,
+SyncScope::System);
 builder_->CreateBr(tid_0_done_bb);
 builder_->SetInsertPoint(tid_0_done_bb);
 tgt_->add_memfence(module, *builder_);
 tgt_->add_barrier(module, *builder_);
 }

 void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
@@ -861,6 +915,115 @@ void generator::visit_select_inst(ir::select_inst* select) {

 }

+void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
+ir::value *op = rc->get_operand(0);
+ir::tile_type::tile_shapes_t shape = rc->get_type()->get_tile_shapes();
+size_t rank = shape.size();
+// temporary layout
+shared_tile *tmp = (shared_tile*)machine_layouts_.at(layouts_->get(layouts_->tmp(rc)))
+->create(rc);
+// pointer to temporary shared memory
+Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_);
+// layouts
+const analysis::layout_t* in_layout = layouts_->get(op);
+const analysis::layout_t* out_layout = layouts_->get(rc);
+// machine tiles
+distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op));
+distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc));
+// WMMA configuration
+long wmma_pt[3] = { 2, 4, 1};
+long wmma[3] = { 8*in_layout->wpt[0]*in_layout->fpw[0],
+8*in_layout->wpt[1]*in_layout->fpw[1],
+1};
+// Work per thread for input layout
+long in_pt[3] = { shape[0] / wmma[0],
+shape[1] / wmma[1],
+1 };
+// Work per thread for output layout
+long out_pt[3] = { shape[0] / out_layout->mts[0],
+shape[1] / out_layout->mts[1],
+1 };
+if(rank > 2){
+wmma[2] = in_layout->wpt[2]*in_layout->fpw[2];
+in_pt[2] = shape[2] / wmma[2];
+out_pt[2] = shape[2] / out_layout->mts[2];
+}
+// Orders
+auto ord = out_layout->order;
+if(ord.size() < 3)
+ord.push_back(2);
+// pointer lanes
+std::vector<std::vector<Value*>> ptrs;
+for(int in_zz = 0; in_zz < wmma_pt[ord[2]]; in_zz++) {
+std::vector<Value*> current;
+for(int in_cc = 0; in_cc < wmma_pt[ord[1]]; in_cc++) {
+Value *base;
+base = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(rc)))));
+base = builder_->CreateBitCast(base, PointerType::get(ty, 3));
+
+// shared memory stride
+Value *stride_0 = builder_->getInt32(tmp->get_shapes()[ord[0]]);
+// indices
+Value *idx_cc = axes_.at(a_axes_->get(op, ord[1])).values[in_cc];
+// offset
+Value *off = builder_->CreateMul(stride_0, idx_cc);
+if(rank > 2){
+Value *stride_1 = builder_->CreateMul(stride_0,
+builder_->getInt32(tmp->get_shapes()[ord[1]]));
+Value *idx_zz = axes_.at(a_axes_->get(op, ord[2])).values[in_zz];
+off = builder_->CreateAdd(off, builder_->CreateMul(stride_1, idx_zz));
+}
+current.push_back(builder_->CreateGEP(base, off));
+}
+ptrs.push_back(current);
+}
+// Re-coalesce loops
+for(int in_z = 0; in_z < in_pt[ord[2]]; in_z++)
+for(int in_c = 0; in_c < in_pt[ord[1]]; in_c++){
+// write to shared
+tgt_->add_barrier(mod_, *builder_);
+for(int in_zz = 0; in_zz < wmma_pt[ord[2]]; in_zz++)
+for(int in_cc = 0; in_cc < wmma_pt[ord[1]]; in_cc++){
+std::vector<int> starts(rank), len(rank);
+starts[ord[0]] = 0;
+starts[ord[1]] = in_c*wmma_pt[ord[1]] + in_cc;
+len[ord[0]] = wmma_pt[ord[0]]*in_pt[ord[0]];
+len[ord[1]] = 1;
+if(rank > 2){
+starts[ord[2]] = in_z*wmma_pt[ord[2]] + in_zz;
+len[ord[2]] = 1;
+}
+in_dt->for_each([&](indices_t idx){
+Value *write_ptr = builder_->CreateGEP(ptrs[in_zz][in_cc], idx[ord[0]]);
+builder_->CreateStore(in_dt->get_value(idx), write_ptr);
+}, starts, len);
+}
+tgt_->add_barrier(mod_, *builder_);
+// load from shared
+for(int out_zz = 0; out_zz < out_pt[ord[2]] / in_pt[ord[2]]; out_zz++)
+for(int out_cc = 0; out_cc < out_pt[ord[1]] / in_pt[ord[1]]; out_cc++){
+std::vector<int> starts(rank), len(rank);
+starts[ord[0]] = 0;
+starts[ord[1]] = in_c*(out_pt[ord[1]] / in_pt[ord[1]]) + out_cc;
+len[ord[0]] = out_pt[ord[0]];
+len[ord[1]] = 1;
+if(rank > 2){
+starts[ord[2]] = in_z*(out_pt[ord[2]] / in_pt[ord[2]]) + out_zz;
+len[ord[2]] = 1;
+}
+out_dt->for_each([&](indices_t idx){
+indices_t read_idx(rank);
+read_idx[ord[0]] = idx[ord[0]];
+read_idx[ord[1]] = axes_.at(a_axes_->get(rc, ord[1])).values[out_cc];
+if(rank > 2)
+read_idx[ord[2]] = axes_.at(a_axes_->get(rc, ord[2])).values[out_zz];
+out_dt->set_value(idx, tmp->get_value(read_idx));
+}, starts, len);
+}
+}
+tgt_->add_barrier(mod_, *builder_);
+}
+
 void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
 unsigned vector_size = 1;
 auto x_order = layouts_->get(cts)->order;
@@ -1126,16 +1289,14 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
 if(tgt_->is_gpu())
 if(unsigned alloc_size = alloc_->allocated_size()){
 Type *int_8_ty = Type::getInt8Ty(*ctx_);
-ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
+Type *int_32_ty = Type::getInt32Ty(*ctx_);
+ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4);
 Type *ptr_ty = PointerType::get(int_8_ty, 3);
 GlobalVariable *sh_mem_array =
 new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
 nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
 sh_mem_ptr_ = builder_->CreateBitCast(sh_mem_array, ptr_ty);
 }
 // allocate constant memory
 for(ir::alloc_const *x: src.allocs())
 visit_alloc_const(x);
 // visit functions
 for(ir::function *fn: src.get_function_list())
 visit_function(fn);

@@ -143,7 +143,7 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
 axes[d].values = {builder_->getInt32(0)};
 }
 }
-return new distributed_tile(ty, shapes, layout_->order, axes, *builder_, false);
+return new distributed_tile(ty, shapes, layout_->order, axes, *builder_);
 }

 machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,

@@ -45,9 +45,8 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) {
 return VectorType::get(ty, vector_size);
 }

-distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
-: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), order_(order), builder_(builder) {
-vector_size_ = vectorize?ty_->getVectorNumElements():1;
+distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
+: tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
 init_indices();
 }

@@ -73,13 +72,31 @@ indices_t distributed_tile::get_ordered_indices(unsigned id) {
 }


-void distributed_tile::for_each(std::function<void (indices_t)> fn) {
-for(unsigned i = 0; i < ordered_indices_.size(); i++){
-if(i % vector_size_ == 0)
-fn(ordered_indices_[i]);
+void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
+if(end < 0)
+end = ordered_indices_.size() + end + 1;
+for(unsigned i = start; i < end; i++)
+fn(ordered_indices_[i]);
 }

+void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
+int rank = sizes.size();
+int len = 1;
+for(int s: sizes)
+len *= s;
+
+for(int i = 0; i < len; i++){
+indices_t idx(rank);
+int current = i;
+for(int k = 0; k < rank; k++){
+idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
+current = current / sizes[k];
+}
+fn(idx);
+}
+}
+

 /* Shared Tile */
 void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
 BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
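The new rectangular for_each added in the hunk above decodes a linear counter in mixed radix: axis k contributes current % sizes[k] and then current /= sizes[k], with starts[] offsetting into each axis's index list. A minimal illustration that prints raw coordinates instead of looking them up in axes_[k].values:

    #include <cstdio>
    #include <vector>

    int main() {
      // Decode linear i into a 3-D index inside a box of sizes {2, 3, 2},
      // offset by starts {1, 0, 4} -- same arithmetic as the new for_each.
      std::vector<int> starts = {1, 0, 4}, sizes = {2, 3, 2};
      int len = 1;
      for(int s : sizes) len *= s; // 12 points in the box
      for(int i = 0; i < len; i++) {
        int current = i;
        std::vector<int> idx(sizes.size());
        for(size_t k = 0; k < sizes.size(); k++) {
          idx[k] = starts[k] + current % sizes[k]; // coordinate along axis k
          current /= sizes[k];
        }
        std::printf("%2d -> (%d, %d, %d)\n", i, idx[0], idx[1], idx[2]);
      }
      // i = 0 -> (1, 0, 4); i = 1 -> (2, 0, 4); i = 2 -> (1, 1, 4); ...
    }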
@@ -126,7 +143,9 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
 }


-Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx) {
+Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
+const std::vector<int>& perm, const std::vector<int>& order,
+indices_t idx) {
 // strides
 std::vector<Value*> strides(order.size());
 strides[order[0]] = builder.getInt32(1);

@@ -1,6 +1,8 @@
 #include <algorithm>
+#include <iostream>
 #include "triton/ir/utils.h"
 #include "triton/ir/instructions.h"
+#include "triton/ir/function.h"
 #include "triton/ir/module.h"
 #include "triton/codegen/transform/coalesce.h"
 #include "triton/codegen/analysis/align.h"
@@ -60,8 +62,43 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
 }

 void coalesce::run(ir::module &mod) {
-// find values to rematerialize
 size_t num_groups = layout_->num_layouts();
+
+for(size_t id = 0; id < num_groups; id++) {
+if(layout_->get(id)->type != analysis::HMMA_884)
+continue;
+// extract memory stores
+const auto& values = layout_->values_of(id);
+ir::value* dot = nullptr;
+for(ir::value *v: values)
+if(auto x = dynamic_cast<ir::dot_inst*>(v))
+dot = x;
+
+ir::builder& builder = mod.get_builder();
+std::vector<ir::value*> worklist = {dot};
+std::set<ir::value*> seen;
+while(!worklist.empty()) {
+ir::value *current = worklist.back();
+seen.insert(current);
+worklist.pop_back();
+// stop if trunc
+if(auto x = dynamic_cast<ir::fp_trunc_inst*>(current)){
+builder.set_insert_point_after(x);
+ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x);
+builder.insert(rc);
+x->replace_all_uses_with(rc);
+rc->replace_uses_of_with(rc, x);
+break;
+}
+// recurse
+for(ir::user *u: current->get_users())
+if(seen.find(u) == seen.end())
+worklist.push_back(u);
+}
+}
+
+
+// find values to rematerialize
 std::vector<ir::io_inst*> remat;
 for(size_t id = 0; id < num_groups; id++) {
 const auto& values = layout_->values_of(id);
@@ -71,8 +108,10 @@ void coalesce::run(ir::module &mod) {
 extract_io_use(v, io);
 // extract leading axes
 std::map<int, std::vector<ir::io_inst*>> axes;
-for(ir::io_inst *i: io)
-extract_ld(i, axes);
+for(ir::io_inst *i: io){
+if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->axes.size())
+extract_ld(i, axes);
+}
 // update list of values to rematerialize
 if(axes.empty())
 continue;

@@ -1,21 +1,37 @@
 #include "triton/codegen/transform/cts.h"
-#include "triton/codegen/instructions.h"
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"
 #include "triton/ir/basic_block.h"
 #include "triton/ir/instructions.h"
+#include <iostream>

 namespace triton {
 namespace codegen{
 namespace transform{

-inline bool is_shared(ir::value *v) {
-auto *i = dynamic_cast<ir::instruction*>(v);
+
+inline bool is_shmem_op(ir::instruction* i, int op) {
+if(i->get_id() == ir::INST_DOT)
+return op==0 || op==1;
+if(i->get_id() == ir::INST_COPY_FROM_SHARED)
+return op==0;
+return false;
+}
+
+inline bool is_shmem_res(ir::value* v){
+ir::instruction* i = dynamic_cast<ir::instruction*>(v);
 if(!i)
 return false;
-return storage_info.at(i->get_id()).first == codegen::SHARED;
+if(i->get_id() == ir::INST_TRANS)
+return true;
+if(i->get_id() == ir::INST_REDUCE)
+return true;
+if(i->get_id() == ir::INST_COPY_TO_SHARED)
+return true;
+return false;
 }


 // run pass on module
 void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
 auto *i = dynamic_cast<ir::instruction*>(x);
@@ -36,9 +52,8 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
 add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
 return;
 }
-ir::value_id_t id = i->get_id();
 // already in shared memory
-if(to_shared && storage_info.at(id).first == SHARED)
+if(to_shared && is_shmem_res(i))
 return;
 // copy
 builder.set_insert_point_after(i);
@@ -53,18 +68,19 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
 void cts::run(ir::module &mod) {
 // Add shared copies
 ir::builder &builder = mod.get_builder();
-for(ir::function *fn: mod.get_function_list()){
-for(ir::basic_block *block: fn->blocks())
-for(ir::instruction *i: block->get_inst_list()){
-auto storage = storage_info.at(i->get_id());
+for(ir::function* fn: mod.get_function_list()){
+for(ir::basic_block* block: fn->blocks())
+for(ir::instruction* i: block->get_inst_list()){
+size_t num_op = i->get_num_operands();
 // copy to shared operands
-for(size_t k = 0; k < storage.second.size(); k++)
-if(storage.second[k] == SHARED)
+for(size_t k = 0; k < num_op; k++)
+if(is_shmem_op(i, k))
 add_copy(i, i->get_operand(k), builder, true);
 // copy from shared operands
-for(size_t k = 0; k < storage.second.size(); k++)
-if(storage.second[k] == DISTRIBUTED &&
-is_shared(i->get_operand(k))){
+for(size_t k = 0; k < num_op; k++)
+if(!dynamic_cast<ir::phi_node*>(i) &&
+!is_shmem_op(i,k) &&
+is_shmem_res(i->get_operand(k))){
 add_copy(i, i->get_operand(k), builder, false);
 }
 }

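With the static storage_info table deleted, shared-memory placement now lives in these two predicates: dot consumes operands 0 and 1 (the A and B tiles) from shared memory, copy_from_shared consumes operand 0, and trans/reduce/copy_to_shared produce shared results. A toy mirror of the rule set, using a plain enum instead of the real IR classes:

    #include <cstdio>

    // Hypothetical mirror of the predicates above, over a plain enum.
    enum inst_id { DOT, TRANS, REDUCE, COPY_TO_SHARED, COPY_FROM_SHARED, OTHER };

    bool is_shmem_op(inst_id id, int op) {
      if(id == DOT) return op == 0 || op == 1;   // A and B operands
      if(id == COPY_FROM_SHARED) return op == 0;
      return false;
    }

    bool is_shmem_res(inst_id id) {
      return id == TRANS || id == REDUCE || id == COPY_TO_SHARED;
    }

    int main() {
      // dot(a, b, c): a and b are staged in shared memory, c stays distributed
      std::printf("dot op0: %d, op1: %d, op2: %d\n",
                  is_shmem_op(DOT, 0), is_shmem_op(DOT, 1), is_shmem_op(DOT, 2));
    }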
@@ -3,7 +3,6 @@
 #include <algorithm>
 #include "triton/codegen/analysis/layout.h"
 #include "triton/codegen/analysis/allocation.h"
-#include "triton/codegen/instructions.h"
 #include "triton/codegen/transform/membar.h"
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"

@@ -180,6 +180,11 @@ host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module> module) {
 hst_->engine = builder.create();
 }

+std::unique_ptr<buffer> host_module::symbol(const char *name) const {
+throw std::runtime_error("not implemented");
+}
+

 /* ------------------------ */
 // OpenCL //
 /* ------------------------ */
@@ -211,10 +216,21 @@ ocl_module::ocl_module(driver::context * context, std::unique_ptr<llvm::Module> module) {
 // }
 }

+std::unique_ptr<buffer> ocl_module::symbol(const char *name) const {
+throw std::runtime_error("not implemented");
+}
+
 /* ------------------------ */
 // CUDA //
 /* ------------------------ */
+static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){
+size_t start_replace = str.find(begin);
+size_t end_replace = str.find(end, start_replace);
+if(start_replace == std::string::npos)
+return false;
+str.replace(start_replace, end_replace + 1 - start_replace, target);
+return true;
+}
+
 std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
 // options
@@ -231,19 +247,17 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
 llvm::SmallVector<char, 0> buffer;
 module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly);
 std::string result(buffer.begin(), buffer.end());
-size_t start_replace = result.find(".version");
-size_t end_replace = result.find('\n', start_replace);
-assert(start_replace != std::string::npos);
-result.replace(start_replace, end_replace - start_replace, ".version 6.4");
+find_and_replace(result, ".version", "\n", ".version 6.4\n");
+while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
+while(find_and_replace(result, "\t// end inline asm", "\n", ""));
 return result;
 }

 cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }

 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
-// exit(EXIT_FAILURE);
-// std::cout << source << std::endl;
 cu_context::context_switcher ctx(*context);
+// std::cout << source << std::endl;
 // JIT compile source-code
 CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
 unsigned int errbufsize = 8096;
@@ -260,11 +274,12 @@ cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
 }
 }

-cu_buffer* cu_module::symbol(const char *name) const{
+std::unique_ptr<buffer> cu_module::symbol(const char *name) const{
 CUdeviceptr handle;
 size_t size;
 dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
-return new cu_buffer(ctx_, size, handle, false);
+std::unique_ptr<buffer> res(new cu_buffer(ctx_, size, handle, false));
+return std::move(res);
 }

@@ -48,6 +48,9 @@ value *builder::get_int32(unsigned val) {
 return constant_int::get(type::get_int32_ty(ctx_), val);
 }

+type *builder::get_void_ty()
+{ return type::get_void_ty(ctx_); }
+
 type *builder::get_int1_ty()
 { return type::get_int1_ty(ctx_); }

@@ -132,19 +135,12 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string &name) {
 return insert(binary_operator::create(OPCODE, lhs, rhs), name);\
 }

-#define DEFINE_UNARY_FLOAT(SUFFIX)\
-value *builder::create_ ## SUFFIX(value *arg, const std::string &name){\
-return insert(binary_operator::create_ ## SUFFIX(arg), name);\
-}
-
 // Binary
 DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul)
 DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv)
 DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem)
 DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd)
 DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub)
-// Unary
-DEFINE_UNARY_FLOAT(fneg)


 //===----------------------------------------------------------------------===//
@@ -171,10 +167,7 @@ value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
 return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\
 }

-#define DEFINE_UNARY_INT(SUFFIX)\
-value *builder::create_ ## SUFFIX(value *arg, const std::string &name){\
-return insert(binary_operator::create_ ## SUFFIX(arg), name);\
-}
-
+
 // Binary
 DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul)
@@ -190,9 +183,6 @@ DEFINE_BINARY_INT(urem, binary_op_t::URem)
 DEFINE_BINARY_INT(and, binary_op_t::And)
 DEFINE_BINARY_INT(or, binary_op_t::Or)
 DEFINE_BINARY_INT(xor, binary_op_t::Xor)
-// Unary
-DEFINE_UNARY_INT(neg)
-DEFINE_UNARY_INT(not)


 //===----------------------------------------------------------------------===//

@@ -138,23 +138,23 @@ binary_operator *binary_operator::create(binary_op_t op, value *lhs, value *rhs,
 return new binary_operator(op, lhs, rhs, lhs->get_type(), name, next);
 }

-binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
-assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
-value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
-return binary_operator::create(binary_op_t::FSub, zero, arg, name, next);
-}
+//binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
+// assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
+// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
+// return binary_operator::create(binary_op_t::FSub, zero, arg, name, next);
+//}

-binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
-assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
-value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
-return binary_operator::create(binary_op_t::Sub, zero, arg, name, next);
-}
+//binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
+// assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
+// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()->get_scalar_ty());
+// return binary_operator::create(binary_op_t::Sub, zero, arg, name, next);
+//}

-binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){
-assert(arg->get_type()->is_integer_ty());
-constant *mask = constant::get_all_ones_value(arg->get_type());
-return binary_operator::create(binary_op_t::Xor, arg, mask, name, next);
-}
+//binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){
+// assert(arg->get_type()->is_integer_ty());
+// constant *mask = constant::get_all_ones_value(arg->get_type());
+// return binary_operator::create(binary_op_t::Xor, arg, mask, name, next);
+//}

 //===----------------------------------------------------------------------===//
 // cmp_inst classes
@@ -762,6 +762,12 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::string &name, instruction *next) {
 return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next);
 }

+// recoalesce
+recoalesce_inst* recoalesce_inst::create(value *arg, const std::string &name, instruction *next) {
+return new recoalesce_inst(arg->get_type(), INST_RECOALESCE, arg, name, next);
+}
+
+
 // barrier
 barrier_inst::barrier_inst(context &ctx, const std::string &name,

@@ -57,7 +57,10 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
  }
  case Token::MASKED_DEREF: {
    ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
    return set_ret(bld_->create_masked_load(rhs, lhs, ir::undef_value::get(ret_ty)));
    ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty());
    if(ret_ty->is_tile_ty())
      false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes());
    return set_ret(bld_->create_masked_load(rhs, lhs, false_value));
  }
  case Token::ELLIPSIS: {
    auto clhs = dynamic_cast<ir::constant_int*>(lhs);
@@ -76,7 +79,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
    return set_ret(bld_->create_add(lhs, rhs));
  case '-':
    if(binary->lhs_->Type()->ToPointer())
      return set_ret(bld_->create_gep(lhs, {bld_->create_neg(rhs)}));
      return set_ret(bld_->create_gep(lhs, {GenUnaryMinus(rhs)}));
    else if(flt)
      return set_ret(bld_->create_fsub(lhs, rhs));
    else
@@ -147,7 +150,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
    if(flt)
      return set_ret(bld_->create_fcmpONE(lhs, rhs));
    else
      return set_ret(bld_->create_icmpEQ(lhs, rhs));
      return set_ret(bld_->create_icmpNE(lhs, rhs));
  default:
    error_not_implemented();
}
@@ -166,6 +169,16 @@ ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
  should_not_happen();
  return reduce_inst::op_t();
}

ir::value* Generator::GenUnaryMinus(ir::value* arg) {
  ir::type *ty = arg->get_type();
  ir::type *sca_ty = ty->get_scalar_ty();
  ir::value *_0 = ir::constant_fp::get_zero_value_for_negation(sca_ty);
  if(ty->is_tile_ty())
    _0 = bld_->create_splat(_0, ty->get_tile_shapes());
  return bld_->create_sub(_0, arg);
}

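GenUnaryMinus above lowers unary minus as `0 - x`, splatting the scalar zero across the tile shape when the operand is a tile. A minimal NumPy sketch of the same semantics (helper name hypothetical, not part of the commit):

import numpy as np

# illustrative sketch, not part of this commit
def gen_unary_minus(arg):
    zero = np.zeros((), dtype=arg.dtype)       # scalar zero of the element type
    if arg.ndim > 0:                           # tile operand: splat the zero
        zero = np.broadcast_to(zero, arg.shape)
    return zero - arg

assert (gen_unary_minus(np.arange(4.0)) == -np.arange(4.0)).all()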
void Generator::VisitUnaryOp(UnaryOp* unary) {
  // recursion
  Visit(unary->operand_);
@@ -174,17 +187,17 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
  ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
  // return
  switch (unary->op_) {
  case Token::PREFIX_INC: return error_not_implemented();
  case Token::PREFIX_DEC: return error_not_implemented();
  case Token::PREFIX_INC: return error_not_implemented();
  case Token::PREFIX_DEC: return error_not_implemented();
  case Token::POSTFIX_INC: return error_not_implemented();
  case Token::POSTFIX_DEC: return error_not_implemented();
  case Token::ADDR: return error_not_implemented();
  case Token::DEREF: return set_ret(bld_->create_load(arg));
  case Token::PLUS: return error_not_implemented();
  case Token::MINUS: return error_not_implemented();
  case '~': return set_ret(bld_->create_neg(arg));
  case '!': return set_ret(bld_->create_not(arg));
  case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
  case Token::ADDR: return error_not_implemented();
  case Token::DEREF: return set_ret(bld_->create_load(arg));
  case Token::PLUS: return error_not_implemented();
  case Token::MINUS: return set_ret(GenUnaryMinus(arg));
  case '~': return error_not_implemented();
  case '!': return error_not_implemented();
  case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
  case Token::REDUCE: {
    int ax, tag;
    UnaryOp::decodeRed(unary->info_, ax, tag);
@@ -232,11 +245,54 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
    else
      return should_not_happen();
  }
  if(name == "get_num_programs"){
    VisitExpr(funcCall->Args()->at(0));
    ir::value* ret = ret_;
    if(auto axis = dynamic_cast<ir::constant_int*>(ret))
      return set_ret(bld_->create_get_num_program(axis->get_value()));
    else
      return should_not_happen();
  }
  if(name == "atomic_cas"){
    VisitExpr(funcCall->Args()->at(0));
    ir::value* ptr = ret_;
    VisitExpr(funcCall->Args()->at(1));
    ir::value* cmp = ret_;
    VisitExpr(funcCall->Args()->at(2));
    ir::value* val = ret_;
    return set_ret(bld_->create_atomic_cas(ptr, cmp, val));
  }
  if(name == "atomic_xchg"){
    VisitExpr(funcCall->Args()->at(0));
    ir::value* ptr = ret_;
    VisitExpr(funcCall->Args()->at(1));
    ir::value* val = ret_;
    return set_ret(bld_->create_atomic_exch(ptr, val));
  }
  if(name == "sqrtf"){
    VisitExpr(funcCall->Args()->at(0));
    ir::value* ret = ret_;
    return set_ret(bld_->create_sqrt(ret));
  }
  if(name == "calloc"){
    VisitExpr(funcCall->Args()->at(0));
    ir::value* ret = ret_;
    ir::constant_int *size = dynamic_cast<ir::constant_int*>(ret);
    assert(size);
    ir::alloc_const* alloc = new ir::alloc_const(bld_->get_int8_ty(), size);
    mod_->add_alloc(alloc);
    return set_ret(alloc);
  }
  //TODO: integrate this into conditionalop
  if(name == "select"){
    VisitExpr(funcCall->Args()->at(0));
    ir::value* cond = ret_;
    VisitExpr(funcCall->Args()->at(1));
    ir::value* true_val = ret_;
    VisitExpr(funcCall->Args()->at(2));
    ir::value* false_val = ret_;
    return set_ret(bld_->create_select(cond, true_val, false_val));
  }
  return error_not_implemented();
}

@@ -350,12 +406,15 @@ void Generator::VisitForStmt(ForStmt *forStmt) {
    ir::value *cond = ret_;
    return bld_->create_cond_br(cond, loop_bb, next_bb);
  });
  VisitStmt(init_);
  VisitExpr(cond_);
  ir::value *cond = ret_;
  bld_->create_cond_br(cond, loop_bb, next_bb);
  if(init_)
    VisitStmt(init_);
  // VisitExpr(cond_);
  // ir::value *cond = ret_;
  // bld_->create_cond_br(cond, loop_bb, next_bb);
  bld_->create_br(loop_bb);
  bld_->set_insert_point(loop_bb);
  VisitStmt(body_);
  if(body_)
    VisitStmt(body_);
  if(!is_terminator(ret_))
    mod_->get_continue_fn()();
  ir::basic_block *stop_bb = bld_->get_insert_block();
@@ -512,6 +571,8 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
  else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
          src_scalar_ty->get_integer_bitwidth())
    return bld_->create_int_cast(src, dst_ty, dst_signed);
  else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_pointer_ty())
    return bld_->create_cast(ir::BitCast, src, dst_ty);
  else{
    should_not_happen();
    return nullptr;
@@ -611,6 +672,8 @@ ir::type* Generator::GenIRFuncType(FuncType* type, ir::context& ctx) {
ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) {
  ir::type* ele_ty = GenIRType(type->Derived().GetPtr(), ctx);
  unsigned addr_space = 1;
  if(type->Derived().IsConstantQualified())
    addr_space = 4;
  return ir::pointer_type::get(ele_ty, addr_space);
}

@@ -1083,14 +1083,12 @@ QualType Parser::ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec)
      *storageSpec |= S_THREAD;
      break;

    case Token::AUTO:
      EnsureAndSetStorageSpec(tok, storageSpec, S_AUTO);
      break;

    // Type qualifier
    case Token::CONST: qualSpec |= Qualifier::CONST; break;
    case Token::RESTRICT: qualSpec |= Qualifier::RESTRICT; break;
    case Token::VOLATILE: qualSpec |= Qualifier::VOLATILE; break;
    case Token::CMEM: qualSpec |= Qualifier::CMEM; break;

    // Type specifier
    case Token::SIGNED:
@@ -1551,6 +1549,7 @@ int Parser::ParseQual() {
    case Token::CONST: qualSpec |= Qualifier::CONST; break;
    case Token::RESTRICT: qualSpec |= Qualifier::RESTRICT; break;
    case Token::VOLATILE: qualSpec |= Qualifier::VOLATILE; break;
    case Token::CMEM: qualSpec |= Qualifier::CMEM; break;
    case Token::ATOMIC: Error(tok, "do not support 'atomic'"); break;
    default: ts_.PutBack(); return qualSpec;
  }
@@ -1769,6 +1768,7 @@ QualType Parser::ParseArrayFuncDeclarator(const Token* ident, QualType base) {
    if (!base->Complete()) {
      Error(ident, "'%s' has incomplete element type", ident->str_.c_str());
    }
    // return a pointer for tiles in constant memory:
    return TileType::New(shape, base);

  } else if (ts_.Try('(')) { // Function declaration
@@ -7,6 +7,7 @@
static MemPoolImp<Token> tokenPool;

const std::unordered_map<std::string, int> Token::kwTypeMap_ {
  { "__constant__", Token::CMEM },
  { "__global__", Token::GLOBAL },
  { "auto", Token::AUTO },
  { "break", Token::BREAK },
@@ -294,7 +294,8 @@ std::string ArithmType::Str() const {
bool PointerType::Compatible(const Type& other) const {
  // C11 6.7.6.1 [2]: pointer compatibility
  auto otherPointer = other.ToPointer();
  return otherPointer && derived_->Compatible(*otherPointer->derived_);
  return otherPointer &&
         derived_->Compatible(*otherPointer->derived_);

  // FIXME(wgtdkp): cannot loosen compatibility constraints
  //return other.IsInteger() ||
@@ -184,10 +184,20 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
  // kernel uses too many resources
  if(!bin)
    return;
  // copy constants
  std::unique_ptr<driver::buffer> buffer;
  for(ir::alloc_const* alloc: ir->allocs()){
    std::string name = alloc->get_name();
    auto it = cst_.find(name);
    if(it == cst_.end())
      throw std::runtime_error("constant not set before execution");
    buffer = bin->symbol(name.c_str());
    stream->write(&*buffer, true, 0, it->second);
  }
  // benchmark
  ir::function *tmp = ir->get_function_list()[0];
  caller call(tmp, std::move(bin), opt);
  double ts = tools::bench([&]() { call(stream, grid_fn(opt), args); }, stream);
  double ts = tools::bench([&]() { call(stream, grid_fn(opt), args); }, stream, true);
  // save best
  if(ts < best_ts) {
    best_ts = ts;
@@ -222,20 +232,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
  codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
  // run passes
  dce.run(module);
  // ir::print(module, std::cout);

  disassociate.run(module);

  // ir::print(module, std::cout);

  dce.run(module);
  // ir::print(module, std::cout);

  peephole.run(module);
  dce.run(module);
  align.run(module);
  cts.run(module);
  axes.run(module);
  // ir::print(module, std::cout);
  layouts.run(module);
  coalesce.run(module);
  dce.run(module);
@@ -246,17 +250,19 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
  dce.run(module);
  align.run(module);
  axes.run(module);
  // ir::print(module, std::cout);
  layouts.run(module);
  liveness.run(module);
  allocation.run(module);
  if(allocation.allocated_size() > context->device()->max_shared_memory())
    return std::unique_ptr<driver::module>();
  barriers.run(module);
  // std::cout << "isel" << std::endl;
  // ir::print(module, std::cout);
  isel.visit(module, *llvm);
  // return binary
  std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
  // done
  // exit(EXIT_FAILURE);
  return res;
}

@@ -273,8 +279,13 @@ R"(
#define __aligned(A) __attribute__((aligned(A)))
#define __multipleof(A) __attribute__((multipleof(A)))

extern int atomic_cas(int*, int, int);
extern int atomic_xchg(int*, int);
extern int get_program_id(int);
extern int get_num_programs(int);
extern float sqrtf(float);
extern int select(bool, int, int);
extern char __constant__ * calloc(int);
)";
}

@@ -316,5 +327,9 @@ void function::operator()(const std::vector<arg>& args, const grid_t& grid, driv
  return this->operator()(args, [&grid](const options_t&){ return grid; }, stream);
}

void function::set_cst(const std::string& name, void* data, size_t n_bytes) {
  cst_[name] = std::vector<char>((char*)data, (char*)data + n_bytes);
}

}
}
@@ -1,157 +0,0 @@
import tensorflow as tf
import triton
import numpy as np

src = '''
#if AT == 1
#define USE_A ^a
#define STRIDE_AK lda
#define STRIDE_AM 1
#define BROADCAST_AK :, newaxis
#define BROADCAST_AM newaxis, :
#define SHAPE_A TK, TM
#else
#define USE_A a
#define STRIDE_AK 1
#define STRIDE_AM lda
#define BROADCAST_AK newaxis, :
#define BROADCAST_AM :, newaxis
#define SHAPE_A TM, TK
#endif

#if BT == 1
#define USE_B ^b
#define STRIDE_BK 1
#define STRIDE_BM ldb
#define BROADCAST_BN newaxis, :
#define BROADCAST_BK :, newaxis
#define SHAPE_B TN, TK
#else
#define USE_B b
#define STRIDE_BK ldb
#define STRIDE_BM 1
#define BROADCAST_BN :, newaxis
#define BROADCAST_BK newaxis, :
#define SHAPE_B TK, TN
#endif

void dot (TYPE* A __readonly __noalias __align(16),
          TYPE* B __readonly __noalias __align(16),
          TYPE* C __writeonly __noalias __align(16),
          int lda, int ldb, int ldc,
          int N, int* lut,
          int* locks, int nlocks) {
  int ridx = get_program_id(0);
  float c[TM, TN] = 0;
  int rka[TK] = 0 ... TK;
  int rkb[TK] = 0 ... TK;
  // load LUT header
  int *header = lut + get_program_id(1) * 4;
  int offset = *(header + 0);
  int K = *(header + 1);
  int column = *(header + 2);
  int lockid = *(header + 3);
  int *plut = lut + offset * 2;
  int offx = ridx;
  int offy = 0;
  // compute x, y offsets
  int rxa[TM] = offx * TM + (0 ... TM);
  int ryb[TN] = offy * TN + (0 ... TN);
  // bounds checking
  bool checka[SHAPE_A] = (rxa < N)[:, newaxis];
  bool checkb[SHAPE_B] = 1;
  // base offset
  int offa[SHAPE_A] = rxa[BROADCAST_AM] * STRIDE_AM + rka[BROADCAST_AK] * STRIDE_AK;
  int offb[SHAPE_B] = ryb[BROADCAST_BN] * STRIDE_BN + rkb[BROADCAST_BK] * STRIDE_BK;
  for(int k = K; k > 0; k -= 1) {
    // fetch block indices
    int ak = *(plut + 0);
    int bk = *(plut + 1);
    lut += 2;
    // compute pointers to blocks
    TYPE* pa[SHAPE_A] = A + offa + ak * TK * lda;
    TYPE* pb[SHAPE_B] = B + offb + bk * TK * TN;
    // load blocks
    TYPE a[SHAPE_A] = checka ? *pa : 0;
    TYPE b[SHAPE_B] = *pb;
    // multiply blocks
    c += USE_A @ USE_B;
  }
  int rxc[TM] = ridx * TM + (0 ... TM);
  int ryc[TN] = column * TN + (0 ... TN);
  TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
  bool checkc[TM, TN] = (rxc < N)[:, newaxis];
  if(lockid == 0) {
    *?(checkc) pc = c;
  }
  else {
    int *plock = locks + ridx*nlocks + lockid - 1;
    int *pcount = plock + get_num_program(0)*nlocks;
    while(atomic_cas(plock, 0, 1));
    int count = *pcount;
    if(count == 0)
      *?(checkc) pc = c;
    else
      *?(checkc) pc = c + *pc;
    atomic_exch(pcount, 1);
    atomic_exch(plock, 0);
  }
}
'''

# std::string dot::triton_c_src_dw() const {
# bool AT = (op_ == WGRAD);
# bool BT = (op_ == FPROP);
# std::string usea = AT ? "trans(a)" : "a";
# std::string useb = BT ? "trans(b)" : "b";
# std::string sizea = AT ? "TK, TM" : "TM, TK";
# std::string sizeb = BT ? "TN, TK" : "TK, TN";
# std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
# std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
# std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
# std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
# std::string lda0 = AT ? "*lda" : "";
# std::string lda1 = AT ? "" : "*lda";
# std::string ldb0 = BT ? "" : "*ldb";
# std::string ldb1 = BT ? "*ldb" : "" ;
# std::string result =
# R"(
# const tunable int TM = {)" + std::to_string(BS_) + R"(};
# const tunable int TN = {)" + std::to_string(BS_) + R"(};
# const tunable int TK = {32};
# void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
#            restrict read_only align(16) )" + ab_ty_ + R"( *B,
#            )" + c_ty_ + R"(* C,
#            int lda, int ldb, int ldc,
#            int N, int* lut,
#            int* locks, int nlocks) {
#    int ridx = get_range_id(0);
#    float acc[TM, TN] = 0;
#    int rka[TK] = 0 ... TK;
#    int rkb[TK] = 0 ... TK;
#    int *header = lut + ridx * 2;
#    int offx = *(header + 0);
#    int offy = *(header + 1);
#    int rxa[TM] = offx*TM + (0 ... TM);
#    int ryb[TN] = offy*TN + (0 ... TN);
#    bool checka[TK, TM] = (rka < N)[:, newaxis];
#    bool checkb[TK, TN] = (rkb < N)[:, newaxis];
#    int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
#    int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
#    )" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa;
#    )" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb;
#    )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
#    )" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0;
#    for(int k = N; k > 0; k = k - TK) {
#      acc = dot()" + usea + ", " + useb + R"(, acc);
#      pa = pa + TK)" + lda1 + R"(;
#      pb = pb + TK)" + ldb1 + R"(;
#      a = checka ? *pa : 0;
#      b = checkb ? *pb : 0;
#    }
#    int rxc[TM] = (0 ... TM);
#    int ryc[TN] = (0 ... TN);
#    )" + c_ty_ + R"( c[TM, TN] = acc;
#    )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN;
#    *pc = c;
#    })";
@@ -1,15 +0,0 @@
import torch
import triton

N, C, K = 32, 8, 32
H, W = 16, 16
R, S = 3, 3
torch.manual_seed(0)
a = torch.randn(N, C, H, W).cuda()
b = torch.ones(C, R, S, K).cuda()

rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
tc = triton.ops.conv(a, b)
print((rc - tc).abs().max())
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
#print(tc[31, 31,:,:])
@@ -1,71 +0,0 @@
import numpy as np
import triton

def run_tf():
    M, N, K = 2048, 2048, 2048
    a = tf.placeholder(tf.float32, shape=[M, K])
    b = tf.placeholder(tf.float32, shape=[N, K])
    triton_c = triton.ops.dot(a, b, False, True, 1)
    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
    triton_y = tf.math.reduce_mean(triton_d)
    fw_c = tf.matmul(a, b, False, True)
    fw_d = tf.matmul(fw_c, b, True, False)
    fw_y = tf.math.reduce_mean(fw_d)
    # Gradient
    triton_da, triton_db = tf.gradients(triton_y, [a, b])
    fw_da, fw_db = tf.gradients(fw_y, [a, b])
    # Reference
    feed_dict = {a: np.random.rand(M, K).astype(np.float32),
                 b: np.random.rand(K, N).astype(np.float32)}
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y], feed_dict = feed_dict)
    triton_da, fw_da = result[0][0], result[1][0]
    triton_db, fw_db = result[2][0], result[3][0]
    # Benchmark
    nanosec = triton.bench_registry[triton_d]
    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
    print('Diff DA:', (triton_da - fw_da).max())
    print('Diff DB:', (triton_db - fw_db).max())


def run_torch():
    torch.manual_seed(0)
    M, N, K = 2048, 2048, 2048
    a = torch.randn(M, K).cuda()
    b = torch.randn(K, N).cuda()
    a.requires_grad_(True)
    b.requires_grad_(True)
    torch_c = torch.matmul(a, torch.t(b))
    torch_d = torch.matmul(torch.t(torch_c), b)
    torch_y = torch.mean(torch_d)
    triton_c = triton.ops.dot(a, b, False, True, 1)
    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
    triton_y = torch.mean(triton_d)
    # torch gradient
    torch_y.backward()
    torch_da = a.grad.clone()
    torch_db = b.grad.clone()
    # triton gradient
    a.grad.zero_()
    b.grad.zero_()
    triton_y.backward()
    triton_da = a.grad.clone()
    triton_db = b.grad.clone()

    #nanosec = triton.bench_registry[triton_d]
    #print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
    print('Diff DA:', (torch_da - triton_da).max())
    print('Diff DB:', (torch_db - triton_db).max())

try:
    import tensorflow as tf
    run_tf()
except ModuleNotFoundError:
    pass

try:
    import torch
    run_torch()
except ModuleNotFoundError:
    pass
@@ -1,92 +1,194 @@
#!/usr/bin/env python

import numpy as np
from enum import Enum
import triton
import torch
from torch.utils.cpp_extension import load
import numpy as np
#import utils
from time import time

class MODE(Enum):
    TF = 1
    TORCH = 2
#torch.backends.cudnn.benchmark = True

try:
    import tensorflow as tf
    mode = MODE.TF
except ModuleNotFoundError:
    pass
configs = []

try:
    import torch
    mode = MODE.TORCH
except ModuleNotFoundError:
    pass
# Matrix multiplication
MNK = [
    (512, 512 ,512),
    (2048, 2048, 2048),
    (8192, 8192, 8192),

    (64, 64, 64000),
    (64, 64, 128000),
    (256, 256, 64000),
    (256, 256, 128000),

cases = []
# Matmul
cases += [[[4, 1024, 1024], [1024, 1024], [4, 1024, 1024], "btc,ck->btk"]]
# Attention
# cases += [[[4, 256, 8, 2, 64], [8, 2, 512, 64], [4, 256, 8, 2, 512], "bchak,hank->bchan"]]
    (1536, 16, 1536),
    (1536, 32, 1536),
    (1536, 64, 1536),
    (1536, 128, 1536),
    (4096, 16, 4096),
    (4096, 32, 4096),
    (4096, 64, 4096),
    (4096, 128, 4096),

    #(127008, 768, 576)
]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a, b)
    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a.t(), b)
    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a, b.t())
    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]

if mode == MODE.TF:
    sess = tf.InteractiveSession()
# Relative attention
NTHSE = [
    #(16, 512, 1, 64, 64),
    # (16, 512, 1, 128, 128),
    # (16, 512, 1, 256, 256),
    # (16, 512, 1, 256, 512),
    #(16, 512, 8, 64, 64),
    # (16, 512, 8, 128, 128),
    # (16, 512, 8, 256, 256),
    # (16, 512, 8, 256, 512),

for a_shape, b_shape, c_shape, einsum in cases:
    # (64, 1024, 1, 64, 64),
    #(64, 1024, 1, 128, 128),
    # (64, 1024, 1, 256, 256),
    # (64, 1024, 1, 256, 512),
    # (64, 1024, 8, 64, 64),
    #(64, 1024, 8, 128, 128),
    # (64, 1024, 8, 256, 256),
    # (64, 1024, 8, 256, 512),

    A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
    B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
    E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
    # (128, 1024, 1, 64, 64),
    # (128, 1024, 1, 128, 128),
    # (128, 1024, 1, 256, 256),
    #(128, 1024, 1, 256, 512),
    # (128, 1024, 8, 64, 64),
    # (128, 1024, 8, 128, 128),
    # (128, 1024, 8, 256, 256),
    #(128, 1024, 8, 256, 512)
]
for N, T, H, S, E in NTHSE:
    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
for N, T, H, S, E in NTHSE:
    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
for N, T, H, S, E in NTHSE:
    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]

    # Execute (tensorflow)
    if mode == MODE.TF:
        a = tf.placeholder(tf.float32, a_shape, name="a")
        b = tf.placeholder(tf.float32, b_shape, name="b")
        e = tf.placeholder(tf.float32, c_shape, name="e")
        c = triton.ops.einsum(einsum, a, b, 1)
        da, db = tf.gradients(c, [a, b], e)
        feed_dict = { a: A.astype(np.float32),
                      b: B.astype(np.float32),
                      e: E }
        sess.run(tf.global_variables_initializer())
        result = sess.run([c, da, db], feed_dict = feed_dict)
    # Execute (torch)
    if mode == MODE.TORCH:
        a = torch.from_numpy(A).cuda()
        b = torch.from_numpy(B).cuda()
        e = torch.from_numpy(E).cuda()
        a.requires_grad_(True)
        b.requires_grad_(True)
        c = triton.ops.einsum(einsum, a, b, 1)
        torch.autograd.backward(c, e)
        da = a.grad
        db = b.grad
        result = [c.cpu().detach().numpy(), da.cpu().detach().numpy(), db.cpu().detach().numpy()]

    # benchmark
    nanosec = triton.bench_registry[c]
    ctx = triton.ctx_registry[c]
    b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4)))
    ops = 2.*b*m*n*k
    print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3)
    #print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3)
    #print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3)
# 1D Dense convolution
NCHKR = [
    # (1, 1152, 12602, 512, 3)
]
for N, C, H, K, R in NCHKR:
    torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
    configs += [([N, C, H],
                 [C, R, K],
                 [N, K, H - R + 1],
                 torch_fn,
                 'nc(h+r),crk->nkh',
                 dict())]

    # test
    ctx = triton.ctx_registry[c]
    t_a = ctx.trans_a
    t_b = ctx.trans_b
    e_a = ctx.einsum_a
    e_b = ctx.einsum_b
    e_c = ctx.einsum_c
    C = np.einsum(einsum, A, B)
    if not t_a and not t_b: # NN
        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
    elif not t_a and t_b: # NT
        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
        DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A)
    elif t_a and not t_b: # TN
        DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E)
        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
    c, da, db = result[0], result[1], result[2]
    print('C diff:', np.abs((C - c)).max())
    print('DA diff:', np.abs((DA - da)).max())
    print('DB diff:', np.abs((DB - db)).max())
# 2D Dense convolution
NCHWKRS = [
    #(8, 64, 128, 128, 768, 3, 3),
    #(8, 128, 64, 64, 256, 3, 3),
    #(8, 256, 32, 32, 512, 3, 3),
    #(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
    configs += [([N, C, H, W],
                 [C, R, S, K],
                 [N, K, H - R + 1, W - R + 1],
                 torch_fn,
                 'nc(h+r)(w+s),crsk->nkhw',
                 dict())]

# 3D Dense Convolution
NCDHWKTRS = [
    #(8, 32, 27, 100, 100, 64, 3, 3, 3),
    #(8, 64, 23, 48, 48, 256, 3, 3, 3),
    #(8, 256, 19, 22, 22, 640, 3, 3, 3),
    #(8, 640, 15, 36, 36, 384, 3, 3, 3)
]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
    torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
    configs += [([N, C, D, H, W],
                 [C, T, R, S, K],
                 [N, K, D - T + 1, H - R + 1, W - R + 1],
                 torch_fn,
                 'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
                 dict())]


# Shift convolution
shift_cuda = torch.utils.cpp_extension.load(
    'shift_cuda', ['kernels/shift_cuda.cpp',
                   'kernels/shift_cuda_kernel.cu'],
    extra_cflags=['-O3'])
class shift(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, shift):
        ctx.save_for_backward(shift)
        return shift_cuda.forward(x, shift)

    @staticmethod
    def backward(ctx, grad_output):
        shift, = ctx.saved_tensors
        grad_output = shift_cuda.backward(grad_output, shift)

        return grad_output, None

NCHWKRS = [
    #(8, 64, 128, 128, 128, 3, 3),
    #(8, 128, 64, 64, 256, 3, 3),
    #(8, 256, 32, 32, 512, 3, 3),
    #(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
    shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
    shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
    def shift_conv(a, b, **kwargs):
        shift_h, shift_w = kwargs['sh'], kwargs['sw']
        shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
        shift_torch = torch.from_numpy(shift_torch).cuda()
        a = shift.apply(a, shift_torch)
        b = b.permute(1, 0)
        b = b.reshape(b.shape[0], b.shape[1], 1, 1)
        return torch.nn.functional.conv2d(a, b)
    configs += [([N, C, H, W],
                 [C, K],
                 [N, K, H, W],
                 shift_conv,
                 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
                 {'sh': shift_h, 'sw': shift_w})]

# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
    dtype = torch.cuda.HalfTensor
    # initialize input tensors
    a = torch.rand(*a_shape).type(dtype).cuda()
    b = torch.rand(*b_shape).type(dtype).cuda()
    # triton output
    #ta = triton.ops._einsum.pad(a, [4,4,4,4])
    tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
    # reference output
    if torch_fn:
        rc = torch_fn(a, b, **arrays)
    else:
        rc = torch.einsum(expr, a, b)
    # performance relative to equivalent matrix multiplication
    ctx = triton.ctx_registry[tc]
    B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
    # a = torch.rand(B, M, K).type(dtype).cuda()
    # b = torch.rand(B, K, N).type(dtype).cuda()
    # tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
    # ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
    ratio = 0
    # test and benchmark
    bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
    diff = (tc - rc).abs().max() / rc.abs().max()
    print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} ({ratio:4.2f}); {diff:4.2f}')
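A quick sanity check of the TFLOPS arithmetic in the benchmark loop above: dividing a FLOP count by a time in nanoseconds yields GFLOP/s, so the extra factor of 1e-3 converts to TFLOPS (the values below are hypothetical):

# illustrative sketch, not part of this commit
B, M, N, K = 1, 2048, 2048, 2048
flops = 2. * B * M * N * K      # one multiply and one add per MAC
nanosec = 1.2e6                 # hypothetical 1.2 ms kernel time
print(flops / nanosec * 1e-3)   # ~14.3 TFLOPS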

python/examples/kernels/shift_cuda.cpp (new file, 42 lines)
@@ -0,0 +1,42 @@
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
// CUDA forward declarations
|
||||
|
||||
at::Tensor shift_cuda_forward(
|
||||
const at::Tensor input,
|
||||
const at::Tensor shift);
|
||||
|
||||
at::Tensor shift_cuda_backward(
|
||||
const at::Tensor grad_input,
|
||||
const at::Tensor shift);
|
||||
|
||||
// C++ interface
|
||||
|
||||
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
at::Tensor shift_forward(
|
||||
const at::Tensor input,
|
||||
const at::Tensor shift) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(shift);
|
||||
|
||||
return shift_cuda_forward(input, shift);
|
||||
}
|
||||
|
||||
at::Tensor shift_backward(
|
||||
const at::Tensor grad_input,
|
||||
const at::Tensor shift) {
|
||||
CHECK_INPUT(grad_input);
|
||||
CHECK_INPUT(shift);
|
||||
return shift_cuda_backward(grad_input, shift);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &shift_forward, "Shift forward (CUDA)");
|
||||
m.def("backward", &shift_backward, "Shift backward (CUDA)");
|
||||
}
|
python/examples/kernels/shift_cuda_kernel.cu (new file, 111 lines)
@@ -0,0 +1,111 @@
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
__global__ void shift_cuda_forward_kernel(
|
||||
const scalar_t* __restrict__ input,
|
||||
const int32_t* __restrict__ shift,
|
||||
scalar_t* __restrict__ output,
|
||||
const int32_t B,
|
||||
const int32_t C,
|
||||
const int32_t H,
|
||||
const int32_t W) {
|
||||
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int32_t size = B*C*H*W;
|
||||
|
||||
const int32_t CHW = C*H*W;
|
||||
const int32_t HW = H*W;
|
||||
const int32_t b = idx / CHW;
|
||||
const int32_t c = (idx - b*CHW) / HW;
|
||||
const int32_t h = (idx - b*CHW - c*HW) / W;
|
||||
const int32_t w = idx - b*CHW - c*HW - h*W;
|
||||
const int32_t target_w = w + shift[2*c];
|
||||
const int32_t target_h = h + shift[2*c + 1];
|
||||
const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
|
||||
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
|
||||
output[target_idx] = input[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void shift_cuda_backward_kernel(
|
||||
const scalar_t* __restrict__ grad_input,
|
||||
scalar_t* __restrict__ grad_output,
|
||||
const int32_t* __restrict__ shift,
|
||||
const int32_t B,
|
||||
const int32_t C,
|
||||
const int32_t W,
|
||||
const int32_t H) {
|
||||
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int32_t size = B*C*W*H;
|
||||
const int32_t CWH = C*W*H;
|
||||
const int32_t WH = W*H;
|
||||
const int32_t b = idx / CWH;
|
||||
const int32_t c = (idx - b*CWH) / WH;
|
||||
const int32_t w = (idx - b*CWH - c*WH) / W;
|
||||
const int32_t h = idx - b*CWH - c*WH - w*H;
|
||||
const int32_t target_w = w - shift[2*c];
|
||||
const int32_t target_h = h - shift[2*c + 1];
|
||||
const int32_t target_idx = b*CWH + c*WH + target_w*W + target_h;
|
||||
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
|
||||
grad_output[target_idx] = grad_input[idx];
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
at::Tensor shift_cuda_forward(
|
||||
const at::Tensor input,
|
||||
const at::Tensor shift) {
|
||||
const auto B = input.size(0);
|
||||
const auto C = input.size(1);
|
||||
const auto H = input.size(2);
|
||||
const auto W = input.size(3);
|
||||
const auto size = B*C*W*H;
|
||||
const int threads = 1024;
|
||||
const int blocks = (size + threads - 1) / threads;
|
||||
auto output = at::zeros_like(input);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] {
|
||||
shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
|
||||
input.data<scalar_t>(),
|
||||
shift.data<int32_t>(),
|
||||
output.data<scalar_t>(),
|
||||
B,
|
||||
C,
|
||||
H,
|
||||
W);
|
||||
}));
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
at::Tensor shift_cuda_backward(
|
||||
const at::Tensor grad_input,
|
||||
const at::Tensor shift) {
|
||||
const auto B = grad_input.size(0);
|
||||
const auto C = grad_input.size(1);
|
||||
const auto H = grad_input.size(2);
|
||||
const auto W = grad_input.size(3);
|
||||
const auto size = B*C*W*H;
|
||||
const int threads = 1024;
|
||||
const int blocks = (size + threads - 1) / threads;
|
||||
auto grad_output = at::zeros_like(grad_input);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] {
|
||||
shift_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
|
||||
grad_input.data<scalar_t>(),
|
||||
grad_output.data<scalar_t>(),
|
||||
shift.data<int32_t>(),
|
||||
B,
|
||||
C,
|
||||
H,
|
||||
W);
|
||||
}));
|
||||
|
||||
return grad_output;
|
||||
}
|
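The forward kernel above recovers (b, c, h, w) from a flat thread index by successive divisions and remainders. A small NumPy check of that decomposition for a contiguous NCHW layout (shapes hypothetical):

import numpy as np

# illustrative sketch, not part of this commit
B, C, H, W = 2, 3, 4, 5
idx = np.arange(B*C*H*W)
b = idx // (C*H*W)
c = (idx - b*C*H*W) // (H*W)
h = (idx - b*C*H*W - c*H*W) // W
w = idx - b*C*H*W - c*H*W - h*W
# recomposing must give back the flat index
assert (idx == ((b*C + c)*H + h)*W + w).all()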
@@ -77,7 +77,7 @@ class CMakeBuild(build_ext):
            pass

        cfg = 'Debug' if self.debug else 'Release'
        #cfg = 'Release'
        cfg = 'Release'
        build_args = ['--config', cfg]

        if platform.system() == "Windows":
@@ -1,4 +1,5 @@
#include <pybind11/pybind11.h>
#include <pybind11/buffer_info.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <string>
@@ -48,6 +49,11 @@ void delete_fn(size_t id) {
  id_fn_map.erase(id);
}

void register_cst(size_t id, const std::string& name, pybind11::buffer& data) {
  pybind11::buffer_info info = data.request();
  id_fn_map[id]->set_cst(name, info.ptr, info.size*info.itemsize);
}

void cleanup() {
  id_grid_map.clear();
  id_fn_map.clear();
@@ -508,7 +514,8 @@ void gen_torch_make_handles(std::ostream &os,
    os << " " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl;
  else{
    os << " CHECK_INPUT(" << name << ");" << std::endl;
    os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage().data(), false);" << std::endl;
    os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), "
          " (CUdeviceptr)((char*)" + name + ".storage().data() + " + name + ".storage_offset() * " + name + ".itemsize()), false);" << std::endl;
  }
}
}
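The two-line replacement above exists because `.storage().data()` points at the start of the underlying storage rather than at the tensor's first element; for views, the base address must be shifted by `storage_offset() * itemsize()`. A small torch sketch of the invariant being relied on:

import torch

# illustrative sketch, not part of this commit
x = torch.arange(10)
y = x[2:]                        # a view sharing x's storage
print(y.storage_offset())        # 2
# first-element address = storage base + offset * element size
assert y.data_ptr() == x.data_ptr() + y.storage_offset() * y.element_size()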
@@ -526,8 +533,8 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
    os << name;
  }
  os << "}, *id_grid_map.at(id), &stream);\n";
  os << " };\n ";
  os << " run();";
  os << " };\n";
  os << " run();\n";
  os << " if(bench > 0)\n ";
  os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
}
@@ -562,18 +569,15 @@ std::tuple<std::string,
  std::ostringstream oss;
  oss << R"(
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/extension.h"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/CUDAHooks.h"

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_INPUT(x) CHECK_CUDA(x);

namespace rt = triton::runtime;
namespace drv = triton::driver;
@@ -628,6 +632,7 @@ PYBIND11_MODULE(libtriton, m) {
  m.def("register_grid", &register_grid);
  m.def("delete_grid", &delete_grid);
  m.def("register_fn", &register_fn);
  m.def("register_cst", &register_cst);
  m.def("delete_fn", &delete_fn);
  m.def("make_op_id", &make_op_id);
  m.def("make_scalar_id", &make_scalar_id);
@@ -36,14 +36,16 @@ class function(metaclass = function_meta):
    def apply_torch(cls, *args, **kwargs):
        class TorchFunction(fw.torch.autograd.Function):
            @staticmethod
            def forward(ctx, *targs, **tkwargs):
                y = cls.forward(ctx, *targs, **tkwargs)
            def forward(ctx, *targs):
                y = cls.forward(ctx, *targs, **cls.torch_kwargs)
                ctx_registry[y] = ctx
                return y
            @staticmethod
            def backward(ctx, grad_output):
                return cls.backward(ctx, grad_output)
        return TorchFunction.apply(*args, **kwargs)
        cls.torch_kwargs = kwargs
        return TorchFunction.apply(*args)
    torch_kwargs = 0

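The kwargs are stashed on the class because `torch.autograd.Function.apply` only forwards positional arguments; `forward` then re-reads them from `cls.torch_kwargs`. A minimal standalone sketch of the same workaround (names hypothetical):

import torch

# illustrative sketch, not part of this commit
class _scale(torch.autograd.Function):
    kwargs = {}
    @staticmethod
    def forward(ctx, x):
        return x * _scale.kwargs.get('alpha', 1.0)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * _scale.kwargs.get('alpha', 1.0)

def scale(x, **kwargs):
    _scale.kwargs = kwargs          # smuggle kwargs past apply()
    return _scale.apply(x)

y = scale(torch.ones(3, requires_grad=True), alpha=2.0)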
    @classmethod
    def extract_tf_tensors(cls, lst, err):
@@ -6,6 +6,8 @@ import hashlib
import sysconfig
import sys
import weakref
import contextlib
import io
# import for just-in-time compilation
import distutils
import setuptools.command.build_ext
@@ -56,7 +58,16 @@ def _write_bindings(src, root):
        handle.writelines(src)
    # return path of cpp file
    return (cpp, so)


@contextlib.contextmanager
def quiet():
    old_stdout, old_stderr = sys.stdout, sys.stderr
    sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
    try:
        yield
    finally:
        sys.stdout, sys.stderr = old_stdout, old_stderr

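`quiet()` temporarily swaps the process-wide stdout/stderr for in-memory buffers so the setuptools build below runs silently; the `finally` clause guarantees the real streams come back even if the build raises. A standalone usage sketch:

import contextlib, io, sys

# illustrative sketch, not part of this commit
@contextlib.contextmanager
def quiet():
    old_stdout, old_stderr = sys.stdout, sys.stderr
    sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
    try:
        yield
    finally:
        sys.stdout, sys.stderr = old_stdout, old_stderr

with quiet():
    print("swallowed")   # captured by the StringIO buffer
print("visible")         # printed normally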
def _build(src, path):
    ccdir = os.path.join(libtriton.__file__, os.path.pardir)
    ccdir = os.path.realpath(ccdir)
@@ -102,7 +113,7 @@ def _build(src, path):
                language = 'c++',
                sources = [src],
                include_dirs = include_dirs,
                extra_compile_args = extra_compile_args,
                extra_compile_args = extra_compile_args + ['-g0'],
                extra_link_args = extra_link_args,
                library_dirs = library_dirs,
                libraries = libraries,
@@ -119,7 +130,8 @@ def _build(src, path):
        ext_modules = [ext],
        script_args = args,
    )
    setuptools.setup(**args)
    with quiet():
        setuptools.setup(**args)
    shutil.rmtree(tmp)

def _cvt_to_def_str(obj):
@@ -188,6 +200,10 @@ class kernel:
        self.src = src
        self.outputs = outputs
        self.tmp = tmp
        self.cst = dict()

    def set_constant(self, name, value):
        self.cst[name] = value

    def __call__(self, *args, **kwargs):
        # create a new framework op when defines are different
@@ -204,15 +220,16 @@ class kernel:
                defines.append((k, values))
            opt = libtriton.options_space()
            opt.defines = defines
            opt.num_warps = [4]
            opt.num_warps = [2, 4]
            # create unique id for this op
            op_id = libtriton.make_op_id()
            self.fw_id[key] = op_id
            # register function
            libtriton.register_fn(op_id, self.src, opt)
            for name, value in self.cst.items():
                libtriton.register_cst(op_id, name, value)
            if self.fw_op is None:
                self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt)

        # benchmarking info
        bench = 0
        if 'bench' in kwargs:
@@ -252,9 +269,9 @@ class kernel:
            for y in ret:
                bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
        elif fw.has_torch():
            args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
            args = [x if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
            ret = self.fw_op(op_id, bench, bench_id, *args)
            if bench > 0:
                bench_registry[ret] = libtriton.retrieve_scalar(op_id)
                bench_registry[ret] = libtriton.retrieve_scalar(bench_id)
        else:
            assert False
@@ -38,25 +38,19 @@ void convnd(A_TYPE *A,
    int rah[TM] = rabh % CH;
    rah = rah * UPAW - off_uah;
    raw = raw * UPAH - off_uaw;
    int racr[TK] = rk / BW;
    int ras[TK] = rk % BW;
    int rac[TK] = racr / BH;
    int rar[TK] = racr % BH;
    rar = UPAR * rar;
    ras = UPAS * ras;
    int ram[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
    int rak[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
    int rak[TK] = *(ADELTA + rk);
    A_TYPE* pa[TM, TK] = A + ram[:, newaxis] + rak[newaxis, :];

    // pointers for B
    int rbk[TK] = rk;
    int rbn[TN] = ryb;
    B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_s;
    B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_c;

    // pointers for A look-up table
    int rklut[TK] = rk % LUT_SIZE;
    int* padiff[TK] = ADIFF + rklut;
    int* padelta[TK] = ADELTA + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
    int* padelta[TK] = ADELTA + TK + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
    int adiff[TK] = *padiff;
    int adelta[TK] = *padelta;

@@ -66,7 +60,7 @@ void convnd(A_TYPE *A,
    for(int k = K; k > 0; k = k - TK){
      c += a @ b;
      pa += adelta[newaxis, :];
      pb += TK * ldb_s;
      pb += TK * ldb_c;
      // increment A look-up table
      padelta = padelta + adiff;
      adelta = *padelta;
@@ -99,29 +93,54 @@ void convnd(A_TYPE *A,
    kernel = triton.kernel(src, ['C'])

    @staticmethod
    def _unpack(idx, D, H, W):
        cdh = idx // W
        w = idx % W
        cd = cdh // H
        h = cdh % H
        c = cd // D
        d = cd % D
        return c, d, h, w
    def _unpack(idx, order, shape_b):
        _123 = idx // shape_b[order[0]]
        _0 = idx % shape_b[order[0]]
        _23 = _123 // shape_b[order[1]]
        _1 = _123 % shape_b[order[1]]
        _3 = _23 // shape_b[order[2]]
        _2 = _23 % shape_b[order[2]]
        return _0, _1, _2, _3

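The new `_unpack` performs a mixed-radix decomposition of a flat reduction index, with `order` listing the axes from innermost to outermost. A minimal round-trip check (shapes hypothetical):

import numpy as np

# illustrative sketch, not part of this commit
def unpack(idx, order, shape):
    _123, _0 = idx // shape[order[0]], idx % shape[order[0]]
    _23,  _1 = _123 // shape[order[1]], _123 % shape[order[1]]
    _3,   _2 = _23 // shape[order[2]], _23 % shape[order[2]]
    return _0, _1, _2, _3

shape = [4, 3, 3, 8]          # e.g. a (C, R, S, K) filter shape
order = [3, 2, 1, 0]          # innermost axis first
idx = np.arange(np.prod(shape))
k, s, r, c = unpack(idx, order, shape)
# recompose and verify the round trip
assert (idx == ((c*3 + r)*3 + s)*8 + k).all()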
    @staticmethod
    def _delta_a(upsample_d, upsample_h, upsample_w, depth, TK,
                 T, R, S, stride_a):
    def _roundup(x, div):
        return (x + div - 1) // div * div

    @staticmethod
    def _delta_a(upsample_d, upsample_h, upsample_w,
                 bc, bd, bh, bw,
                 ac, ad, ah, aw,
                 stride_a, shape_b,
                 TK):
        # Parse the axes so that the reduction is done
        # from the innermost dimension outward
        order = sorted([bc, bd, bh, bw], reverse = True)
        c, d, h, w = [order.index(x) for x in [bc, bd, bh, bw]]
        # Size of the lookup table is the product of the 3 innermost dimensions
        K = _conv._roundup(TK, shape_b[order[0]] * shape_b[order[1]] * shape_b[order[2]])
        # Allocate temporary arrays
        ud = np.arange(upsample_d, dtype=np.int32)[:, np.newaxis, np.newaxis, np.newaxis]
        uh = np.arange(upsample_h, dtype=np.int32)[np.newaxis, :, np.newaxis, np.newaxis]
        uw = np.arange(upsample_w, dtype=np.int32)[np.newaxis, np.newaxis, :, np.newaxis]
        ctrs = np.arange(depth, dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :]
        c, t, r, s = _conv._unpack(ctrs, T, R, S)
        nextc, nextt, nextr, nexts = _conv._unpack(ctrs + TK, T, R, S)
        cdiff = nextc - c
        tdiff = nextt - t
        rdiff = nextr - r
        sdiff = nexts - s
        return cdiff*stride_a[1] + tdiff*stride_a[2] + rdiff*stride_a[3] + sdiff*stride_a[4]
        k = np.arange(K , dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :]
        # Find reduction indices at the current and next reduction indices
        currentk = _conv._unpack(k     , order, shape_b)
        nextk    = _conv._unpack(k + TK, order, shape_b)
        # Compute memory stride
        result = 0
        result += (nextk[c] - currentk[c]) * stride_a[ac]
        result += (nextk[d] - currentk[d]) * stride_a[ad]
        result += (nextk[h] - currentk[h]) * stride_a[ah]
        result += (nextk[w] - currentk[w]) * stride_a[aw]
        # Initial k
        ki = np.arange(TK , dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :]
        currentk = _conv._unpack(ki, order, shape_b)
        resulti = 0
        resulti += currentk[c] * stride_a[ac]
        resulti += currentk[d] * stride_a[ad]
        resulti += currentk[h] * stride_a[ah]
        resulti += currentk[w] * stride_a[aw]
        return np.concatenate((resulti, result), axis=-1)
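The idea behind the `_delta_a` table is that advancing the reduction index by TK does not move the A pointer by a constant amount: the per-axis coordinates wrap at different points, so the address delta varies periodically. A minimal 1D sketch of that effect (all values hypothetical):

import numpy as np

# illustrative sketch, not part of this commit: the reduction index k
# enumerates (c, r) pairs of a C x R filter, so the address delta from
# k to k + TK depends on where k sits within the r axis.
C, R, TK = 4, 3, 4
stride_c, stride_r = 100, 1      # hypothetical strides of A
k  = np.arange(C * R)
c,  r  = k // R, k % R
nc, nr = (k + TK) // R, (k + TK) % R
delta = (nc - c) * stride_c + (nr - r) * stride_r
print(delta[:R])                 # deltas repeat with period R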
    @staticmethod
    def _extract_strides(shape):
@@ -134,38 +153,56 @@ void convnd(A_TYPE *A,

    @staticmethod
    def _call(a, b,
              upsample_d, upsample_h, upsample_w,
              pad_d, pad_h, pad_w,
              stride_d, stride_h, stride_w,
              mode):
              stride_d, stride_h, stride_w,
              upsample_d, upsample_h, upsample_w,
              a_layout, b_layout, c_layout):
        # input shapes
        shape_a = list(triton.shape(a))
        shape_b = list(triton.shape(b))
        # add depth
        shape_a.insert(2, 1)
        shape_b.insert(1, 1)
        NB, NC, AD, AH, AW = shape_a
        NC, BD, BH, BW, NF = shape_b
        dim = len(shape_a) - 2
        # indices
        an, ac, ad, ah, aw = [a_layout.find(x) for x in 'ncdhw']
        bk, bc, bd, bh, bw = [b_layout.find(x) for x in 'kctrs']
        cn, ck, cd, ch, cw = [c_layout.find(x) for x in 'nkdhw']
        # extract shapes
        if dim == 2:
            shape_a.insert(ad, 1)
        if dim == 2:
            shape_b.insert(bd, 1)
        # output shape
        CD = (AD*upsample_d - BD + 1 + 2*pad_d + stride_d - 1) // stride_d
        CH = (AH*upsample_h - BH + 1 + 2*pad_h + stride_h - 1) // stride_h
        CW = (AW*upsample_w - BW + 1 + 2*pad_w + stride_w - 1) // stride_w
        shape_c = [NB, NF, CD, CH, CW]
        shape_c = [0] * 5
        shape_c[cn] = shape_a[an]
        shape_c[ck] = shape_b[bk]
        shape_c[cd] = (shape_a[ad]*upsample_d - shape_b[bd] + 1 + 2*pad_d + stride_d - 1) // stride_d
        shape_c[ch] = (shape_a[ah]*upsample_h - shape_b[bh] + 1 + 2*pad_h + stride_h - 1) // stride_h
        shape_c[cw] = (shape_a[aw]*upsample_w - shape_b[bw] + 1 + 2*pad_w + stride_w - 1) // stride_w
        # strides
        stride_a = _conv._extract_strides(shape_a)
        stride_b = _conv._extract_strides(shape_b)
        stride_c = _conv._extract_strides(shape_c)
        # look-up tables
        # tiling parameters
        TM = [32]
        TN = [32]
        TK = 8
        FS = BD * BH * BW
        depth = (TK + FS - 1)//FS * FS
        # pointer deltas for a
        delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w,
                                 depth, TK, BD, BH, BW, stride_a)
                                 bc, bd, bh, bw,
                                 ac, ad, ah, aw,
                                 stride_a, shape_b,
                                 TK)
        delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
        inc_a = np.arange(depth, dtype=np.int32)
        inc_a = ((inc_a + TK) % depth) - inc_a
        # delta increments for a
        inc_a = np.arange(delta_a.shape[-1] - TK, dtype=np.int32)
        inc_a = ((inc_a + TK) % inc_a.size) - inc_a
        inc_a = triton.fw.torch.from_numpy(inc_a).cuda()

        # allocate output
        if dim == 2:
            shape_c.pop(cd)
        c = triton.empty(shape_c, dtype=a.dtype)
        if dim == 2:
            shape_c.insert(cd, 1)
        # execute kernel
        trans_b = False
        is_wgrad = False
        is_blut = False
@@ -174,31 +211,99 @@ void convnd(A_TYPE *A,
                  'UPAS': 'stride_w' if is_wgrad else '1',
                  'UPAH': '' if is_wgrad else 'stride_h',
                  'UPAW': '' if is_wgrad else 'stride_w',
                  'LUT_SIZE': depth,
                  'TM': [32],
                  'TN': [32],
                  'TK': TK,
                  'A_TYPE': 'float',
                  'B_TYPE': 'float'
                  'LUT_SIZE': delta_a.shape[-1],
                  'TM': TM, 'TN': TN, 'TK': TK,
                  'A_TYPE': 'float', 'B_TYPE': 'float'
                  }

        shape_c.pop(2)
        c = triton.empty(shape_c, dtype=a.dtype)
        grid = lambda opt: [triton.cdiv(NB*CD*CH*CW, opt.d('TM')), triton.cdiv(NF, opt.d('TN'))]
        print(stride_c)
        print(stride_b)
        _conv.kernel(a, b, c, NB*CD*CH*CW, NF, NC*BD*BH*BW, AH, AW, BH, BW, CH, CW, NC,
                     stride_a[0], stride_a[1], stride_a[2], stride_a[3], stride_a[4],
                     stride_b[0], stride_b[1], stride_b[2], stride_b[3], stride_b[4],
                     stride_c[0], stride_c[1], stride_c[2], stride_c[3], stride_c[4],
                     pad_h, pad_w, stride_h, stride_w, upsample_h, upsample_w,
        MATMUL_M = shape_c[cn] * shape_c[cd] * shape_c[ch] * shape_c[cw]
        MATMUL_N = shape_c[ck]
        MATMUL_K = shape_b[bc] * shape_b[bd] * shape_b[bh] * shape_b[bw]
        _conv.kernel(a, b, c,
                     # matrix multiplication shapes
                     MATMUL_M, MATMUL_N, MATMUL_K,
                     # shapes for a
                     shape_a[ah], shape_a[aw],
                     # shapes for b
                     shape_b[bh], shape_b[bw],
                     # shapes for c
                     shape_c[ch], shape_c[cw], shape_c[cn],
                     # strides for a
                     stride_a[an], stride_a[ac], stride_a[ad + 0], stride_a[ad + 1], stride_a[ad + 2],
                     # strides for b
                     stride_b[bc], stride_b[bd + 0], stride_b[bd + 1], stride_b[bd + 2], stride_b[bk],
                     # strides for c
                     stride_c[cn], stride_c[ck], stride_c[cd], stride_c[cd + 1], stride_c[cd + 2],
                     # padding
                     pad_h, pad_w,
                     # striding
                     stride_h, stride_w,
                     # upsampling
                     upsample_h, upsample_w,
                     0, 0, 0, 0, 0, 0,
                     # look-up table
                     delta_a, inc_a,
                     grid, **macros)
                     lambda opt: [triton.cdiv(MATMUL_M, opt.d('TM')), triton.cdiv(MATMUL_N, opt.d('TN'))],
                     **macros)
        return c
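A worked instance of the output-shape formula used in `_call` above, for the H dimension (all values hypothetical):

# illustrative sketch, not part of this commit
AH, BH = 16, 3                # input height, filter height
pad_h, stride_h, up_h = 1, 2, 1
CH = (AH*up_h - BH + 1 + 2*pad_h + stride_h - 1) // stride_h
print(CH)                     # (16 - 3 + 1 + 2 + 1) // 2 = 8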
    @staticmethod
    def forward(ctx, input, weight):
        return _conv._call(input, weight, 1, 1, 1, 0, 0, 0, 1, 1, 1, '')
    def forward(ctx, x, w,
                pad_d = 0, pad_h = 0, pad_w = 0,
                stride_d = 1, stride_h = 1, stride_w = 1,
                upsample_d = 1, upsample_h = 1, upsample_w = 1,
                layout_a = 'ncdhw', layout_b = 'ktrsc', layout_c = 'nkdhw'):
        # save for backward
        ctx.save_for_backward(x, w)
        ctx.pad_d = pad_d
        ctx.pad_h = pad_h
        ctx.pad_w = pad_w
        ctx.stride_d = stride_d
        ctx.stride_h = stride_h
        ctx.stride_w = stride_w
        ctx.upsample_d = upsample_d
        ctx.upsample_h = upsample_h
        ctx.upsample_w = upsample_w
        ctx.layout_a = layout_a
        ctx.layout_b = layout_b
        ctx.layout_c = layout_c
        # return
        return _conv._call(x, w,
                           pad_d, pad_h, pad_w,
                           stride_d, stride_h, stride_w,
                           upsample_d, upsample_h, upsample_w,
                           layout_a, layout_b, layout_c)

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensors
        pad_d = ctx.pad_d
        pad_h = ctx.pad_h
        pad_w = ctx.pad_w
        stride_d = ctx.stride_d
        stride_h = ctx.stride_h
        stride_w = ctx.stride_w
        upsample_d = ctx.upsample_d
        upsample_h = ctx.upsample_h
        upsample_w = ctx.upsample_w
        layout_a = ctx.layout_a
        layout_b = ctx.layout_b
        layout_c = ctx.layout_c

        # TODO: Deal with this
        dx_pad_d = 1
        dx_pad_h = 1
        dx_pad_w = 1
        dx = _conv.call(dy, w,
                        dw_pad_d, dw_pad_h, dw_pad_w,
                        upsample_w, upsample_h, upsample_w,
                        stride_d, stride_h, stride_w,
                        'ncdhw', 'cktrs', 'nkdhw')



        ret = [None] * 14
        ret[0] = None
        ret[1] = dw
        return None,

conv = _conv.apply
@@ -3,37 +3,50 @@ import triton
|
||||
class _dot(triton.function):
|
||||
|
||||
src = """
|
||||
void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C,
|
||||
float alpha,
|
||||
int M, int N, int K,
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc) {
|
||||
// prologue
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rm[TM] = ridx * TM + 0 ... TM;
|
||||
int rn[TN] = ridy * TN + 0 ... TN;
|
||||
int rk[TK] = 0 ... TK;
|
||||
float c[TM, TN] = 0;
|
||||
// pointers to operands
|
||||
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
|
||||
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
|
||||
// prefetches operands
|
||||
TYPE a[SHAPE_A] = *pa;
|
||||
TYPE b[SHAPE_B] = *pb;
|
||||
// reduction loop
|
||||
for(int k = K; k > 0; k-= TK){
|
||||
c += USE_A @ USE_B;
|
||||
pa = pa + TK * STRIDE_AK;
|
||||
pb = pb + TK * STRIDE_BK;
|
||||
bool checka[SHAPE_A] = k > TK;
|
||||
bool checkb[SHAPE_B] = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
}
|
||||
// epilogue
|
||||
TYPE* pc[TM, TN] = C + rm[:, newaxis] * ldc + rn[newaxis, :];
|
||||
*pc = c;
|
||||
// prologue
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rm[TM] = ridx * TM + 0 ... TM;
|
||||
int rn[TN] = ridy * TN + 0 ... TN;
|
||||
int rk[TK] = 0 ... TK;
|
||||
|
||||
// pointers to operands
|
||||
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
|
||||
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
|
||||
|
||||
// prefetches operands
|
||||
bool checka[SHAPE_A] = rk[BROADCAST_AK] < K;
|
||||
bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K;
|
||||
TYPE a[SHAPE_A] = checka ? *pa : 0;
|
||||
TYPE b[SHAPE_B] = checkb ? *pb : 0;
|
||||
|
||||
// reduction loop
|
||||
float c[TM, TN] = 0;
|
||||
for(int k = K; k > 0; k -= TK){
|
||||
c += USE_A @ USE_B;
|
||||
bool checka[SHAPE_A] = k > TK;
|
||||
bool checkb[SHAPE_B] = k > TK;
|
||||
pa += TK * STRIDE_AK;
|
||||
pb += TK * STRIDE_BK;
|
||||
a = *?(checka)pa;
|
||||
b = *?(checkb)pb;
|
||||
}
|
||||
//c = c * alpha;
|
||||
|
||||
// epilogue
|
||||
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
|
||||
int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
|
||||
TYPE* pc[TM, TN] = C + rxm[:, newaxis] * ldc + rxn[newaxis, :];
|
||||
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
|
||||
*?(checkc)pc = (TYPE[TM, TN])c;
|
||||
}
|
||||
"""
|
||||
kernel = triton.kernel(src, ['C'])
|
||||
@@ -75,10 +88,10 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
|
||||
'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
|
||||
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
|
||||
_dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
|
||||
_dot.kernel(a, b, c, 1., M, N, Ka, lda, ldb, ldc,
|
||||
grid, bench=bench,
|
||||
AT = transpose_a, BT = transpose_b, TYPE = dtype,
|
||||
TM = [64, 128], TN = [64, 128], TK = [8], **macros)
|
||||
TM = [64], TN = [128], TK = [8], **macros)
|
||||
return c
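For reference, the blocked reduction the kernel above performs can be modeled in NumPy. This is a sketch of the arithmetic only; the short final block contributes exactly the remaining terms, which is what the `checka ? *pa : 0` guard achieves in-kernel:

    import numpy as np

    # NumPy model (assumption: row-major, no transposition) of the reduction loop:
    # accumulate c += a @ b over K in steps of TK.
    def dot_block(A, B, TK):
        M, K = A.shape
        _, N = B.shape
        c = np.zeros((M, N), dtype=np.float32)
        for k0 in range(0, K, TK):
            a = A[:, k0:k0+TK]      # tail block is simply shorter, i.e. zero-padded
            b = B[k0:k0+TK, :]
            c += a @ b
        return c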
@@ -1,234 +1,651 @@
# Special thanks to Scott Gray from OpenAI for writing the einsum parsing function

import numpy as np
import torch
from math import ceil, log2
from enum import IntEnum
import triton
import math
from functools import reduce
from operator import mul
from sympy.parsing.sympy_parser import parse_expr
import sympy as sp
from collections import OrderedDict, namedtuple
import re
from sympy.printing.ccode import C89CodePrinter


class _einsum(triton.function):

    #############################
    ## Triton-C code generation
    #############################

    def print_cc(expr, axes_0, axes_1, axes_2):

        class TritonCodePrinter(C89CodePrinter):

            def __init__(self, axes_0, axes_1, axes_2):
                super(TritonCodePrinter, self).__init__()
                self.axes_0 = axes_0
                self.axes_1 = axes_1
                self.axes_2 = axes_2

            def _print_Symbol(self, expr):
                name = super(C89CodePrinter, self)._print_Symbol(expr)
                if expr in self.axes_0:
                    return f'r{name}[:, newaxis, newaxis]'
                if expr in self.axes_1:
                    return f'r{name}[newaxis, :, newaxis]'
                if expr in self.axes_2:
                    return f'r{name}[newaxis, newaxis, :]'
                return name

            def _print_Indexed(self, expr):
                assert len(expr.indices) == 1
                return "*(%s + %s)" % (self._print(expr.base.label),
                                       self._print(expr.indices[0]))

        return TritonCodePrinter(axes_0, axes_1, axes_2).doprint(expr)


    def unpack_cc(tile, axes, prefix, remat):
        ret = ''
        axes = list(map(str, axes))
        for i, d in enumerate(reversed(axes)):
            if i == len(axes) - 1:
                break
            currs = ''.join(axes[: len(axes) - i])
            nexts = ''.join(axes[: len(axes) - (i + 1)])
            ty = '' if remat else 'int '
            sz = '' if remat else f'[{tile}]'
            ret += f'    {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n'
            ret += f'    {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n'
        return ret
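As an illustration of the generator, a hypothetical call (the axis labels 'b' and 'm' are made up) shows how a fused range is unpacked into per-axis ranges by div/mod:

    print(_einsum.unpack_cc('TM', ['b', 'm'], 'r', False))
    # emits:
    #     int rb[TM] = rbm / dim_m;
    #     int rm[TM] = rbm % dim_m;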

    def strides_cc(name, expr):
        ret = [f'stride_{name}_{d}' for d in expr[:-1]] + ['1']
        ret = dict(zip(expr, ret))
        return ret

    def make_kernel(name,
                    expr_a, expr_b, expr_c,
                    axes_m, axes_n, axes_k, axes_b,
                    multipleof_a, multipleof_b, multipleof_c,
                    lut_mode_a, lut_mode_b,
                    delta_a, delta_b,
                    subscripted):

        use_lut_a = True
        use_lut_b = True

        src = ""

        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
            src += f"""
char __constant__* AD = calloc({4*len(delta_a)});"""
        if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
            src += f"""
char __constant__* BD = calloc({4*len(delta_b)});"""

        src += f"""
__global__ void {name}(
              TYPE * A __noalias __readonly __aligned(16)
            , TYPE * B __noalias __readonly __aligned(16)
            , TYPE * C
            , int * locks
            , float alpha
            , int matmul_m, int matmul_n, int matmul_k __multipleof(16)
            , int div_m
            """
        for dim in [axes_m, axes_n, axes_k, axes_b]:
            for d in dim:
                src += f", int dim_{d}"
        src += "\n            "
        for dim, name, mult in zip([expr_a, expr_b, expr_c],
                                   ['a', 'b', 'c'],
                                   [multipleof_a, multipleof_b, multipleof_c]):
            for d in range(len(dim) - 1):
                attr = f'__multipleof({mult})'
                src += f", int stride_{name}_{d} {attr}"
            src += "\n            "
        if lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_a_inner __multipleof({multipleof_a})"
        elif lut_mode_a == _einsum.LUT_MODE.DRAM:
            src += ", int* AD __noalias __readonly __aligned(16)"
        src += "\n            "
        if lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_b_inner __multipleof({multipleof_b})"
        elif lut_mode_b == _einsum.LUT_MODE.DRAM:
            src += ", int* BD"
        for ptr in subscripted:
            src += f", int* {ptr}"
        src += """) {

    // re-order outer program ids
    int grid_m = (matmul_m + TM - 1) / TM;
    int grid_n = (matmul_n + TN - 1) / TN;
    int pid_mn = get_program_id(0) / div_m;
    int pid_n = pid_mn % grid_n;
    int pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);

    // get batch program id
    int pid_b = get_program_id(1);

#if TZ == 1
    int off_k = 0;
#else
    // get reduction sub-group program id
    int pid_z = get_program_id(2);
    int grid_z = get_num_programs(2);
    int div_z = matmul_k / TZ;
    int rem_z = matmul_k % TZ;
    int off_k = pid_z * div_z;
    matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
#endif

    // create ranges
"""
        rk = 'r{}'.format(''.join(map(str,axes_k)))
        for axes, tile, off in zip([axes_m, axes_n, axes_b, axes_k],
                                   ['TM', 'TN', 'TB', 'TK'],
                                   ['pid_m*TM', 'pid_n*TN', 'pid_b*TB', 'off_k']):
            currs = ''.join(map(str,axes))
            if axes:
                src += f"    int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
                src += _einsum.unpack_cc(tile, axes, 'r', False)

        src += """
    // initialize pointers to A
    int offa[TM, TK, TB] = """
        for i, sym in enumerate(expr_a):
            ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
            stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
            if i > 0:
                src += ' + '
            src += f"({ccode}) * {stride}\n                            "
        src += ';'

        src += """
    TYPE *pa[TM, TK, TB] = A + offa;"""

        if use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
            spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
            cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
            src += f"""
    // initialize pointers to A look-up table
    int offadelta[TK] = off_k + 0 ... TK;
    int {spec} *padelta[TK] = {cast}AD + offadelta;
    int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""

        src += """

    // initialize pointers to B
    int offb[TK, TN, TB] = """
        for i, sym in enumerate(expr_b):
            ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
            stride = f'stride_b_{i}' if i < len(expr_b) - 1 else '1'
            if i > 0:
                src += ' + '
            src += f"({ccode}) * {stride}\n                            "
        src += ';'

        src += """
    TYPE *pb[TK, TN, TB] = B + offb;"""

        if use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
            spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
            cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
            src += f"""
    // initialize pointers to B look-up table
    int offbdelta[TK] = off_k + 0 ... TK;
    int *pbdelta[TK] = BD + offbdelta;"""

        src += f"""

    // prefetch
    bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
    bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
    bool checkk[TK] = {rk} < matmul_k + off_k;
    bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
    bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
    TYPE a[TM, TK, TB] = checka ? *pa : 0;
    TYPE b[TK, TN, TB] = checkb ? *pb : 0;
    // accumulate
    float acc[TM, TN, TB] = 0;
    for(int k = matmul_k; k > 0; k -= TK) {{
        acc += a @ b;"""

        if not use_lut_a or not use_lut_b:
            src += f"""
        {rk} += TK;
"""
            src += _einsum.unpack_cc(tile, axes_k, 'r', True)

        if use_lut_a:
            if lut_mode_a == _einsum.LUT_MODE.SCALAR:
                src += """
        pa += stride_a_inner;"""
            else:
                src += """
        pa += incda;
        padelta += TK;
        incda = (*padelta)[newaxis, :, newaxis];"""
        else:
            src += """
        offa = """
            for i, sym in enumerate(expr_a):
                ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
                stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
                if i > 0:
                    src += ' + '
                src += f"({ccode}) * {stride}\n                            "
            src += """;
        TYPE *pa[TM, TK, TB] = A + offa;"""

        if lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += """
        pb += stride_b_inner;"""
        else:
            src += """
        pb += (*pbdelta)[:, newaxis, newaxis];
        pbdelta += TK;"""

        src += f"""
        checkk = k > TK;
        checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
        checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
        a = *?(checka)pa;
        b = *?(checkb)pb;
    }}
    TYPE c[TM, TN, TB] = acc;

    // re-materialize ranges
"""
        for axes, tile, off in zip([axes_m, axes_n, axes_b],
                                   ['TM', 'TN', 'TB'],
                                   ['pid_m*TM', 'pid_n*TN', 'pid_b*TB']):
            currs = ''.join(map(str,axes))
            if axes:
                src += f"    r{currs} = {off} + 0 ... {tile};\n"
                src += _einsum.unpack_cc(tile, axes, 'r', True)

        src += """
    // initialize pointers to C
    int offc[TM, TN, TB] = """
        for i, sym in enumerate(expr_c):
            stride = f'stride_c_{i}' if i < len(expr_c) - 1 else '1'
            ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
            if i > 0:
                src += ' + '
            src += f"({ccode}) * {stride}\n                            "
        src += ';'

        src += """
    TYPE *pc[TM, TN, TB] = C + offc;

    // bounds-checking
    checkm = r""" + ''.join(map(str,axes_m)) + """ < matmul_m;
    checkn = r""" + ''.join(map(str,axes_n)) + """ < matmul_n;
    bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
                              checkn[newaxis, :, newaxis];

    // write back
#if TZ == 1
    *?(checkc)pc = c;
#else
    int *plock = locks + pid_mn + pid_b * get_num_programs(0);
    int *pcount = plock + 1024*1024;
    // spin
    for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
    int count = *pcount;
    if(count == 0)
        *?(checkc)pc = c;
    else
        *?(checkc)pc = c + *?(checkc)pc;
    atomic_xchg(pcount, (count + 1) % (grid_z));
    atomic_xchg(plock, 0);
#endif
}
"""

        #print(src)
        ret = triton.kernel(src, ['C'])
        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
            ret.set_constant('AD', delta_a)
        if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
            ret.set_constant('BD', delta_b)
        return ret
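The TZ > 1 epilogue generated above implements split-K: each program along the third grid axis reduces a slice of K, and the spin lock serializes the accumulation into C. A NumPy model of the result, with the lock replaced by a sequential loop (illustrative only):

    import numpy as np

    # Each of grid_z "programs" reduces a K-slice; the first writer stores,
    # the others add, exactly what the lock-protected write-back computes.
    def split_k(A, B, grid_z):
        K = A.shape[1]
        C = np.zeros((A.shape[0], B.shape[1]), dtype=np.float32)
        for pid_z in range(grid_z):
            ks = slice(pid_z * K // grid_z, (pid_z + 1) * K // grid_z)
            C += A[:, ks] @ B[ks, :]
        return C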

    ############################
    ## Look-up Table
    ############################

    class LUT_MODE(IntEnum):
        SCALAR = 1
        CONSTANT = 2
        DRAM = 3

    def lut_mode(delta):
        if delta.size == 0 or np.min(delta) == np.max(delta):
            return _einsum.LUT_MODE.SCALAR
        #if delta.size < 4096:
        #    return _einsum.LUT_MODE.CONSTANT
        return _einsum.LUT_MODE.DRAM

    def symbolic_delta(symbols, axes):
        rank = len(symbols)
        strides = [sp.symbols(f'stride{d}') for d in range(rank)]
        nexts = {s: sp.symbols(f'next{s}') for s in axes}
        delta = 0
        for i in range(rank):
            delta += strides[i] * (symbols[i].subs(nexts) - symbols[i])
        return delta

    def unpack_offset(k, axes, dims):
        ret = dict()
        for d in reversed(axes):
            ret[d] = k % dims[d]
            k = k // dims[d]
        return ret

    def make_delta(axes, step, stride, dims, symbols, arrays):
        # symbolic pointer increments
        delta = _einsum.symbolic_delta(symbols, axes)
        args  = [f'stride{d}' for d in range(len(stride))]
        args += [f'{sk}' for sk in axes]
        args += [f'next{sk}' for sk in axes]
        args += [f'{sk}' for sk, _ in arrays]
        fn = sp.lambdify(args, delta, 'numpy')
        # inner axes values
        inner = [dims[d] for d in axes]
        k = np.arange(np.prod(inner), dtype=np.int32)
        off = _einsum.unpack_offset(k, axes, dims)
        nextoff = _einsum.unpack_offset(k + step, axes, dims)
        # evaluate deltas
        args  = [s for s in stride]
        args += [off[sk] for sk in axes]
        args += [nextoff[sk] for sk in axes]
        args += [x for _, x in arrays]
        delta = fn(*args)
        return delta, _einsum.lut_mode(delta[:-step])
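A small NumPy illustration of why `lut_mode` can degrade to SCALAR (the numbers are made up): when the reduction axis advances with a fixed stride, every pointer moves by the same constant and no table needs to be materialized:

    import numpy as np

    TK, stride_k = 8, 576
    off = np.arange(64) * stride_k       # offset of element k
    delta = off[TK:] - off[:-TK]         # pointer increment after one TK step
    assert np.min(delta) == np.max(delta) == TK * stride_k   # constant -> SCALAR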

    ############################
    ## Einsum parsing
    ############################

    def uniq(seq):
        seen = set()
        seen_add = seen.add
        return [x for x in seq if not (x in seen or seen_add(x))]

    def parse_axes(expr_a, expr_b, expr_c, subscripted):
        is_index = lambda x: type(x) == sp.indexed.Indexed or str(x) in subscripted
        sym_a = [x for s in expr_a for x in s.free_symbols if not is_index(x)]
        sym_b = [x for s in expr_b for x in s.free_symbols if not is_index(x)]
        sym_c = [x for s in expr_c for x in s.free_symbols]
        batch = [d for d in sym_a if d in sym_b and d in sym_c]
        outer = [d for d in sym_a if d not in sym_b and d in sym_c]
        inner = [d for d in sym_a if d in sym_b and d not in sym_c]
        illegal = [d for d in sym_a if d not in sym_b and d not in sym_c]
        if illegal:
            raise ValueError(f"einsum labels {illegal} ({expr_a}) "\
                             f"not present in {expr_b} or {expr_c}")
        return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner)


    def replace_subscript(expr, arrays):
        # replace array indexing by Indexed()
        indexed = re.findall(r'([_a-zA-Z][_a-zA-Z0-9]*)\[([_a-z]*)\]', expr)
        for x in indexed:
            arrays.append(x[0])
            expr = expr.replace(f'{x[0]}[{x[1]}]', f'Indexed({x[0]},{x[1]})')
        return expr


    def parse_expr(expr, arrays):
        # extract symbols
        sym = []
        i = 0
        while i < len(expr):
            d = expr[i]
            if d == '(':
                size = expr[i:].find(')')
                d = expr[i : i + size + 1]
                d = _einsum.replace_subscript(d, arrays)
                sym.append(parse_expr(d))   # sympy's parse_expr, not this method
                i += size + 1
            else:
                sym.append(parse_expr(d))
                i += 1
        return sym
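A sketch of the classification that `parse_axes` performs for the einsum 'bmk,bkn->bmn' (sympy symbols stand in for the parsed subscripts; this mirrors the membership tests above with set operations):

    import sympy as sp

    b, m, k, n = sp.symbols('b m k n')
    sym_a, sym_b, sym_c = {b, m, k}, {b, k, n}, {b, m, n}
    batch = sym_a & sym_b & sym_c       # {b}: kept as the TB dimension
    outer = (sym_a - sym_b) & sym_c     # {m}: rows of the equivalent matmul
    inner = (sym_a & sym_b) - sym_c     # {k}: the reduction axis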

    ############################
    ## Compilation
    ############################

    class instance:

        locks = None
        kernel_cache = dict()

        def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays):
            # parse symbols
            expr_a, expr_bc = einsum.split(",")
            expr_b, expr_c  = expr_bc.split("->")
            subscripted = []
            sym_a = _einsum.parse_expr(expr_a, subscripted)
            sym_b = _einsum.parse_expr(expr_b, subscripted)
            sym_c = _einsum.parse_expr(expr_c, subscripted)
            # parse axes
            axes_b, axes_m, axes_k = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted)
            _, axes_n, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted)
            axes = axes_b + axes_m + axes_n + axes_k
            # check dimensions
            dims_a = dict(zip(sym_a, shape_a))
            dims_b = dict(zip(sym_b, shape_b))
            dims_c = dict(zip(sym_c, shape_c))
            for axes in [axes_b, axes_k]:
                for d in axes:
                    dim_a = dims_a[d] if d in sym_a else None
                    dim_b = dims_b[d] if d in sym_b else None
                    if dim_a and dim_b and dim_a != dim_b:
                        raise ValueError(f'incompatible dimension {d}'
                                         f' (a: {dim_a}; b: {dim_b})')
            dims = dict()
            dims.update(dims_a)
            dims.update(dims_b)
            dims.update(dims_c)
            # look-up tables
            TK = 16 if dtype == triton.fw.torch.float16 else 8
            arrays = [(x, arrays[x]) for x in subscripted]
            delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
            delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
            # hash for recompilation
            stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
            stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
            stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
            name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
                   f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'
            # recompile if necessary
            cache = _einsum.instance.kernel_cache
            if name not in cache:
                cachesize = len(cache)
                cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
                                                  sym_a, sym_b, sym_c,
                                                  axes_m, axes_n, axes_k, axes_b,
                                                  stride_a_multiple, stride_b_multiple, stride_c_multiple,
                                                  lut_mode_a, lut_mode_b,
                                                  delta_a, delta_b,
                                                  subscripted)
            self.kernel = cache[name]
            # Initialize locks
            if _einsum.instance.locks is None:
                _einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda()
            # Kernel arguments
            dim_m = [dims[d] for d in axes_m]
            dim_n = [dims[d] for d in axes_n]
            dim_k = [dims[d] for d in axes_k]
            dim_b = [dims[d] for d in axes_b]
            M = reduce(mul, dim_m, 1)
            N = reduce(mul, dim_n, 1)
            K = reduce(mul, dim_k, 1)
            B = reduce(mul, dim_b, 1)
            stride_a = list(stride_a[:-1])
            stride_b = list(stride_b[:-1])
            stride_c = list(stride_c[:-1])
            arrays = [torch.from_numpy(x).cuda() for _, x in arrays]
            alpha = 1.
            div_m = 1
            self.args  = [None, None, None,
                          _einsum.instance.locks,
                          alpha, M, N, K, div_m] +\
                          dim_m + dim_n + dim_k + dim_b +\
                          stride_a + stride_b + stride_c
            if lut_mode_a != _einsum.LUT_MODE.CONSTANT:
                delta_a = delta_a[0] if lut_mode_a == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_a).cuda()
                self.args += [delta_a]
            if lut_mode_b != _einsum.LUT_MODE.CONSTANT:
                delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda()
                self.args += [delta_b]
            self.args += arrays
            self.args += [lambda opt: [triton.cdiv(M, opt.d('TM')) *
                                       triton.cdiv(N, opt.d('TN')),
                                       triton.cdiv(B, opt.d('TB')),
                                       opt.d('TZ')]]
            # position of dynamic arguments
            self.pos_a = 0
            self.pos_b = 1
            self.pos_c = 2
            # pre-processor macros
            TM = [x for x in [16, 32, 64, 128] if x <= M]
            TN = [x for x in [16, 32, 64, 128] if x <= N]
            TB = [x for x in [1, 2, 4] if x <= B]
            MAX_GZ = K // 2048
            MIN_GM = M // max(TM)
            MIN_GN = N // max(TN)
            MIN_GB = B // max(TB)
            TZ = [x for x in [1, 2, 4, 8, 16, 32] \
                  if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
            TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
            #TB, TZ = [1], [1]
            #TM, TN, TB, TZ = [128], [128], [1], [1]
            self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
            self.dtype = dtype
            self.flops = 2 * B * M * N * K
            self.sym_a = sym_a
            self.sym_b = sym_b
            self.sym_c = sym_c
            # save equivalent mat-mul dimensions
            self.matmul_B = B
            self.matmul_M = M
            self.matmul_N = N
            self.matmul_K = K

        def run(self, a, b, c, bench):
            self.args[self.pos_a] = a
            self.args[self.pos_b] = b
            self.args[self.pos_c] = c
            self.kernel(*self.args, bench=bench, **self.macros)


    ############################
    ## Forward
    ############################

    instance_cache = dict()

    @staticmethod
    def forward(ctx, einsum, a, b, shape_c, **kwargs):
        bench  = kwargs['bench']  if 'bench'  in kwargs else False
        arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
        # allocate output
        dtype = a.dtype
        c = triton.empty(shape_c, dtype=dtype)
        key = (einsum, dtype,
               a.stride(), b.stride(), c.stride(),
               a.shape, b.shape, c.shape)
        # compile einsum instance
        cache = _einsum.instance_cache
        if key not in cache:
            cache[key] = _einsum.instance(einsum, dtype,
                                          a.stride(), b.stride(), c.stride(),
                                          a.shape, b.shape, c.shape, arrays)
        instance = cache[key]
        instance.run(a, b, c, bench)
        # save information in context
        ctx.flops = instance.flops
        ctx.sym_a = instance.sym_a
        ctx.sym_b = instance.sym_b
        ctx.sym_c = instance.sym_c
        ctx.matmul_B = instance.matmul_B
        ctx.matmul_M = instance.matmul_M
        ctx.matmul_N = instance.matmul_N
        ctx.matmul_K = instance.matmul_K
        ctx.bench = bench
        ctx.save_for_backward(a, b)
        return c
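Assuming the call pattern implied by the forward signature above (the shapes are illustrative; the last dimension of each tensor is kept a multiple of 8 to satisfy the alignment hints):

    import torch

    a = torch.randn(16, 64, 128, dtype=torch.float16, device='cuda')
    b = torch.randn(16, 128, 96, dtype=torch.float16, device='cuda')
    c = einsum('bmk,bkn->bmn', a, b, [16, 64, 96])   # batched matrix product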

    ############################
    ## Backward
    ############################

    def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
        for i, expr in enumerate(sym_x):
            if expr.is_symbol:
                continue
            sc = [x for x in expr.free_symbols if x in sym_c][0]
            sx = sp.symbols(f'{prefix}{i}')
            renamed[expr] = sx
            inverse[sc] = sp.solve(sp.Eq(expr, sx), sc)[0]
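The `sp.solve` call above is what inverts a compound index; a standalone sympy illustration with made-up symbol names:

    import sympy as sp

    h, r, a0 = sp.symbols('h r a0')
    sp.solve(sp.Eq(h + r, a0), h)   # -> [a0 - r]; h re-expressed w.r.t. the renamed axis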

    def sym_to_expr(sym):
        res = [f'({x})' for x in sym]
        res = ''.join(res)
        return res

    @staticmethod
    def backward(ctx, dy):
        a, b = ctx.saved_tensors
        bench = ctx.bench
        sym_a = ctx.sym_a
        sym_b = ctx.sym_b
        sym_c = ctx.sym_c
        inverse = dict()
        renamed = dict()
        _einsum.sym_invert(sym_c, sym_a, 'a', renamed, inverse)
        _einsum.sym_invert(sym_c, sym_b, 'b', renamed, inverse)
        sym_a = [renamed[x] if x in renamed else x for x in sym_a]
        sym_b = [renamed[x] if x in renamed else x for x in sym_b]
        sym_c = [inverse[x] if x in inverse else x for x in sym_c]
        expr_a = _einsum.sym_to_expr(sym_a)
        expr_b = _einsum.sym_to_expr(sym_b)
        expr_c = _einsum.sym_to_expr(sym_c)
        # gradient of a: swap the output and first-input subscripts
        expr = f'{expr_c},{expr_b}->{expr_a}'
        da = einsum(expr, dy, b, a.shape, False)
        return None, da, None, None, None


einsum = _einsum.apply
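Sanity check of the subscript rewriting for a plain batched product, in pure PyTorch and independent of the kernel: for c = einsum('bmk,bkn->bmn', a, b), the backward above builds 'bmn,bkn->bmk', i.e. dA = dC @ B^T:

    import torch

    a = torch.randn(2, 4, 8, requires_grad=True)
    b = torch.randn(2, 8, 5)
    dy = torch.randn(2, 4, 5)
    da = torch.einsum('bmn,bkn->bmk', dy, b)            # what backward computes
    assert torch.allclose(da, dy @ b.transpose(1, 2), atol=1e-6)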
@@ -24,7 +24,7 @@ def empty(shape, dtype):
        return tf_empty_proxy(shape, dtype)
        #return fw.tf_extra_ops.alloc_empty(args, T = dtype)
    elif fw.has_torch():
        return fw.torch.empty(shape, dtype=dtype).cuda()

def shape(A):
    if fw.has_tensorflow():
@@ -47,16 +47,23 @@ class id_dict:
            return libtriton.retrieve_scalar(self.id)

    def __init__(self):
        self.data = dict()

    def __delitem__(self, key):
        del self.data[key]

    @staticmethod
    def _get_key(key):
        if fw.has_tensorflow():
            if isinstance(key, fw.tensorflow.Tensor):
                key = id(key.op)
        if fw.has_torch():
            if isinstance(key, fw.torch.Tensor):
                key = id(key)
        return key

    def __getitem__(self, key):
        ret = self.data[id_dict._get_key(key)]
        if isinstance(ret, id_dict.lazy_entry):
            return ret.get()
        return ret
@@ -65,7 +72,4 @@ class id_dict:
        return len(self.data)

    def __setitem__(self, key, value):
        self.data[id_dict._get_key(key)] = value
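The identity-keyed cache above can be mimicked with a plain dict; this sketch shows the behaviour `_get_key` relies on (tensors are keyed by object identity, not by value, since they are not reliably hashable across frameworks):

    import torch

    cache = {}
    t = torch.empty(4)
    cache[id(t)] = 'value'          # identity key, mirrors id_dict._get_key
    assert cache[id(t)] == 'value'  # same object -> same id -> same entry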
@@ -10,10 +10,12 @@ int main() {
  typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
  std::vector<config_t> configs;
  for(auto ord: std::vector<std::vector<int>>{{1, 0}})
  for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
    std::vector<config_t> tmp = {
      config_t{ord, x[0], x[1], 2048, 2048, 2048},
//      config_t{ord, x[0], x[1], 512, 512, 512},
//      config_t{ord, x[0], x[1], 1024, 1024, 1024},
      config_t{ord, x[0], x[1], 127008, 768, 576},
//      config_t{ord, x[0], x[1], 8192, 8192, 8192}
//      config_t{ord, x[0], x[1], 16, 2048, 2048},
//      config_t{ord, x[0], x[1], 32, 2048, 2048},
//      config_t{ord, x[0], x[1], 64, 2048, 2048},
@@ -33,7 +35,7 @@ int main() {
  int32_t M, N, K;
  for(const auto& c: configs){
    std::tie(ord, AT, BT, M, N, K) = c;
    std::cout << "// " << c ;
    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
      std::cout << ", " << perf << std::flush;
    std::cout << std::endl;

@@ -20,7 +20,7 @@ static void cc_dot(std::vector<T> &c, const std::vector<T> &a, const std::vector
      float acc = 0;
      for(size_t k = 0; k < K; k++)
        acc = acc + (!AT ? a[k*M + m] : a[m*K + k]) * (!BT ? b[n*K + k] : b[k*N + n]);
      c[m*N + n] = static_cast<T>(acc);
    }
}
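A NumPy model of the arithmetic this reference check verifies, ignoring the column-major bookkeeping of the C++ version (shapes and flags are illustrative); the corrected store `c[m*N + n]` corresponds to a row-major result:

    import numpy as np

    def cc_dot_model(a, b, AT, BT):
        a = a.T if AT else a     # a: M x K after optional transposition
        b = b.T if BT else b     # b: K x N
        return (a @ b).ravel()   # element (m, n) lands at m*N + n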
@@ -72,9 +72,9 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
  std::string ty = to_string<T>::value;
  size_t dt_nbytes = sizeof(T);
  drv::context* context = stream->context();
  int32_t lda = (AT ^ a_order[0]==1) ? K : M;
  int32_t ldb = (BT ^ b_order[0]==1) ? N : K;
  int32_t ldc = N;
  std::vector<std::string> sa = { "1", "lda" };
  std::vector<std::string> sb = { "1", "ldb" };
@@ -86,17 +86,17 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
  // macros
  rt::function::options_space_t opt;
  // A access patterns
  opt.defines.push_back({"USEA",         {AT? "a" : "a" }});
  opt.defines.push_back({"BROADCAST_AK", {AT? "newaxis, :" : "newaxis, :" }});
  opt.defines.push_back({"BROADCAST_AM", {AT? ":, newaxis" : ":, newaxis" }});
  opt.defines.push_back({"SHAPE_A",      {AT? "TM, TK" : "TM, TK" }});
  opt.defines.push_back({"STRIDE_AK",    {AT? sa[a_order[0]] : sa[a_order[1]] }});
  opt.defines.push_back({"STRIDE_AM",    {AT? sa[a_order[1]] : sa[a_order[0]] }});
  // B access patterns
  opt.defines.push_back({"USEB",         {BT? "b" : "b" }});
  opt.defines.push_back({"BROADCAST_BK", {BT? ":, newaxis" : ":, newaxis" }});
  opt.defines.push_back({"BROADCAST_BN", {BT? "newaxis, :" : "newaxis, :" }});
  opt.defines.push_back({"SHAPE_B",      {BT? "TK, TN" : "TK, TN" }});
  opt.defines.push_back({"STRIDE_BK",    {BT? sb[b_order[1]] : sb[b_order[0]] }});
  opt.defines.push_back({"STRIDE_BN",    {BT? sb[b_order[0]] : sb[b_order[1]] }});
  // data-type
@@ -109,15 +109,15 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
    opt.num_warps = {nwarp};
  }
  if(mode == BENCH) {
    opt.defines.push_back({"TM", {"32", "64", "128"}});
    opt.defines.push_back({"TN", {"32", "64", "128"}});
    opt.defines.push_back({"TK", {to_string<T>::value == "half" ? "16" : "8"}});
    opt.num_warps = {2, 4, 8};
  }

  // kernels
  rt::function function(src::dot, opt);
  std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc};
  auto grid = grid2d(M, N);

  // metrics
@@ -126,17 +126,17 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
    double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
    bench.push_back(tflops(triton_ns));

//    // cublas
//    if(cublas::cublasinit()){
//      T alpha(static_cast<double>(1));
//      T beta(static_cast<double>(0));
//      cublasGemmAlgo_t fastest;
//      cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
//      double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
//                                                                 &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
//                                                                 ldc, nullptr, fastest); }, stream);
//      bench.push_back(tflops(cublas_ms));
//    }
  }

  // test triton
@@ -147,9 +147,9 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
    std::vector<T> ha(M*K);
    std::vector<T> hb(K*N);
    for(size_t i = 0; i < ha.size(); i++)
      ha[i] = (float)rand()/RAND_MAX;
    for(size_t i = 0; i < hb.size(); i++)
      hb[i] = (float)rand()/RAND_MAX;
    // copy buffer
    stream->write(&*da, true, 0, ha);
    stream->write(&*db, true, 0, hb);

@@ -2,37 +2,58 @@ namespace src {

const char *dot =
R"(
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                    TYPE * B __noalias __readonly __aligned(16),
                    TYPE * C __noalias __aligned(16),
                    float alpha,
                    int M, int N, int K,
                    int lda __multipleof(8),
                    int ldb __multipleof(8),
                    int ldc __multipleof(8)) {
  // prologue
  int ridx = get_program_id(0);
  int ridy = get_program_id(1);
  int gridx = M / TM;
  int gridy = N / TN;
  int rid = ridx + ridy * gridx;
  ridx = rid / gridy;
  ridy = rid % gridy;
  int rm[TM] = ridx * TM + 0 ... TM;
  int rn[TN] = ridy * TN + 0 ... TN;
  int rk[TK] = 0 ... TK;

  // pointers to operands
  int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
  int offb[SHAPE_B] = rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
  TYPE* pa[SHAPE_A] = A + offa;
  TYPE* pb[SHAPE_B] = B + offb;

  // prefetch operands
  bool checka[SHAPE_A] = rk[BROADCAST_AK] < K;
  bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K;
  TYPE a[SHAPE_A] = checka ? *pa : 0;
  TYPE b[SHAPE_B] = checkb ? *pb : 0;

  // reduction loop
  float c[TM, TN] = 0;
  for(int k = K; k > 0; k -= TK){
    c += USEA @ USEB;
    bool checka[SHAPE_A] = k > TK;
    bool checkb[SHAPE_B] = k > TK;
    pa += TK * STRIDE_AK;
    pb += TK * STRIDE_BK;
    a = *?(checka)pa;
    b = *?(checkb)pb;
  }
  //c = c * alpha;

  // epilogue
  int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
  int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
  int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
  TYPE* pc[TM, TN] = C + offc;
  bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
  *?(checkc)pc = (TYPE[TM, TN])c;
}
)";
@@ -159,7 +159,7 @@ bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
  if(hc.size() != rc.size())
    return false;
  for(size_t i = 0; i < hc.size(); i++)
    if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
      std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
      return false;
    }

@@ -10,8 +10,8 @@ int main() {
  // shapes to test
  typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
  std::vector<config_t> configs;
  for(int TM: std::vector<int>{32, 64, 128})
  for(int TN: std::vector<int>{32, 64, 128})
  for(int TK: std::vector<int>{16})
  for(int nwarps: std::vector<int>{4})
  for(bool AT: std::array<bool, 2>{false, true})