[GENERAL] Merged einsum feature branch. Various feature and performance
improvements, and bugfixes:

* Added preliminary support for extended Einstein summation in PyTriton
  (a short reference sketch of the notation follows below)
* Significant performance improvement on FP32 kernels containing matrix multiplication
* Added re-coalescing pass for FP16 kernels containing matrix multiplication
* Various bugfixes
Philippe Tillet
2020-01-16 12:09:50 -05:00
parent 50a52df489
commit f278d9741a
49 changed files with 1923 additions and 994 deletions
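
Note on the first bullet: Einstein summation contracts tensors over every index that appears in the inputs but not in the output. The PyTriton entry point added by this commit is not shown in this diff, so purely as a neutral reference for the notation, here is a minimal C++ sketch of the simplest contraction, "mk,kn->mn" (plain loops, no Triton involved):

#include <cstddef>
#include <vector>

// Reference semantics of einsum("mk,kn->mn", A, B): the repeated index k is
// summed away; the surviving indices m and n address the result.
std::vector<float> einsum_mk_kn_mn(const std::vector<float>& A,
                                   const std::vector<float>& B,
                                   std::size_t M, std::size_t K, std::size_t N) {
  std::vector<float> C(M * N, 0.0f);
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t k = 0; k < K; ++k)
        C[m * N + n] += A[m * K + k] * B[k * N + n];
  return C;
}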

View File

@@ -1,83 +0,0 @@
#ifndef _TRITON_CODEGEN_INSTRUCTIONS_H_
#define _TRITON_CODEGEN_INSTRUCTIONS_H_
#include "triton/ir/enums.h"
#include <map>
#include <vector>
namespace triton{
namespace ir{
class instruction;
}
namespace codegen{
enum storage_info_t {
NONE,
ANY,
SHARED,
DISTRIBUTED,
REPLICATED
};
typedef std::pair<storage_info_t, std::vector<storage_info_t>> inst_storage_info_t;
static const std::map<ir::value_id_t, inst_storage_info_t> storage_info = {
// scalars
{ ir::INST_GET_PROGRAM_ID, {REPLICATED, {}}},
{ ir::INST_GET_NUM_PROGRAMS, {REPLICATED, {}}},
// scalar/array
{ ir::INST_PHI, {ANY, {ANY, ANY}}},
{ ir::INST_BINOP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
{ ir::INST_GETELEMENTPTR, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
{ ir::INST_SELECT, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED, DISTRIBUTED}}},
{ ir::INST_SQRT, {DISTRIBUTED, {DISTRIBUTED}}},
// cmp
{ ir::INST_ICMP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
{ ir::INST_FCMP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
// cast
{ ir::INST_CAST_TRUNC, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_ZEXT, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_SEXT, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_FP_TRUNC, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_FP_EXT, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_UI_TO_FP, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_SI_TO_FP, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_FP_TO_UI, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_FP_TO_SI, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_PTR_TO_INT, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_INT_TO_PTR, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_BIT_CAST, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_CAST_ADDR_SPACE_CAST, {DISTRIBUTED, {DISTRIBUTED}}},
// io
{ ir::INST_UNMASKED_LOAD, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_MASKED_LOAD, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
{ ir::INST_UNMASKED_STORE, {NONE , {DISTRIBUTED, DISTRIBUTED}}},
{ ir::INST_MASKED_STORE, {NONE , {DISTRIBUTED, DISTRIBUTED, DISTRIBUTED}}},
// retile
{ ir::INST_RESHAPE, {DISTRIBUTED, {DISTRIBUTED}}},
{ ir::INST_SPLAT, {DISTRIBUTED, {REPLICATED}}},
{ ir::INST_BROADCAST, {DISTRIBUTED, {REPLICATED}}},
{ ir::INST_DOWNCAST, {DISTRIBUTED, {REPLICATED}}},
// array arithmetic
{ ir::INST_TRANS, {SHARED, {SHARED}}},
{ ir::INST_REDUCE, {SHARED, {DISTRIBUTED}}},
{ ir::INST_DOT, {DISTRIBUTED, {SHARED, SHARED, DISTRIBUTED}}},
// terminator
{ ir::INST_RETURN, {NONE, {}}},
{ ir::INST_UNCOND_BRANCH, {NONE, {}}},
{ ir::INST_COND_BRANCH, {NONE, {REPLICATED}}},
// intrinsics
{ ir::INST_COPY_TO_SHARED, {SHARED, {DISTRIBUTED}}},
{ ir::INST_COPY_FROM_SHARED, {DISTRIBUTED, {SHARED}}},
{ ir::INST_BARRIER, {NONE, {}}},
{ ir::INST_MAKE_RANGE_DYN, {DISTRIBUTED, {}}},
{ ir::INST_MAKE_RANGE_STA, {DISTRIBUTED, {}}},
{ ir::INST_MAKE_RANGE, {DISTRIBUTED, {}}}
};
}
}
#endif

View File

@@ -122,6 +122,7 @@ public:
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_recoalesce_inst(ir::recoalesce_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);

View File

@@ -128,22 +128,22 @@ private:
Type *make_vector_ty(Type *ty, size_t vector_size);
public:
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder, bool vectorize);
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder);
void set_value(indices_t idx, Value *v);
Value* get_value(indices_t idx);
const std::vector<int>& get_order() { return order_; }
unsigned get_linear_index(indices_t idx);
indices_t get_ordered_indices(unsigned id);
void for_each(std::function<void(indices_t)> fn);
const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
void for_each(std::function<void(indices_t)> fn, int start = 0, int end = -1);
void for_each(std::function<void(indices_t)> fn, std::vector<int> start, std::vector<int> size);
const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
private:
axes_t axes_;
std::vector<int> order_;
indices_map_t indices_;
values_map_t values_;
ordered_indices_vec_t ordered_indices_;
size_t vector_size_;
Builder &builder_;
};

View File

@@ -45,6 +45,8 @@ public:
llvm::SmallVectorImpl<char> &buffer,
const std::string &features,
file_type_t file_type);
virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
protected:
driver::context* ctx_;
@@ -54,13 +56,14 @@ protected:
class host_module: public module{
public:
host_module(driver::context* context, std::unique_ptr<llvm::Module> module);
std::unique_ptr<buffer> symbol(const char * name) const;
};
// OpenCL
class ocl_module: public module{
public:
ocl_module(driver::context* context, std::unique_ptr<llvm::Module> module);
std::unique_ptr<buffer> symbol(const char * name) const;
};
// CUDA
@@ -70,7 +73,7 @@ class cu_module: public module {
public:
cu_module(driver::context* context, std::unique_ptr<llvm::Module> module);
cu_module(driver::context* context, const std::string& source);
cu_buffer* symbol(const char * name) const;
std::unique_ptr<buffer> symbol(const char * name) const;
private:
std::string source_;

View File

@@ -37,6 +37,7 @@ public:
// Constants
value *get_int32(unsigned val);
// Types
type *get_void_ty();
type *get_int1_ty();
type *get_int8_ty();
type *get_int16_ty();
@@ -115,10 +116,10 @@ public:
value *create_and(value *lhs, value *rhs, const std::string &name = "");
value *create_xor(value *lhs, value *rhs, const std::string &name = "");
value *create_or(value *lhs, value *rhs, const std::string &name = "");
// Side effects
value *create_fneg(value *arg, const std::string &name = "");
value *create_neg(value *arg, const std::string &name = "");
value *create_not(value *arg, const std::string &name = "");
// Unary
// value *create_fneg(value *arg, const std::string &name = "");
// value *create_neg(value *arg, const std::string &name = "");
// value *create_not(value *arg, const std::string &name = "");
// Input/Output
value *create_load(value *arg, const std::string &name = "");
value *create_store(value *ptr, value *val, const std::string &name = "");

View File

@@ -134,6 +134,7 @@ enum value_id_t: unsigned {
// intrinsics
INST_COPY_TO_SHARED,
INST_COPY_FROM_SHARED,
INST_RECOALESCE,
INST_BARRIER,
INST_MAKE_RANGE_DYN,
INST_MAKE_RANGE_STA,

View File

@@ -148,9 +148,9 @@ public:
// Factory methods
static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(binary_operator)
_TRITON_DEFINE_ACCEPT(binary_operator)
@@ -732,6 +732,17 @@ public:
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
};
class recoalesce_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "recoalesce_inst"; }
public:
static recoalesce_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(recoalesce_inst)
_TRITON_DEFINE_ACCEPT(recoalesce_inst)
};
class barrier_inst: public instruction{
private:
barrier_inst(context &ctx, const std::string &name, instruction *next);

View File

@@ -59,6 +59,7 @@ class sqrt_inst;
class reduce_inst;
class select_inst;
class recoalesce_inst;
class copy_to_shared_inst;
class copy_from_shared_inst;
class barrier_inst;
@@ -129,6 +130,7 @@ public:
virtual void visit_reduce_inst(reduce_inst*) = 0;
virtual void visit_select_inst(select_inst*) = 0;
virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
virtual void visit_barrier_inst(barrier_inst*) = 0;

View File

@@ -47,6 +47,7 @@ protected:
};
void set_ret(ir::value* value);
ir::value *GenUnaryMinus(ir::value* arg);
public:
Generator(Parser* parser) : parser_(parser) {}

View File

@@ -145,6 +145,7 @@ public:
THREAD, // _Thread_local
AUTO,
GLOBAL,
CMEM, // constant memory
// STORAGE CLASS SPECIFIER END
BREAK,

View File

@@ -39,7 +39,7 @@ enum {
S_EXTERN = 0x02,
S_STATIC = 0x04,
S_THREAD = 0x08,
S_AUTO = 0x10,
S_CONSTANT = 0x10,
S_GLOBAL = 0x20,
// Type specifier
@@ -73,7 +73,8 @@ struct Qualifier {
CONST = 0x01,
RESTRICT = 0x02,
VOLATILE = 0x04,
MASK = CONST | RESTRICT | VOLATILE
CMEM = 0x08,
MASK = CONST | RESTRICT | VOLATILE | CMEM
};
};
@@ -111,6 +112,7 @@ public:
bool IsConstQualified() const { return ptr_ & Qualifier::CONST; }
bool IsRestrictQualified() const { return ptr_ & Qualifier::RESTRICT; }
bool IsVolatileQualified() const { return ptr_ & Qualifier::VOLATILE; }
bool IsConstantQualified() const { return ptr_ & Qualifier::CMEM; }
private:
intptr_t ptr_;

View File

@@ -8,11 +8,9 @@
#include <string>
#include <memory>
#include <functional>
#include <mutex>
// codegen
#include "triton/ir/context.h"
#include "triton/codegen/target.h"
#include "triton/lang/parser.h"
#include "triton/runtime/arg.h"
namespace llvm {
@@ -20,6 +18,8 @@ namespace llvm {
class LLVMContext;
}
class Parser;
namespace triton {
namespace driver{
@@ -106,14 +106,14 @@ public:
function(const std::string& src, const options_space_t& opt = options_space_t());
void operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream* stream);
void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
std::string make_tensorflow_src(const std::vector<size_t> &outputs, const std::string &macro);
void set_cst(const std::string& name, void* data, size_t n_bytes);
private:
ir::context ctx_;
std::string src_;
options_space_t opt_space_;
std::map<cache_key_t, caller> cache_;
std::mutex src_mutex_;
std::map<std::string, std::vector<char>> cst_;
};
}

View File

@@ -30,7 +30,7 @@ private:
high_resolution_clock::time_point _start;
};
inline double bench(std::function<void()> const & op, driver::stream * stream)
inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
{
// const driver::device * device = stream->context()->device();
timer tmr;
@@ -38,9 +38,10 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
double total_time = 0;
op();
stream->synchronize();
while(total_time*1e-9 < 1e-2){
while(total_time*1e-9 < 1e-1){
float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
if(normalize)
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();

View File

@@ -5,24 +5,41 @@
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace analysis{
inline int gcd(int a, int b) {
if (a == 0)
return b;
if (b == 0)
return a;
if (a == b)
return a;
if (a > b)
return gcd(a - b, b);
return gcd(a, b - a);
// Function for extended Euclidean Algorithm
int gcd_impl(int a, int b, int *x, int *y)
{
// Base Case
if (a == 0)
{
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b%a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b/a) * x1;
*y = x1;
return gcd;
}
int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
}
inline int lcm(int a, int b) {
return (a * b) / gcd(a, b);
}
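Note: the helper above is the extended Euclidean algorithm; besides gcd(a, b) it yields Bezout coefficients x, y with a*x + b*y == gcd(a, b), although the gcd(a, b) wrapper discards them and only returns the gcd. A tiny standalone check (hypothetical test harness, not part of this commit):

#include <cassert>

// Same recurrence as gcd_impl above: gcd(a, b) == gcd(b % a, a), with the
// coefficients of the recursive call folded back into x and y.
int gcd_bezout(int a, int b, int *x, int *y) {
  if (a == 0) { *x = 0; *y = 1; return b; }
  int x1, y1;
  int g = gcd_bezout(b % a, a, &x1, &y1);
  *x = y1 - (b / a) * x1;
  *y = x1;
  return g;
}

int main() {
  int x, y;
  int g = gcd_bezout(12, 20, &x, &y);      // g == 4, x == 2, y == -1
  assert(g == 4 && 12 * x + 20 * y == g);  // Bezout identity holds
  return 0;
}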
@@ -156,7 +173,7 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
if(is_constant_.find(v) != is_constant_.end())
return is_constant_.at(v);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_);
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
if(dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
@@ -448,7 +465,7 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(x, {(unsigned)x->get_value()}, starting_multiple_);
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range*>(v))
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range_dyn*>(v))
@@ -484,6 +501,7 @@ void align::populate(ir::value *v) {
populate_is_constant(v);
populate_starting_multiple(v);
populate_max_contiguous(v);
}
void align::run(ir::module &mod) {

View File

@@ -113,6 +113,7 @@ void axes::update_graph(ir::instruction *i) {
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);

case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_RECOALESCE: return update_graph_no_edge(i);
default: return update_graph_elementwise(i);
}
return;

View File

@@ -1,9 +1,9 @@
#include <algorithm>
#include <numeric>
#include <iostream>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/instructions.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
@@ -148,8 +148,11 @@ layout_t::layout_t(layout_type_t _type,
extract_io_use(v, ptr);
order.resize(axes.size());
std::iota(order.begin(), order.end(), 0);
for(ir::value *v: ptr){
auto max_contiguous = align->contiguous(v);
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
});
if(*largest){
auto max_contiguous = align->contiguous(*largest);
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b];
});
@@ -166,9 +169,8 @@ layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
const std::vector<unsigned>& _shapes,
const std::vector<ir::value *> &values, ir::type *_ty, size_t _id,
analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, _id, align) {
unsigned shape_0 = shapes[order[0]];
unsigned shape_1 = shapes[order[1]];
unsigned shape_0 = shapes[0];
unsigned shape_1 = shapes[1];
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
fpw = {1, 1, 1};
@@ -196,6 +198,7 @@ layout_hmma_884_t::layout_hmma_884_t(size_t num_warps,
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_warps *= wpt[d];
if(num_warps != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
@@ -213,20 +216,38 @@ layout_scanline_t::layout_scanline_t(size_t num_warps,
unsigned num_threads = num_warps * 32;
nts.resize(shapes.size());
mts.resize(shapes.size());
bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
ir::value *ptr = nullptr;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
ptr = st->get_pointer_operand();
unsigned i = order[0];
nts[i] = clamp(size / num_threads, 1, 4);
int contiguous = 4;
if(ptr)
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
nts[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shapes[i]));
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
num_threads = num_threads / mts[i];
size /= shapes[i];
num_threads /= mts[i];
if(is_dot)
nts[order[1]] = clamp(size / num_threads, 1, std::min<int>(4, shapes[order[1]]));
for(size_t d = 1; d < shapes.size(); d++){
i = order[d];
nts[i] = 1;
mts[i] = clamp(num_threads, 1, shapes[i]);
if(d > 1 || !is_dot)
nts[i] = 1;
mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]);
num_threads = num_threads / mts[i];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= mts[d];
if(num_warps * 32 != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
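Worked example for the new scanline sizing (hedged: it assumes size starts as the total number of tile elements and order == {0, 1}, as computed earlier in this constructor). Take shapes = {128, 128}, num_warps = 8 (256 threads), a dot among the values, and contiguous = 4: nts[0] = clamp(16384 / 256, 1, min(4, 128)) = 4 and mts[0] = clamp(256, 1, 128 / 4) = 32; after the first axis, size = 128 and num_threads = 8. The dot branch then sets nts[1] = clamp(128 / 8, 1, min(4, 128)) = 4, and the loop gives mts[1] = clamp(8, 1, 128 / 4) = 8. The sanity check passes: mts[0] * mts[1] = 32 * 8 = 256 = num_warps * 32.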
@@ -259,8 +280,8 @@ void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!i_0 || !i_1 ||
storage_info.at(i_0->get_id()).first != codegen::SHARED ||
storage_info.at(i_1->get_id()).first != codegen::SHARED)
!dynamic_cast<ir::copy_to_shared_inst*>(i_0) ||
!dynamic_cast<ir::copy_to_shared_inst*>(i_1) )
return;
if(is_latch_1)
res.reset(new double_buffer_info_t{value_0, value_1, phi});
@@ -284,10 +305,9 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
extract_double_bufferable(v, double_buffer);
// order
if(arg->type == SCANLINE)
order = arg->order;
else
order = arg->order;
std::vector<int> arg_order = arg ? arg->order : std::vector<int>{0};
order = arg_order;
ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr;
ir::value* hmma_dot_a = nullptr;
@@ -304,24 +324,27 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
col.push_back(s);
row.push_back(s);
}
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
order = is_trans(dot_a) ? row : col;
if(is_nonhmma_dot_b)
else if(is_nonhmma_dot_b)
order = is_trans(dot_b) ? col : row;
// else
// order = row;
// padding
pad = 0;
if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order[0] != 0;
pad = 24 - shapes[row ? order[0] : order[1]] % 32;
pad = 24 - shapes[row ? 0 : 1] % 32;
}
else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order[0] != 0;
pad = 24 - shapes[row ? order[1] : order[0]] % 32;
pad = 24 - shapes[row ? 1 : 0] % 32;
}
else if(order != arg->order) {
else if(order != arg_order) {
pad = 4;
}
shapes[order[0]] += pad;
@@ -395,6 +418,29 @@ void layout::run(ir::module &mod) {
layouts_[id] = new layout_shared_t(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), id, align_);
tmp_[red] = id;
}
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
ir::value *val = recoalasce->get_operand(0);
const layout_t* in_layout = get(val);
const layout_t* out_layout = get(i);
if(in_layout->type != HMMA_884)
return;
id++;
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
ir::type::tile_shapes_t shape(in_shape.size());
size_t ld = out_layout->order[0];
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)
if(k != ld)
shape[k] = 4*in_layout->fpw[k]*in_layout->wpt[k];
// create layout
layouts_[id] = new layout_shared_t(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), id, align_);
tmp_[recoalasce] = id;
}
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
id++;
layouts_[id] = new layout_shared_t(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), id, align_);
tmp_[atom] = id;
}
});
}
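For the recoalesce branch above, the temporary shared buffer keeps the full extent only along the output's leading dimension; every other dimension shrinks to 4 * fpw[k] * wpt[k]. As a hedged illustration with hypothetical HMMA parameters fpw = {2, 2, 1} and wpt = {4, 2, 1} on a 128x128 tile with ld = 0, shape becomes {128, 4*2*2} = {128, 16}, so the generated code re-coalesces 16 columns of the accumulator per shared-memory round trip rather than staging the whole tile at once.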

View File

@@ -7,7 +7,6 @@
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/instructions.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
@@ -351,10 +350,9 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
Value *ptr = pointers->get_value(idx);
ptr = builder_->CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
ptr->getType()->getPointerAddressSpace()));
Value *mask = masks->get_value(idx);
BasicBlock *current_bb = builder_->GetInsertBlock();
Function *parent = builder_->GetInsertBlock()->getParent();
@@ -386,9 +384,9 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// if(cst)
// offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size);
// Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2);
// Type *fp16x2_pack4_ty = StructType::get(ctx, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
// Type *fp16x2_pack4_ty = StructType::get(*ctx_, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
// FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false);
// std::string asm_str = "@$0 ld.global.nc.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
// if(false_values)
// asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 0, 0, 0};";
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true);
@@ -420,31 +418,83 @@ void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* st) {
void generator::visit_masked_store_inst(ir::masked_store_inst* st) {
distributed_tile* ptrs = (distributed_tile*)tmap_.at(st->get_pointer_operand());
distributed_tile* scalars = (distributed_tile*)tmap_.at(st->get_value_operand());
ir::value *mask = st->get_mask_operand();
distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
ptrs->for_each([&](indices_t idx){
Value *scalar = scalars->get_value(idx);
Value *ptr = ptrs->get_value(idx);
Value *pred = preds->get_value(idx);
Function *parent = builder_->GetInsertBlock()->getParent();
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
builder_->CreateCondBr(pred, mask_then_bb, mask_done_bb);
builder_->SetInsertPoint(mask_then_bb);
builder_->CreateStore(scalar, ptr);
builder_->CreateBr(mask_done_bb);
builder_->SetInsertPoint(mask_done_bb);
// std::string offset = "";
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
// if(gep->getNumIndices() == 1)
// if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
// offset = " + " + std::to_string(cst->getValue().getSExtValue()*4);
// }
// FunctionType *ty = FunctionType::get(Type::getVoidTy(ctx), {pred->getType(), ptr->getType(), scalar->getType()}, false);
// std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;";
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true);
// builder.CreateCall(iasm, {pred, ptr, scalar});
distributed_tile* masks = (distributed_tile*)tmap_.at(st->get_mask_operand());
// vector size
int vector_size = 1;
int ld = ptrs->get_order()[0];
unsigned alignment = alignment_->get(st->get_pointer_operand(), ld);
vector_size = std::min<unsigned>(ptrs->axis(ld).contiguous, alignment);
// create packets
std::map<unsigned, Value*> packets;
ir::value *arg = st->get_value_operand();
for_each(arg, [&](indices_t idx){
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
unsigned linear = in->get_linear_index(idx);
unsigned id = linear / vector_size;
Value *in_value = in->get_value(idx);
if(linear % vector_size == 0)
packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size));
packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size);
});
// write-back packets
for_each(arg, [&](indices_t idx){
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
unsigned linear = in->get_linear_index(idx);
unsigned id = linear / vector_size;
if(linear % vector_size == 0){
// fetch tile elements
Value *elt = packets[id];
Value *ptr = ptrs->get_value(idx);
Value *pred = masks->get_value(idx);
// type information
Type *ty = elt->getType();
unsigned nbits = ty->getScalarSizeInBits();
unsigned nbytes = nbits / 8;
// extract pointer offset
std::string offset = "";
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
if(gep->getNumIndices() == 1)
if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbytes);
ptr = gep->getPointerOperand();
}
ptr = builder_->CreateBitCast(ptr, ty->getPointerTo(1));
// asm argument type
std::vector<Type*> arg_ty = {pred->getType(), ptr->getType()};
for(int v = 0; v < vector_size; v++)
arg_ty.push_back(ty->getScalarType());
// asm function type
FunctionType *fn_ty = FunctionType::get(builder_->getVoidTy(), arg_ty, false);
// asm string
std::string asm_str;
asm_str += "@$0 st.global";
if(vector_size > 1)
asm_str += ".v" + std::to_string(vector_size);
asm_str += ".b" + std::to_string(nbits) + " [$1" + offset + "],";
if(vector_size > 1)
asm_str += "{";
for(int v = 0; v < vector_size; v++){
if(v > 0)
asm_str += ", ";
asm_str += "$" + std::to_string(2 + v);
}
if(vector_size > 1)
asm_str += "}";
asm_str += ";";
// asm constraint
std::string constraint = "b,l";
for(int v = 0; v < vector_size; v++){
constraint += ",";
constraint += (nbits == 32 ? "r" : "h");
}
// create inline asm
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
// call asm
std::vector<Value*> args = {pred, ptr};
for(int v = 0; v < vector_size; v++)
args.push_back(builder_->CreateExtractElement(elt, builder_->getInt32(v)));
builder_->CreateCall(iasm, args);
}
});
}
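To make the string assembly above concrete: for vector_size == 4, a 32-bit scalar type, and no folded pointer offset, the code produces (reconstructed by hand from the lines above)

  asm_str    == "@$0 st.global.v4.b32 [$1],{$2, $3, $4, $5};"
  constraint == "b,l,r,r,r,r"

i.e. one predicated, vectorized store per packet. For vector_size == 1 and fp16 elements it degenerates to "@$0 st.global.b16 [$1],$2;" with constraint "b,l,h".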
@@ -504,23 +554,27 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0));
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
Value *ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(cas))));
ptr = builder_->CreateBitCast(ptr, PointerType::get(builder_->getInt32Ty(), ptr->getType()->getPointerAddressSpace()));
tgt_->add_memfence(module, *builder_);
tgt_->add_barrier(module, *builder_);
tgt_->add_memfence(module, *builder_);
builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
builder_->SetInsertPoint(tid_0_bb);
Value *cas_ptr = vmap_.at(cas->get_operand(0));
Value *cas_cmp = vmap_.at(cas->get_operand(1));
Value *cas_val = vmap_.at(cas->get_operand(2));
Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic);
Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val,
AtomicOrdering::Monotonic,
AtomicOrdering::Monotonic);
old = builder_->CreateExtractValue(old, {0});
builder_->CreateStore(old, ptr);
Value *atom_ptr;
atom_ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))));
atom_ptr = builder_->CreateBitCast(atom_ptr, PointerType::get(old->getType(), 3));
builder_->CreateStore(old, atom_ptr);
builder_->CreateBr(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, *builder_);
tgt_->add_barrier(module, *builder_);
vmap_[cas] = builder_->CreateLoad(ptr);
vmap_[cas] = builder_->CreateLoad(atom_ptr);
}
void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
@@ -533,14 +587,14 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent());
tgt_->add_memfence(module, *builder_);
tgt_->add_barrier(module, *builder_);
builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
builder_->SetInsertPoint(tid_0_bb);
vmap_[xchg] = builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System);
builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val,
AtomicOrdering::Monotonic,
SyncScope::System);
builder_->CreateBr(tid_0_done_bb);
builder_->SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, *builder_);
tgt_->add_barrier(module, *builder_);
}
void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
@@ -861,6 +915,115 @@ void generator::visit_select_inst(ir::select_inst* select) {
}
void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
ir::value *op = rc->get_operand(0);
ir::tile_type::tile_shapes_t shape = rc->get_type()->get_tile_shapes();
size_t rank = shape.size();
// temporary layout
shared_tile *tmp = (shared_tile*)machine_layouts_.at(layouts_->get(layouts_->tmp(rc)))
->create(rc);
// pointer to temporary shared memory
Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_);
// layouts
const analysis::layout_t* in_layout = layouts_->get(op);
const analysis::layout_t* out_layout = layouts_->get(rc);
// machine tiles
distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op));
distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc));
// WMMA configuration
long wmma_pt[3] = { 2, 4, 1};
long wmma[3] = { 8*in_layout->wpt[0]*in_layout->fpw[0],
8*in_layout->wpt[1]*in_layout->fpw[1],
1};
// Work per thread for input layout
long in_pt[3] = { shape[0] / wmma[0],
shape[1] / wmma[1],
1 };
// Work per thread for output layout
long out_pt[3] = { shape[0] / out_layout->mts[0],
shape[1] / out_layout->mts[1],
1 };
if(rank > 2){
wmma[2] = in_layout->wpt[2]*in_layout->fpw[2];
in_pt[2] = shape[2] / wmma[2];
out_pt[2] = shape[2] / out_layout->mts[2];
}
// Orders
auto ord = out_layout->order;
if(ord.size() < 3)
ord.push_back(2);
// pointer lanes
std::vector<std::vector<Value*>> ptrs;
for(int in_zz = 0; in_zz < wmma_pt[ord[2]]; in_zz++) {
std::vector<Value*> current;
for(int in_cc = 0; in_cc < wmma_pt[ord[1]]; in_cc++) {
Value *base;
base = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(rc)))));
base = builder_->CreateBitCast(base, PointerType::get(ty, 3));
// shared memory stride
Value *stride_0 = builder_->getInt32(tmp->get_shapes()[ord[0]]);
// indices
Value *idx_cc = axes_.at(a_axes_->get(op, ord[1])).values[in_cc];
// offset
Value *off = builder_->CreateMul(stride_0, idx_cc);
if(rank > 2){
Value *stride_1 = builder_->CreateMul(stride_0,
builder_->getInt32(tmp->get_shapes()[ord[1]]));
Value *idx_zz = axes_.at(a_axes_->get(op, ord[2])).values[in_zz];
off = builder_->CreateAdd(off, builder_->CreateMul(stride_1, idx_zz));
}
current.push_back(builder_->CreateGEP(base, off));
}
ptrs.push_back(current);
}
// Re-coalesce loops
for(int in_z = 0; in_z < in_pt[ord[2]]; in_z++)
for(int in_c = 0; in_c < in_pt[ord[1]]; in_c++){
// write to shared
tgt_->add_barrier(mod_, *builder_);
for(int in_zz = 0; in_zz < wmma_pt[ord[2]]; in_zz++)
for(int in_cc = 0; in_cc < wmma_pt[ord[1]]; in_cc++){
std::vector<int> starts(rank), len(rank);
starts[ord[0]] = 0;
starts[ord[1]] = in_c*wmma_pt[ord[1]] + in_cc;
len[ord[0]] = wmma_pt[ord[0]]*in_pt[ord[0]];
len[ord[1]] = 1;
if(rank > 2){
starts[ord[2]] = in_z*wmma_pt[ord[2]] + in_zz;
len[ord[2]] = 1;
}
in_dt->for_each([&](indices_t idx){
Value *write_ptr = builder_->CreateGEP(ptrs[in_zz][in_cc], idx[ord[0]]);
builder_->CreateStore(in_dt->get_value(idx), write_ptr);
}, starts, len);
}
tgt_->add_barrier(mod_, *builder_);
// load from shared
for(int out_zz = 0; out_zz < out_pt[ord[2]] / in_pt[ord[2]]; out_zz++)
for(int out_cc = 0; out_cc < out_pt[ord[1]] / in_pt[ord[1]]; out_cc++){
std::vector<int> starts(rank), len(rank);
starts[ord[0]] = 0;
starts[ord[1]] = in_c*(out_pt[ord[1]] / in_pt[ord[1]]) + out_cc;
len[ord[0]] = out_pt[ord[0]];
len[ord[1]] = 1;
if(rank > 2){
starts[ord[2]] = in_z*(out_pt[ord[2]] / in_pt[ord[2]]) + out_zz;
len[ord[2]] = 1;
}
out_dt->for_each([&](indices_t idx){
indices_t read_idx(rank);
read_idx[ord[0]] = idx[ord[0]];
read_idx[ord[1]] = axes_.at(a_axes_->get(rc, ord[1])).values[out_cc];
if(rank > 2)
read_idx[ord[2]] = axes_.at(a_axes_->get(rc, ord[2])).values[out_zz];
out_dt->set_value(idx, tmp->get_value(read_idx));
}, starts, len);
}
}
tgt_->add_barrier(mod_, *builder_);
}
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
unsigned vector_size = 1;
auto x_order = layouts_->get(cts)->order;
@@ -1126,16 +1289,14 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
if(tgt_->is_gpu())
if(unsigned alloc_size = alloc_->allocated_size()){
Type *int_8_ty = Type::getInt8Ty(*ctx_);
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
Type *int_32_ty = Type::getInt32Ty(*ctx_);
ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4);
Type *ptr_ty = PointerType::get(int_8_ty, 3);
GlobalVariable *sh_mem_array =
new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
sh_mem_ptr_ = builder_->CreateBitCast(sh_mem_array, ptr_ty);
}
// allocate constant memory
for(ir::alloc_const *x: src.allocs())
visit_alloc_const(x);
// visit functions
for(ir::function *fn: src.get_function_list())
visit_function(fn);

View File

@@ -143,7 +143,7 @@ tile *machine_layout_distributed_t::create(ir::value *v) {
axes[d].values = {builder_->getInt32(0)};
}
}
return new distributed_tile(ty, shapes, layout_->order, axes, *builder_, false);
return new distributed_tile(ty, shapes, layout_->order, axes, *builder_);
}
machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder,

View File

@@ -45,9 +45,8 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size)
return VectorType::get(ty, vector_size);
}
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), order_(order), builder_(builder) {
vector_size_ = vectorize?ty_->getVectorNumElements():1;
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
: tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
init_indices();
}
@@ -73,13 +72,31 @@ indices_t distributed_tile::get_ordered_indices(unsigned id) {
}
void distributed_tile::for_each(std::function<void (indices_t)> fn) {
for(unsigned i = 0; i < ordered_indices_.size(); i++){
if(i % vector_size_ == 0)
fn(ordered_indices_[i]);
void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
if(end < 0)
end = ordered_indices_.size() + end + 1;
for(unsigned i = start; i < end; i++)
fn(ordered_indices_[i]);
}
void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
int rank = sizes.size();
int len = 1;
for(int s: sizes)
len *= s;
for(int i = 0; i < len; i++){
indices_t idx(rank);
int current = i;
for(int k = 0; k < rank; k++){
idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
current = current / sizes[k];
}
fn(idx);
}
}
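The new (starts, sizes) overload decodes the loop counter in mixed radix, so the first entry of sizes varies fastest; the resulting positions are then looked up in axes_[k].values. A standalone sketch of just the index decomposition (hypothetical, outside the tile class):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> starts = {1, 0};
  std::vector<int> sizes  = {2, 3};
  int len = sizes[0] * sizes[1];
  for (int i = 0; i < len; ++i) {
    int current = i;
    std::vector<int> idx(2);
    for (int k = 0; k < 2; ++k) {
      idx[k] = starts[k] + current % sizes[k];   // position within axis k
      current /= sizes[k];
    }
    // prints (1,0) (2,0) (1,1) (2,1) (1,2) (2,2)
    std::printf("(%d,%d) ", idx[0], idx[1]);
  }
  return 0;
}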
/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
@@ -126,7 +143,9 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
}
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx) {
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
const std::vector<int>& perm, const std::vector<int>& order,
indices_t idx) {
// strides
std::vector<Value*> strides(order.size());
strides[order[0]] = builder.getInt32(1);

View File

@@ -1,6 +1,8 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/analysis/align.h"
@@ -60,8 +62,43 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
}
void coalesce::run(ir::module &mod) {
// find values to rematerialize
size_t num_groups = layout_->num_layouts();
for(size_t id = 0; id < num_groups; id++) {
if(layout_->get(id)->type != analysis::HMMA_884)
continue;
// extract memory stores
const auto& values = layout_->values_of(id);
ir::value* dot = nullptr;
for(ir::value *v: values)
if(auto x = dynamic_cast<ir::dot_inst*>(v))
dot = x;
ir::builder& builder = mod.get_builder();
std::vector<ir::value*> worklist = {dot};
std::set<ir::value*> seen;
while(!worklist.empty()) {
ir::value *current = worklist.back();
seen.insert(current);
worklist.pop_back();
// stop if trunc
if(auto x = dynamic_cast<ir::fp_trunc_inst*>(current)){
builder.set_insert_point_after(x);
ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x);
builder.insert(rc);
x->replace_all_uses_with(rc);
rc->replace_uses_of_with(rc, x);
break;
}
// recurse
for(ir::user *u: current->get_users())
if(seen.find(u) == seen.end())
worklist.push_back(u);
}
}
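In the loop above, the worklist starts at each HMMA dot and follows users until the first fp_trunc x, where a recoalesce node is spliced in. Because replace_all_uses_with(rc) also rewires rc's own operand, the final replace_uses_of_with(rc, x) restores it; the net effect is that users which previously read x (for example a store of the down-converted accumulator) now read recoalesce(x), while recoalesce itself still reads x. (Sketch of the intended rewrite; the store user is an assumed, typical consumer.)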
// find values to rematerialize
std::vector<ir::io_inst*> remat;
for(size_t id = 0; id < num_groups; id++) {
const auto& values = layout_->values_of(id);
@@ -71,8 +108,10 @@ void coalesce::run(ir::module &mod) {
extract_io_use(v, io);
// extract leading axes
std::map<int, std::vector<ir::io_inst*>> axes;
for(ir::io_inst *i: io)
extract_ld(i, axes);
for(ir::io_inst *i: io){
if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->axes.size())
extract_ld(i, axes);
}
// update list of values to rematerialize
if(axes.empty())
continue;

View File

@@ -1,21 +1,37 @@
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/instructions.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace transform{
inline bool is_shared(ir::value *v) {
auto *i = dynamic_cast<ir::instruction*>(v);
inline bool is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT)
return op==0 || op==1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
return op==0;
return false;
}
inline bool is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
return storage_info.at(i->get_id()).first == codegen::SHARED;
if(i->get_id() == ir::INST_TRANS)
return true;
if(i->get_id() == ir::INST_REDUCE)
return true;
if(i->get_id() == ir::INST_COPY_TO_SHARED)
return true;
return false;
}
// run pass on module
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
auto *i = dynamic_cast<ir::instruction*>(x);
@@ -36,9 +52,8 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
return;
}
ir::value_id_t id = i->get_id();
// already in shared memory
if(to_shared && storage_info.at(id).first == SHARED)
if(to_shared && is_shmem_res(i))
return;
// copy
builder.set_insert_point_after(i);
@@ -53,18 +68,19 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool
void cts::run(ir::module &mod) {
// Add shared copies
ir::builder &builder = mod.get_builder();
for(ir::function *fn: mod.get_function_list()){
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
auto storage = storage_info.at(i->get_id());
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
size_t num_op = i->get_num_operands();
// copy to shared operands
for(size_t k = 0; k < storage.second.size(); k++)
if(storage.second[k] == SHARED)
for(size_t k = 0; k < num_op; k++)
if(is_shmem_op(i, k))
add_copy(i, i->get_operand(k), builder, true);
// copy from shared operands
for(size_t k = 0; k < storage.second.size(); k++)
if(storage.second[k] == DISTRIBUTED &&
is_shared(i->get_operand(k))){
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
}
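With the storage_info table removed, the pass now hard-codes shared-memory placement: operands 0 and 1 of a dot (and operand 0 of copy_from_shared) are shared-memory operands, and values produced by trans, reduce, or copy_to_shared count as shared-memory results. Concretely (a sketch, value names invented): for c = dot(a, b, acc) the pass inserts a1 = copy_to_shared(a) and b1 = copy_to_shared(b) and makes the dot read those, unless a or b is already a shared result; conversely, a non-phi instruction reading a trans or reduce result gets a copy_from_shared inserted in front of that operand.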

View File

@@ -3,7 +3,6 @@
#include <algorithm>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/instructions.h"
#include "triton/codegen/transform/membar.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"

View File

@@ -180,6 +180,11 @@ host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module
hst_->engine = builder.create();
}
std::unique_ptr<buffer> host_module::symbol(const char *name) const {
throw std::runtime_error("not implemented");
}
/* ------------------------ */
// OpenCL //
/* ------------------------ */
@@ -211,10 +216,21 @@ ocl_module::ocl_module(driver::context * context, std::unique_ptr<llvm::Module>
// }
}
std::unique_ptr<buffer> ocl_module::symbol(const char *name) const {
throw std::runtime_error("not implemented");
}
/* ------------------------ */
// CUDA //
/* ------------------------ */
static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){
size_t start_replace = str.find(begin);
size_t end_replace = str.find(end, start_replace);
if(start_replace == std::string::npos)
return false;
str.replace(start_replace, end_replace + 1 - start_replace, target);
return true;
}
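The helper above deletes everything from the first occurrence of begin through the next occurrence of end (inclusive) and inserts target in its place, returning false when begin is absent. A standalone check mirroring how it is applied to the PTX text below (hypothetical snippet, not part of the commit):

#include <cassert>
#include <string>

// Same logic as cu_module's find_and_replace above.
static bool find_and_replace(std::string &str, const std::string &begin,
                             const std::string &end, const std::string &target) {
  size_t start_replace = str.find(begin);
  size_t end_replace = str.find(end, start_replace);
  if (start_replace == std::string::npos)
    return false;
  str.replace(start_replace, end_replace + 1 - start_replace, target);
  return true;
}

int main() {
  std::string ptx = ".version 6.3\n.target sm_70\n";
  find_and_replace(ptx, ".version", "\n", ".version 6.4\n");   // pin PTX ISA version
  assert(ptx == ".version 6.4\n.target sm_70\n");
  return 0;
}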
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
// options
@@ -231,19 +247,17 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly);
std::string result(buffer.begin(), buffer.end());
size_t start_replace = result.find(".version");
size_t end_replace = result.find('\n', start_replace);
assert(start_replace != std::string::npos);
result.replace(start_replace, end_replace - start_replace, ".version 6.4");
find_and_replace(result, ".version", "\n", ".version 6.4\n");
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
return result;
}
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// exit(EXIT_FAILURE);
// std::cout << source << std::endl;
cu_context::context_switcher ctx(*context);
// std::cout << source << std::endl;
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;
@@ -260,11 +274,12 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
}
}
cu_buffer* cu_module::symbol(const char *name) const{
std::unique_ptr<buffer> cu_module::symbol(const char *name) const{
CUdeviceptr handle;
size_t size;
dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
return new cu_buffer(ctx_, size, handle, false);
std::unique_ptr<buffer> res(new cu_buffer(ctx_, size, handle, false));
return std::move(res);
}

View File

@@ -48,6 +48,9 @@ value *builder::get_int32(unsigned val) {
return constant_int::get(type::get_int32_ty(ctx_), val);
}
type *builder::get_void_ty()
{ return type::get_void_ty(ctx_); }
type *builder::get_int1_ty()
{ return type::get_int1_ty(ctx_); }
@@ -132,19 +135,12 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string
return insert(binary_operator::create(OPCODE, lhs, rhs), name);\
}
#define DEFINE_UNARY_FLOAT(SUFFIX)\
value *builder::create_ ## SUFFIX(value *arg, const std::string &name){\
return insert(binary_operator::create_ ## SUFFIX(arg), name);\
}
// Binary
DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul)
DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv)
DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem)
DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd)
DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub)
// Unary
DEFINE_UNARY_FLOAT(fneg)
//===----------------------------------------------------------------------===//
@@ -171,10 +167,7 @@ value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\
}
#define DEFINE_UNARY_INT(SUFFIX)\
value *builder::create_ ## SUFFIX(value *arg, const std::string &name){\
return insert(binary_operator::create_ ## SUFFIX(arg), name);\
}
// Binary
DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul)
@@ -190,9 +183,6 @@ DEFINE_BINARY_INT(urem, binary_op_t::URem)
DEFINE_BINARY_INT(and, binary_op_t::And)
DEFINE_BINARY_INT(or, binary_op_t::Or)
DEFINE_BINARY_INT(xor, binary_op_t::Xor)
// Unary
DEFINE_UNARY_INT(neg)
DEFINE_UNARY_INT(not)
//===----------------------------------------------------------------------===//

View File

@@ -138,23 +138,23 @@ binary_operator *binary_operator::create(binary_op_t op, value *lhs, value *rhs,
return new binary_operator(op, lhs, rhs, lhs->get_type(), name, next);
}
binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
return binary_operator::create(binary_op_t::FSub, zero, arg, name, next);
}
//binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){
// assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty());
// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
// return binary_operator::create(binary_op_t::FSub, zero, arg, name, next);
//}
binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
value *zero = constant_fp::get_zero_value_for_negation(arg->get_type());
return binary_operator::create(binary_op_t::Sub, zero, arg, name, next);
}
//binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){
// assert(arg->get_type()->get_scalar_ty()->is_integer_ty());
// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()->get_scalar_ty());
// return binary_operator::create(binary_op_t::Sub, zero, arg, name, next);
//}
binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){
assert(arg->get_type()->is_integer_ty());
constant *mask = constant::get_all_ones_value(arg->get_type());
return binary_operator::create(binary_op_t::Xor, arg, mask, name, next);
}
//binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){
// assert(arg->get_type()->is_integer_ty());
// constant *mask = constant::get_all_ones_value(arg->get_type());
// return binary_operator::create(binary_op_t::Xor, arg, mask, name, next);
//}
//===----------------------------------------------------------------------===//
// cmp_inst classes
@@ -762,6 +762,12 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri
return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next);
}
// recoalesce
recoalesce_inst* recoalesce_inst::create(value *arg, const std::string &name, instruction *next) {
return new recoalesce_inst(arg->get_type(), INST_RECOALESCE, arg, name, next);
}
// barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name,

View File

@@ -57,7 +57,10 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
}
case Token::MASKED_DEREF: {
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
return set_ret(bld_->create_masked_load(rhs, lhs, ir::undef_value::get(ret_ty)));
ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty());
if(ret_ty->is_tile_ty())
false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes());
return set_ret(bld_->create_masked_load(rhs, lhs, false_value));
}
case Token::ELLIPSIS: {
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
@@ -76,7 +79,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
return set_ret(bld_->create_add(lhs, rhs));
case '-':
if(binary->lhs_->Type()->ToPointer())
return set_ret(bld_->create_gep(lhs, {bld_->create_neg(rhs)}));
return set_ret(bld_->create_gep(lhs, {GenUnaryMinus(rhs)}));
else if(flt)
return set_ret(bld_->create_fsub(lhs, rhs));
else
@@ -147,7 +150,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
if(flt)
return set_ret(bld_->create_fcmpONE(lhs, rhs));
else
return set_ret(bld_->create_icmpEQ(lhs, rhs));
return set_ret(bld_->create_icmpNE(lhs, rhs));
default:
error_not_implemented();
}
@@ -166,6 +169,16 @@ ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
should_not_happen();
return reduce_inst::op_t();
}
ir::value* Generator::GenUnaryMinus(ir::value* arg) {
ir::type *ty = arg->get_type();
ir::type *sca_ty = ty->get_scalar_ty();
ir::value *_0 = ir::constant_fp::get_zero_value_for_negation(sca_ty);
if(ty->is_tile_ty())
_0 = bld_->create_splat(_0, ty->get_tile_shapes());
return bld_->create_sub(_0, arg);
}
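For example (illustration only): negating a float[TM] tile x now lowers to splat(0, {TM}) - x rather than 0 - x, since a scalar zero would not match the shape of the tile operand.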
void Generator::VisitUnaryOp(UnaryOp* unary) {
// recursion
Visit(unary->operand_);
@@ -174,17 +187,17 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
// return
switch (unary->op_) {
case Token::PREFIX_INC: return error_not_implemented();
case Token::PREFIX_DEC: return error_not_implemented();
case Token::PREFIX_INC: return error_not_implemented();
case Token::PREFIX_DEC: return error_not_implemented();
case Token::POSTFIX_INC: return error_not_implemented();
case Token::POSTFIX_DEC: return error_not_implemented();
case Token::ADDR: return error_not_implemented();
case Token::DEREF: return set_ret(bld_->create_load(arg));
case Token::PLUS: return error_not_implemented();
case Token::MINUS: return error_not_implemented();
case '~': return set_ret(bld_->create_neg(arg));
case '!': return set_ret(bld_->create_not(arg));
case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
case Token::ADDR: return error_not_implemented();
case Token::DEREF: return set_ret(bld_->create_load(arg));
case Token::PLUS: return error_not_implemented();
case Token::MINUS: return set_ret(GenUnaryMinus(arg));
case '~': return error_not_implemented();
case '!': return error_not_implemented();
case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
case Token::REDUCE: {
int ax, tag;
UnaryOp::decodeRed(unary->info_, ax, tag);
@@ -232,11 +245,54 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
else
return should_not_happen();
}
if(name == "get_num_programs"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ret = ret_;
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
return set_ret(bld_->create_get_num_program(axis->get_value()));
else
return should_not_happen();
}
if(name == "atomic_cas"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ptr = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* cmp = ret_;
VisitExpr(funcCall->Args()->at(2));
ir::value* val = ret_;
return set_ret(bld_->create_atomic_cas(ptr, cmp, val));
}
if(name == "atomic_xchg"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ptr = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* val = ret_;
return set_ret(bld_->create_atomic_exch(ptr, val));
}
if(name == "sqrtf"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ret = ret_;
return set_ret(bld_->create_sqrt(ret));
}
if(name == "calloc"){
VisitExpr(funcCall->Args()->at(0));
ir::value* ret = ret_;
ir::constant_int *size = dynamic_cast<ir::constant_int*>(ret);
assert(size);
ir::alloc_const* alloc = new ir::alloc_const(bld_->get_int8_ty(), size);
mod_->add_alloc(alloc);
return set_ret(alloc);
}
//TODO: integrate this into conditionalop
if(name == "select"){
VisitExpr(funcCall->Args()->at(0));
ir::value* cond = ret_;
VisitExpr(funcCall->Args()->at(1));
ir::value* true_val = ret_;
VisitExpr(funcCall->Args()->at(2));
ir::value* false_val = ret_;
return set_ret(bld_->create_select(cond, true_val, false_val));
}
return error_not_implemented();
}
@@ -350,12 +406,15 @@ void Generator::VisitForStmt(ForStmt *forStmt) {
ir::value *cond = ret_;
return bld_->create_cond_br(cond, loop_bb, next_bb);
});
VisitStmt(init_);
VisitExpr(cond_);
ir::value *cond = ret_;
bld_->create_cond_br(cond, loop_bb, next_bb);
if(init_)
VisitStmt(init_);
// VisitExpr(cond_);
// ir::value *cond = ret_;
// bld_->create_cond_br(cond, loop_bb, next_bb);
bld_->create_br(loop_bb);
bld_->set_insert_point(loop_bb);
VisitStmt(body_);
if(body_)
VisitStmt(body_);
if(!is_terminator(ret_))
mod_->get_continue_fn()();
ir::basic_block *stop_bb = bld_->get_insert_block();
@@ -512,6 +571,8 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
src_scalar_ty->get_integer_bitwidth())
return bld_->create_int_cast(src, dst_ty, dst_signed);
else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_pointer_ty())
return bld_->create_cast(ir::BitCast, src, dst_ty);
else{
should_not_happen();
return nullptr;
@@ -611,6 +672,8 @@ ir::type* Generator::GenIRFuncType(FuncType* type, ir::context& ctx) {
ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) {
ir::type* ele_ty = GenIRType(type->Derived().GetPtr(), ctx);
unsigned addr_space = 1;
if(type->Derived().IsConstantQualified())
addr_space = 4;
return ir::pointer_type::get(ele_ty, addr_space);
}
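Effect (for illustration): a pointer declared with the new __constant__ qualifier, e.g. __constant__ float* x, is lowered to a pointer in address space 4 (NVPTX constant memory), while unqualified pointers keep the default global address space 1.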

View File

@@ -1083,14 +1083,12 @@ QualType Parser::ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec)
*storageSpec |= S_THREAD;
break;
case Token::AUTO:
EnsureAndSetStorageSpec(tok, storageSpec, S_AUTO);
break;
// Type qualifier
case Token::CONST: qualSpec |= Qualifier::CONST; break;
case Token::RESTRICT: qualSpec |= Qualifier::RESTRICT; break;
case Token::VOLATILE: qualSpec |= Qualifier::VOLATILE; break;
case Token::CMEM: qualSpec |= Qualifier::CMEM; break;
// Type specifier
case Token::SIGNED:
@@ -1551,6 +1549,7 @@ int Parser::ParseQual() {
case Token::CONST: qualSpec |= Qualifier::CONST; break;
case Token::RESTRICT: qualSpec |= Qualifier::RESTRICT; break;
case Token::VOLATILE: qualSpec |= Qualifier::VOLATILE; break;
case Token::CMEM: qualSpec |= Qualifier::CMEM; break;
case Token::ATOMIC: Error(tok, "do not support 'atomic'"); break;
default: ts_.PutBack(); return qualSpec;
}
@@ -1769,6 +1768,7 @@ QualType Parser::ParseArrayFuncDeclarator(const Token* ident, QualType base) {
if (!base->Complete()) {
Error(ident, "'%s' has incomplete element type", ident->str_.c_str());
}
// return a pointer for tiles in constant memory:
return TileType::New(shape, base);
} else if (ts_.Try('(')) { // Function declaration

View File

@@ -7,6 +7,7 @@
static MemPoolImp<Token> tokenPool;
const std::unordered_map<std::string, int> Token::kwTypeMap_ {
{ "__constant__", Token::CMEM },
{ "__global__", Token::GLOBAL },
{ "auto", Token::AUTO },
{ "break", Token::BREAK },

View File

@@ -294,7 +294,8 @@ std::string ArithmType::Str() const {
bool PointerType::Compatible(const Type& other) const {
// C11 6.7.6.1 [2]: pointer compatibility
auto otherPointer = other.ToPointer();
return otherPointer && derived_->Compatible(*otherPointer->derived_);
return otherPointer &&
derived_->Compatible(*otherPointer->derived_);
// FIXME(wgtdkp): cannot loose compatible constraints
//return other.IsInteger() ||

View File

@@ -184,10 +184,20 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
// kernel uses too much resources
if(!bin)
return;
// copy constants
std::unique_ptr<driver::buffer> buffer;
for(ir::alloc_const* alloc: ir->allocs()){
std::string name = alloc->get_name();
auto it = cst_.find(name);
if(it == cst_.end())
throw std::runtime_error("constant not set before execution");
buffer = bin->symbol(name.c_str());
stream->write(&*buffer, true, 0, it->second);
}
// benchmark
ir::function *tmp = ir->get_function_list()[0];
caller call(tmp, std::move(bin), opt);
double ts = tools::bench([&]() { call(stream, grid_fn(opt), args); }, stream);
double ts = tools::bench([&]() { call(stream, grid_fn(opt), args); }, stream, true);
// save best
if(ts < best_ts) {
best_ts = ts;
@@ -222,20 +232,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
// run passes
dce.run(module);
// ir::print(module, std::cout);
disassociate.run(module);
// ir::print(module, std::cout);
dce.run(module);
// ir::print(module, std::cout);
peephole.run(module);
dce.run(module);
align.run(module);
cts.run(module);
axes.run(module);
// ir::print(module, std::cout);
layouts.run(module);
coalesce.run(module);
dce.run(module);
@@ -246,17 +250,19 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module);
align.run(module);
axes.run(module);
// ir::print(module, std::cout);
layouts.run(module);
liveness.run(module);
allocation.run(module);
if(allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>();
barriers.run(module);
// std::cout << "isel" << std::endl;
// ir::print(module, std::cout);
isel.visit(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
// done
// exit(EXIT_FAILURE);
return res;
}
@@ -273,8 +279,13 @@ R"(
#define __aligned(A) __attribute__((aligned(A)))
#define __multipleof(A) __attribute__((multipleof(A)))
extern int atomic_cas(int*, int, int);
extern int atomic_xchg(int*, int);
extern int get_program_id(int);
extern int get_num_programs(int);
extern float sqrtf(float);
extern int select(bool, int, int);
extern char __constant__ * calloc(int);
)";
}
@@ -316,5 +327,9 @@ void function::operator()(const std::vector<arg>& args, const grid_t& grid, driv
return this->operator()(args, [&grid](const options_t&){ return grid; }, stream);
}
void function::set_cst(const std::string& name, void* data, size_t n_bytes) {
cst_[name] = std::vector<char>((char*)data, (char*)data + n_bytes);
}
}
}

View File

@@ -1,157 +0,0 @@
import tensorflow as tf
import triton
import numpy as np
src = '''
#if AT == 1
#define USE_A ^a
#define STRIDE_AK lda
#define STRIDE_AM 1
#define BROADCAST_AK :, newaxis
#define BROADCAST_AM newaxis, :
#define SHAPE_A TK, TM
#else
#define USE_A a
#define STRIDE_AK 1
#define STRIDE_AM lda
#define BROADCAST_AK newaxis, :
#define BROADCAST_AM :, newaxis
#define SHAPE_A TM, TK
#endif
#if BT == 1
#define USE_B ^b
#define STRIDE_BK 1
#define STRIDE_BM ldb
#define BROADCAST_BN newaxis, :
#define BROADCAST_BK :, newaxis
#define SHAPE_B TN, TK
#else
#define USE_B b
#define STRIDE_BK ldb
#define STRIDE_BM 1
#define BROADCAST_BN :, newaxis
#define BROADCAST_BK newaxis, :
#define SHAPE_B TK, TN
#endif
void dot (TYPE* A __readonly __noalias __align(16),
TYPE* B __readonly __noalias __align(16),
TYPE* C __writeonly __noalias __align(16),
int lda, int ldb, int ldc,
int N, int* lut,
int* locks, int nlocks) {
int ridx = get_program_id(0);
float c[TM, TN] = 0;
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
// load LUT header
int *header = lut + get_program_id(1) * 4;
int offset = *(header + 0);
int K = *(header + 1);
int column = *(header + 2);
int lockid = *(header + 3);
int *plut = lut + offset * 2;
int offx = ridx;
int offy = 0;
// compute x, y offsets
int rxa[TM] = offx * TM + (0 ... TM);
int ryb[TN] = offy * TN + (0 ... TN);
// bounds checking
bool checka[SHAPE_A] = (rxa < N)[:, newaxis];
bool checkb[SHAPE_B] = 1;
// base offset
int offa[SHAPE_A] = rxa[BROADCAST_AM] * STRIDE_AM + rka[BROADCAST_AK] * STRIDE_AK;
int offb[SHAPE_B] = ryb[BROADCAST_BN] * STRIDE_BN + rkb[BROADCAST_BK] * STRIDE_BK;
for(int k = K; k > 0; k -= 1) {
// fetch block indices
int ak = *(plut + 0);
int bk = *(plut + 1);
lut += 2;
// compute pointers to blocks
TYPE* pa[SHAPE_A] = A + offa + ak * TK * lda;
TYPE* pb[SHAPE_B] = B + offb + bk * TK * TN;
// load blocks
TYPE a[SHAPE_A] = checka ? *pa : 0;
TYPE b[SHAPE_B] = *pb;
// multiply blocks
c += USE_A @ USE_B;
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = column * TN + (0 ... TN);
TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
bool checkc[TM, TN] = (rxc < N)[:, newaxis];
if(lockid == 0) {
*?(checkc) pc = c;
}
else {
int *plock = locks + ridx*nlocks + lockid - 1;
int *pcount = plock + get_num_program(0)*nlocks;
while(atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *pc;
atomic_exch(pcount, 1);
atomic_exch(plock, 0);
}
}
'''
# std::string dot::triton_c_src_dw() const {
# bool AT = (op_ == WGRAD);
# bool BT = (op_ == FPROP);
# std::string usea = AT ? "trans(a)" : "a";
# std::string useb = BT ? "trans(b)" : "b";
# std::string sizea = AT ? "TK, TM" : "TM, TK";
# std::string sizeb = BT ? "TN, TK" : "TK, TN";
# std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
# std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
# std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
# std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
# std::string lda0 = AT ? "*lda" : "";
# std::string lda1 = AT ? "" : "*lda";
# std::string ldb0 = BT ? "" : "*ldb";
# std::string ldb1 = BT ? "*ldb" : "" ;
# std::string result =
# R"(
# const tunable int TM = {)" + std::to_string(BS_) + R"(};
# const tunable int TN = {)" + std::to_string(BS_) + R"(};
# const tunable int TK = {32};
# void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
# restrict read_only align(16) )" + ab_ty_ + R"( *B,
# )" + c_ty_ + R"(* C,
# int lda, int ldb, int ldc,
# int N, int* lut,
# int* locks, int nlocks) {
# int ridx = get_range_id(0);
# float acc[TM, TN] = 0;
# int rka[TK] = 0 ... TK;
# int rkb[TK] = 0 ... TK;
# int *header = lut + ridx * 2;
# int offx = *(header + 0);
# int offy = *(header + 1);
# int rxa[TM] = offx*TM + (0 ... TM);
# int ryb[TN] = offy*TN + (0 ... TN);
# bool checka[TK, TM] = (rka < N)[:, newaxis];
# bool checkb[TK, TN] = (rkb < N)[:, newaxis];
# int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
# int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
# )" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa;
# )" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb;
# )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
# )" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0;
# for(int k = N; k > 0; k = k - TK) {
# acc = dot()" + usea + ", " + useb + R"(, acc);
# pa = pa + TK)" + lda1 + R"(;
# pb = pb + TK)" + ldb1 + R"(;
# a = checka ? *pa : 0;
# b = checkb ? *pb : 0;
# }
# int rxc[TM] = (0 ... TM);
# int ryc[TN] = (0 ... TN);
# )" + c_ty_ + R"( c[TM, TN] = acc;
# )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN;
# *pc = c;
# })";

View File

@@ -1,15 +0,0 @@
import torch
import triton
N, C, K = 32, 8, 32
H, W = 16, 16
R, S = 3, 3
torch.manual_seed(0)
a = torch.randn(N, C, H, W).cuda()
b = torch.ones(C, R, S, K).cuda()
rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
tc = triton.ops.conv(a, b)
print((rc - tc).abs().max())
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
#print(tc[31, 31,:,:])

View File

@@ -1,71 +0,0 @@
import numpy as np
import triton
def run_tf():
M, N, K = 2048, 2048, 2048
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[N, K])
triton_c = triton.ops.dot(a, b, False, True, 1)
triton_d = triton.ops.dot(triton_c, b, True, False, 1)
triton_y = tf.math.reduce_mean(triton_d)
fw_c = tf.matmul(a, b, False, True)
fw_d = tf.matmul(fw_c, b, True, False)
fw_y = tf.math.reduce_mean(fw_d)
# Gradient
triton_da, triton_db = tf.gradients(triton_y, [a, b])
fw_da, fw_db = tf.gradients(fw_y, [a, b])
# Reference
feed_dict = {a: np.random.rand(M, K).astype(np.float32),
b: np.random.rand(K, N).astype(np.float32)}
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y], feed_dict = feed_dict)
triton_da, fw_da = result[0][0], result[1][0]
triton_db, fw_db = result[2][0], result[3][0]
# Benchmark
nanosec = triton.bench_registry[triton_d]
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
print('Diff DA:', (triton_da - fw_da).max())
print('Diff DB:', (triton_db - fw_db).max())
def run_torch():
torch.manual_seed(0)
M, N, K = 2048, 2048, 2048
a = torch.randn(M, K).cuda()
b = torch.randn(K, N).cuda()
a.requires_grad_(True)
b.requires_grad_(True)
torch_c = torch.matmul(a, torch.t(b))
torch_d = torch.matmul(torch.t(torch_c), b)
torch_y = torch.mean(torch_d)
triton_c = triton.ops.dot(a, b, False, True, 1)
triton_d = triton.ops.dot(triton_c, b, True, False, 1)
triton_y = torch.mean(triton_d)
# torch gradient
torch_y.backward()
torch_da = a.grad.clone()
torch_db = b.grad.clone()
# triton gradient
a.grad.zero_()
b.grad.zero_()
triton_y.backward()
triton_da = a.grad.clone()
triton_db = b.grad.clone()
#nanosec = triton.bench_registry[triton_d]
#print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
print('Diff DA:', (torch_da - triton_da).max())
print('Diff DB:', (torch_db - triton_db).max())
try:
import tensorflow as tf
run_tf()
except ModuleNotFoundError:
pass
try:
import torch
run_torch()
except ModuleNotFoundError:
pass

View File

@@ -1,92 +1,194 @@
#!/usr/bin/env python
import numpy as np
from enum import Enum
import triton
import torch
from torch.utils.cpp_extension import load
import numpy as np
#import utils
from time import time
class MODE(Enum):
TF = 1
TORCH = 2
#torch.backends.cudnn.benchmark = True
try:
import tensorflow as tf
mode = MODE.TF
except ModuleNotFoundError:
pass
configs = []
try:
import torch
mode = MODE.TORCH
except ModuleNotFoundError:
pass
# Matrix multiplication
MNK = [
(512, 512 ,512),
(2048, 2048, 2048),
(8192, 8192, 8192),
(64, 64, 64000),
(64, 64, 128000),
(256, 256, 64000),
(256, 256, 128000),
cases = []
# Matmul
cases += [[[4, 1024, 1024], [1024, 1024], [4, 1024, 1024], "btc,ck->btk"]]
# Attention
# cases += [[[4, 256, 8, 2, 64], [8, 2, 512, 64], [4, 256, 8, 2, 512], "bchak,hank->bchan"]]
(1536, 16, 1536),
(1536, 32, 1536),
(1536, 64, 1536),
(1536, 128, 1536),
(4096, 16, 4096),
(4096, 32, 4096),
(4096, 64, 4096),
(4096, 128, 4096),
#(127008, 768, 576)
]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b)
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a.t(), b)
configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b.t())
configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
if mode == MODE.TF:
sess = tf.InteractiveSession()
# Relative attention
NTHSE = [
#(16, 512, 1, 64, 64),
# (16, 512, 1, 128, 128),
# (16, 512, 1, 256, 256),
# (16, 512, 1, 256, 512),
#(16, 512, 8, 64, 64),
# (16, 512, 8, 128, 128),
# (16, 512, 8, 256, 256),
# (16, 512, 8, 256, 512),
for a_shape, b_shape, c_shape, einsum in cases:
# (64, 1024, 1, 64, 64),
#(64, 1024, 1, 128, 128),
# (64, 1024, 1, 256, 256),
# (64, 1024, 1, 256, 512),
# (64, 1024, 8, 64, 64),
#(64, 1024, 8, 128, 128),
# (64, 1024, 8, 256, 256),
# (64, 1024, 8, 256, 512),
A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
# (128, 1024, 1, 64, 64),
# (128, 1024, 1, 128, 128),
# (128, 1024, 1, 256, 256),
#(128, 1024, 1, 256, 512),
# (128, 1024, 8, 64, 64),
# (128, 1024, 8, 128, 128),
# (128, 1024, 8, 256, 256),
#(128, 1024, 8, 256, 512)
]
for N, T, H, S, E in NTHSE:
configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
for N, T, H, S, E in NTHSE:
configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
for N, T, H, S, E in NTHSE:
configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
# Execute (tensorflow)
if mode == MODE.TF:
a = tf.placeholder(tf.float32, a_shape, name="a")
b = tf.placeholder(tf.float32, b_shape, name="b")
e = tf.placeholder(tf.float32, c_shape, name="e")
c = triton.ops.einsum(einsum, a, b, 1)
da, db = tf.gradients(c, [a, b], e)
feed_dict = { a: A.astype(np.float32),
b: B.astype(np.float32),
e: E }
sess.run(tf.global_variables_initializer())
result = sess.run([c, da, db], feed_dict = feed_dict)
# Execute (torch)
if mode == MODE.TORCH:
a = torch.from_numpy(A).cuda()
b = torch.from_numpy(B).cuda()
e = torch.from_numpy(E).cuda()
a.requires_grad_(True)
b.requires_grad_(True)
c = triton.ops.einsum(einsum, a, b, 1)
torch.autograd.backward(c, e)
da = a.grad
db = b.grad
result = [c.cpu().detach().numpy(), da.cpu().detach().numpy(), db.cpu().detach().numpy()]
# benchmark
nanosec = triton.bench_registry[c]
ctx = triton.ctx_registry[c]
b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4)))
ops = 2.*b*m*n*k
print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3)
#print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3)
#print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3)
# 1D Dense convolution
NCHKR = [
# (1, 1152, 12602, 512, 3)
]
for N, C, H, K, R in NCHKR:
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
configs += [([N, C, H],
[C, R, K],
[N, K, H - R + 1],
torch_fn,
'nc(h+r),crk->nkh',
dict())]
# test
ctx = triton.ctx_registry[c]
t_a = ctx.trans_a
t_b = ctx.trans_b
e_a = ctx.einsum_a
e_b = ctx.einsum_b
e_c = ctx.einsum_c
C = np.einsum(einsum, A, B)
if not t_a and not t_b: # NN
DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
elif not t_a and t_b: # NT
DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A)
elif t_a and not t_b: # TN
DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E)
DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
c, da, db = result[0], result[1], result[2]
print('C diff:', np.abs((C - c)).max())
print('DA diff:', np.abs((DA - da)).max())
print('DB diff:', np.abs((DB - db)).max())
# 2D Dense convolution
NCHWKRS = [
#(8, 64, 128, 128, 768, 3, 3),
#(8, 128, 64, 64, 256, 3, 3),
#(8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
configs += [([N, C, H, W],
[C, R, S, K],
[N, K, H - R + 1, W - S + 1],
torch_fn,
'nc(h+r)(w+s),crsk->nkhw',
dict())]
# 3D Dense Convolution
NCDHWKTRS = [
#(8, 32, 27, 100, 100, 64, 3, 3, 3),
#(8, 64, 23, 48, 48, 256, 3, 3, 3),
#(8, 256, 19, 22, 22, 640, 3, 3, 3),
#(8, 640, 15, 36, 36, 384, 3, 3, 3)
]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
configs += [([N, C, D, H, W],
[C, T, R, S, K],
[N, K, D - T + 1, H - R + 1, W - S + 1],
torch_fn,
'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
dict())]
# Shift convolution
shift_cuda = torch.utils.cpp_extension.load(
'shift_cuda', ['kernels/shift_cuda.cpp',
'kernels/shift_cuda_kernel.cu'],
extra_cflags=['-O3'])
class shift(torch.autograd.Function):
@staticmethod
def forward(ctx, x, shift):
ctx.save_for_backward(shift)
return shift_cuda.forward(x, shift)
@staticmethod
def backward(ctx, grad_output):
shift, = ctx.saved_tensors
grad_output = shift_cuda.backward(grad_output, shift)
return grad_output, None
NCHWKRS = [
#(8, 64, 128, 128, 128, 3, 3),
#(8, 128, 64, 64, 256, 3, 3),
#(8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
def shift_conv(a, b, **kwargs):
shift_h, shift_w = kwargs['sh'], kwargs['sw']
shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
shift_torch = torch.from_numpy(shift_torch).cuda()
a = shift.apply(a, shift_torch)
b = b.permute(1, 0)
b = b.reshape(b.shape[0], b.shape[1], 1, 1)
return torch.nn.functional.conv2d(a, b)
configs += [([N, C, H, W],
[C, K],
[N, K, H, W],
shift_conv,
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
{'sh': shift_h, 'sw': shift_w})]
# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
dtype = torch.cuda.HalfTensor
# initialize input tensors
a = torch.rand(*a_shape).type(dtype).cuda()
b = torch.rand(*b_shape).type(dtype).cuda()
# triton output
#ta = triton.ops._einsum.pad(a, [4,4,4,4])
tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
# reference output
if torch_fn:
rc = torch_fn(a, b, **arrays)
else:
rc = torch.einsum(expr, a, b)
# performance relative to equivalent matrix multiplication
ctx = triton.ctx_registry[tc]
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
# a = torch.rand(B, M, K).type(dtype).cuda()
# b = torch.rand(B, K, N).type(dtype).cuda()
# tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
# ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
ratio = 0
# test and benchmark
bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
diff = (tc - rc).abs().max() / rc.abs().max()
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} ({ratio:4.2f}); {diff:4.2f}')
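# Minimal standalone sketch of the new einsum op (illustrative shapes, mirroring
# the first matmul case above): the call takes the einsum expression, the two
# operands and the output shape, plus optional subscripted `arrays`.
dtype = torch.cuda.HalfTensor
a = torch.rand(4, 1024, 1024).type(dtype).cuda()
b = torch.rand(1024, 1024).type(dtype).cuda()
c = triton.ops.einsum('btc,ck->btk', a, b, [4, 1024, 1024])
print('sanity check:', (c - torch.einsum('btc,ck->btk', a, b)).abs().max())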

View File

@@ -0,0 +1,42 @@
#include <torch/torch.h>
#include <vector>
// CUDA forward declarations
at::Tensor shift_cuda_forward(
const at::Tensor input,
const at::Tensor shift);
at::Tensor shift_cuda_backward(
const at::Tensor grad_input,
const at::Tensor shift);
// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
at::Tensor shift_forward(
const at::Tensor input,
const at::Tensor shift) {
CHECK_INPUT(input);
CHECK_INPUT(shift);
return shift_cuda_forward(input, shift);
}
at::Tensor shift_backward(
const at::Tensor grad_input,
const at::Tensor shift) {
CHECK_INPUT(grad_input);
CHECK_INPUT(shift);
return shift_cuda_backward(grad_input, shift);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &shift_forward, "Shift forward (CUDA)");
m.def("backward", &shift_backward, "Shift backward (CUDA)");
}

View File

@@ -0,0 +1,111 @@
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
namespace {
template <typename scalar_t>
__global__ void shift_cuda_forward_kernel(
const scalar_t* __restrict__ input,
const int32_t* __restrict__ shift,
scalar_t* __restrict__ output,
const int32_t B,
const int32_t C,
const int32_t H,
const int32_t W) {
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const int32_t size = B*C*H*W;
const int32_t CHW = C*H*W;
const int32_t HW = H*W;
const int32_t b = idx / CHW;
const int32_t c = (idx - b*CHW) / HW;
const int32_t h = (idx - b*CHW - c*HW) / W;
const int32_t w = idx - b*CHW - c*HW - h*W;
const int32_t target_w = w + shift[2*c];
const int32_t target_h = h + shift[2*c + 1];
const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
output[target_idx] = input[idx];
}
}
template <typename scalar_t>
__global__ void shift_cuda_backward_kernel(
const scalar_t* __restrict__ grad_input,
scalar_t* __restrict__ grad_output,
const int32_t* __restrict__ shift,
const int32_t B,
const int32_t C,
const int32_t W,
const int32_t H) {
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const int32_t size = B*C*W*H;
const int32_t CWH = C*W*H;
const int32_t WH = W*H;
const int32_t b = idx / CWH;
const int32_t c = (idx - b*CWH) / WH;
const int32_t w = (idx - b*CWH - c*WH) / W;
const int32_t h = idx - b*CWH - c*WH - w*H;
const int32_t target_w = w - shift[2*c];
const int32_t target_h = h - shift[2*c + 1];
const int32_t target_idx = b*CWH + c*WH + target_w*W + target_h;
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
grad_output[target_idx] = grad_input[idx];
}
}
} // namespace
at::Tensor shift_cuda_forward(
const at::Tensor input,
const at::Tensor shift) {
const auto B = input.size(0);
const auto C = input.size(1);
const auto H = input.size(2);
const auto W = input.size(3);
const auto size = B*C*W*H;
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
auto output = at::zeros_like(input);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] {
shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
input.data<scalar_t>(),
shift.data<int32_t>(),
output.data<scalar_t>(),
B,
C,
H,
W);
}));
return output;
}
at::Tensor shift_cuda_backward(
const at::Tensor grad_input,
const at::Tensor shift) {
const auto B = grad_input.size(0);
const auto C = grad_input.size(1);
const auto H = grad_input.size(2);
const auto W = grad_input.size(3);
const auto size = B*C*W*H;
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
auto grad_output = at::zeros_like(grad_input);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] {
shift_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
grad_input.data<scalar_t>(),
grad_output.data<scalar_t>(),
shift.data<int32_t>(),
B,
C,
H,
W);
}));
return grad_output;
}

View File

@@ -77,7 +77,7 @@ class CMakeBuild(build_ext):
pass
cfg = 'Debug' if self.debug else 'Release'
#cfg = 'Release'
cfg = 'Release'
build_args = ['--config', cfg]
if platform.system() == "Windows":

View File

@@ -1,4 +1,5 @@
#include <pybind11/pybind11.h>
#include <pybind11/buffer_info.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <string>
@@ -48,6 +49,11 @@ void delete_fn(size_t id) {
id_fn_map.erase(id);
}
void register_cst(size_t id, const std::string& name, pybind11::buffer& data) {
pybind11::buffer_info info = data.request();
id_fn_map[id]->set_cst(name, info.ptr, info.size*info.itemsize);
}
void cleanup() {
id_grid_map.clear();
id_fn_map.clear();
@@ -508,7 +514,8 @@ void gen_torch_make_handles(std::ostream &os,
os << " " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl;
else{
os << " CHECK_INPUT(" << name << ");" << std::endl;
os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage().data(), false);" << std::endl;
os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), "
" (CUdeviceptr)((char*)" + name + ".storage().data() + " + name + ".storage_offset() * " + name + ".itemsize()), false);" << std::endl;
}
}
}
@@ -526,8 +533,8 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
os << name;
}
os << "}, *id_grid_map.at(id), &stream);\n";
os << " };\n ";
os << " run();";
os << " };\n";
os << " run();\n";
os << " if(bench > 0)\n ";
os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
}
@@ -562,18 +569,15 @@ std::tuple<std::string,
std::ostringstream oss;
oss << R"(
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "torch/extension.h"
#include "torch/script.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/CUDAHooks.h"
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_INPUT(x) CHECK_CUDA(x);
namespace rt = triton::runtime;
namespace drv = triton::driver;
@@ -628,6 +632,7 @@ PYBIND11_MODULE(libtriton, m) {
m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn);
m.def("register_cst", &register_cst);
m.def("delete_fn", &delete_fn);
m.def("make_op_id", &make_op_id);
m.def("make_scalar_id", &make_scalar_id);

View File

@@ -36,14 +36,16 @@ class function(metaclass = function_meta):
def apply_torch(cls, *args, **kwargs):
class TorchFunction(fw.torch.autograd.Function):
@staticmethod
def forward(ctx, *targs, **tkwargs):
y = cls.forward(ctx, *targs, **tkwargs)
def forward(ctx, *targs):
y = cls.forward(ctx, *targs, **cls.torch_kwargs)
ctx_registry[y] = ctx
return y
@staticmethod
def backward(ctx, grad_output):
return cls.backward(ctx, grad_output)
return TorchFunction.apply(*args, **kwargs)
cls.torch_kwargs = kwargs
return TorchFunction.apply(*args)
torch_kwargs = 0
@classmethod
def extract_tf_tensors(cls, lst, err):

View File

@@ -6,6 +6,8 @@ import hashlib
import sysconfig
import sys
import weakref
import contextlib
import io
# import for just-in-time compilation
import distutils
import setuptools.command.build_ext
@@ -56,7 +58,16 @@ def _write_bindings(src, root):
handle.writelines(src)
# return path of cpp file
return (cpp, so)
@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
try:
yield
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr
def _build(src, path):
ccdir = os.path.join(libtriton.__file__, os.path.pardir)
ccdir = os.path.realpath(ccdir)
@@ -102,7 +113,7 @@ def _build(src, path):
language = 'c++',
sources = [src],
include_dirs = include_dirs,
extra_compile_args = extra_compile_args,
extra_compile_args = extra_compile_args + ['-g0'],
extra_link_args = extra_link_args,
library_dirs = library_dirs,
libraries = libraries,
@@ -119,7 +130,8 @@ def _build(src, path):
ext_modules = [ext],
script_args = args,
)
setuptools.setup(**args)
with quiet():
setuptools.setup(**args)
shutil.rmtree(tmp)
def _cvt_to_def_str(obj):
@@ -188,6 +200,10 @@ class kernel:
self.src = src
self.outputs = outputs
self.tmp = tmp
self.cst = dict()
def set_constant(self, name, value):
self.cst[name] = value
def __call__(self, *args, **kwargs):
# create a new framework op when defines are different
@@ -204,15 +220,16 @@ class kernel:
defines.append((k, values))
opt = libtriton.options_space()
opt.defines = defines
opt.num_warps = [4]
opt.num_warps = [2, 4]
# create unique id for this op
op_id = libtriton.make_op_id()
self.fw_id[key] = op_id
# register function
libtriton.register_fn(op_id, self.src, opt)
for name, value in self.cst.items():
libtriton.register_cst(op_id, name, value)
if self.fw_op is None:
self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt)
# benchmarking info
bench = 0
if 'bench' in kwargs:
@@ -252,9 +269,9 @@ class kernel:
for y in ret:
bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
elif fw.has_torch():
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
args = [x if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
ret = self.fw_op(op_id, bench, bench_id, *args)
if bench > 0:
bench_registry[ret] = libtriton.retrieve_scalar(op_id)
bench_registry[ret] = libtriton.retrieve_scalar(bench_id)
else:
assert False

View File

@@ -38,25 +38,19 @@ void convnd(A_TYPE *A,
int rah[TM] = rabh % CH;
rah = rah * UPAW - off_uah;
raw = raw * UPAH - off_uaw;
int racr[TK] = rk / BW;
int ras[TK] = rk % BW;
int rac[TK] = racr / BH;
int rar[TK] = racr % BH;
rar = UPAR * rar;
ras = UPAS * ras;
int ram[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
int rak[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
int rak[TK] = *(ADELTA + rk);
A_TYPE* pa[TM, TK] = A + ram[:, newaxis] + rak[newaxis, :];
// pointers for B
int rbk[TK] = rk;
int rbn[TN] = ryb;
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_s;
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_c;
// pointers for A look-up table
int rklut[TK] = rk % LUT_SIZE;
int* padiff[TK] = ADIFF + rklut;
int* padelta[TK] = ADELTA + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
int* padelta[TK] = ADELTA + TK + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
int adiff[TK] = *padiff;
int adelta[TK] = *padelta;
@@ -66,7 +60,7 @@ void convnd(A_TYPE *A,
for(int k = K; k > 0; k = k - TK){
c += a @ b;
pa += adelta[newaxis, :];
pb += TK * ldb_s;
pb += TK * ldb_c;
// increment A look-up table
padelta = padelta + adiff;
adelta = *padelta;
@@ -99,29 +93,54 @@ void convnd(A_TYPE *A,
kernel = triton.kernel(src, ['C'])
@staticmethod
def _unpack(idx, D, H, W):
cdh = idx // W
w = idx % W
cd = cdh // H
h = cdh % H
c = cd // D
d = cd % D
return c, d, h, w
def _unpack(idx, order, shape_b):
_123 = idx // shape_b[order[0]]
_0 = idx % shape_b[order[0]]
_23 = _123 // shape_b[order[1]]
_1 = _123 % shape_b[order[1]]
_3 = _23 // shape_b[order[2]]
_2 = _23 % shape_b[order[2]]
return _0, _1, _2, _3
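# Worked example (hypothetical shapes) of the unpacking above: with
# shape_b = [C, T, R, S] = [8, 1, 3, 3] and order = [3, 2, 1, 0] (innermost
# dimension first), a flattened reduction index idx = c*(T*R*S) + t*(R*S) + r*S + s
# is split back into (s, r, t, c); e.g. idx = 4 maps to (1, 1, 0, 0).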
@staticmethod
def _delta_a(upsample_d, upsample_h, upsample_w, depth, TK,
T, R, S, stride_a):
def _roundup(x, div):
return (x + div - 1) // div * div
@staticmethod
def _delta_a(upsample_d, upsample_h, upsample_w,
bc, bd, bh, bw,
ac, ad, ah, aw,
stride_a, shape_b,
TK):
# Parse the axes so that the reduction is done
# from the innermost dimension outward
order = sorted([bc, bd, bh, bw], reverse = True)
c, d, h, w = [order.index(x) for x in [bc, bd, bh, bw]]
# Size of the lookup table is the product of the 3 innermost dimensions
K = _conv._roundup(TK, shape_b[order[0]] * shape_b[order[1]] * shape_b[order[2]])
# Allocate temporary arrays
ud = np.arange(upsample_d, dtype=np.int32)[:, np.newaxis, np.newaxis, np.newaxis]
uh = np.arange(upsample_h, dtype=np.int32)[np.newaxis, :, np.newaxis, np.newaxis]
uw = np.arange(upsample_w, dtype=np.int32)[np.newaxis, np.newaxis, :, np.newaxis]
ctrs = np.arange(depth, dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :]
c, t, r, s = _conv._unpack(ctrs, T, R, S)
nextc, nextt, nextr, nexts = _conv._unpack(ctrs + TK, T, R, S)
cdiff = nextc - c
tdiff = nextt - t
rdiff = nextr - r
sdiff = nexts - s
return cdiff*stride_a[1] + tdiff*stride_a[2] + rdiff*stride_a[3] + sdiff*stride_a[4]
k = np.arange(K , dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :]
# Find reduction indices at the current and next reduction indices
currentk = _conv._unpack(k , order, shape_b)
nextk = _conv._unpack(k + TK, order, shape_b)
# Compute memory stride
result = 0
result += (nextk[c] - currentk[c]) * stride_a[ac]
result += (nextk[d] - currentk[d]) * stride_a[ad]
result += (nextk[h] - currentk[h]) * stride_a[ah]
result += (nextk[w] - currentk[w]) * stride_a[aw]
# Initial k
ki = np.arange(TK , dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :]
currentk = _conv._unpack(ki, order, shape_b)
resulti = 0
resulti += currentk[c] * stride_a[ac]
resulti += currentk[d] * stride_a[ad]
resulti += currentk[h] * stride_a[ah]
resulti += currentk[w] * stride_a[aw]
return np.concatenate((resulti, result), axis=-1)
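# The table returned above is laid out as [initial offsets | increments]: the
# first TK entries give the memory offset of each of the first TK reduction
# indices (read in the kernel via `int rak[TK] = *(ADELTA + rk)`), and the
# remaining entries give the pointer increment from reduction index k to k + TK
# (read through `padelta = ADELTA + TK + ...`), broadcast over the upsampling phases.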
@staticmethod
def _extract_strides(shape):
@@ -134,38 +153,56 @@ void convnd(A_TYPE *A,
@staticmethod
def _call(a, b,
upsample_d, upsample_h, upsample_w,
pad_d, pad_h, pad_w,
stride_d, stride_h, stride_w,
mode):
stride_d, stride_h, stride_w,
upsample_d, upsample_h, upsample_w,
a_layout, b_layout, c_layout):
# input shapes
shape_a = list(triton.shape(a))
shape_b = list(triton.shape(b))
# add depth
shape_a.insert(2, 1)
shape_b.insert(1, 1)
NB, NC, AD, AH, AW = shape_a
NC, BD, BH, BW, NF = shape_b
dim = len(shape_a) - 2
# indices
an, ac, ad, ah, aw = [a_layout.find(x) for x in 'ncdhw']
bk, bc, bd, bh, bw = [b_layout.find(x) for x in 'kctrs']
cn, ck, cd, ch, cw = [c_layout.find(x) for x in 'nkdhw']
# extract shapes
if dim == 2:
shape_a.insert(ad, 1)
if dim == 2:
shape_b.insert(bd, 1)
# output shape
CD = (AD*upsample_d - BD + 1 + 2*pad_d + stride_d - 1) // stride_d
CH = (AH*upsample_h - BH + 1 + 2*pad_h + stride_h - 1) // stride_h
CW = (AW*upsample_w - BW + 1 + 2*pad_w + stride_w - 1) // stride_w
shape_c = [NB, NF, CD, CH, CW]
shape_c = [0] * 5
shape_c[cn] = shape_a[an]
shape_c[ck] = shape_b[bk]
shape_c[cd] = (shape_a[ad]*upsample_d - shape_b[bd] + 1 + 2*pad_d + stride_d - 1) // stride_d
shape_c[ch] = (shape_a[ah]*upsample_h - shape_b[bh] + 1 + 2*pad_h + stride_h - 1) // stride_h
shape_c[cw] = (shape_a[aw]*upsample_w - shape_b[bw] + 1 + 2*pad_w + stride_w - 1) // stride_w
# strides
stride_a = _conv._extract_strides(shape_a)
stride_b = _conv._extract_strides(shape_b)
stride_c = _conv._extract_strides(shape_c)
# look-up tables
# tiling parameters
TM = [32]
TN = [32]
TK = 8
FS = BD * BH * BW
depth = (TK + FS - 1)//FS * FS
# pointer deltas for a
delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w,
depth, TK, BD, BH, BW, stride_a)
bc, bd, bh, bw,
ac, ad, ah, aw,
stride_a, shape_b,
TK)
delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
inc_a = np.arange(depth, dtype=np.int32)
inc_a = ((inc_a + TK) % depth) - inc_a
# delta increments for a
inc_a = np.arange(delta_a.shape[-1] - TK, dtype=np.int32)
inc_a = ((inc_a + TK) % inc_a.size) - inc_a
inc_a = triton.fw.torch.from_numpy(inc_a).cuda()
# allocate output
if dim == 2:
shape_c.pop(cd)
c = triton.empty(shape_c, dtype=a.dtype)
if dim == 2:
shape_c.insert(cd, 1)
# execute kernel
trans_b = False
is_wgrad = False
is_blut = False
@@ -174,31 +211,99 @@ void convnd(A_TYPE *A,
'UPAS': 'stride_w' if is_wgrad else '1',
'UPAH': '' if is_wgrad else 'stride_h',
'UPAW': '' if is_wgrad else 'stride_w',
'LUT_SIZE': depth,
'TM': [32],
'TN': [32],
'TK': TK,
'A_TYPE': 'float',
'B_TYPE': 'float'
'LUT_SIZE': delta_a.shape[-1],
'TM': TM, 'TN': TN, 'TK': TK,
'A_TYPE': 'float', 'B_TYPE': 'float'
}
shape_c.pop(2)
c = triton.empty(shape_c, dtype=a.dtype)
grid = lambda opt: [triton.cdiv(NB*CD*CH*CW, opt.d('TM')), triton.cdiv(NF, opt.d('TN'))]
print(stride_c)
print(stride_b)
_conv.kernel(a, b, c, NB*CD*CH*CW, NF, NC*BD*BH*BW, AH, AW, BH, BW, CH, CW, NC,
stride_a[0], stride_a[1], stride_a[2], stride_a[3], stride_a[4],
stride_b[0], stride_b[1], stride_b[2], stride_b[3], stride_b[4],
stride_c[0], stride_c[1], stride_c[2], stride_c[3], stride_c[4],
pad_h, pad_w, stride_h, stride_w, upsample_h, upsample_w,
MATMUL_M = shape_c[cn] * shape_c[cd] * shape_c[ch] * shape_c[cw]
MATMUL_N = shape_c[ck]
MATMUL_K = shape_b[bc] * shape_b[bd] * shape_b[bh] * shape_b[bw]
_conv.kernel(a, b, c,
# matrix multiplication shapes
MATMUL_M, MATMUL_N, MATMUL_K,
# shapes for a
shape_a[ah], shape_a[aw],
# shapes for b
shape_b[bh], shape_b[bw],
# shapes for c
shape_c[ch], shape_c[cw], shape_c[cn],
# strides for a
stride_a[an], stride_a[ac], stride_a[ad + 0], stride_a[ad + 1], stride_a[ad + 2],
# strides for b
stride_b[bc], stride_b[bd + 0], stride_b[bd + 1], stride_b[bd + 2], stride_b[bk],
# strides for c
stride_c[cn], stride_c[ck], stride_c[cd], stride_c[cd + 1], stride_c[cd + 2],
# padding
pad_h, pad_w,
# striding
stride_h, stride_w,
# upsampling
upsample_h, upsample_w,
0, 0, 0, 0, 0, 0,
# look-up table
delta_a, inc_a,
grid, **macros)
lambda opt: [triton.cdiv(MATMUL_M, opt.d('TM')), triton.cdiv(MATMUL_N, opt.d('TN'))],
**macros)
return c
@staticmethod
def forward(ctx, input, weight):
return _conv._call(input, weight, 1, 1, 1, 0, 0, 0, 1, 1, 1, '')
def forward(ctx, x, w,
pad_d = 0, pad_h = 0, pad_w = 0,
stride_d = 1, stride_h = 1, stride_w = 1,
upsample_d = 1, upsample_h = 1, upsample_w = 1,
layout_a = 'ncdhw', layout_b = 'ktrsc', layout_c = 'nkdhw'):
# save for backward
ctx.save_for_backward(x, w)
ctx.pad_d = pad_d
ctx.pad_h = pad_h
ctx.pad_w = pad_w
ctx.stride_d = stride_d
ctx.stride_h = stride_h
ctx.stride_w = stride_w
ctx.upsample_d = upsample_d
ctx.upsample_h = upsample_h
ctx.upsample_w = upsample_w
ctx.layout_a = layout_a
ctx.layout_b = layout_b
ctx.layout_c = layout_c
# return
return _conv._call(x, w,
pad_d, pad_h, pad_w,
stride_d, stride_h, stride_w,
upsample_d, upsample_h, upsample_w,
layout_a, layout_b, layout_c)
@staticmethod
def backward(ctx, dy):
x, w = ctx.saved_tensors
pad_d = ctx.pad_d
pad_h = ctx.pad_h
pad_w = ctx.pad_w
stride_d = ctx.stride_d
stride_h = ctx.stride_h
stride_w = ctx.stride_w
upsample_d = ctx.upsample_d
upsample_h = ctx.upsample_h
upsample_w = ctx.upsample_w
layout_a = ctx.layout_a
layout_b = ctx.layout_b
layout_c = ctx.layout_c
# TODO: Deal with this
dx_pad_d = 1
dx_pad_h = 1
dx_pad_w = 1
dx = _conv._call(dy, w,
dx_pad_d, dx_pad_h, dx_pad_w,
upsample_d, upsample_h, upsample_w,
stride_d, stride_h, stride_w,
'ncdhw', 'cktrs', 'nkdhw')
ret = [None] * 14
ret[0] = None
ret[1] = dw
return None,
conv = _conv.apply

View File

@@ -3,37 +3,50 @@ import triton
class _dot(triton.function):
src = """
void dot(TYPE * A, TYPE * B, TYPE * C,
void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C,
float alpha,
int M, int N, int K,
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
float c[TM, TN] = 0;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
// reduction loop
for(int k = K; k > 0; k-= TK){
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
}
// epilogue
TYPE* pc[TM, TN] = C + rm[:, newaxis] * ldc + rn[newaxis, :];
*pc = c;
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
bool checka[SHAPE_A] = rk[BROADCAST_AK] < K;
bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K;
TYPE a[SHAPE_A] = checka ? *pa : 0;
TYPE b[SHAPE_B] = checkb ? *pb : 0;
// reduction loop
float c[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
c += USE_A @ USE_B;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
a = *?(checka)pa;
b = *?(checkb)pb;
}
//c = c * alpha;
// epilogue
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rxm[:, newaxis] * ldc + rxn[newaxis, :];
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
*?(checkc)pc = (TYPE[TM, TN])c;
}
"""
kernel = triton.kernel(src, ['C'])
@@ -75,10 +88,10 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
_dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
_dot.kernel(a, b, c, 1., M, N, Ka, lda, ldb, ldc,
grid, bench=bench,
AT = transpose_a, BT = transpose_b, TYPE = dtype,
TM = [64, 128], TN = [64, 128], TK = [8], **macros)
TM = [64], TN = [128], TK = [8], **macros)
return c
@staticmethod

View File

@@ -1,234 +1,651 @@
# Special thanks to Scott Gray from OpenAI for writing the einsum parsing function
import numpy as np
import torch
from math import ceil, log2
from enum import IntEnum
import triton
import math
from functools import reduce
from operator import mul
from sympy.parsing.sympy_parser import parse_expr
import sympy as sp
from collections import OrderedDict
from collections import namedtuple
import re
from sympy.printing.ccode import C89CodePrinter
class _einsum(triton.function):
src = """
void einsumk(TYPE * A, TYPE * B, TYPE * C,
int dim_M, int dim_N, int dim_K,
int std_A0 __multipleof(8),
int std_B0 __multipleof(8),
int std_C0 __multipleof(8),
int std_A1 __multipleof(8),
int std_B1 __multipleof(8),
int std_C1 __multipleof(8)) {
// program id
int pgm = get_program_id(0);
int pgn = get_program_id(1);
int pgb = get_program_id(2);
// range
int rm[TM] = pgm * TM + 0 ... TM;
int rn[TN] = pgn * TN + 0 ... TN;
int rb[TB] = pgb * TB + 0 ... TB;
int rk[TK] = 0 ... TK;
// accumulator
float c[TM, TN, TB] = 0;
// pointers to a
TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK
+ rm[BROADCAST_AM] * STRIDE_AM
+ rb[newaxis, newaxis, :] * std_A0;
// pointers to b
TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK
+ rn[BROADCAST_BN] * STRIDE_BN
+ rb[newaxis, newaxis, :] * std_B0;
// prefetch
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
// accumulation
for(int k = dim_K; k > 0; k -= TK) {
c += USE_A @ USE_B;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
}
// write-back
TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1
+ rn[newaxis, :, newaxis] * 1
+ rb[newaxis, newaxis, :] * std_C0;
bool checkm[TM] = rm < dim_M;
bool checkn[TN] = rn < dim_N;
#############################
## Triton-C code generation
#############################
def print_cc(expr, axes_0, axes_1, axes_2):
class TritonCodePrinter(C89CodePrinter):
def __init__(self, axes_0, axes_1, axes_2):
super(TritonCodePrinter, self).__init__()
self.axes_0 = axes_0
self.axes_1 = axes_1
self.axes_2 = axes_2
def _print_Symbol(self, expr):
name = super(C89CodePrinter, self)._print_Symbol(expr)
if expr in self.axes_0:
return f'r{name}[:, newaxis, newaxis]'
if expr in self.axes_1:
return f'r{name}[newaxis, :, newaxis]'
if expr in self.axes_2:
return f'r{name}[newaxis, newaxis, :]'
return name
def _print_Indexed(self, expr):
assert len(expr.indices) == 1
return "*(%s + %s)" % (self._print(expr.base.label),
self._print(expr.indices[0]))
return TritonCodePrinter(axes_0, axes_1, axes_2).doprint(expr)
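# Illustrative sketch of the printer above (assumed symbols): with m in axes_0,
# k in axes_1 and b in axes_2, an index expression such as 4*m + k is rendered as
#     4*rm[:, newaxis, newaxis] + rk[newaxis, :, newaxis]
# so that each axis symbol becomes its range tile, broadcast to the 3D shape
# used by the pointer arithmetic below.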
def unpack_cc(tile, axes, prefix, remat):
ret = ''
axes = list(map(str, axes))
for i, d in enumerate(reversed(axes)):
if i == len(axes) - 1:
break
currs = ''.join(axes[: len(axes) - i])
nexts = ''.join(axes[: len(axes) - (i + 1)])
ty = '' if remat else 'int '
sz = '' if remat else f'[{tile}]'
ret += f' {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n'
ret += f' {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n'
return ret
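# Worked example (hypothetical axes): unpack_cc('TM', ['n', 'h', 'w'], 'r', False)
# emits the div/mod chain that recovers each index from the fused range rnhw:
#     int rnh[TM] = rnhw / dim_w;
#     int rw[TM] = rnhw % dim_w;
#     int rn[TM] = rnh / dim_h;
#     int rh[TM] = rnh % dim_h;
# with remat=True the declarations are omitted so the same ranges can be
# re-materialized after the reduction loop.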
def strides_cc(name, expr):
ret = [f'stride_{name}_{d}' for d in expr[:-1]] + ['1']
ret = dict(zip(expr, ret))
return ret
def make_kernel(name,
expr_a, expr_b, expr_c,
axes_m, axes_n, axes_k, axes_b,
multipleof_a, multipleof_b, multipleof_c,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
subscripted):
use_lut_a = True
use_lut_b = True
src = ""
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
src += f"""
char __constant__* AD = calloc({4*len(delta_a)});"""
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
src += f"""
char __constant__* BD = calloc({4*len(delta_b)});"""
src += f"""
__global__ void {name}(
TYPE * A __noalias __readonly __aligned(16)
, TYPE * B __noalias __readonly __aligned(16)
, TYPE * C
, int * locks
, float alpha
, int matmul_m, int matmul_n, int matmul_k __multipleof(16)
, int div_m
"""
for dim in [axes_m, axes_n, axes_k, axes_b]:
for d in dim:
src += f", int dim_{d}"
src += "\n "
for dim, name, mult in zip([expr_a, expr_b, expr_c],
['a', 'b', 'c'],
[multipleof_a, multipleof_b, multipleof_c]):
for d in range(len(dim) - 1):
attr = f'__multipleof({mult})'
src += f", int stride_{name}_{d} {attr}"
src += "\n "
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += f", int stride_a_inner __multipleof({multipleof_a})"
elif lut_mode_a == _einsum.LUT_MODE.DRAM:
src += ", int* AD __noalias __readonly __aligned(16)"
src += "\n "
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += f", int stride_b_inner __multipleof({multipleof_b})"
elif lut_mode_b == _einsum.LUT_MODE.DRAM:
src += ", int* BD"
for ptr in subscripted:
src += f", int* {ptr}"
src += """) {
// re-order outer program ids
int grid_m = (matmul_m + TM - 1) / TM;
int grid_n = (matmul_n + TN - 1) / TN;
int pid_mn = get_program_id(0) / div_m;
int pid_n = pid_mn % grid_n;
int pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
// get batch program id
int pid_b = get_program_id(1);
#if TZ == 1
int off_k = 0;
#else
// get reduction sub-group program id
int pid_z = get_program_id(2);
int grid_z = get_num_programs(2);
int div_z = matmul_k / TZ;
int rem_z = matmul_k % TZ;
int off_k = pid_z * div_z;
matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
#endif
// create ranges
"""
rk = 'r{}'.format(''.join(map(str,axes_k)))
for axes, tile, off in zip([axes_m, axes_n, axes_b, axes_k],
['TM', 'TN', 'TB', 'TK'],
['pid_m*TM', 'pid_n*TN', 'pid_b*TB', 'off_k']):
currs = ''.join(map(str,axes))
if axes:
src += f" int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, 'r', False)
src += """
// initialize pointers to A
int offa[TM, TK, TB] = """
for i, sym in enumerate(expr_a):
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pa[TM, TK, TB] = A + offa;"""
if use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
src += f"""
// initialize pointers to A look-up table
int offadelta[TK] = off_k + 0 ... TK;
int {spec} *padelta[TK] = {cast}AD + offadelta;
int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""
src += """
// initialize pointers to B
int offb[TK, TN, TB] = """
for i, sym in enumerate(expr_b):
ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else '1'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pb[TK, TN, TB] = B + offb;"""
if use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
src += f"""
// initialize pointers to B look-up table
int offbdelta[TK] = off_k + 0 ... TK;
int *pbdelta[TK] = BD + offbdelta;"""
src += f"""
// prefetch
bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
bool checkk[TK] = {rk} < matmul_k + off_k;
bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
TYPE a[TM, TK, TB] = checka ? *pa : 0;
TYPE b[TK, TN, TB] = checkb ? *pb : 0;
// accumulate
float acc[TM, TN, TB] = 0;
for(int k = matmul_k; k > 0; k -= TK) {{
acc += a @ b;"""
if not use_lut_a or not use_lut_b:
src += f"""
{rk} += TK;
"""
src += _einsum.unpack_cc(tile, axes_k, 'r', True)
if use_lut_a:
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += """
pa += stride_a_inner;"""
else:
src += """
pa += incda;
padelta += TK;
incda = (*padelta)[newaxis, :, newaxis];"""
else:
src += """
offa = """
for i, sym in enumerate(expr_a):
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += """;
TYPE *pa[TM, TK, TB] = A + offa;"""
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += """
pb += stride_b_inner;"""
else:
src += """
pb += (*pbdelta)[:, newaxis, newaxis];
pbdelta += TK;"""
src += f"""
checkk = k > TK;
checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
a = *?(checka)pa;
b = *?(checkb)pb;
}}
TYPE c[TM, TN, TB] = acc;
// re-materialize ranges
"""
for axes, tile, off in zip([axes_m, axes_n, axes_b],
['TM', 'TN', 'TB'],
['pid_m*TM', 'pid_n*TN', 'pid_b*TB']):
currs = ''.join(map(str,axes))
if axes:
src += f" r{currs} = {off} + 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, 'r', True)
src += """
// initialize pointers to C
int offc[TM, TN, TB] = """
for i, sym in enumerate(expr_c):
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else '1'
ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pc[TM, TN, TB] = C + offc;
// bounds-checking
checkm = r""" + ''.join(map(str,axes_m)) + """ < matmul_m;
checkn = r""" + ''.join(map(str,axes_n)) + """ < matmul_n;
bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
checkn[newaxis, :, newaxis];
*?(checkc)pc = (TYPE[TM, TN, TB])c;
checkn[newaxis, :, newaxis];
// write back
#if TZ == 1
*?(checkc)pc = c;
#else
int *plock = locks + pid_mn + pid_b * get_num_programs(0);
int *pcount = plock + 1024*1024;
// spin
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc)pc = c;
else
*?(checkc)pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % (grid_z));
atomic_xchg(plock, 0);
#endif
}
"""
kernel = triton.kernel(src, ['C'])
#print(src)
ret = triton.kernel(src, ['C'])
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('AD', delta_a)
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('BD', delta_b)
return ret
############################
## Look-up Table
############################
class LUT_MODE(IntEnum):
SCALAR = 1
CONSTANT = 2
DRAM = 3
def lut_mode(delta):
if delta.size == 0 or np.min(delta) == np.max(delta):
return _einsum.LUT_MODE.SCALAR
#if delta.size < 4096:
# return _einsum.LUT_MODE.CONSTANT
return _einsum.LUT_MODE.DRAM
def symbolic_delta(symbols, axes):
rank = len(symbols)
strides = [sp.symbols(f'stride{d}') for d in range(rank)]
nexts = {s: sp.symbols(f'next{s}') for s in axes}
delta = 0
for i in range(rank):
delta += strides[i] * (symbols[i].subs(nexts) - symbols[i])
return delta
def unpack_offset(k, axes, dims):
ret = dict()
for d in reversed(axes):
ret[d] = k % dims[d]
k = k // dims[d]
return ret
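# e.g. with axes = [c, r, s] and dims = {c: C, r: R, s: S}, unpack_offset maps a
# flattened reduction index k to {s: k % S, r: (k // S) % R, c: (k // (S*R)) % C},
# i.e. the innermost axis varies fastest.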
def make_delta(axes, step, stride, dims, symbols, arrays):
# symbolic pointer increments
delta = _einsum.symbolic_delta(symbols, axes)
args = [f'stride{d}' for d in range(len(stride))]
args += [f'{sk}' for sk in axes]
args += [f'next{sk}' for sk in axes]
args += [f'{sk}' for sk, _ in arrays]
fn = sp.lambdify(args, delta, 'numpy')
# inner axes values
inner = [dims[d] for d in axes]
k = np.arange(np.prod(inner), dtype=np.int32)
off = _einsum.unpack_offset(k, axes, dims)
nextoff = _einsum.unpack_offset(k + step, axes, dims)
# evaluate deltas
args = [s for s in stride]
args += [off[sk] for sk in axes]
args += [nextoff[sk] for sk in axes]
args += [x for _, x in arrays]
delta = fn(*args)
return delta, _einsum.lut_mode(delta[:-step])
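# Worked example (assumed row-major strides): for 'mk,kn->mn' the reduction axis
# is k alone, so with stride_a = [K, 1] the symbolic delta reduces to
# 1 * (next_k - k); away from the wrap-around points this is the constant TK,
# hence LUT_MODE.SCALAR and the kernel just does `pa += stride_a_inner`.
# For composite reduction axes (e.g. the c, r, s filter axes of a convolution)
# the delta varies with k, so the table is materialized and streamed from DRAM.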
############################
## Einsum parsing
############################
def uniq(seq):
seen = set()
seen_add = seen.add
return [x for x in seq if not (x in seen or seen_add(x))]
def parse_axes(expr_a, expr_b, expr_c, subscripted):
is_index = lambda x: type(x) == sp.indexed.Indexed or str(x) in subscripted
sym_a = [x for s in expr_a for x in s.free_symbols if not is_index(x)]
sym_b = [x for s in expr_b for x in s.free_symbols if not is_index(x)]
sym_c = [x for s in expr_c for x in s.free_symbols]
batch = [d for d in sym_a if d in sym_b and d in sym_c]
outer = [d for d in sym_a if d not in sym_b and d in sym_c]
inner = [d for d in sym_a if d in sym_b and d not in sym_c]
illegal = [d for d in sym_a if d not in sym_b and d not in sym_c]
if illegal:
raise ValueError(f"einsum labels {illegal} ({expr_a}) "\
f"not present in {expr_b} or {expr_c}")
return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner)
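# Worked example: for 'btc,ck->btk', parse_axes(sym_a, sym_b, sym_c, []) finds
# no batch axes (b and t do not appear in the second operand), [b, t] as outer
# axes and [c] as the inner (reduction) axis; swapping the operands, as done in
# the constructor below, yields [k] as the outer axes of the second operand.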
def replace_subscript(expr, arrays):
# replace array indexing by Indexed()
indexed = re.findall('([_a-zA-Z][_a-zA-Z0-9]*)\[([_a-z]*)\]', expr)
for x in indexed:
arrays.append(x[0])
expr = expr.replace(f'{x[0]}[{x[1]}]', f'Indexed({x[0]},{x[1]})')
return expr
def parse_expr(expr, arrays):
# extract symbols
sym = []
i = 0
while i < len(expr):
d = expr[i]
if d == '(':
size = expr[i:].find(')')
d = expr[i : i + size + 1]
d = _einsum.replace_subscript(d, arrays)
sym.append(parse_expr(d))
i += size + 1
else:
sym.append(parse_expr(d))
i += 1
return sym
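# Worked example: _einsum.parse_expr('nc(h+r)(w+s)', []) yields the sympy
# expressions [n, c, h + r, w + s], one per index of the operand; subscripted
# terms such as 'sh[c]' in 'nc(h + sh[c])(w + sw[c])' are rewritten to
# Indexed(sh, c) and 'sh' is recorded so that a pointer argument is generated
# for it in the kernel signature.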
@staticmethod
def _append_dim(dim_data, dim_type, idx, label, dim, stride):
if dim_type in dim_data:
data = dim_data[dim_type]
if idx != data["idx"] + 1:
raise ValueError("aggregate inner, outer and batch dims must be adjacent to each other.")
data["dim"] *= dim
data["lab"] = label + data["lab"]
else:
dim_data[dim_type] = dict(idx=idx, lab=label, dim=dim, std=stride)
return dim_type
############################
## Preprocessing
############################
@staticmethod
def _parse_abc(labels_a, labels_b, labels_c, shape_a, is_a=False):
def pad(tensor, pad):
pad = pad + [0] * (2*len(tensor.shape) - len(pad))
begin = [ x if x > 0 else None for x in pad[-1::-2]]
end = [-x if x > 0 else None for x in pad[-2::-2]]
slices = [slice(b, e) for b, e in zip(begin, end)]
tensor = torch.nn.functional.pad(tensor, pad, 'constant', 0)
tensor = tensor[slices]
return tensor
if len(labels_a) != len(shape_a):
raise ValueError(f"einsum notation dims do not match shape: {labels_a} {shape_a}")
trans = False
stride = 1
std1 = None
data = dict()
for idx, (lab, dim) in enumerate(reversed(list(zip(labels_a, shape_a)))):
#print(idx, lab, dim)
if dim is None:
raise ValueError("einsum doesn't currently work on shapes with placeholder dims.")
if idx == 0 and dim % 8 != 0:
raise ValueError("contiguous dim must be multiple of 8")
############################
## Compilation
############################
if lab in labels_c:
# batch dim
if lab in labels_b:
_einsum._append_dim(data, "B", idx, lab, dim, stride)
if idx == 0:
raise ValueError(f"batch dim can not be contiguous dim: {lab} {labels_a} {shape_a}")
# outer dim
else:
std1 = _einsum._append_dim(data, "O", idx, lab, dim, stride)
if idx == 0:
trans = is_a
# inner dim
elif lab in labels_b:
std1 = _einsum._append_dim(data, "I", idx, lab, dim, stride)
if idx == 0:
trans = not is_a
else:
raise ValueError(f"einsum def for output: {lab} ({labels_a}), not present in either other def")
class instance:
stride *= dim
locks = None
kernel_cache = dict()
if "B" not in data:
data["B"] = dict(dim=1, std=1)
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays):
# parse symbols
expr_a, expr_bc = einsum.split(",")
expr_b, expr_c = expr_bc.split("->")
subscripted = []
sym_a = _einsum.parse_expr(expr_a, subscripted)
sym_b = _einsum.parse_expr(expr_b, subscripted)
sym_c = _einsum.parse_expr(expr_c, subscripted)
# parse axes
axes_b, axes_m, axes_k = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted)
_, axes_n, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted)
axes = axes_b + axes_m + axes_n + axes_k
# check dimensions
dims_a = dict(zip(sym_a, shape_a))
dims_b = dict(zip(sym_b, shape_b))
dims_c = dict(zip(sym_c, shape_c))
for axes in [axes_b, axes_k]:
for d in axes:
dim_a = dims_a[d] if d in sym_a else None
dim_b = dims_b[d] if d in sym_b else None
if dim_a and dim_b and dim_a != dim_b:
raise ValueError(f'incompatible dimension {d}'
f' (a: {dim_a}; b: {dim_b})')
dims = dict()
dims.update(dims_a)
dims.update(dims_b)
dims.update(dims_c)
# look-up tables
TK = 16 if dtype == triton.fw.torch.float16 else 8
arrays = [(x, arrays[x]) for x in subscripted]
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
# hash for recompilation
stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'
# recompile if necessary
cache = _einsum.instance.kernel_cache
if name not in cache:
cachesize = len(cache)
cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
sym_a, sym_b, sym_c,
axes_m, axes_n, axes_k, axes_b,
stride_a_multiple, stride_b_multiple, stride_c_multiple,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
subscripted)
self.kernel = cache[name]
# Initialize locks
if _einsum.instance.locks is None:
_einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda()
# Kernel arguments
dim_m = [dims[d] for d in axes_m]
dim_n = [dims[d] for d in axes_n]
dim_k = [dims[d] for d in axes_k]
dim_b = [dims[d] for d in axes_b]
M = reduce(mul, dim_m, 1)
N = reduce(mul, dim_n, 1)
K = reduce(mul, dim_k, 1)
B = reduce(mul, dim_b, 1)
stride_a = list(stride_a[:-1])
stride_b = list(stride_b[:-1])
stride_c = list(stride_c[:-1])
arrays = [torch.from_numpy(x).cuda() for _, x in arrays]
alpha = 1.
div_m = 1
self.args = [None, None, None,
_einsum.instance.locks,
alpha, M, N, K, div_m] +\
dim_m + dim_n + dim_k + dim_b +\
stride_a + stride_b + stride_c
if lut_mode_a != _einsum.LUT_MODE.CONSTANT:
delta_a = delta_a[0] if lut_mode_a == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_a).cuda()
self.args += [delta_a]
if lut_mode_b != _einsum.LUT_MODE.CONSTANT:
delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda()
self.args += [delta_b]
self.args += arrays
self.args += [lambda opt: [triton.cdiv(M, opt.d('TM')) *
triton.cdiv(N, opt.d('TN')),
triton.cdiv(B, opt.d('TB')),
opt.d('TZ')]]
# position of dynamic arguments
self.pos_a = 0
self.pos_b = 1
self.pos_c = 2
# pre-processor macros
TM = [x for x in [16, 32, 64, 128] if x <= M]
TN = [x for x in [16, 32, 64, 128] if x <= N]
TB = [x for x in [1, 2, 4] if x <= B]
MAX_GZ = K // 2048
MIN_GM = M // max(TM)
MIN_GN = N // max(TN)
MIN_GB = B // max(TB)
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
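            # explanatory note on TZ ("split-K"): when K is large (MAX_GZ > 1) and the
            # M/N/B grid alone would not fill the device, the reduction is split across
            # up to TZ program instances (the third grid dimension of the launch lambda
            # above); partial results are presumably combined through the `locks`
            # buffer that is passed to the kernel.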
#TB, TZ = [1], [1]
#TM, TN, TB, TZ = [128], [128], [1], [1]
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
self.dtype = dtype
self.flops = 2 * B * M * N * K
self.sym_a = sym_a
self.sym_b = sym_b
self.sym_c = sym_c
# save equivalent mat-mul dimensions
self.matmul_B = B
self.matmul_M = M
self.matmul_N = N
self.matmul_K = K
def run(self, a, b, c, bench):
self.args[self.pos_a] = a
self.args[self.pos_b] = b
self.args[self.pos_c] = c
self.kernel(*self.args, bench=bench, **self.macros)
# batch, outer, inner, std0, std1, trans
return data["B"]["dim"], data["O"]["dim"], data["I"]["dim"], data["B"]["std"], data[std1]["std"], trans
############################
## Forward
############################
instance_cache = dict()
@staticmethod
def _parse_einsum(labels_a, labels_b, labels_c, shape_a, shape_b):
dims_a = dict(zip(labels_a, shape_a))
dims_b = dict(zip(labels_b, shape_b))
shape_c = list()
for lab in labels_c:
if lab in dims_a:
shape_c.append(dims_a[lab])
elif lab in dims_b:
shape_c.append(dims_b[lab])
else:
raise ValueError(f"einsum def for output: {lab} ({labels_c}), not present in either input def ({labels_a}, {labels_b})")
BA, M, KA, std_a0, std_a1, ta = _einsum._parse_abc(labels_a, labels_b, labels_c, shape_a, True)
BB, N, KB, std_b0, std_b1, tb = _einsum._parse_abc(labels_b, labels_a, labels_c, shape_b, False)
BC, _, _, std_c0, std_c1, _ = _einsum._parse_abc(labels_c, labels_b, labels_a, shape_c)
if not (BA == BB == BC):
raise ValueError("mismatched batch dims")
if KA != KB:
raise ValueError("mismatched reduction dims")
return shape_c, (BA, M, N, KA), (std_a0, std_b0, std_c0), (std_a1, std_b1, std_c1), ta, tb
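    # illustrative example: for einsum('bmk,bkn->bmn') the parse above reduces the
    # problem to a batched matrix multiplication with B = dim('b'), M = dim('m'),
    # N = dim('n') and K = dim('k'), plus the strides and transposition flags that
    # the kernel needs.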
@staticmethod
def call(a, b, trans_a, trans_b, shape_c, bmnk,
std0, std1, einsum_a, einsum_b, einsum_c,
bench):
def forward(ctx, einsum, a, b, shape_c, **kwargs):
bench = kwargs['bench'] if 'bench' in kwargs else False
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
# allocate output
dtype = a.dtype
c = triton.empty(shape_c, dtype)
grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
triton.cdiv(bmnk[2], opt.d('TN')),
triton.cdiv(bmnk[0], opt.d('TB'))]
macros = {# handle A transposition
'USE_A' : 'a[^1, ^0, ^2]' if trans_a else 'a',
'STRIDE_AK' : 'std_A1' if trans_a else '1',
'STRIDE_AM' : '1' if trans_a else 'std_A1',
'BROADCAST_AK': ':, newaxis, newaxis' if trans_a else 'newaxis, :, newaxis',
'BROADCAST_AM': 'newaxis, :, newaxis' if trans_a else ':, newaxis, newaxis',
'SHAPE_A' : 'TK, TM, TB' if trans_a else 'TM, TK, TB',
# handle B transposition
'USE_B' : 'b' if not trans_b else 'b[^1, ^0, ^2]',
'STRIDE_BK' : 'std_B1' if not trans_b else '1',
'STRIDE_BN' : '1' if not trans_b else 'std_B1',
'BROADCAST_BK': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
'BROADCAST_BN': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
'SHAPE_B' : 'TK, TN, TB' if not trans_b else 'TN, TK, TB'}
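        # explanatory note: these macros textually specialize the kernel source for
        # each layout; USE_A / USE_B select either the tile as loaded or a permuted
        # view (the `a[^1, ^0, ^2]` form appears to request an axis permutation,
        # i.e. a transpose), while the STRIDE_* / BROADCAST_* / SHAPE_* entries pick
        # the matching pointer arithmetic, broadcasting and tile shapes.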
TM = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[1]) + 1 ))))]
TN = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[2]) + 1 ))))]
TB = [2**i for i in range(0, max(1, min(3, int(math.log2(bmnk[0]) + 1 ))))]
        TK = [bmnk[3]] if bmnk[3] < 16 else [8, 16]  # threshold on K (bmnk = (B, M, N, K))
_einsum.kernel(a, b, c,
bmnk[1], bmnk[2], bmnk[3],
std0[0], std0[1], std0[2],
std1[0], std1[1], std1[2],
grid, bench=bench,
**macros,
TYPE=dtype, TM=TM, TN=TN, TK=TK, TB=TB)
c = triton.empty(shape_c, dtype=dtype)
key = (einsum, dtype,
a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape)
# compile einsum instance
cache = _einsum.instance_cache
#if key not in cache:
cache[key] = _einsum.instance(einsum, dtype,
a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape, arrays)
instance = cache[key]
instance.run(a, b, c, bench)
# save information in context
ctx.flops = instance.flops
ctx.sym_a = instance.sym_a
ctx.sym_b = instance.sym_b
ctx.sym_c = instance.sym_c
ctx.matmul_B = instance.matmul_B
ctx.matmul_M = instance.matmul_M
ctx.matmul_N = instance.matmul_N
ctx.matmul_K = instance.matmul_K
ctx.bench = bench
ctx.save_for_backward(a, b)
return c
############################
## Backward
############################
@staticmethod
def forward(ctx, subscripts, a, b, bench = 0):
ctx.save_for_backward(a, b)
# parse
if type(subscripts) is str:
einsum_a, einsum_bc = subscripts.split(",")
einsum_b, einsum_c = einsum_bc.split("->")
else:
einsum_a, einsum_b, einsum_c = subscripts
shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
einsum_a, einsum_b, einsum_c,
triton.shape(a), triton.shape(b))
# save for backward
ctx.trans_a = ta
ctx.trans_b = tb
ctx.einsum_a = einsum_a
ctx.einsum_b = einsum_b
ctx.einsum_c = einsum_c
ctx.bench = bench
ctx.bmnk = bmnk
# run
return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench)
def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
for i, expr in enumerate(sym_x):
if expr.is_symbol:
continue
sc = [x for x in expr.free_symbols if x in sym_c][0]
sx = sp.symbols(f'{prefix}{i}')
renamed[expr] = sx
inverse[sc] = sp.solve(sp.Eq(expr, sx), sc)[0]
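    # illustrative example of the inversion above: if an input subscript is the
    # expression 2*m + 1, where m also appears in the output, the expression is
    # renamed to a fresh symbol (e.g. a0) and sympy solves 2*m + 1 == a0 for m,
    # recording the inverse map m -> (a0 - 1)/2 that is later used when building
    # the backward einsum expression.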
@staticmethod
def backward(ctx, dc):
def sym_to_expr(sym):
res = [f'({x})' for x in sym]
res = ''.join(res)
return res
@staticmethod
def backward(ctx, dy):
a, b = ctx.saved_tensors
trans_a = ctx.trans_a
trans_b = ctx.trans_b
einsum_a = ctx.einsum_a
einsum_b = ctx.einsum_b
einsum_c = ctx.einsum_c
bench = ctx.bench
sym_a = ctx.sym_a
sym_b = ctx.sym_b
sym_c = ctx.sym_c
inverse = dict()
renamed = dict()
_einsum.sym_invert(sym_c, sym_a, 'a', renamed, inverse)
_einsum.sym_invert(sym_c, sym_b, 'b', renamed, inverse)
sym_a = [renamed[x] if x in renamed else x for x in sym_a]
sym_b = [renamed[x] if x in renamed else x for x in sym_b]
sym_c = [inverse[x] if x in inverse else x for x in sym_c]
expr_a = _einsum.sym_to_expr(sym_a)
expr_b = _einsum.sym_to_expr(sym_b)
expr_c = _einsum.sym_to_expr(sym_c)
expr = f'{expr_c},{expr_b}->{expr_a}'
da = einsum(expr, dy, b, a.shape, False)
return None, da, None, None, None
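        # (older code path below) for a matmul-like einsum the gradients follow
        # da = dc @ b^T and db = a^T @ dc; the branches realize this by permuting
        # the subscript triple and swapping operands according to which inputs
        # were transposed in the forward pass.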
if not trans_a and not trans_b: # NN
da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench)
db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench)
elif not trans_a and trans_b: # NT
da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench)
db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
elif trans_a and not trans_b: # TN
da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench)
elif trans_a and trans_b: # TT (not used)
da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
return None, da, db, None
einsum = _einsum.apply
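# Hedged usage sketch (added for exposition, not part of the original commit).
# Assuming the instance-based signature forward(ctx, einsum, a, b, shape_c, **kwargs)
# shown above, a batched matrix multiplication could be written roughly as:
#
#   import torch
#   a = torch.randn(16, 128, 64, dtype=torch.float16, device='cuda').requires_grad_()
#   b = torch.randn(16, 64, 256, dtype=torch.float16, device='cuda')
#   c = einsum('bmk,bkn->bmn', a, b, [16, 128, 256])   # c has shape (16, 128, 256)
#   c.sum().backward()  # the gradient w.r.t. a is derived via an inverted einsum expression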

View File

@@ -24,7 +24,7 @@ def empty(shape, dtype):
return tf_empty_proxy(shape, dtype)
#return fw.tf_extra_ops.alloc_empty(args, T = dtype)
elif fw.has_torch():
return fw.torch.empty(shape).cuda()
return fw.torch.empty(shape, dtype=dtype).cuda()
def shape(A) :
if fw.has_tensorflow():
@@ -47,16 +47,23 @@ class id_dict:
return libtriton.retrieve_scalar(self.id)
def __init__(self):
self.data = weakref.WeakKeyDictionary()
self.data = dict()
def __delitem__(self, key):
del self.data[key]
def __getitem__(self, key):
@staticmethod
def _get_key(key):
if fw.has_tensorflow():
if isinstance(key, fw.tensorflow.Tensor):
key = key.op
ret = self.data[key]
key = id(key.op)
if fw.has_torch():
if isinstance(key, fw.torch.Tensor):
key = id(key)
return key
def __getitem__(self, key):
ret = self.data[id_dict._get_key(key)]
if isinstance(ret, id_dict.lazy_entry):
return ret.get()
return ret
@@ -65,7 +72,4 @@ class id_dict:
return len(self.data)
def __setitem__(self, key, value):
if fw.has_tensorflow():
if isinstance(key, fw.tensorflow.Tensor):
key = key.op
self.data[key] = value
self.data[id_dict._get_key(key)] = value

View File

@@ -10,10 +10,12 @@ int main() {
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
std::vector<config_t> configs;
for(auto ord: std::vector<std::vector<int>>{{1, 0}})
for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
{true, false}}){
for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
std::vector<config_t> tmp = {
config_t{ord, x[0], x[1], 2048, 2048, 2048},
// config_t{ord, x[0], x[1], 512, 512, 512},
// config_t{ord, x[0], x[1], 1024, 1024, 1024},
config_t{ord, x[0], x[1], 127008, 768, 576},
// config_t{ord, x[0], x[1], 8192, 8192, 8192}
// config_t{ord, x[0], x[1], 16, 2048, 2048},
// config_t{ord, x[0], x[1], 32, 2048, 2048},
// config_t{ord, x[0], x[1], 64, 2048, 2048},
@@ -33,7 +35,7 @@ int main() {
int32_t M, N, K;
for(const auto& c: configs){
std::tie(ord, AT, BT, M, N, K) = c;
std::cout << "// " << c << std::flush;
std::cout << "// " << c ;
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;

View File

@@ -20,7 +20,7 @@ static void cc_dot(std::vector<T> &c, const std::vector<T> &a, const std::vector
float acc = 0;
for(size_t k = 0; k < K; k++)
acc = acc + (!AT ? a[k*M + m] : a[m*K + k]) * (!BT ? b[n*K + k] : b[k*N + n]);
c[m + n*M] = static_cast<T>(acc);
c[m*N + n] = static_cast<T>(acc);
}
}
@@ -72,9 +72,9 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
std::string ty = to_string<T>::value;
size_t dt_nbytes = sizeof(T);
drv::context* context = stream->context();
int32_t lda = AT ? K : M;
int32_t ldb = BT ? N : K;
int32_t ldc = M;
int32_t lda = (AT ^ a_order[0]==1) ? K : M;
int32_t ldb = (BT ^ b_order[0]==1) ? N : K;
int32_t ldc = N;
std::vector<std::string> sa = { "1", "lda" };
std::vector<std::string> sb = { "1", "ldb" };
@@ -86,17 +86,17 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
// macros
rt::function::options_space_t opt;
// A access patterns
opt.defines.push_back({"USEA", {AT? "a[^1, ^0]" : "a" }});
opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }});
opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }});
opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }});
opt.defines.push_back({"USEA", {AT? "a" : "a" }});
opt.defines.push_back({"BROADCAST_AK", {AT? "newaxis, :" : "newaxis, :" }});
opt.defines.push_back({"BROADCAST_AM", {AT? ":, newaxis" : ":, newaxis" }});
opt.defines.push_back({"SHAPE_A", {AT? "TM, TK" : "TM, TK" }});
opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
// B access patterns
opt.defines.push_back({"USEB", {BT? "b[^1, ^0]" : "b" }});
opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }});
opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }});
opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }});
opt.defines.push_back({"USEB", {BT? "b" : "b" }});
opt.defines.push_back({"BROADCAST_BK", {BT? ":, newaxis" : ":, newaxis" }});
opt.defines.push_back({"BROADCAST_BN", {BT? "newaxis, :" : "newaxis, :" }});
opt.defines.push_back({"SHAPE_B", {BT? "TK, TN" : "TK, TN" }});
opt.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
opt.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
// data-type
@@ -109,15 +109,15 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
opt.num_warps = {nwarp};
}
if(mode == BENCH) {
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"16"}});
opt.num_warps = {4};
opt.defines.push_back({"TM", {"32", "64", "128"}});
opt.defines.push_back({"TN", {"32", "64", "128"}});
opt.defines.push_back({"TK", {to_string<T>::value == "half" ? "16" : "8"}});
opt.num_warps = {2, 4, 8};
}
// kernels
rt::function function(src::dot, opt);
std::vector<rt::arg> args = {&*da, &*db, &*dc, M, N, K, lda, ldb, ldc};
std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc};
auto grid = grid2d(M, N);
// metrics
@@ -126,17 +126,17 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
bench.push_back(tflops(triton_ns));
// // cublas
// if(cublas::cublasinit()){
// NumericT alpha(static_cast<double>(1));
// NumericT beta(static_cast<double>(0));
// cublasGemmAlgo_t fastest;
// cublasGemm(cuty, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(cuty, stream, AT, BT, M, N, K,
// &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
// ldc, nullptr, fastest); }, stream);
// result.push_back(tflops(cublas_ms));
// }
// // cublas
// if(cublas::cublasinit()){
// T alpha(static_cast<double>(1));
// T beta(static_cast<double>(0));
// cublasGemmAlgo_t fastest;
// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
// &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
// ldc, nullptr, fastest); }, stream);
// bench.push_back(tflops(cublas_ms));
// }
}
// test triton
@@ -147,9 +147,9 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
std::vector<T> ha(M*K);
std::vector<T> hb(K*N);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = 1;
ha[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = 1;
hb[i] = (float)rand()/RAND_MAX;
// copy buffer
stream->write(&*da, true, 0, ha);
stream->write(&*db, true, 0, hb);

View File

@@ -2,37 +2,58 @@ namespace src {
const char *dot =
R"(
void dot(TYPE * A, TYPE * B, TYPE * C,
int M, int N, int K,
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
float c[TM, TN] = 0;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
// reduction loop
for(int k = K; k > 0; k-= TK){
c += USEA @ USEB;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
}
// epilogue
TYPE* pc[TM, TN] = C + rm[:, newaxis] + rn[newaxis, :] * ldc;
*pc = c;
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
float alpha,
int M, int N, int K,
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
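  // note (added for exposition): the 2D program id is re-linearized here, which
  // appears intended to change the order in which output tiles are traversed
  // (e.g. for better cache/scheduling behavior); functionally, rm/rn below are
  // still just the tile offsets of this program instance.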
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
// pointers to operands
int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
int offb[SHAPE_B] = rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
TYPE* pa[SHAPE_A] = A + offa;
TYPE* pb[SHAPE_B] = B + offb;
// prefetches operands
bool checka[SHAPE_A] = rk[BROADCAST_AK] < K;
bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K;
TYPE a[SHAPE_A] = checka ? *pa : 0;
TYPE b[SHAPE_B] = checkb ? *pb : 0;
// reduction loop
float c[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
c += USEA @ USEB;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
a = *?(checka)pa;
b = *?(checkb)pb;
}
//c = c * alpha;
// epilogue
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
*?(checkc)pc = (TYPE[TM, TN])c;
}
)";

View File

@@ -159,7 +159,7 @@ bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
if(hc.size() != rc.size())
return false;
for(size_t i = 0; i < hc.size(); i++)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
return false;
}

View File

@@ -10,8 +10,8 @@ int main() {
// shapes to test
typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
std::vector<config_t> configs;
for(int TM: std::vector<int>{32, 64})
for(int TN: std::vector<int>{32, 64})
for(int TM: std::vector<int>{32, 64, 128})
for(int TN: std::vector<int>{32, 64, 128})
for(int TK: std::vector<int>{16})
for(int nwarps: std::vector<int>{4})
for(bool AT: std::array<bool, 2>{false, true})