[PYTHON] Cleaned up legacy code; added simple standalone compilation API (#22)

Philippe Tillet
2022-07-26 11:06:45 -07:00
committed by GitHub
parent 96cc6fb563
commit 3265e0df5a
84 changed files with 1382 additions and 14023 deletions
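Note: the standalone compilation API referenced in the commit title is not visible in the excerpt below, which mostly shows deletions of legacy C++ codegen. As a rough illustration of what "standalone compilation" means here (lowering a @triton.jit kernel ahead of time, without binding it to tensors or launching it), a hypothetical call might look like the sketch below; the entry point `triton.compile`, the signature string, and the `constants` argument are assumptions, not the exact interface added by #22.

```python
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

# Hypothetical standalone compilation: produce a compiled artifact for a given
# argument signature and meta-parameters, with no kernel launch involved.
compiled = triton.compile(add_kernel, signature="*fp32,*fp32,*fp32,i32",
                          constants={"BLOCK": 1024})
```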

1
.clang-format Normal file

@@ -0,0 +1 @@
BasedOnStyle: LLVM


@@ -1,51 +0,0 @@
name: Documentation
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
jobs:
Build-Documentation:
runs-on: self-hosted
steps:
- name: Checkout gh-pages
uses: actions/checkout@v1
with:
ref: 'gh-pages'
- name: Checkout branch
uses: actions/checkout@v1
- name: Build docs
run: |
git fetch origin master:master
cd docs
sphinx-multiversion . _build/html/
- name: Publish docs
run: |
git branch
# update docs
rm -r /tmp/triton-docs;
mkdir /tmp/triton-docs;
mv docs/_build/html/* /tmp/triton-docs/
git checkout gh-pages
cp -r CNAME /tmp/triton-docs/
cp -r index.html /tmp/triton-docs/
cp -r .nojekyll /tmp/triton-docs/
rm -r *
cp -r /tmp/triton-docs/* .
# ln -s master/index.html .
# mv master docs
git add .
git commit -am "[GH-PAGES] Updated website"
# publish docs
eval `ssh-agent -s`
DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }}
git remote set-url origin git@github.com:openai/triton.git
git push


@@ -4,8 +4,7 @@ on:
workflow_dispatch:
pull_request:
branches:
- master
- v2.0
- main
jobs:
@@ -21,7 +20,7 @@ jobs:
- name: Clear cache
run: |
rm -r /tmp/triton/
rm -r ~/.triton/
continue-on-error: true
- name: Install Triton


@@ -1,40 +0,0 @@
name: Wheels
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
jobs:
Build-Wheels:
runs-on: self-hosted
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Patch setup.py
run: |
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g')
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
echo "" >> python/setup.cfg
echo "[build_ext]" >> python/setup.cfg
echo "base-dir=/project" >> python/setup.cfg
- name: Build wheels
run: |
export CIBW_MANYLINUX_X86_64_IMAGE="manylinux2014"
export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="manylinux2014"
export CIBW_BEFORE_BUILD="pip install cmake;\
yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel;"
export CIBW_SKIP="{cp,pp}35-*"
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
python3 -m cibuildwheel python --output-dir wheelhouse
- name: Upload wheels to PyPI
run: |
python3 -m twine upload wheelhouse/* --skip-existing
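For reference, the `sed` line in the "Patch setup.py" step above rewrites the package version into a dated nightly version. A small Python equivalent of that substitution (the concrete version string below is made up for illustration):

```python
import re

setup_line = 'version="2.0.0"'   # hypothetical line from python/setup.py
latest_date = "20220726"         # $LATEST_DATE: last commit date, digits only

nightly = re.sub(r'version="(.*)"', rf'version="\1-dev{latest_date}"', setup_line)
print(nightly)  # version="2.0.0-dev20220726"
```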


@@ -126,18 +126,10 @@ include_directories(${LLVM_INCLUDE_DIRS})
# Python module
if(BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
# Build CUTLASS python wrapper if requested
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc)
add_definitions(-DWITH_CUTLASS_BINDINGS)
set(CUTLASS_LIBRARIES "cutlass.a")
endif()
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS})
link_directories(${PYTHON_LINK_DIRS})
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
endif()

1
deps/dlfcn-win32 vendored

Submodule deps/dlfcn-win32 deleted from 522c301ec3


@@ -59,13 +59,13 @@ For example, a row-major coalesced layout may distribute a 64x16 tensor over 2 w
thread tile size 2
- - - - - - /\ - - - - - -
block| thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3]
warp | thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3]
tile | tile size 2 || A_{1, 0}[T0] A_{1, 1}[T0] ... A_{1, 6}[T3] A_{1, 7}[T3] A_{1, 8}[T0] A_{1, 9}[T0] ... A_{1, 14}[T3] A_{1, 15}[T3]
size |   ....
32 | A_{30, 0}[T60] A_{14, 1}[T60] ... A_{14, 6}[T63] A_{14, 7}[T63] A_{14, 8}[T60] A_{14, 9}[T60] ... A_{14, 14}[T63] A_{14, 15}[T63]
| A_{31, 0}[T60] A_{15, 1}[T60] ... A_{15, 6}[T63] A_{15, 7}[T63] A_{15, 8}[T60] A_{15, 9}[T60] ... A_{15, 14}[T63] A_{15, 15}[T63]
-----------------------------/\-----------------------------------
block tile size 8
warp tile size 8
A_{32, 0}[T0] A_{32, 1}[T0] ... A_{32, 6}[T3] A_{32, 7}[T3] A_{32, 8}[T0] A_{32, 9}[T0] ... A_{32, 14}[T3] A_{32, 15}[T3]
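The removed and updated diagram lines above appear without their +/- diff markers, so some row labels look duplicated or garbled; the thread labels themselves are consistent with each thread owning a contiguous 2x2 element tile, with 4 thread columns and 16 thread rows repeating across the tensor. A small Python sketch under those assumptions (not part of the original file):

```python
def owner_thread(row, col, tile=2, thread_cols=4, thread_rows=16):
    # Which thread (T0..T63, spanning the 2 warps) holds element A_{row, col}
    # in the coalesced layout sketched above: contiguous 2x2 tiles per thread,
    # repeating every 8 columns and every 32 rows.
    return ((row // tile) % thread_rows) * thread_cols + (col // tile) % thread_cols

assert owner_thread(0, 0) == 0     # A_{0, 0}[T0]
assert owner_thread(0, 7) == 3     # A_{0, 7}[T3]
assert owner_thread(0, 8) == 0     # A_{0, 8}[T0]
assert owner_thread(31, 15) == 63  # A_{31, 15}[T63]
assert owner_thread(32, 0) == 0    # A_{32, 0}[T0]
```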


@@ -1,80 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#include <map>
#include <vector>
namespace triton {
namespace ir {
class value;
class module;
class phi_node;
class splat_inst;
class cast_inst;
class reshape_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
}
namespace codegen{
namespace analysis{
class align {
private:
struct cst_info {
unsigned num_cst;
unsigned value;
};
// helpers
std::vector<unsigned> get_shapes(ir::value *v);
// populate is_constant
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
std::vector<cst_info> populate_is_constant_default(ir::value* v);
std::vector<cst_info> populate_is_constant(ir::value *v);
// populate max_contiguous
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
std::vector<unsigned> populate_max_contiguous(ir::value *v);
// populate starting_multiple
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v);
// populate all maps
void populate(ir::value *v);
public:
void run(ir::module &mod);
unsigned get(ir::value* v, unsigned ax) const;
std::vector<unsigned> contiguous(ir::value* v) const;
private:
std::map<ir::value*, std::vector<cst_info>> is_constant_;
std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
};
}
}
}
#endif
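The deleted pass declared above tracks three per-axis facts for every IR value: how many consecutive elements are known constant (is_constant), how many are contiguous (max_contiguous), and what multiple the first element starts at (starting_multiple); get() then reports min(starting_multiple, max_contiguous) for an axis. As a hedged illustration (mirroring the add/sub rules in the deleted align.cc further below, with made-up input facts), pointer arithmetic like `ptr + arange(0, 128)` composes as:

```python
from math import gcd

# Assumed per-axis facts: a splatted pointer aligned to 16 elements,
# added to a make_range(0, 128) offset.
ptr_start_multiple, ptr_max_contiguous = 16, 1
offs_start_multiple, offs_max_contiguous = 128, 128

# Integer add/sub rules from the deleted populate_*_binop:
sum_start_multiple = gcd(ptr_start_multiple, offs_start_multiple)       # 16
sum_max_contiguous = max(gcd(offs_max_contiguous, ptr_start_multiple),
                         gcd(ptr_max_contiguous, offs_start_multiple))  # 16

# align::get(v, ax) reports min(starting_multiple, max_contiguous) == 16,
# i.e. accesses through this pointer may be vectorized by up to 16 elements.
print(min(sum_start_multiple, sum_max_contiguous))
```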


@@ -1,47 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#define TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#include <map>
#include <set>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
namespace triton{
namespace ir{
class value;
class function;
class module;
}
namespace codegen{
namespace analysis{
class tiles;
class liveness;
class cts;
class allocation {
public:
allocation(liveness *live)
: liveness_(live) { }
// accessors
bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
unsigned offset(const data_layout *x) const { return offsets_.at(x); }
unsigned allocated_size() const { return allocated_size_; }
// run
void run(ir::module& mod);
private:
std::map<const data_layout*, unsigned> offsets_;
size_t allocated_size_;
// dependences
liveness *liveness_;
};
}
}
}
#endif


@@ -1,52 +0,0 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
#include "triton/tools/graph.h"
#include <map>
#include <vector>
namespace triton{
namespace ir{
class value;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
class axes {
typedef std::pair<ir::value*, unsigned> node_t;
private:
// update graph
void update_graph_store(ir::instruction *i);
void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i);
void update_graph_trans(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i,
bool is_masked_load_async=false);
void update_graph_no_edge(ir::instruction *i);
void update_graph(ir::instruction *i);
public:
axes();
void run(ir::module &mod);
// accessors
int get(ir::value *value, unsigned dim);
std::vector<int> get(ir::value *value);
private:
tools::graph<node_t> graph_;
std::map<node_t, size_t> axes_;
};
}
}
}
#endif


@@ -1,345 +0,0 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
#include <map>
#include <set>
#include <vector>
#include <memory>
#include "triton/tools/graph.h"
#include "triton/codegen/target.h"
namespace triton{
namespace ir{
class value;
class type;
class module;
class instruction;
class phi_node;
}
namespace codegen{
namespace analysis{
class axes;
class align;
class layout_visitor;
class data_layout;
class mma_layout;
class scanline_layout;
class shared_layout;
class layout_visitor {
public:
virtual void visit_layout(data_layout *);
virtual void visit_layout_mma(mma_layout*) = 0;
virtual void visit_layout_scanline(scanline_layout*) = 0;
virtual void visit_layout_shared(shared_layout*) = 0;
};
class data_layout {
protected:
enum id_t {
MMA,
SCANLINE,
SHARED
};
typedef std::vector<int> axes_t;
typedef std::vector<unsigned> shape_t;
typedef std::vector<int> order_t;
typedef std::vector<ir::value*> values_t;
private:
template<typename T>
T* downcast(id_t id) {
if(id_ == id)
return static_cast<T*>(this);
return nullptr;
}
public:
data_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align);
// visitor
virtual void accept(layout_visitor* vst) = 0;
// downcast
mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
// accessors
size_t get_rank() { return shape_.size(); }
const shape_t& get_shape() const { return shape_; }
const order_t& get_order() const { return order_; }
const values_t& get_values() const { return values_;}
int get_axis(size_t k) const { return axes_.at(k); }
std::vector<int> get_axes() const { return axes_; }
const int get_order(size_t k) const { return order_.at(k); }
// find the position of given axis
int find_axis(int to_find) const;
private:
id_t id_;
axes_t axes_;
values_t values_;
protected:
order_t order_;
shape_t shape_;
};
class distributed_layout: public data_layout{
public:
distributed_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value*>& values,
analysis::align* align);
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
virtual int contig_per_thread(size_t k) = 0;
protected:
std::vector<int> shape_per_cta_;
};
class mma_layout: public distributed_layout {
public:
enum TensorCoreType : uint8_t {
// floating-point tensor core instr
FP32_FP16_FP16_FP32 = 0, // default
FP32_BF16_BF16_FP32,
FP32_TF32_TF32_FP32,
// integer tensor core instr
INT32_INT1_INT1_INT32, // Not implemented
INT32_INT4_INT4_INT32, // Not implemented
INT32_INT8_INT8_INT32, // Not implemented
//
NOT_APPLICABLE,
};
// Used on nvidia GPUs with sm >= 80
inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
{FP32_FP16_FP16_FP32, {16, 8, 16}},
{FP32_BF16_BF16_FP32, {16, 8, 16}},
{FP32_TF32_TF32_FP32, {16, 8, 8}},
{INT32_INT1_INT1_INT32, {16, 8, 256}},
{INT32_INT4_INT4_INT32, {16, 8, 64}},
{INT32_INT8_INT8_INT32, {16, 8, 32}},
};
// shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
{FP32_FP16_FP16_FP32, {8, 8, 8}},
{FP32_BF16_BF16_FP32, {8, 8, 8}},
{FP32_TF32_TF32_FP32, {8, 8, 4}},
{INT32_INT1_INT1_INT32, {8, 8, 64}},
{INT32_INT4_INT4_INT32, {8, 8, 32}},
{INT32_INT8_INT8_INT32, {8, 8, 16}},
};
inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
{FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
{FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
{FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
{INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
{INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
{INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
};
// vector length per ldmatrix (16*8/element_size_in_bits)
inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
{FP32_FP16_FP16_FP32, 8},
{FP32_BF16_BF16_FP32, 8},
{FP32_TF32_TF32_FP32, 4},
{INT32_INT1_INT1_INT32, 128},
{INT32_INT4_INT4_INT32, 32},
{INT32_INT8_INT8_INT32, 16},
};
public:
mma_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values,
analysis::align* align, target *tgt,
shared_layout* layout_a,
shared_layout* layout_b,
ir::value *dot);
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
// accessor
int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); }
int spw(size_t k) { return spw_.at(k); }
int rep(size_t k) { return rep_.at(k); }
int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }
// helpers for generator.cc
std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
// setter
void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }
private:
// fragment per warp
std::vector<int> fpw_;
// shape per warp
std::vector<int> spw_;
// warp per tile
std::vector<int> wpt_;
// shape per tile
std::vector<int> spt_;
// repetitions
std::vector<int> rep_;
// contiguous per thread
std::vector<int> contig_per_thread_;
TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
};
struct scanline_layout: public distributed_layout {
scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align,
target* tgt);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
// accessor
int mts(size_t k) { return mts_.at(k); }
int nts(size_t k) { return nts_.at(k); }
int contig_per_thread(size_t k) { return nts_.at(k); }
public:
// micro tile size. The size of a tile held by a thread block.
std::vector<int> mts_;
// nano tile size. The size of a tile held by a thread.
std::vector<int> nts_;
};
struct double_buffer_info_t {
ir::value* first;
ir::value* latch;
ir::phi_node* phi;
};
struct N_buffer_info_t {
std::vector<ir::value*> firsts; // not necessarily in input order
ir::value* latch;
ir::phi_node* phi;
std::map<ir::value*, int> firsts_idx;
};
// abstract for dot and corresponding smem values
class shared_layout: public data_layout {
private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);
public:
shared_layout(data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align, target *tgt);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
ir::type* get_type() { return ty_; }
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
int get_num_stages() const;
size_t get_per_stage_size() const { return size_ / get_num_stages(); }
size_t get_per_stage_elements() const;
size_t get_num_per_phase() { return num_per_phase_; }
ir::value* hmma_dot_a() { return hmma_dot_a_; }
ir::value* hmma_dot_b() { return hmma_dot_b_; }
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
int get_mma_vec() { return mma_vec_;}
int get_mma_strided() { return mma_strided_; }
bool allow_swizzle() const { return allow_swizzle_; }
data_layout* get_arg_layout() { return arg_layout_; }
private:
size_t size_;
ir::type *ty_;
std::shared_ptr<double_buffer_info_t> double_buffer_;
std::shared_ptr<N_buffer_info_t> N_buffer_;
size_t num_per_phase_;
ir::value* hmma_dot_a_;
ir::value* hmma_dot_b_;
data_layout* arg_layout_;
int mma_vec_;
int mma_strided_;
bool allow_swizzle_ = true;
target *tgt_;
};
class layouts {
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
private:
// graph creation
void connect(ir::value *x, ir::value *y);
void make_graph(ir::instruction *i);
void init_hmma_tile(data_layout& layouts);
void init_scanline_tile(data_layout &layouts);
void create(size_t id, const std::vector<ir::value*>& values);
public:
// constructor
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
// accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); }
data_layout* get(ir::value *v) { return get(layout_of(v));}
std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);}
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// execution
void run(ir::module &mod);
private:
analysis::axes* axes_;
analysis::align* align_;
size_t num_warps_;
target* tgt_;
tools::graph<ir::value*> graph_;
std::map<ir::value*, size_t> groups_;
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_;
};
}
}
}
#endif
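One relation worth noting in the deleted mma_layout tables above: the mma_instr_vec_ entries are exactly 16*8 bits divided by the operand element width, as the "vector length per ldmatrix" comment states. A quick illustrative check:

```python
# Element widths (bits) and the vector lengths listed in mma_instr_vec_ above.
element_bits = {"fp16": 16, "bf16": 16, "tf32": 32, "int1": 1, "int4": 4, "int8": 8}
listed_vec   = {"fp16": 8,  "bf16": 8,  "tf32": 4,  "int1": 128, "int4": 32, "int8": 16}

for name, bits in element_bits.items():
    assert 16 * 8 // bits == listed_vec[name]  # vector length per ldmatrix
```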


@@ -1,67 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#include <map>
#include <set>
#include <vector>
#include "triton/codegen/analysis/layout.h"
#include "triton/tools/graph.h"
namespace triton{
namespace ir{
class value;
class phi_node;
class function;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
typedef unsigned slot_index;
class tiles;
class layouts;
class data_layout;
struct segment {
slot_index start;
slot_index end;
bool contains(slot_index idx) const {
return start <= idx && idx < end;
}
bool intersect(const segment &Other){
return contains(Other.start) || Other.contains(start);
}
};
class liveness {
private:
typedef std::map<shared_layout*, segment> intervals_map_t;
public:
// constructor
liveness(layouts *l): layouts_(l){ }
// accessors
const intervals_map_t& get() const { return intervals_; }
segment get(shared_layout* v) const { return intervals_.at(v); }
// run
void run(ir::module &mod);
private:
// analysis
layouts *layouts_;
intervals_map_t intervals_;
};
}
}
}
#endif
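The segment struct above treats live ranges as half-open intervals: contains() checks start <= idx < end, and two segments intersect iff either contains the other's start. An equivalent Python sketch, for illustration:

```python
from dataclasses import dataclass

@dataclass
class Segment:
    start: int
    end: int  # half-open interval [start, end)

    def contains(self, idx: int) -> bool:
        return self.start <= idx < self.end

    def intersect(self, other: "Segment") -> bool:
        # Mirrors segment::intersect above: overlap iff either interval
        # contains the other's start point.
        return self.contains(other.start) or other.contains(self.start)

assert Segment(0, 4).intersect(Segment(3, 8))      # [0,4) and [3,8) overlap at 3
assert not Segment(0, 4).intersect(Segment(4, 8))  # adjacent intervals do not
```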


@@ -1,43 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#include <map>
namespace triton{
namespace ir{
class module;
}
namespace codegen{
class target;
namespace analysis{
class layouts;
class data_layout;
class swizzle {
public:
// constructor
swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
// accessors
int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
int get_vec (data_layout* layout) { return vec_.at(layout); }
// run
void run(ir::module &mod);
private:
layouts* layouts_;
target* tgt_;
std::map<data_layout*, int> per_phase_;
std::map<data_layout*, int> max_phase_;
std::map<data_layout*, int> vec_;
};
}
}
}
#endif


@@ -1,42 +0,0 @@
#ifndef _TRITON_CODEGEN_PASS_H_
#define _TRITON_CODEGEN_PASS_H_
#include <memory>
namespace llvm{
class Module;
class LLVMContext;
}
namespace triton{
namespace codegen {
class target;
}
namespace ir{
class module;
}
namespace driver{
class device;
class module;
class kernel;
}
}
namespace triton{
namespace codegen{
// TODO:
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
codegen::target* target,
int sm, int num_warps,
int num_stages, int &shared_static);
}
}
#endif


@@ -1,258 +0,0 @@
#pragma once
#ifndef _TRITON_SELECTION_GENERATOR_H_
#define _TRITON_SELECTION_GENERATOR_H_
#include "triton/ir/visitor.h"
#include "triton/codegen/analysis/layout.h"
#include <functional>
// forward
namespace llvm{
class Type;
class Value;
class PHINode;
class BasicBlock;
class Attribute;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace ir{
class attribute;
class load_inst;
class store_inst;
}
namespace codegen{
// forward
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
class swizzle;
}
// typedef
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Attribute Attribute;
typedef llvm::BasicBlock BasicBlock;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
typedef std::vector<Value*> indices_t;
class target;
}
}
namespace triton{
namespace codegen{
struct distributed_axis {
int contiguous;
std::vector<Value*> values;
Value* thread_id;
};
class adder{
public:
adder(Builder** builder): builder_(builder) { }
Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
Builder** builder_;
};
class multiplier{
public:
multiplier(Builder** builder): builder_(builder) { }
Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
Builder** builder_;
};
class geper{
public:
geper(Builder** builder): builder_(builder) { }
Value* operator()(Value *ptr, Value* off, const std::string& name = "");
Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
private:
Builder** builder_;
};
class generator: public ir::visitor, public analysis::layout_visitor {
private:
void init_idx(ir::value *x);
Instruction* add_barrier();
Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
void finalize_shared_layout(analysis::shared_layout*);
void finalize_function(ir::function*);
void finalize_phi_node(ir::phi_node*);
private:
Type *cvt(ir::type *ty);
llvm::Attribute cvt(ir::attribute attr);
public:
generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
analysis::swizzle *swizzle,
target *tgt,
unsigned num_warps);
void visit_value(ir::value* v);
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*);
void visit_icmp_inst(ir::icmp_inst*);
void visit_fcmp_inst(ir::fcmp_inst*);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
void visit_load_inst(ir::load_inst*);
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
void visit_masked_load_inst(ir::masked_load_inst*);
void visit_store_inst(ir::store_inst*);
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*);
void visit_cat_inst(ir::cat_inst*);
void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*);
void visit_downcast_inst(ir::downcast_inst*);
void visit_exp_inst(ir::exp_inst*);
void visit_cos_inst(ir::cos_inst*);
void visit_umulhi_inst(ir::umulhi_inst* x);
void visit_sin_inst(ir::sin_inst*);
void visit_log_inst(ir::log_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_dot_inst(ir::dot_inst*);
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
Value* shfl_sync(Value* acc, int32_t i);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_layout_convert(ir::value *out, ir::value *in);
void visit_cvt_layout_inst(ir::cvt_layout_inst*);
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);
void visit_prefetch_s_inst(ir::prefetch_s_inst*);
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
void visit_constant_fp(ir::constant_fp*);
void visit_alloc_const(ir::alloc_const*);
void visit_function(ir::function*);
void visit_basic_block(ir::basic_block*);
void visit_argument(ir::argument*);
void visit(ir::module &, llvm::Module &);
// layouts
void visit_layout_mma(analysis::mma_layout*);
void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*);
private:
LLVMContext *ctx_;
Builder* builder_;
Module *mod_;
analysis::axes *a_axes_;
analysis::swizzle *swizzle_;
std::map<unsigned, distributed_axis> axes_;
target *tgt_;
analysis::layouts *layouts_;
analysis::align *alignment_;
analysis::allocation *alloc_;
Value *shmem_;
std::set<ir::value*> seen_;
unsigned num_warps_;
std::map<analysis::data_layout*, Value*> offset_a_m_;
std::map<analysis::data_layout*, Value*> offset_a_k_;
std::map<analysis::data_layout*, Value*> offset_b_k_;
std::map<analysis::data_layout*, Value*> offset_b_n_;
/// layout -> base ptr
std::map<analysis::data_layout*, Value*> shared_ptr_;
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
/// offset for double-buffered layout
std::map<analysis::data_layout*, Value*> shared_off_;
/// Base shmem pointer of ir value
std::map<ir::value*, Value*> shmems_;
std::map<ir::value*, Value*> shoffs_;
std::map<ir::value*, std::vector<indices_t>> idxs_;
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
/// idx for multi-stage pipeline
std::map<analysis::data_layout*, Value*> read_smem_idx_;
std::map<analysis::data_layout*, Value*> write_smem_idx_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;
// helper for creating llvm values
adder add;
multiplier mul;
geper gep;
/// PHI nodes
std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
/// Record prefetch instrs that need to be moved
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
};
}
}
#endif


@@ -1,105 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H
#define TDL_INCLUDE_IR_CODEGEN_TARGET_H
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
// typedefs
namespace triton{
namespace codegen{
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
}
}
namespace triton{
namespace codegen{
class nvidia_cu_target;
class target {
public:
target(bool is_gpu): is_gpu_(is_gpu){}
virtual ~target() {}
virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
virtual unsigned guaranteed_alignment() = 0;
nvidia_cu_target* as_nvidia();
bool is_gpu() const;
private:
bool is_gpu_;
};
class amd_cl_target: public target {
public:
amd_cl_target(): target(true){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 16; }
};
class nvidia_cu_target: public target {
public:
nvidia_cu_target(int sm): target(true), sm_(sm){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
int sm() { return sm_; }
unsigned guaranteed_alignment() { return 16; }
private:
int sm_;
};
class cpu_target: public target {
public:
cpu_target(): target(false){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 1; }
};
}
}
#endif


@@ -1,48 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#include <map>
#include <set>
#include <vector>
namespace triton {
namespace ir {
class module;
class value;
class io_inst;
class instruction;
class builder;
}
namespace codegen{
namespace analysis{
class align;
class layouts;
class cts;
}
namespace transform{
class coalesce {
private:
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod);
private:
analysis::align* align_;
analysis::layouts* layout_;
};
}
}
}
#endif


@@ -1,36 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#include <set>
#include <map>
namespace triton {
namespace ir {
class module;
class value;
class phi_node;
class instruction;
class builder;
}
namespace codegen{
namespace transform{
class cts {
private:
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
public:
cts(bool use_async = false): use_async_(use_async) {}
void run(ir::module &mod);
private:
bool use_async_;
};
}
}
}
#endif


@@ -1,24 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class dce {
public:
dce() {}
void run(ir::module &mod);
};
}
}
}
#endif


@@ -1,22 +0,0 @@
#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class disassociate {
public:
void run(ir::module &mod);
};
}
}
}
#endif


@@ -1,72 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
#include <vector>
#include <map>
#include <list>
#include <set>
#include "triton/codegen/target.h"
namespace triton {
namespace ir {
class module;
class basic_block;
class instruction;
class masked_load_async_inst;
class value;
class builder;
}
namespace codegen{
namespace analysis{
class allocation;
class liveness;
class layouts;
class cts;
class shared_layout;
}
namespace transform{
class prefetch;
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::set<ir::value*> val_set_t;
typedef std::vector<ir::value*> val_vec_t;
private:
bool intersect(const val_set_t &X, const val_set_t &Y);
bool check_safe_war(ir::instruction* i);
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
transform::prefetch *prefetch, target* tgt):
liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layouts *layouts_;
analysis::allocation *alloc_;
transform::prefetch *prefetch_;
target* tgt_;
};
}
}
}
#endif


@@ -1,53 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#include "triton/codegen/target.h"
namespace triton {
namespace ir {
class module;
class value;
class instruction;
class trans_inst;
class builder;
class constant_int;
class dot_inst;
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class peephole {
private:
// bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
public:
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
void run(ir::module &mod);
private:
target* tgt_;
analysis::layouts* layouts_;
};
}
}
}
#endif


@@ -1,30 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
// forward declaration
namespace triton {
namespace ir {
class module;
}
} // namespace triton
namespace triton {
namespace codegen {
namespace transform {
class pipeline {
public:
pipeline(bool has_copy_async, int num_stages)
: has_copy_async_(has_copy_async), num_stages_(num_stages) {}
void run(ir::module &module);
private:
bool has_copy_async_;
int num_stages_;
};
} // namespace transform
} // namespace codegen
} // namespace triton
#endif


@@ -1,27 +0,0 @@
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#include <set>
// forward declaration
namespace triton::ir{
class module;
class value;
}
namespace triton::codegen {
class target;
}
namespace triton::codegen::transform {
class prefetch {
target* tgt_;
std::set<ir::value*> prefetched_vals_;
public:
prefetch(target *tgt) : tgt_(tgt) {}
void run(ir::module &module);
bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
};
}
#endif


@@ -1,26 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H
namespace triton {
// forward declaration
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class reorder {
public:
void run(ir::module& module);
};
}
}
}
#endif


@@ -1,5 +0,0 @@
file(GLOB_RECURSE CODEGEN_SRC *.cc)
add_library(TritonCodeGen
${CODEGEN_SRC}
)


@@ -1,533 +0,0 @@
#include "triton/codegen/analysis/align.h"
#include "triton/ir/utils.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace analysis{
// Function for extended Euclidean Algorithm
int gcd_impl(int a, int b, int *x, int *y)
{
// Base Case
if (a == 0)
{
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b%a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b/a) * x1;
*y = x1;
return gcd;
}
int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
}
inline int lcm(int a, int b) {
return (a * b) / gcd(a, b);
}
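gcd_impl above is the standard extended Euclidean algorithm: besides gcd(a, b) it computes Bezout coefficients x, y with a*x + b*y = gcd(a, b), although this pass only ever uses the gcd itself (via the gcd/lcm wrappers). A small Python equivalent for reference:

```python
def gcd_impl(a, b):
    # Extended Euclid: returns (g, x, y) such that a*x + b*y == g == gcd(a, b).
    if a == 0:
        return b, 0, 1
    g, x1, y1 = gcd_impl(b % a, a)
    return g, y1 - (b // a) * x1, x1

g, x, y = gcd_impl(30, 42)
assert (g, 30 * x + 42 * y) == (6, 6)
```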
template<class T>
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
return map[i] = value;
}
/*
* is constant
*/
std::vector<unsigned> align::get_shapes(ir::value *v) {
ir::type *ty = v->get_type();
if(ty->is_block_ty())
return ty->get_block_shapes();
else
return {1};
}
std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto it = is_constant_.find(inc);
if(it != is_constant_.end())
result = it->second;
}
return add_to_cache(x, result, is_constant_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto cst = populate_is_constant(inc);
for(size_t d = 0; d < cst.size(); d++)
result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
auto shapes = get_shapes(x);
ir::value* op = x->get_operand(0);
std::vector<cst_info> result;
auto op_cst = populate_is_constant(op);
for(auto d: shapes)
result.push_back(cst_info{d, op_cst[0].value});
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < x_shapes.size(); d ++){
cst_info ax ;
if(x_shapes[d] == 1)
ax = {1, op_cst[current].value};
else if(!is_skewed
&& x_shapes[d] == op_shapes[current])
ax = {x_shapes[d], op_cst[current++].value};
else {
is_skewed = true;
ax = {x_shapes[d], 0};
}
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++)
if(op_shapes[d] == 1)
result.push_back(cst_info{x_shapes[d], op_cst[d].value});
else
result.push_back(op_cst[d]);
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
auto max_contiguous = populate_max_contiguous(lhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax;
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
// todo might not be entirely true
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
ax = {num_constants, 0};
}
else
ax = {std::min(lhs[d].num_cst, rhs[d].num_cst), 0};
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
auto x_shapes = get_shapes(x);
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
std::vector<cst_info> result;
for(size_t d = 0; d < x_shapes.size(); d++)
result.push_back({std::min(lhs[d].num_cst, rhs[d].num_cst), 0});
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
auto shapes = get_shapes(v);
std::vector<cst_info> result(shapes.size(), {1, 0});
return add_to_cache(v, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
if(is_constant_.find(v) != is_constant_.end())
return is_constant_.at(v);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_is_constant_phi(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_is_constant_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_is_constant_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_is_constant_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_is_constant_binop(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_is_constant_gep(x);
return populate_is_constant_default(v);
}
/*
* max contiguous
*/
std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result(shapes.size(), 1);
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto it = max_contiguous_.find(inc);
if(it != max_contiguous_.end())
result = it->second;
}
add_to_cache(x, result, max_contiguous_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto contiguous = populate_max_contiguous(inc);
for(size_t d = 0; d < result.size(); d++)
result[d] = std::min(result[d], contiguous[d]);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<unsigned> result;
for(size_t d = 0; d < x_shapes.size(); d++)
result.push_back({1});
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_mc = populate_max_contiguous(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result.push_back(1);
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result.push_back(op_mc[current++]);
else {
is_skewed = true;
result.push_back(1);
}
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++)
if(op_shapes[d] == 1)
result.push_back(1);
else
result.push_back(op_mc[d]);
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
auto lhs_starting_multiple = populate_starting_multiple(lhs);
auto rhs_starting_multiple = populate_starting_multiple(rhs);
std::vector<unsigned> result;
for(size_t d = 0; d < shapes.size(); d++){
unsigned value = 1;
if(x->is_int_rem() && rhs_starting_multiple[d] > 0){
value = std::min(lhs_max_contiguous[d], rhs_starting_multiple[d]);
}
if(x->is_int_mult()){
unsigned lvalue = 1, rvalue = 1;
if(rhs_cst_info[d].value == 1)
lvalue = lhs_max_contiguous[d];
if(lhs_cst_info[d].value == 1)
rvalue = rhs_max_contiguous[d];
value = std::max(lvalue, rvalue);
}
if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1;
lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
value = std::max(lvalue, rvalue);
}
result.push_back(value);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
std::vector<unsigned> result(shapes.size(), 1);
for(size_t d = 0; d < shapes.size(); d++){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst)
lvalue = rhs_max_contiguous[d];
if(rhs_cst_info[d].num_cst)
rvalue = lhs_max_contiguous[d];
result[d] = std::max(lvalue, rvalue);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
if(!v->get_type()->is_block_ty())
return add_to_cache(v, {1}, max_contiguous_);
auto shapes = v->get_type()->get_block_shapes();
if(dynamic_cast<ir::make_range*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
auto result = populate_max_contiguous(v->get_operand(0));
return add_to_cache(v, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
if(max_contiguous > 0)
return add_to_cache(x, {max_contiguous}, max_contiguous_);
}
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_max_contiguous_cast(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_max_contiguous_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_max_contiguous_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_max_contiguous_binop(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_max_contiguous_gep(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_max_contiguous_phi(x);
return populate_max_contiguous_default(v);
}
/*
* starting multiple
*/
std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
auto shapes = get_shapes(x);
auto op = populate_starting_multiple(x->get_operand(0));
std::vector<unsigned> result(shapes.size(), op[0]);
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
auto op = populate_starting_multiple(x->get_operand(0));
auto op_shapes = get_shapes(x->get_operand(0));
auto shapes = get_shapes(x);
std::vector<unsigned> result(shapes.size(), 1);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result[d] = 1;
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result[d] = op[current++];
else {
is_skewed = true;
result[d] = 1;
}
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
auto lhs = populate_starting_multiple(x->get_operand(0));
auto rhs = populate_starting_multiple(x->get_operand(1));
std::vector<unsigned> result(lhs.size(), 1);
for(size_t d = 0; d < lhs.size(); d++){
if(x->is_int_mult())
result[d] = lhs[d] * rhs[d];
if(x->is_int_add_sub())
result[d] = gcd(lhs[d], rhs[d]);
if(x->is_int_div())
result[d] = 1;
if(x->is_int_rem() && rhs[d] > 1){
result[d] = gcd(lhs[d], rhs[d]);
}
if(x->is_shl())
result[d] = lhs[d] << rhs[d];
if(x->is_shr())
result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
auto lhs = populate_starting_multiple(x->get_operand(0));
auto rhs = populate_starting_multiple(x->get_operand(1));
std::vector<unsigned> result(lhs.size(), 1);
for(size_t d = 0; d < lhs.size(); d++){
result[d] = gcd(lhs[d], rhs[d]);
// std::cout << "starting multiple: " << x->get_name() << " " << d << " " << result[d] << std::endl;
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
auto shape = get_shapes(x);
std::vector<unsigned> result(shape.size(), 1);
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
if(starting_multiple_.find(inc) != starting_multiple_.end())
result = starting_multiple_.at(inc);
}
add_to_cache(x, result, starting_multiple_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto sm = populate_starting_multiple(inc);
for(size_t d = 0; d < result.size(); d++)
result[d] = gcd(result[d], sm[d]);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
ir::type* ty = v->get_type();
if(ty->is_block_ty()) {
return add_to_cache(v, ty->get_block_shapes(), starting_multiple_);
}
if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
for(auto attr: attributes){
if(attr.get_kind() == ir::multiple_of){
return add_to_cache(x, {attr.get_value()}, starting_multiple_);
}
if(attr.get_kind() == ir::aligned){
ir::type* ty = x->get_type()->get_pointer_element_ty();
int nbits = ty->get_primitive_size_in_bits();
int nbytes = std::max<int>(nbits / 8, 1);
return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
}
}
}
return add_to_cache(v, {1}, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return add_to_cache(x, {multiple_of}, starting_multiple_);
}
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_starting_multiple_cast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range*>(v))
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_starting_multiple_gep(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_starting_multiple_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_starting_multiple_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_starting_multiple_broadcast(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_starting_multiple_phi(x);
return populate_starting_multiple_default(v);
}
unsigned align::get(ir::value *v, unsigned ax) const {
unsigned starting_multiple = starting_multiple_.at(v)[ax];
unsigned max_contiguous = max_contiguous_.at(v)[ax];
return std::min(starting_multiple, max_contiguous);
}
std::vector<unsigned> align::contiguous(ir::value* v) const {
return max_contiguous_.at(v);
}
void align::populate(ir::value *v) {
populate_is_constant(v);
populate_starting_multiple(v);
populate_max_contiguous(v);
}
void align::run(ir::module &mod) {
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
// ir::for_each_value(mod, [this](ir::value* v) {
// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
// std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
// });
}
}
}
}

View File

@@ -1,101 +0,0 @@
#include <algorithm>
#include <climits>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace analysis{
void allocation::run(ir::module &mod) {
using std::max;
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<shared_layout*> I;
for(auto x: liveness_->get())
I.push_back(x.first);
std::vector<shared_layout*> J = I;
triples_map_type H;
H.insert({0, segment{0, INT_MAX}});
std::vector<shared_layout*> V;
std::map<shared_layout*, unsigned> starts;
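// greedy interval packing: H maps a candidate byte offset to the live segment still free at
// that offset; repeatedly pop the lowest offset, place a layout whose live range overlaps
// that slot but none of the remaining ones, then split the slot around the placed interval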
while(!J.empty()){
auto h_it = H.begin();
unsigned w = h_it->first;
segment xh = h_it->second;
H.erase(h_it);
auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
segment xj = liveness_->get(JJ);
bool res = xj.intersect(xh);
for(auto val: H)
res = res && !val.second.intersect(xj);
return res;
});
if(j_it != J.end()){
unsigned size = (*j_it)->get_size();
segment xj = liveness_->get(*j_it);
starts[*j_it] = w;
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
if(xh.start < xj.start)
H.insert({w, segment{xh.start, xj.end}});
if(xj.end < xh.end)
H.insert({w, segment{xj.start, xh.end}});
V.push_back(*j_it);
J.erase(j_it);
}
}
// Build interference graph
std::map<shared_layout*, std::set<shared_layout*>> interferences;
for(shared_layout* x: V)
for(shared_layout* y: V){
if(x == y)
continue;
unsigned X0 = starts[x], Y0 = starts[y];
unsigned NX = x->get_size();
unsigned NY = y->get_size();
segment XS = {X0, X0 + NX};
segment YS = {Y0, Y0 + NY};
if(liveness_->get(x).intersect(liveness_->get(y))
&& XS.intersect(YS))
interferences[x].insert(y);
}
// Initialize colors
std::map<shared_layout*, int> colors;
for(shared_layout* X: V)
colors[X] = (X==V[0])?0:-1;
// First-fit graph coloring
std::vector<bool> available(V.size());
for(shared_layout* x: V){
// Non-neighboring colors are available
std::fill(available.begin(), available.end(), true);
for(shared_layout* Y: interferences[x]){
int color = colors[Y];
if(color >= 0)
available[color] = false;
}
// Assigns first available color
auto It = std::find(available.begin(), available.end(), true);
colors[x] = std::distance(available.begin(), It);
}
// Finalize allocation
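// interfering layouts were given distinct colors; shifting each color band by the furthest
// extent of the interfering neighbors keeps the bands disjoint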
for(shared_layout* x: V){
unsigned Adj = 0;
for(shared_layout* y: interferences[x])
Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
offsets_[x] = starts[x] + colors[x] * Adj;
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(shared_layout* x: V)
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
}
}
}
}

View File

@@ -1,162 +0,0 @@
#include "triton/codegen/analysis/axes.h"
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton{
namespace codegen{
namespace analysis{
axes::axes() {}
void axes::update_graph_reduce(ir::instruction *i) {
auto* red = static_cast<ir::reduce_inst*>(i);
unsigned axis = red->get_axis();
ir::value *arg = red->get_operand(0);
auto in_shapes = arg->get_type()->get_block_shapes();
unsigned current = 0;
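// axis k of the reduction result is tied to the k-th non-reduced axis of the argument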
for(unsigned d = 0; d < in_shapes.size(); d++){
if(d == axis)
continue;
graph_.add_edge({i, current++}, {arg, d});
}
}
void axes::update_graph_reshape(ir::instruction *i) {
auto* reshape = static_cast<ir::reshape_inst*>(i);
// operands
ir::value *op = reshape->get_operand(0);
// shapes
auto op_shapes = op->get_type()->get_block_shapes();
auto res_shapes = reshape->get_type()->get_block_shapes();
// construct edges
unsigned current = 0;
bool is_skewed = false;
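// e.g. a hypothetical reshape [8, 8] -> [1, 8, 8] ties result axes 1 and 2 to operand
// axes 0 and 1; once a dimension of size > 1 fails to line up, the remaining result
// axes get fresh (self-edge) nodes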
for(unsigned d = 0; d < res_shapes.size(); d ++){
bool same_shape = res_shapes[d] == op_shapes[current];
// either add edge between axis or just add a node in the graph
if(!is_skewed && same_shape)
graph_.add_edge({i, d}, {op, current++});
else
graph_.add_edge({i, d}, {i, d});
// reshaping is skewed
if(res_shapes[d] > 1 && !same_shape)
is_skewed = true;
}
}
void axes::update_graph_trans(ir::instruction *i) {
auto *trans = static_cast<ir::trans_inst*>(i);
ir::value *op = trans->get_operand(0);
auto perm = trans->get_perm();
// add edge between axis perm[d] and axis d
for(unsigned d = 0; d < perm.size(); d++)
graph_.add_edge({i, perm[d]}, {op, d});
}
void axes::update_graph_broadcast(ir::instruction *i) {
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
auto shapes = broadcast->get_type()->get_block_shapes();
ir::value *op = broadcast->get_operand(0);
ir::type *op_ty = op->get_type();
const auto& op_shapes = op_ty->get_block_shapes();
// add edge between non-broadcast axes
for(unsigned d = 0; d < shapes.size(); d ++)
if(op_shapes[d] == shapes[d])
graph_.add_edge({i, d}, {op, d});
}
void axes::update_graph_dot(ir::instruction *i) {
auto *dot = static_cast<ir::dot_inst*>(i);
auto shapes = dot->get_type()->get_block_shapes();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
// add edges between result and accumulator
for(unsigned d = 0; d < shapes.size(); d++)
graph_.add_edge({dot, d}, {D, d});
}
void axes::update_graph_elementwise(ir::instruction *i,
bool is_masked_load_async) {
if(i->get_num_operands() == 0)
return;
ir::value *op = i->get_operand(0);
if(!op->get_type()->is_block_ty())
return;
auto rank = op->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++) {
// If we are dealing with a masked async load we need to attach the
// dimensions so we match the behaviour of the copy_to_shared instruction
// which async masked load replaces.
if (is_masked_load_async) {
graph_.add_edge({i, d}, {i, d});
}
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()) {
if(!is_masked_load_async && !i->get_type()->is_void_ty())
graph_.add_edge({i, d}, {opx, d});
graph_.add_edge({opx, d}, {opy, d});
}
}
}
void axes::update_graph_no_edge(ir::instruction *i) {
if(!i->get_type()->is_block_ty())
return;
auto rank = i->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
graph_.add_edge({i, d}, {i, d});
}
void axes::update_graph(ir::instruction *i) {
switch (i->get_id()) {
case ir::INST_REDUCE: return update_graph_reduce(i);
case ir::INST_RESHAPE: return update_graph_reshape(i);
case ir::INST_SPLAT: return update_graph_no_edge(i);
case ir::INST_CAT: return update_graph_elementwise(i, true);
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
default: return update_graph_elementwise(i);
}
return;
}
int axes::get(ir::value *value, unsigned dim) {
return axes_.at({value, dim});
}
std::vector<int> axes::get(ir::value *value) {
std::vector<int> result;
for(size_t d = 0; d < value->get_type()->get_tile_rank(); d++)
result.push_back(this->get(value, d));
return result;
}
void axes::run(ir::module &mod) {
// make graph
graph_.clear();
axes_.clear();
ir::for_each_instruction(mod, [this](ir::instruction *x) {
update_graph(x);
});
// find connected components
graph_.connected_components(nullptr, &axes_);
std::set<size_t> uniq;
for(auto x: axes_)
uniq.insert(x.second);
}
}
}
}

View File

@@ -1,653 +0,0 @@
#include <algorithm>
#include <numeric>
#include <iostream>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
// #include "triton/ir/type.h"
namespace triton{
namespace codegen{
namespace analysis{
/* -------------------------------- *
* Helper Functions *
* -------------------------------- */
inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
unsigned lo = std::min(a, b);
unsigned hi = std::max(a, b);
return std::min(std::max(x, lo), hi);
}
inline bool is_hmma_c(ir::value *v, int sm){
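// the dot can map onto tensor cores when both inputs are fp16 or bf16, or,
// on sm >= 80, fp32 with TF32 allowed or int8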
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
(a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
x->allow_tf32() && sm >= 80) ||
(a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
sm >= 80);
}
return result;
}
static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
mma_layout::TensorCoreType mma_type;
if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
ir::type* a_ty = a->get_type();
ir::type* b_ty = b->get_type();
ir::type* c_ty = v->get_type();
if (c_ty->get_scalar_ty()->is_fp32_ty()) {
// floating point tensor cores
if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
mma_type = mma_layout::FP32_FP16_FP16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
mma_type = mma_layout::FP32_BF16_BF16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
&& dot->allow_tf32()) {
mma_type = mma_layout::FP32_TF32_TF32_FP32;
return mma_type;
}
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
// throw std::runtime_error("integer tensor cores are not yet supported");
// // integer tensor cores
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
// return mma_type;
// }
// if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
// return mma_type;
// }
if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
mma_type = mma_layout::INT32_INT8_INT8_INT32;
return mma_type;
}
}
}
return mma_layout::NOT_APPLICABLE;
}
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(v);
}
}
inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && i->get_operand(n) == v)
result = v;
}
}
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
result = i;
}
}
}
inline bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
}
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
bool result = true;
for(ir::value *op: phi->ops())
result = result && is_trans(op);
return result;
}
return false;
}
/* -------------------------------- *
* Layout Visitor *
* -------------------------------- */
void layout_visitor::visit_layout(data_layout *layout) {
layout->accept(this);
}
/* -------------------------------- *
* Base Data Layout *
* -------------------------------- */
data_layout::data_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
// io pointer
std::set<ir::value*> ptr;
for(ir::value* v: values_)
extract_io_use(v, ptr);
order_.resize(axes_.size());
std::iota(order_.begin(), order_.end(), 0);
std::vector<unsigned> max_contiguous;
for(ir::value* p: ptr){
std::vector<unsigned> curr = align->contiguous(p);
if(curr.size() > max_contiguous.size())
max_contiguous = curr;
else if(curr.size() == max_contiguous.size()){
if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end()))
max_contiguous = curr;
}
}
if(max_contiguous.size() > 0){
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b];
});
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
// std::cout << order_[0] << " " << order_[1] << std::endl;
}
}
int data_layout::find_axis(int to_find) const {
auto it = std::find(axes_.begin(), axes_.end(), to_find);
if(it == axes_.end())
return -1;
return std::distance(axes_.begin(), it);
}
distributed_layout::distributed_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(id, axes, shape, values, align)
{ }
/* -------------------------------- *
* MMA Layout *
* -------------------------------- */
mma_layout::mma_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align, target* tgt,
shared_layout *layout_a, shared_layout *layout_b,
ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
tensor_core_type_ = get_mma_type(dot);
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
if(tgt->as_nvidia()->sm() < 80){
fpw_ = {2, 2, 1};
auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order();
bool is_a_row = ord_a[0] != 0;
bool is_b_row = ord_b[0] != 0;
bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
contig_per_thread_ = {1, 1};
}
else{
// fpw_ = {1, 1, 1};
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
contig_per_thread_ = {1, 2};
// rep_ = {2, 2, 1};
}
order_ = {0, 1};
/* warps per tile */
wpt_ = {1, 1, 1};
// try to make warp-level tiles as square as possible to maximize data re-use
if (tgt->as_nvidia()->sm() < 80) {
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
}while(wpt_nm1 != wpt_);
} else {
bool changed = false;
do {
changed = false;
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
break;
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
if (wpt_[0] < shape_[0] / spw_[0]) {
wpt_[0] *= 2;
changed = true;
}
} else {
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
wpt_[1] *= 2;
changed = true;
}
}
} while (changed);
}
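// e.g. a hypothetical 128x128 f32.f16.f16.f32 tile (spw_ = {16, 8, 16}) with 8 warps
// converges to wpt_ = {4, 2, 1}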
/* shape per block */
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}
/* -------------------------------- *
* Scanline Layout *
* -------------------------------- */
scanline_layout::scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
nts_.resize(shape_.size());
mts_.resize(shape_.size());
bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
std::vector<ir::value*> ptrs;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank())
ptrs.push_back(io->get_pointer_operand());
}
unsigned i = order_[0];
int contiguous = 1;
for(ir::value* ptr: ptrs){
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
}
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shape_[i];
num_threads /= mts_[i];
if(is_dot)
nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
for(size_t d = 1; d < shape_.size(); d++){
i = order_[d];
if(d > 1 || !is_dot)
nts_[i] = 1;
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i];
}
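// sketch of one hypothetical case: a 128x64 fp16 block with 4 warps (128 threads),
// order {0, 1} and 8-element contiguity gives nts_ = {8, 1} and mts_ = {16, 8},
// i.e. each thread keeps 8 contiguous elements along axis 0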
shape_per_cta_.resize(shape_.size());
for(size_t d = 0; d < shape_.size(); d++)
shape_per_cta_[d] = mts_[d]*nts_[d];
}
/* -------------------------------- *
* Shared Layout *
* -------------------------------- */
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
return br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
return false;
else
throw std::runtime_error("unreachable");
}
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
auto* phi = dynamic_cast<ir::phi_node*>(v);
if(!phi || phi->get_num_incoming() != 2)
return;
ir::basic_block *block_0 = phi->get_incoming_block(0);
ir::basic_block *block_1 = phi->get_incoming_block(1);
ir::instruction *terminator_0 = block_0->get_inst_list().back();
ir::instruction *terminator_1 = block_1->get_inst_list().back();
bool is_latch_0 = is_loop_latch(phi, terminator_0);
bool is_latch_1 = is_loop_latch(phi, terminator_1);
ir::value *value_0 = phi->get_incoming_value(0);
ir::value *value_1 = phi->get_incoming_value(1);
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!(i_0 && !i_1) &&
!(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
!(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
return;
if(is_latch_1)
res.reset(new double_buffer_info_t{value_0, value_1, phi});
if(is_latch_0)
res.reset(new double_buffer_info_t{value_1, value_0, phi});
}
static bool is_smem(ir::value* v) {
if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v))
return true;
else
return false;
}
/// Walks a chain of phi nodes whose incoming values are shared-memory copies.
/// values_0: copies coming from bb0, one per pipeline stage
/// value_1: the copy coming from bb1 (the next value)
static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1,
std::vector<ir::value*>& values_0, ir::value*& value_1) {
ir::value* next = phi;
while (auto cphi = dynamic_cast<ir::phi_node*>(next)) {
// smem from previous bb & phi/smem from current bb
ir::value* c0 = cphi->get_incoming_value(0);
ir::value* c1 = cphi->get_incoming_value(1);
ir::basic_block *cbb0 = cphi->get_incoming_block(0);
ir::basic_block *cbb1 = cphi->get_incoming_block(1);
if (is_smem(c0)) {
assert(cbb0 == bb0);
values_0.push_back(c0);
if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
next = phi1;
continue;
} else {
if (is_smem(c1)) {
value_1 = c1;
assert(cbb1 == bb1);
return true;
} else {
return false;
}
}
} else
return false;
}
return false;
}
void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
auto* phi = dynamic_cast<ir::phi_node*>(v);
// if the phi node is nested
if (!phi)
return;
ir::basic_block *bb0 = phi->get_incoming_block(0);
ir::basic_block *bb1 = phi->get_incoming_block(1);
std::vector<ir::value*> values_0;
ir::value* value_1;
if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1))
return;
// double-buffer is a special case
if (values_0.size() == 1)
return;
// compute original values_0 input order
std::map<ir::value*, int> order;
int idx = 0;
for (ir::instruction* instr : *bb0) {
if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end())
order[static_cast<ir::value*>(instr)] = idx++;
}
assert(order.size() == values_0.size() && "order size incorrect");
int curr_stages = values_0.size() + 1;
if (curr_stages > prev_stages) {
res.reset(new N_buffer_info_t{values_0, value_1, phi, order});
prev_stages = curr_stages;
}
}
shared_layout::shared_layout(data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align, target *tgt)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {
size_ = 0;
arg_layout_ = arg;
// N-stage buffering
int prev_stages = 0;
for (ir::value *v : values)
extract_N_bufferable(v, N_buffer_, prev_stages);
// double-buffering
if (!N_buffer_)
for(ir::value *v: values)
extract_double_bufferable(v, double_buffer_);
// order
std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
order_ = arg_order;
ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr;
ir::value* hmma_dot_a = nullptr;
ir::value* hmma_dot_b = nullptr;
for(ir::value* v: values){
extract_dot_use(v, dot_a, 0);
extract_dot_use(v, dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
}
hmma_dot_a_ = hmma_dot_a;
hmma_dot_b_ = hmma_dot_b;
// Update mma_vec
if (hmma_dot_a_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 0) // need transpose
allow_swizzle_ = false;
} else if (hmma_dot_b_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 1) // need transpose
allow_swizzle_ = false;
}
// size
size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shape_)
size_ *= s;
if(double_buffer_)
size_ *= 2;
if (N_buffer_) {
size_ *= (N_buffer_->firsts.size() + 1);
}
}
int shared_layout::get_num_stages() const {
if (double_buffer_)
return 2;
if (N_buffer_)
return N_buffer_->firsts.size() + 1;
return 1;
}
size_t shared_layout::get_per_stage_elements() const {
return get_per_stage_size()/(ty_->get_primitive_size_in_bits()/8);
}
/* -------------------------------- *
* ---- Layouts Inference Pass ---- *
* -------------------------------- */
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt)
: axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ }
void layouts::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_block_ty())
return;
if(!y->get_type()->is_block_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
std::set<int> common;
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
}
void layouts::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
// if(layouts_.find(id) != layouts_.end())
// return;
auto it_hmma_c = std::find_if(values.begin(), values.end(),
[&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
auto cmp = [](ir::value* x, ir::value *y) {
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
return xx < yy;
};
std::vector<ir::value*> lvalue = values;
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_block_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v);
});
// type
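// groups containing a tensor-core dot get an mma layout (its operand groups are
// materialized as shared layouts first); groups anchored on copy-to-shared or async
// loads become shared layouts; everything else falls back to scanline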
if(it_hmma_c != values.end()){
ir::instruction *dot = (ir::instruction*)*it_hmma_c;
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
create(groups_.at(a), values_.at(groups_.at(a)));
create(groups_.at(b), values_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
(shared_layout*)layouts_.at(groups_.at(a)),
(shared_layout*)layouts_.at(groups_.at(b)),
dot);
}
else if(it_cts != values.end()){
ir::instruction *cts = (ir::instruction*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
}
else{
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
}
}
void layouts::run(ir::module &mod) {
// make graph
graph_.clear();
layouts_.clear();
groups_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) {
make_graph(i);
});
// connected components
graph_.connected_components(&values_, &groups_);
// create layouts
for(const auto& x: values_)
create(x.first, x.second);
// create temporaries
size_t id = values_.size();
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
id++;
ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis();
// shape
auto shapes = arg->get_type()->get_block_shapes();
scanline_layout *layout = get(arg)->to_scanline();
shapes[axis] = layout->mts(axis);
// create layout
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[red] = id;
}
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
id++;
size_t dim = val->get_type()->get_tile_rank();
ir::type::block_shapes_t shape(dim);
for(size_t k = 0; k < dim; k++){
shape[k] = std::max(in_layout->shape_per_cta(k),
out_layout->shape_per_cta(k));
}
auto in_ord = in_layout->get_order();
auto out_ord = out_layout->get_order();
int in_vec = in_layout->contig_per_thread(in_ord[0]);
int out_vec = out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec);
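// the temporary must hold one CTA tile of either layout; the fastest-varying axis is
// padded by the larger per-thread vector width, likely to reduce bank conflicts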
shape[out_ord[0]] += pad;
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[val] = id;
}
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[atom] = id;
}
});
}
}
}
}

View File

@@ -1,59 +0,0 @@
#include <climits>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace analysis{
void liveness::run(ir::module &mod) {
intervals_.clear();
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
for(ir::function *fn: mod.get_function_list()){
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()){
index += 1;
indices.insert({instr, index});
}
}
// create live intervals
for(auto &x: layouts_->get_all()) {
shared_layout* layout = x.second->to_shared();
if(!layout)
continue;
// users
std::set<ir::user*> users;
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
if(end == 0)
end = start + 1;
intervals_[layout] = segment{start, end};
}
}
}
}
}

View File

@@ -1,61 +0,0 @@
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/target.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton{
namespace codegen{
namespace analysis{
void swizzle::run(ir::module &) {
per_phase_.clear();
max_phase_.clear();
for(auto &x: layouts_->get_all()){
shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
if(!layout)
continue;
ir::value* mma_dot_a = layout->hmma_dot_a();
ir::value* mma_dot_b = layout->hmma_dot_b();
if(!mma_dot_a && !mma_dot_b){
per_phase_[layout] = 1;
max_phase_[layout] = 1;
vec_[layout] = 1;
continue;
}
auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout)
continue;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
else
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
}
else {
if (!layout->allow_swizzle()) {
per_phase_[layout] = 1;
max_phase_[layout] = 1;
vec_[layout] = 1;
} else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
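// e.g. for a hypothetical fp16 operand (dtsize = 2) whose row covers 32 elements
// (mts * nts = 32), per_phase = 128 / 64 = 2 and max_phase = mma_strided / 2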
}
}
}
}
}
}
}

View File

@@ -1,86 +0,0 @@
#include "triton/codegen/pass.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
namespace triton {
namespace codegen {
// TODO:
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
int cc, int num_warps, int num_stages, int& shared_static) {
// generate llvm code
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target);
codegen::analysis::allocation allocation(&liveness);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::prefetch prefetch_s(target);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
// run passes
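// dce is re-run after most rewrites to keep the IR small, and align/axes/layouts are
// recomputed whenever a transform may have invalidated them; the final group
// (swizzle, liveness, allocation, prefetch, membar) prepares shared-memory layout
// and synchronization before instruction selection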
dce.run(ir);
peephole.run(ir);
dce.run(ir);
pipeline.run(ir);
dce.run(ir);
disassociate.run(ir);
dce.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
peephole.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
coalesce.run(ir);
dce.run(ir);
align.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
dce.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
peephole.run(ir);
dce.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
swizzle.run(ir);
liveness.run(ir);
allocation.run(ir);
prefetch_s.run(ir);
barriers.run(ir);
isel.visit(ir, *llvm);
shared_static = allocation.allocated_size();
return llvm;
}
} // namespace codegen
} // namespace triton

File diff suppressed because it is too large

View File

@@ -1,173 +0,0 @@
#include "triton/codegen/target.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/IRBuilder.h"
#include <iostream>
using namespace llvm;
namespace triton{
namespace codegen{
// base
nvidia_cu_target* target::as_nvidia() {
return dynamic_cast<nvidia_cu_target*>(this);
}
bool target::is_gpu() const {
return is_gpu_;
}
// AMD
void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
fn->setCallingConv(CallingConv::AMDGPU_KERNEL);
}
Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {});
}
Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
throw std::runtime_error("not implemented");
}
Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workgroup_id_x,
Intrinsic::amdgcn_workgroup_id_y,
Intrinsic::amdgcn_workgroup_id_z
};
Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {});
return group_id;
}
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented on AMD");
}
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workitem_id_x,
Intrinsic::amdgcn_workitem_id_y,
Intrinsic::amdgcn_workitem_id_z
};
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_local_id, {});
}
// NVIDIA
void nvidia_cu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn){
// set metadata
Metadata *md_args[] = {
ValueAsMetadata::get(fn),
MDString::get(ctx, "kernel"),
ValueAsMetadata::get(builder.getInt32(1))
};
module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
}
Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0);
return builder.CreateCall(barrier, {});
}
Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl);
return builder.CreateCall(barrier, {});
}
Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> cta_ids = {
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {});
return cta_id;
}
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_tid_x,
Intrinsic::nvvm_read_ptx_sreg_tid_y,
Intrinsic::nvvm_read_ptx_sreg_tid_z
};
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_local_id, {});
}
Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
};
return builder.CreateIntrinsic(ids[ax], {}, {});
}
// CPU
void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
// normal cpu functions can be kernels
}
Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
// no barrier on CPU
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) {
// no barrier on CPU
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
const Function *fn = builder.GetInsertBlock()->getParent();
size_t num_params = fn->getFunctionType()->getNumParams();
static std::array<const Argument*, 3> ids = {
fn->arg_begin() + num_params - 3,
fn->arg_begin() + num_params - 2,
fn->arg_begin() + num_params - 1
};
return (Argument*)ids[ax];
}
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented");
}
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
return result;
}
Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
return builder.getInt32(0);
}
}
}

View File

@@ -1,133 +0,0 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { }
// simplify layout conversions using the following simple rules:
// - cvt_1(cvt_2(x)) -> x, if cvt_1 is the inverse of cvt_2
// - cvt_1(elementwise(x, y)) -> elementwise(cvt_1(x), cvt_1(y))
//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
// ir::value* _op = inst->get_operand(0);
// ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
// analysis::mma_layout* mma_in = layout_->get(op) ->to_mma();
// analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
// std::cout << 1 << std::endl;
// // i must be layout conversion instruction
// if(!mma_in && !mma_out)
// return inst;
// // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
// bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
// if((mma_in || mma_out) && is_op_cvt &&
// (layout_->get(inst) == layout_->get(op->get_operand(0))))
// return op->get_operand(0);
// // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
// if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
// return inst;
// std::cout << 1 << std::endl;
// for(size_t i = 0; i < op->get_num_operands(); i++){
// ir::value* arg_i = op->get_operand(i);
// builder.set_insert_point(op);
// // create new layout transform
// ir::instruction* new_arg_i = inst->clone();
// builder.insert(new_arg_i);
// // set the right args
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
// op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
// }
// std::cout << 2 << std::endl;
// return op;
//}
void coalesce::run(ir::module &mod) {
ir::builder& builder = mod.get_builder();
// add layout conversion instructions
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
// coalesce before store
if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
if(ir::value* op = i->get_operand(1))
if(op->get_type()->is_block_ty())
if(layout_->get(op)->to_mma()){
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i);
builder.insert(new_op);
i->replace_uses_of_with(op, new_op);
}
// uncoalesce after load
if(auto x = dynamic_cast<ir::load_inst*>(i))
if(x->get_type()->is_block_ty())
if(x->get_type()->get_tile_rank()==2)
if(layout_->get(x)->to_mma()){
builder.set_insert_point_after(x);
ir::instruction* new_x = ir::cvt_layout_inst::create(x);
builder.insert(new_x);
x->replace_all_uses_with(new_x);
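// replace_all_uses_with also rewired new_x's own operand, so point it back at x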
new_x->replace_uses_of_with(new_x, x);
}
}
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
// re-arrange scanline to promote memory coalescing
if(auto x = dynamic_cast<ir::store_inst*>(i)){
ir::value* ptr = x->get_pointer_operand();
ir::value* val = x->get_value_operand();
auto out_contig = align_->contiguous(ptr);
auto val_inst = dynamic_cast<ir::instruction*>(val);
if(!val_inst)
break;
if(dynamic_cast<ir::cvt_layout_inst*>(val))
break;
std::vector<unsigned> in_contig;
std::vector<ir::instruction*> queue = {val_inst};
std::set<ir::instruction*> seen;
std::vector<ir::io_inst*> ios;
while(!queue.empty()){
ir::instruction* curr = queue.back();
seen.insert(curr);
queue.pop_back();
if(auto dot_inst = dynamic_cast<ir::dot_inst*>(curr))
break;
if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
in_contig = align_->contiguous(io_inst->get_pointer_operand());
break;
}
for(ir::value* op: curr->ops()){
auto inst_op = dynamic_cast<ir::instruction*>(op);
if(!inst_op || seen.find(inst_op) != seen.end())
continue;
if(!op->get_type()->is_block_ty() ||
!val->get_type()->is_block_ty())
continue;
if(op->get_type()->get_tile_num_elements() ==
val->get_type()->get_tile_num_elements())
queue.push_back(inst_op);
}
}
if(in_contig.size() <= 1 || out_contig==in_contig)
continue;
builder.set_insert_point_after(val_inst);
auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));
x->replace_uses_of_with(val_inst, new_val);
}
}
}
}
}
}

View File

@@ -1,97 +0,0 @@
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace transform{
inline bool is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT)
return op==0 || op==1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
return op==0;
if(i->get_id() == ir::INST_TRANS)
return op==0;
return false;
}
inline bool is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
if(i->get_id() == ir::INST_TRANS)
return true;
if(i->get_id() == ir::INST_COPY_TO_SHARED)
return true;
if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
return true;
return false;
}
// insert a copy to/from shared memory for operand x of `parent`, recursing through phi nodes
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
builder.set_insert_point(parent);
ir::value *copy;
if(to_shared)
copy = builder.create_copy_to_shared(x);
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
return;
}
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
return;
}
// already in shared memory
if(to_shared && is_shmem_res(i))
return;
// copy
builder.set_insert_point_after(i);
ir::value *copy;
if(to_shared){
copy = builder.create_copy_to_shared(x);
}
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
}
void cts::run(ir::module &mod) {
// Add shared copies
ir::builder &builder = mod.get_builder();
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
size_t num_op = i->get_num_operands();
// copy to shared operands
for(size_t k = 0; k < num_op; k++)
if(is_shmem_op(i, k)){
add_copy(i, i->get_operand(k), builder, true);
}
// copy from shared operands
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
}
}
}
}
}
}

View File

@@ -1,75 +0,0 @@
#include "triton/codegen/transform/dce.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
void dce::run(ir::module &mod) {
std::list<ir::instruction*> work_list;
std::set<ir::instruction*> marked;
// initialize work-list
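// roots are instructions with side effects or control flow (stores, atomics,
// branches, returns, barriers); everything not reachable from a root is dead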
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// iterate through blocks
for(ir::basic_block *block: rpo)
for(ir::instruction *i: block->get_inst_list()){
switch(i->get_id()){
case ir::INST_RETURN:
case ir::INST_UNCOND_BRANCH:
case ir::INST_COND_BRANCH:
case ir::INST_UNMASKED_STORE:
case ir::INST_MASKED_STORE:
case ir::INST_ATOMIC_CAS:
case ir::INST_ATOMIC_RMW:
case ir::INST_ATOMIC_EXCH:
case ir::INST_BARRIER: {
work_list.push_back(i);
marked.insert(i);
break;
}
default:
break;
}
}
}
// mark -- ignore branches
while(!work_list.empty()){
ir::instruction* current = work_list.back();
work_list.pop_back();
// mark instruction operands
for(ir::value* op: current->ops()) {
if(auto *i = dynamic_cast<ir::instruction*>(op)){
if(marked.insert(i).second)
work_list.push_back(i);
}
}
// TODO: mark last instruction of current's reverse-dominance frontier
}
// sweep -- delete non-branch unmarked instructions
std::vector<ir::instruction*> to_delete;
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// iterate through blocks
for(ir::basic_block *block: rpo)
for(ir::instruction *i: block->get_inst_list()){
if(marked.find(i) == marked.end())
to_delete.push_back(i);
}
}
// delete
for(ir::instruction* i: to_delete)
i->erase_from_parent();
}
}
}
}

View File

@@ -1,62 +0,0 @@
#include "triton/codegen/transform/disassociate.h"
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/builder.h"
#include "triton/ir/module.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace transform{
ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
std::set<ir::value*>& seen) {
if (dynamic_cast<ir::phi_node*>(root))
return root;
if(!seen.insert(root).second)
return root;
if(!root->get_type()->is_block_ty())
return root;
bld.set_insert_point(root);
ir::instruction *new_root = bld.insert(root->clone());
for(ir::value *op: root->ops()){
ir::instruction *i = dynamic_cast<ir::instruction*>(op);
if(!i || i->get_id() == ir::INST_REDUCE)
continue;
ir::instruction* new_op = rematerialize(bld, i, seen);
new_root->replace_uses_of_with(op, new_op);
}
return new_root;
}
void disassociate::run(ir::module &mod) {
ir::builder &bld = mod.get_builder();
// ir::for_each_instruction(mod, [&](ir::instruction *i){
// bld.set_insert_point(i);
// for(ir::value* op: i->ops()){
// auto reshape = dynamic_cast<ir::make_range*>(op);
// if(!reshape)
// continue;
// ir::instruction* new_op = bld.insert(reshape->clone());
// i->replace_uses_of_with(op, new_op);
// }
// });
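// clone the computation tree feeding each reshape/splat so the copy is free to take
// its own layout, independent of the original producer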
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){
std::set<ir::value*> seen;
ir::instruction* new_i = rematerialize(bld, i, seen);
i->replace_all_uses_with(new_i);
}
});
}
}
}
}

View File

@@ -1,244 +0,0 @@
#include <vector>
#include <set>
#include <algorithm>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
if (analysis::double_buffer_info_t* info = layout->get_double_buffer())
return group_of(info->first, async_write);
else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) {
if (v == info->phi)
return group_of(info->firsts[0], async_write);
else // prefetched value
return group_of(info->firsts[1], async_write);
}
std::vector<int> groups(phi->get_num_operands());
std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
return *std::max_element(groups.begin(), groups.end());
}
else{
if(layouts_->has_tmp(v))
return async_write.size() - 1;
auto it = std::find(async_write.begin(), async_write.end(), v);
return std::distance(async_write.begin(), it);
}
}
inline bool membar::intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout) {
if(!a_layout || !b_layout)
return false;
int a_start = alloc_->offset(a_layout);
int a_end = a_start + a_layout->get_size();
int b_start = alloc_->offset(b_layout);
int b_end = b_start + b_layout->get_size();
if(a_start < b_end || b_start < a_end)
return true;
return false;
}
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
val_set_t ret;
for(ir::value* a: as){
if(!a->get_type()->is_block_ty())
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
for(ir::value* b: bs){
if(!b->get_type()->is_block_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
if(intersect_with(a_layout, b_layout) ||
intersect_with(a_layout, b_tmp) ||
intersect_with(a_tmp, b_layout) ||
intersect_with(a_tmp, b_tmp))
ret.insert(b);
}
}
return ret;
}
bool membar::check_safe_war(ir::instruction* i) {
bool is_i_shared_block = i->get_type()->is_block_ty() &&
layouts_->get(i)->to_shared();
bool is_i_double_buffered = is_i_shared_block &&
layouts_->get(i)->to_shared()->get_double_buffer();
bool is_i_n_buffered = is_i_shared_block &&
layouts_->get(i)->to_shared()->get_N_buffer();
if (is_i_double_buffered || is_i_n_buffered) {
// an async copy that is not prefetched cannot be treated as a safe WAR
if (dynamic_cast<ir::masked_load_async_inst*>(i) && !prefetch_->is_prefetched(i))
return false;
else
return true;
}
return false;
}
void membar::transfer(ir::basic_block *block,
val_vec_t& async_write,
val_set_t& sync_write,
val_set_t& sync_read,
std::set<ir::value*>& safe_war,
bool& inserted, ir::builder& builder) {
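// async_write: pending async copies in program order; sync_write / sync_read:
// shared-memory writes and reads issued since the last barrier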
std::vector<ir::async_wait_inst*> async_waits;
ir::basic_block::inst_list_t instructions = block->get_inst_list();
for(ir::instruction *i: instructions){
if(dynamic_cast<ir::phi_node*>(i))
continue;
if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
dynamic_cast<ir::masked_load_async_inst*>(i)){
async_write.push_back(i);
}
if(dynamic_cast<ir::copy_to_shared_inst*>(i))
sync_write.insert(i);
ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
// Get shared memory reads
std::set<ir::value*> read;
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
[&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
if(layouts_->has_tmp(i))
read.insert(i);
// RAW (async)
val_set_t tmp;
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
if(intersect_with(read, tmp).size()){
std::vector<int> groups(read.size());
std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
int N = *std::max_element(groups.begin(), groups.end());
if(N < async_write.size()){
builder.set_insert_point(i);
async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
async_waits.push_back(async_wait);
}
}
// RAW, WAR
bool is_safe_war = check_safe_war(i);
// WAR barrier is not required when data is double-buffered
if(!intersect_with(read, sync_write).empty() ||
(!intersect_with({i}, sync_read).empty() && !is_safe_war)) {
builder.set_insert_point(i);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
}
// update state of asynchronous copies
if(async_wait){
int N = async_write.size() - async_wait->get_N();
async_write.erase(async_write.begin(), async_write.begin() + N);
}
// all the copy_to_shared and read from shared are synchronized after barrier
if(barrier){
sync_write.clear();
sync_read.clear();
}
sync_read.insert(read.begin(), read.end());
}
// coalesce barriers
// fixme: to support more general cases
if (async_waits.size() == 2) {
// (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;)
for (int idx=0; idx<async_waits.size()-1; ++idx) {
ir::async_wait_inst *first_async_wait = async_waits[idx];
std::vector<ir::instruction*> to_erase;
ir::basic_block::inst_list_t instructions = block->get_inst_list();
for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){
ir::instruction *i = *iter;
if (static_cast<ir::instruction*>(first_async_wait) == i) {
// peek at the next 5 instructions
auto peak_iter = std::next(iter);
if (std::distance(peak_iter, instructions.end()) >= 5) {
auto first_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
auto first_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter++);
auto second_async_wait = dynamic_cast<ir::async_wait_inst*>(*peak_iter++);
auto second_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
auto second_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter);
if (first_bar && first_pf && second_async_wait && second_bar && second_pf) {
int first_n = first_async_wait->get_N();
int second_n = second_async_wait->get_N();
to_erase.push_back(second_async_wait);
to_erase.push_back(second_bar);
first_async_wait->set_N(second_n);
}
} else
break;
for (ir::instruction *i : to_erase)
block->erase(i);
}
}
}
}
}
void membar::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// extract values associated with double-buffered shared-memory copies;
// these can be read from and written to without needing synchronization
std::set<ir::value*> safe_war;
for(const auto& x: layouts_->get_all()){
analysis::shared_layout* layout = x.second->to_shared();
if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer())
continue;
for(ir::value *v: layout->get_values())
if(v != layout->get_double_buffer()->phi){
safe_war.insert(v);
}
}
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, val_vec_t> async_writes;
std::map<ir::basic_block*, val_set_t> sync_writes;
std::map<ir::basic_block*, val_set_t> sync_reads;
std::list<ir::value *> pipelined;
bool inserted;
do{
inserted = false;
// find barrier location
for(ir::basic_block *block: rpo){
// join inputs
val_vec_t async_write;
val_set_t sync_write;
val_set_t sync_read;
val_set_t tmp;
for(ir::basic_block* pred: block->get_predecessors()){
for(ir::value* v: async_writes[pred])
if(tmp.insert(v).second)
async_write.push_back(v);
sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
}
transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
async_writes[block] = async_write;
sync_writes[block] = sync_write;
sync_reads[block] = sync_read;
}
}while(inserted);
}
}
}
}
}

View File

@@ -1,309 +0,0 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/analysis/layout.h"
namespace triton {
namespace codegen{
namespace transform{
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
const std::vector<int>& perm) {
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
// transpose operands
std::vector<ir::value*> incs;
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
incs.push_back(rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm));
// create phi for transposed values
builder.set_insert_point(phi);
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
result->add_incoming(incs[n], phi->get_incoming_block(n));
return result;
}
else if(auto i = dynamic_cast<ir::instruction*>(value)){
ir::basic_block* block = i->get_parent();
auto it = std::find(block->begin(), block->end(), i);
it++;
builder.set_insert_point(it);
ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
trans->set_operand(0, i);
return trans;
}
return nullptr;
}
bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
auto trans = dynamic_cast<ir::trans_inst*>(value);
if(!trans)
return false;
auto users = trans->get_users();
auto ops = trans->ops();
if(users.size() > 1 || ops.size() > 1)
return false;
ir::value* op = *ops.begin();
// trans(phi) -> phi(trans(), trans()...)
auto* phi = dynamic_cast<ir::phi_node*>(op);
if(!phi)
return false;
ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
if(!new_phi)
return false;
trans->replace_all_uses_with(new_phi);
return true;
}
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
// dot(a, b, c) + d -> dot(a, b, c + d)
// d + dot(a, b, c) -> dot(a, b, c + d)
auto add = dynamic_cast<ir::binary_operator*>(value);
if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
ir::value *lhs = add->get_operand(0);
ir::value *rhs = add->get_operand(1);
ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
ir::dot_inst *rhs_dot = dynamic_cast<ir::dot_inst*>(rhs);
if(!lhs_dot && !rhs_dot)
return false;
ir::dot_inst *dot = lhs_dot ? lhs_dot : rhs_dot;
ir::value *other = (dot == lhs) ? rhs : lhs;
ir::value *acc = dot->get_operand(2);
ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
ir::constant *_0 = nullptr;
if(splat)
_0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
if(!_0)
return false;
if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
if (fp_0->get_value() != 0.0)
return false;
if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
if (int_0->get_value() != 0)
return false;
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
builder.set_insert_point(add);
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
add->replace_all_uses_with(new_dot);
return true;
}
return false;
}
//bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
// auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
// if(cfs) {
// ir::value *arg = cfs->get_operand(0);
// ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
// if(!cts)
// return false;
// cfs->replace_all_uses_with(cts->get_operand(0));
// return true;
// }
//}
bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
if(!copy_to_shared)
return false;
ir::value *arg = copy_to_shared->get_operand(0);
ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
if(!ld)
return false;
builder.set_insert_point(copy_to_shared);
ir::value *ptr = ld->get_pointer_operand();
ir::value *msk = ld->get_mask_operand();
ir::value *val = ld->get_false_value_operand();
analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
return false;
// analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
// std::cout << layout->nts(layout->get_order(0)) << std::endl;
// return true;
}
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
auto x = dynamic_cast<ir::reduce_inst*>(value);
if(!x)
return false;
ir::value *arg = x->get_operand(0);
auto shapes = arg->get_type()->get_block_shapes();
if(shapes[x->get_axis()] == 1){
builder.set_insert_point(x);
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes());
x->replace_all_uses_with(new_red);
return true;
}
return false;
}
bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
auto binop = dynamic_cast<ir::binary_operator*>(value);
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
ir::value *lhs = binop->get_operand(0);
ir::value *rhs = binop->get_operand(1);
ir::constant_int *_1_lhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_lhs = cst;
}
ir::constant_int *_1_rhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_rhs = cst;
}
if(_1_lhs){
binop->replace_all_uses_with(rhs);
return true;
}
else if(_1_rhs){
binop->replace_all_uses_with(lhs);
return true;
}
}
return false;
}
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
auto x = dynamic_cast<ir::getelementptr_inst*>(value);
if(!x)
return false;
auto y = dynamic_cast<ir::getelementptr_inst*>(x->get_pointer_operand());
if(!y)
return false;
auto idx = *y->idx_begin();
auto z = dynamic_cast<ir::binary_operator*>(idx);
if(!z)
return false;
bool is_sub = z->get_op() == ir::binary_op_t::Sub;
auto *lhs = dynamic_cast<ir::constant_int*>(z->get_operand(0));
bool is_lhs_0 = lhs && (lhs->get_value()==0);
bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin();
if(is_sub && is_lhs_0 && is_rhs_eq_x_rhs){
x->replace_all_uses_with(y->get_pointer_operand());
return true;
}
return false;
}
bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& builder){
auto select = dynamic_cast<ir::select_inst*>(value);
if(!select)
return false;
auto if_value = dynamic_cast<ir::masked_load_inst*>(select->get_if_value_op());
if(!if_value)
return false;
if(select->get_pred_op() != if_value->get_mask_operand())
return false;
builder.set_insert_point(select);
ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
if_value->get_mask_operand(),
select->get_else_value_op(),
if_value->get_cache_modifier(),
if_value->get_eviction_policy(),
if_value->get_is_volatile());
select->replace_all_uses_with(new_load);
return true;
}
bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){
auto cvt = dynamic_cast<ir::cvt_layout_inst*>(value);
if(!cvt)
return false;
ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0));
if(!op)
return false;
// // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
// if(op->get_id() == ir::INST_BINOP){
// for(size_t i = 0; i < op->get_num_operands(); i++){
// ir::value* arg_i = op->get_operand(i);
// builder.set_insert_point(op);
// // create new layout transform
// ir::instruction* new_arg_i = cvt->clone();
// layouts_->copy(new_arg_i, op);
// builder.insert(new_arg_i);
// // set the right args
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
// op->replace_uses_of_with(arg_i, new_arg_i);
// }
// cvt->replace_all_uses_with(op);
// return true;
// }
auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op);
if(!cvt_op)
return false;
// convert1(convert2(x)) if convert1 is the inverse of convert2
ir::value* op_op = cvt_op->get_operand(0);
if(layouts_->has(cvt) && layouts_->has(op_op) &&
layouts_->get(cvt) && layouts_->get(op_op)){
cvt->replace_all_uses_with(op_op);
return true;
}
return false;
}
void peephole::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// keep track of whether any modification was made
std::set<ir::value*> seen;
size_t n_seen;
// rewrite dots first
do{
n_seen = seen.size();
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(seen.find(i) != seen.end())
continue;
bool was_modified = rewrite_dot(i, builder);
if(was_modified){
seen.insert(i);
}
}
}while(seen.size() != n_seen);
// rewrite other ops
seen.clear();
do{
n_seen = seen.size();
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(seen.find(i) != seen.end())
continue;
bool was_modified = false;
was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
// was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
// TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
// was_modified = was_modified || rewrite_select_masked_load(i, builder);
was_modified = was_modified || rewrite_cvt_layout(i, builder);
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(was_modified)
seen.insert(i);
}
}while(seen.size() != n_seen);
}
}
}
}
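At the Triton source level, the `rewrite_dot` peephole above folds an element-wise add of a dot's output into the dot's accumulator. A minimal sketch of the kind of pattern it targets; the kernel, names, and shapes are illustrative only and not part of this commit:
import triton
import triton.language as tl

@triton.jit
def dot_plus_bias(A, B, Bias, C, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    a = tl.load(A + idx)
    b = tl.load(B + idx)
    bias = tl.load(Bias + idx)
    # tl.dot(a, b) lowers to dot(a, b, splat(0)); the peephole folds the
    # addition below into the accumulator, i.e. dot(a, b, bias)
    c = tl.dot(a, b) + bias
    tl.store(C + idx, c)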

View File

@@ -1,330 +0,0 @@
#include <iostream>
#include <algorithm>
#include "triton/codegen/transform/pipeline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i || i->get_parent() != block)
return;
if(i->get_id()==ir::INST_PHI)
return;
ret.push_back(i);
for(ir::user* u: i->get_users())
recursive_deps(u, block, ret);
}
void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
auto instr = dynamic_cast<ir::instruction*>(cond);
for (auto op : instr->ops()) {
if (auto phi_op = dynamic_cast<ir::phi_node*>(op)) {
phis.insert(phi_op);
return;
}
if (dynamic_cast<ir::instruction*>(op))
get_induction_vars(op, phis);
}
}
/// assumes the back-edge value is at incoming index 1
ir::value* rematerialize_vals(ir::builder& builder, ir::basic_block* block, ir::value* v,
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i || i->get_parent() != block)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
throw std::runtime_error("Don't have that phi node\n");
return prev_phi_vals.at(phi);
}
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize_vals(builder, block, op, prev_phi_vals));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
ir::value* rematerialize(ir::builder& builder, ir::basic_block* block,
ir::value* v, size_t phi_idx){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i || i->get_parent() != block)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
return phi->get_incoming_value(phi_idx);
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize(builder, block, op, phi_idx));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
/// moving the prev phi vals to the next iteration
std::map<ir::phi_node*, ir::value*> update_prev_phi_vals(
ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
std::map<ir::phi_node*, ir::value*> next_phi_vals;
for (auto &[phi, val] : prev_phi_vals) {
next_phi_vals[phi] = rematerialize_vals(builder, block, phi->get_incoming_value(1), prev_phi_vals);
}
return next_phi_vals;
}
void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& load_ivs,
std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
for (auto& [phi, val] : load_ivs) {
if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
ir::value* next_k = rematerialize_vals(builder, block, phi->get_incoming_value(1), load_ivs);
assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
new_phi->add_incoming(next_k, phi->get_incoming_block(1));
// cache next_k (to be used by next_mask)
next_load_ivs[phi] = next_k;
} else
throw std::runtime_error("must be phi");
}
}
struct pipeline_info_t {
ir::load_inst* load;
ir::phi_node* ptr;
ir::dot_inst* dot;
pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
: load(load), ptr(ptr), dot(dot) {}
};
void pipeline::run(ir::module &mod) {
if (num_stages_ <= 1)
return;
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
// - the pointer is a phi node that references a value
// in its basic block (i.e., pointer induction variable)
// - the load has only a single use in a dot instruction
// As more use cases become apparent, this pass will be improved
std::vector<pipeline_info_t> to_pipeline;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
auto users = load->get_users();
auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
&& users.size() == 1 && dot)
to_pipeline.push_back({load, ptr, dot});
}});
// do the pipelining
std::vector<ir::phi_node*> new_loads;
ir::builder &builder = mod.get_builder();
const int num_stages = num_stages_;
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
for(auto info: to_pipeline){
ir::load_inst* load = info.load;
ir::phi_node* ptr = info.ptr;
ir::basic_block* block = load->get_parent();
ir::basic_block* header = block->get_predecessors()[0];
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
assert(block_br);
assert(header_br);
ir::type* ty = load->get_type();
// multi-stage pipe
if (has_copy_async_ && num_stages > 2) {
ir::value* header_cond = header_br->get_cond();
ir::value* block_cond = block_br->get_cond();
// 1. collect induction variables
std::set<ir::phi_node*> induction_vars;
get_induction_vars(block_cond, induction_vars);
std::vector<ir::value*> first_ptrs(num_stages-1);
std::vector<ir::value*> first_loads(num_stages-1);
std::vector<ir::value*> first_masks(num_stages-1);
std::vector<ir::value*> loop_conds(num_stages-1);
std::map<ir::phi_node*, ir::value*> prev_phi_vals;
// initialize prev_phi_vals
// Add all phi nodes. The following DCE pass will delete dead ones.
for (ir::instruction *instr : block->get_inst_list())
if (auto *phi = dynamic_cast<ir::phi_node*>(instr))
if (phi->get_incoming_block(1) == block)
prev_phi_vals[phi] = phi->get_value_for_block(header);
builder.set_insert_point(header->get_inst_list().back());
first_ptrs[0] = ptr->get_value_for_block(header);
loop_conds[0] = header_cond;
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
ir::value* false_value = nullptr;
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_false_value =
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
false_value = remat_false_value;
} else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration
loop_conds[stage] = rematerialize_vals(builder, block, block_cond, prev_phi_vals);
prev_phi_vals = update_prev_phi_vals(builder, block, prev_phi_vals);
first_ptrs[stage] = rematerialize_vals(builder, block, ptr, prev_phi_vals);
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_false_value =
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
false_value = remat_false_value;
}
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
}
// create new phis for induction variables
builder.set_insert_point(block->get_first_non_phi());
std::map<ir::phi_node*, ir::value*> load_ivs;
std::map<ir::phi_node*, ir::value*> next_load_ivs;
for (auto& [iv, val] : prev_phi_vals) {
ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
pn->add_incoming(prev_phi_vals[iv], header);
load_ivs[iv] = pn;
}
// add incoming for phis & update next_load_ivs
finalize_iv_vals(builder, block, load_ivs, next_load_ivs);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
// ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs);
ir::value* next_mask = builder.create_splat(
rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), next_load_ivs);
// TODO: the false value may depend on some other phi nodes
ir::value* remat_false_value =
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), next_load_ivs);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// phi node
ptr->set_incoming_value(0, first_ptrs.back());
builder.set_insert_point(block->get_first_non_phi());
// nested phis for load
std::vector<ir::phi_node*> new_load_phis(num_stages-1);
for (auto& pn : new_load_phis)
pn = builder.create_phi(ty, 2);
for (int i=0; i<num_stages-2; ++i) {
new_load_phis[i]->add_incoming(first_loads[i], header);
new_load_phis[i]->add_incoming(new_load_phis[i+1], block);
}
new_load_phis.back()->add_incoming(first_loads.back(), header);
new_load_phis.back()->add_incoming(next_load, block);
load->replace_all_uses_with(new_load_phis.front());
new_loads.push_back(new_load_phis.back());
// record first_loads to reorder them
preheader_loads.push_back({new_load_phis.front(), first_loads});
} else {
// pre-fetch first iteration
builder.set_insert_point(header->get_inst_list().back());
ir::value* first_ptr = ptr->get_value_for_block(header);
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 0);
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 0);
first_mask = builder.create_and(first_mask, remat_mask);
false_value = remat_false_value;
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 1);
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 1);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// phi node
builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2);
new_load->add_incoming(first_load, header);
new_load->add_incoming(next_load, block);
load->replace_all_uses_with(new_load);
new_loads.push_back(new_load);
}
}
// try to reorder prefetched values from a0, a1, a2, ..., b0, b1, b2, ... to
// a0, b0, a1, b1, ...
if (!preheader_loads.empty()) {
ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0);
builder.set_insert_point(header->get_inst_list().back());
for (int i=1; i<num_stages-1; ++i) {
for (auto iter = preheader_loads.begin(); iter != preheader_loads.end(); ++iter) {
ir::instruction* original_load = static_cast<ir::instruction*>(iter->second.at(i));
ir::instruction* moved_load = original_load->clone();
builder.insert(moved_load);
original_load->replace_all_uses_with(moved_load);
}
}
}
// try to move dot_inst after loads
// for better overlap of io and compute
struct move_config_t{
std::vector<ir::instruction*> insts;
ir::load_inst* dst;
};
std::vector<move_config_t> to_move(to_pipeline.size());
if(has_copy_async_){
for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
auto info = to_pipeline[idx];
ir::load_inst* load = info.load;
ir::phi_node* ptr = info.ptr;
ir::dot_inst* dot = info.dot;
ir::basic_block* bb = dot->get_parent();
recursive_deps(dot, bb, to_move[idx].insts);
to_move[idx].dst = load;
}
for(auto& move_config: to_move){
builder.set_insert_point_after(move_config.dst);
for(ir::instruction* i: move_config.insts){
i->get_parent()->erase(i);
builder.insert(i);
}
}
}
}
}
}
}
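For reference, the pipelining heuristic above (the pointer is a phi-node induction variable and the load's single use is a dot instruction) matches the canonical Triton matmul inner loop. A hedged sketch with illustrative names, assuming row-major inputs; it is not code from this commit:
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(A, B, C, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # assumes row-major A (M x K), B (K x N), C (M x N)
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # a_ptrs / b_ptrs become phi nodes (pointer induction variables),
        # and each load's single use is the dot below, so both loads qualify
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K
        b_ptrs += BLOCK_K * N
    c_ptrs = C + offs_m[:, None] * N + offs_n[None, :]
    tl.store(c_ptrs, acc)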

View File

@@ -1,133 +0,0 @@
#include "triton/codegen/transform/prefetch.h"
#include "triton/codegen/target.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include "triton/ir/print.h"
#include <iostream>
#include <vector>
#include <algorithm>
namespace triton::codegen::transform {
/// find defs up to (but not including) phis
static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::instruction*> &ret) {
ir::instruction *i = dynamic_cast<ir::instruction*>(v);
if (!i || i->get_parent() != bb)
return;
if (i->get_id() == ir::INST_PHI)
return;
ret.push_back(i);
for (ir::value *op : i->ops())
recursive_defs(op, bb, ret);
}
void prefetch::run(ir::module &mod) {
// 1. collect dots that can be prefetched
std::vector<ir::dot_inst*> to_prefetch;
ir::for_each_instruction(mod, [&](ir::instruction *i) {
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
// Now only do prefetching when dot is using tensor cores
if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
&& dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
)
)
return;
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
if (a && a->get_incoming_block(1) == a->get_parent() &&
b && b->get_incoming_block(1) == b->get_parent())
to_prefetch.push_back(dot);
}
});
assert(to_prefetch.size() <= 1 && "Don't know what to do with multiple dots");
ir::builder &builder = mod.get_builder();
// 2. do the prefetching
for (ir::dot_inst* dot : to_prefetch) {
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
assert(a->get_incoming_block(0) == b->get_incoming_block(0));
ir::basic_block *loop_header = a->get_incoming_block(0);
ir::basic_block *loop_body = a->get_parent();
// mark as prefetched
dot->set_prefetched(true);
// 1. in the loop header (first iteration)
builder.set_insert_point(loop_header->get_inst_list().back());
assert(a && b);
builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);
// 2. at the end of the loop body (next iteration)
builder.set_insert_point(loop_body->get_inst_list().back());
builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);
prefetched_vals_.insert(a->get_incoming_value(0));
prefetched_vals_.insert(b->get_incoming_value(0));
// nested phis
ir::value* next_a = a->get_incoming_value(1);
while (auto* next_a_phi = dynamic_cast<ir::phi_node*>(next_a)) {
prefetched_vals_.insert(next_a_phi->get_incoming_value(0));
next_a = next_a_phi->get_incoming_value(1);
}
prefetched_vals_.insert(next_a);
ir::value* next_b = b->get_incoming_value(1);
while (auto* next_b_phi = dynamic_cast<ir::phi_node*>(next_b)) {
prefetched_vals_.insert(next_b_phi->get_incoming_value(0));
next_b = next_b_phi->get_incoming_value(1);
}
prefetched_vals_.insert(next_b);
}
// move loads to the beginning of the loop
if (tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80) {
for (ir::function *fn : mod.get_function_list())
for (ir::basic_block *bb : fn->blocks()) {
// only apply to loop body
if (bb->get_predecessors().size() != 2 || bb->get_predecessors()[1] != bb)
continue;
// record loads (& dependency) to move
std::vector<ir::instruction*> loads;
// record original inst order
std::map<ir::instruction*, size_t> idx_map;
size_t idx = 0;
for (ir::instruction *inst : bb->get_inst_list()) {
if (auto *i = dynamic_cast<ir::masked_load_inst*>(inst))
recursive_defs(i, bb, loads);
idx_map[inst] = idx;
idx++;
}
// remove duplicates & keep the original input order
std::sort(loads.begin(), loads.end());
loads.erase(std::unique(loads.begin(), loads.end()), loads.end());
std::sort(loads.begin(), loads.end(), [&idx_map](ir::instruction *a, ir::instruction *b) {
return idx_map[a] < idx_map[b];
});
builder.set_insert_point(bb->get_first_non_phi());
auto& inst_list = bb->get_inst_list();
for (ir::instruction *i : loads){
auto it = std::find(inst_list.begin(), inst_list.end(), i);
// make sure we don't invalidate insert point
// in case instruction already at the top
if(it == builder.get_insert_point())
continue;
bb->erase(i);
builder.insert(i);
}
}
}
}
} // namespace triton::codegen::transform

View File

@@ -1,51 +0,0 @@
#include <iostream>
#include <algorithm>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/transform/reorder.h"
namespace triton {
namespace codegen{
namespace transform{
void reorder::run(ir::module& mod){
// ir::builder &builder = mod.get_builder();
// std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
// for(ir::function *fn: mod.get_function_list())
// for(ir::basic_block *block: fn->blocks())
// for(ir::instruction* i: block->get_inst_list()){
// if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
// ir::value* _ptr = ld->get_pointer_operand();
// ir::value* _msk = ld->get_mask_operand();
// ir::value* _val = ld->get_false_value_operand();
// auto ptr = std::find(block->begin(), block->end(), _ptr);
// auto msk = std::find(block->begin(), block->end(), _msk);
// auto val = std::find(block->begin(), block->end(), _val);
// if(ptr == block->end() || msk == block->end() || val == block->end())
// continue;
// auto it = std::find(block->begin(), block->end(), i);
// int dist_ptr = std::distance(ptr, it);
// int dist_msk = std::distance(msk, it);
// int dist_val = std::distance(val, it);
// if(dist_ptr < dist_msk && dist_ptr < dist_val)
// builder.set_insert_point(++ptr);
// if(dist_msk < dist_ptr && dist_msk < dist_val)
// builder.set_insert_point(++msk);
// if(dist_val < dist_ptr && dist_val < dist_msk)
// builder.set_insert_point(++val);
// ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
// to_replace.push_back(std::make_pair(ld, new_ld));
// }
// }
// for(auto& x: to_replace)
// x.first->replace_all_uses_with(x.second);
}
}
}
}

View File

@@ -1,5 +0,0 @@
# Run the benchmarks
Install the required dependencies via `pip install -r requirements-bench.txt` from the triton/python/bench folder.
Run the benchmarks with `python3 bench/run.py`; this produces an HTML report in the `results` folder.
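The runner can also be invoked programmatically; a hedged sketch, assuming the bench folder lives at `python/bench` as described above:
# hypothetical programmatic invocation of the benchmark runner (illustrative only)
import sys
sys.path.insert(0, 'python/bench')  # assumed path to the bench folder
from run import main                # run.py's argparse-based entry point
main(['--result-dir', 'results', '--names', 'matmul'])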

View File

@@ -1,92 +0,0 @@
import torch
import triton
# -------------------------------
# Matrix Multiplication
# -------------------------------
nt = {False: 'n', True: 't'}
square_confs = [
triton.testing.Benchmark(
x_names=['M', 'N', 'K'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg='block',
line_vals=[16, 32, 64, 128],
line_names=['Block16', 'Block32', 'Block64', 'Block128'],
ylabel='TFLOPS',
plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args={'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)
for AT in [False] for BT in [False]
for op_mode in ['dsd'] for layout_mode in ['dense']
]
@triton.testing.perf_report(square_confs)
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
Z, H = 1, 1
make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode]
# create layout
shape = {'sdd': (M, N), 'dsd': (K, M) if AT else (M, K), 'dds': (N, K) if BT else (K, N)}[op_mode]
layout = make_layout(H, shape[0] // block, shape[1] // block)
# create inputs
a = torch.randn((Z, H, K, M) if AT else (Z, H, M, K), dtype=dtype, device='cuda')
b = torch.randn((Z, H, N, K) if BT else (Z, H, K, N), dtype=dtype, device='cuda')
# create op
tflops = lambda ms: num_flops / ms * 1e3
if provider == 'triton':
op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT)
# inputs
a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = {
'sdd': 2 * Z * K * float(layout.sum()) * block * block,
'dsd': 2 * Z * N * float(layout.sum()) * block * block,
'dds': 2 * Z * M * float(layout.sum()) * block * block
}[op_mode] * 1e-12
return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
# -------------------------------
# Softmax
# -------------------------------
square_confs = [
triton.testing.Benchmark(
x_names=['M', 'N'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg='block',
line_vals=[16, 32, 64],
line_names=['Block16', 'Block32', 'Block64'],
ylabel='GBPS',
plot_name=f'{layout_mode}-square',
args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)
for layout_mode in ['dense', 'tril']
]
@triton.testing.perf_report(square_confs)
def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
Z, H = 1, 1
make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode]
layout = make_layout(H, M // block, N // block)
a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda')
if provider == 'triton':
a = triton.testing.sparsify_tensor(a, layout, block)
op = triton.ops.blocksparse.softmax(layout, block)
gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
bench_matmul.run(print_data=True, show_plots=True)

View File

@@ -1,41 +0,0 @@
import torch
import triton
confs = [
triton.testing.Benchmark(
x_names=['N'],
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
ylabel='GBPS',
plot_name=f'{mode}-2048',
args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
)
for mode in ['forward', 'backward']
]
@triton.testing.perf_report(confs)
def bench_op(M, N, dtype, mode, provider):
# create inputs
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
num_gb = (2 * x.numel() * x.element_size() * 1e-9)
gbps = lambda ms: num_gb / ms * 1e3
# forward pass
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward':
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
if mode == 'backward':
y = op(x, idx)
dy = torch.randn_like(y)
fn = lambda: y.backward(dy, retain_graph=True)
mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
if __name__ == '__main__':
bench_op.run(print_data=True)

View File

@@ -1,67 +0,0 @@
import torch
import triton
def rounded_linspace(low, high, steps, div):
ret = torch.linspace(low, high, steps)
ret = (ret.int() + div - 1) // div * div
ret = torch.unique(ret)
return list(map(int, ret))
# Square benchmarks
nt = {False: "n", True: "t"}
square_confs = [
triton.testing.Benchmark(
x_names=["M", "N", "K"],
x_vals=rounded_linspace(512, 8192, 32, 128),
line_arg="provider",
line_vals=["cublas", "triton", "cutlass"],
line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS",
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
args={"AT": AT, "BT": BT, "dtype": torch.float16},
) for AT in [False] for BT in [False]
]
# Transformer training benchmarks
transformer_confs = [
triton.testing.Benchmark(
x_names=[x],
x_vals=rounded_linspace(NK // 16, NK, 32, 128),
line_arg="provider",
line_vals=["cublas", "triton", "cutlass"],
line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]
for i, x in enumerate(["N", "K"])
for M in [2048]
]
@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT:
a = a.t()
if BT:
b = b.t()
tflops = lambda ms: 2. * M * N * K / ms * 1e-9
if provider == "cublas":
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
if provider == "cutlass":
cutlass_matmul = triton.testing.cutlass_matmul
try:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms)
except Exception:
return None
return None

View File

@@ -1,2 +0,0 @@
pandas >= 1.3.3
matplotlib >= 3.4.3

View File

@@ -1,44 +0,0 @@
import argparse
import inspect
import os
import sys
import triton
def run_all(result_dir, names):
if not os.path.exists(result_dir):
os.makedirs(result_dir)
for mod in os.listdir(os.path.dirname(os.path.realpath(__file__))):
# skip non python files
if not mod.endswith('.py'):
continue
# skip file not in provided names
if names and names not in mod:
continue
# skip files that don't start with 'bench_'
if not mod.startswith('bench_'):
continue
print(f'running {mod}...')
mod = __import__(os.path.splitext(mod)[0])
benchmarks = inspect.getmembers(mod, lambda x: isinstance(x, triton.testing.Mark))
for name, bench in benchmarks:
curr_dir = os.path.join(result_dir, mod.__name__.replace('bench_', ''))
if len(benchmarks) > 1:
curr_dir = os.path.join(curr_dir, name.replace('bench_', ''))
if not os.path.exists(curr_dir):
os.makedirs(curr_dir)
bench.run(save_path=curr_dir)
def main(args):
parser = argparse.ArgumentParser(description="Run the benchmark suite.")
parser.add_argument("-r", "--result-dir", type=str, default='results', required=False)
parser.add_argument("-n", "--names", type=str, default='', required=False)
parser.set_defaults(feature=False)
args = parser.parse_args(args)
run_all(args.result_dir, args.names)
if __name__ == '__main__':
main(sys.argv[1:])

View File

@@ -0,0 +1,18 @@
import triton
import triton.language as tl
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xn,
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
tl.store(Zs, tl.load(Xs))
ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir")
print(ret)

8
python/examples/empty.py Normal file
View File

@@ -0,0 +1,8 @@
import triton
import triton.language as tl
@triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
pass
ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")

View File

@@ -1,202 +0,0 @@
#include "cutlass/library/handle.h"
#include "cutlass/library/library.h"
#include "cutlass/library/operation_table.h"
#include "cutlass/library/singleton.h"
#include "pybind11/pybind11.h"
#include "triton/tools/bench.hpp"
using namespace cutlass;
using namespace cutlass::library;
std::map<std::vector<size_t>, const Operation *> op_cache_;
static int const kHostWorkspaceSize = (4 << 10);
static int const kDeviceWorkspaceSize = (4 << 20);
void run(int M, int N, int K,
int lda, int ldb, int ldc, int ldd,
void const *ptr_A, void const *ptr_B, void const *ptr_C, void *ptr_D,
void const *alpha, void const *beta,
ScalarPointerMode scalar_mode,
const Operation *operation,
cudaStream_t stream) {
GemmUniversalConfiguration configuration{
GemmUniversalMode::kGemm,
{M, N, K},
1,
lda,
ldb,
ldc,
ldd};
// host workspace size
uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);
if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed)
throw std::runtime_error("Unable to find gemm operation");
char host_workspace[kHostWorkspaceSize];
// device workspace size
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
if (uint64_t(kDeviceWorkspaceSize) < device_workspace_size_needed)
throw std::runtime_error("Unable to find gemm operation");
static void *device_workspace;
// Initialize host and device workspaces
Status status = operation->initialize(&configuration, host_workspace, device_workspace, stream);
if (status != cutlass::Status::kSuccess)
throw std::runtime_error("Unable to initialize workspace");
// Run the operator
GemmArguments arguments{ptr_A, ptr_B, ptr_C, ptr_D, alpha, beta, scalar_mode};
operation->run(&arguments, host_workspace, device_workspace, stream);
}
const Operation *autotune(int M, int N, int K,
NumericTypeID element_compute,
NumericTypeID element_scalar,
void const *alpha,
NumericTypeID element_A,
LayoutTypeID layout_A,
ComplexTransform transform_A,
void const *ptr_A,
int lda,
NumericTypeID element_B,
LayoutTypeID layout_B,
ComplexTransform transform_B,
void const *ptr_B,
int ldb,
void const *beta,
NumericTypeID element_C,
void const *ptr_C,
int ldc,
void *ptr_D,
int ldd,
ScalarPointerMode scalar_mode,
int device_id,
cudaStream_t stream) {
// index operation table with functional key
GemmFunctionalKey key(
Provider::kCUTLASS,
GemmKind::kUniversal,
element_compute,
element_scalar,
element_A,
layout_A,
transform_A,
element_B,
layout_B,
transform_B,
element_C);
auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);
if (operators_it == Singleton::get().operation_table.gemm_operations.end())
throw std::runtime_error("Unable to find gemm operation");
if (operators_it->second.empty())
throw std::runtime_error("Unable to find gemm operation");
cudaDeviceProp device_prop;
cudaError_t error = cudaGetDeviceProperties(&device_prop, device_id);
if (error != cudaSuccess)
throw std::runtime_error("Unable to get device properties");
int cc = device_prop.major * 10 + device_prop.minor;
// index operation table with preference key
// assume 8-bytes aligned memory pointers
int alignment = 8;
GemmPreferenceKey preference_key(cc, alignment);
auto autotune_it = operators_it->second.find(preference_key);
if (autotune_it == operators_it->second.end())
throw std::runtime_error("Unable to find gemm operation");
const std::vector<const Operation *> &operations = autotune_it->second;
if (operations.empty())
throw std::runtime_error("Unable to find gemm operation");
// auto-tune
const Operation *best = nullptr;
double best_ms = std::numeric_limits<double>::max();
for (const Operation *op : operations) {
auto fn = [&]() { run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D,
alpha, beta, scalar_mode, op, stream); };
triton::driver::cu_stream tt_stream((CUstream)stream, false);
double ms = triton::tools::bench(fn, &tt_stream, 10, 25);
if (ms < best_ms) {
best_ms = ms;
best = op;
}
}
return best;
}
// map of torch datatypes to cutlass datatypes
std::map<std::string, NumericTypeID> type_map = {
{"float16", NumericTypeID::kF16},
{"float32", NumericTypeID::kF32},
{"float64", NumericTypeID::kF64}};
void cutlass_matmul(uintptr_t A, uintptr_t B, uintptr_t C,
size_t M, size_t N, size_t K,
size_t stride_a_0, size_t stride_a_1,
size_t stride_b_0, size_t stride_b_1,
size_t stride_c_0, size_t stride_c_1,
std::string type_a, std::string type_b, std::string type_c,
size_t dev_id, uint64_t stream_handle) {
void *ptr_A = (void *)A;
void *ptr_B = (void *)B;
void *ptr_C = (void *)C;
void *ptr_D = ptr_C;
size_t lda = stride_a_0;
size_t ldb = stride_b_0;
size_t ldc = stride_c_1;
size_t ldd = ldc;
float alpha = 1.0f;
float beta = 0.0f;
// layout for A
LayoutTypeID layout_A;
if (stride_a_0 == 1)
layout_A = LayoutTypeID::kColumnMajor;
else if (stride_a_1 == 1)
layout_A = LayoutTypeID::kRowMajor;
else
throw std::runtime_error("A layout is not supported");
// layout for B
LayoutTypeID layout_B;
if (stride_b_0 == 1)
layout_B = LayoutTypeID::kColumnMajor;
else if (stride_b_1 == 1)
layout_B = LayoutTypeID::kRowMajor;
else
throw std::runtime_error("B layout is not supported");
// data types
NumericTypeID element_compute = NumericTypeID::kF32;
NumericTypeID element_A = type_map[type_a];
NumericTypeID element_B = type_map[type_b];
NumericTypeID element_C = type_map[type_c];
// misc. flags
ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
NumericTypeID element_scalar = NumericTypeID::kF32;
ComplexTransform transform_A = ComplexTransform::kNone;
ComplexTransform transform_B = ComplexTransform::kNone;
// runtime flags
cudaStream_t stream = (cudaStream_t)stream_handle;
// auto-tune
std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
dev_id, (size_t)element_compute, (size_t)scalar_mode};
auto it = op_cache_.find(tune_key);
if (it == op_cache_.end()) {
const Operation *op = autotune(M, N, K, element_compute, element_scalar, &alpha,
element_A, layout_A, transform_A, ptr_A, lda,
element_B, layout_B, transform_B, ptr_B, ldb,
&beta, element_C, ptr_C, ldc, ptr_D, ldd, scalar_mode,
dev_id, stream);
it = op_cache_.insert({tune_key, op}).first;
}
run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D, &alpha, &beta,
scalar_mode, it->second, stream);
}
void init_cutlass(pybind11::module &m) {
pybind11::module subm = m.def_submodule("cutlass");
subm.def("matmul", &cutlass_matmul, "matrix multiplication");
}

View File

@@ -1,676 +0,0 @@
#include "triton/ir/builder.h"
#include <functional>
#include <iostream>
#include <pybind11/pybind11.h>
namespace ir = triton::ir;
namespace py = pybind11;
static const std::string _builder_doc = R"pbdoc(
:param builder: IR builder to generate code into, optional, set automatically when called inside a @triton.jit function
:type builder: triton.ir.builder
)pbdoc";
#define VA_ARGS(...) , ##__VA_ARGS__
#define DEF_FUNC(MOD, PY_NAME, C_FUNC, ...) \
MOD.def(PY_NAME, C_FUNC, (C_FUNC##_docstr + _builder_doc).c_str(), \
ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)
void throw_not_implemented(std::string key) {
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. This is likely a bug on our side.");
}
void throw_not_int_or_float(std::string key) {
throw std::runtime_error("`" + key + "` only supported for integer and floating point types.");
}
enum type_code {
_bool,
int8,
int16,
int32,
int64,
float16,
float32,
float64
};
ir::type *make_ir(type_code ty, ir::builder *builder) {
switch (ty) {
case float16:
return builder->get_half_ty();
case float32:
return builder->get_float_ty();
default:
throw_not_implemented("make_ir");
}
}
type_code from_ir(ir::type *ty) {
if (ty->is_half_ty())
return float16;
if (ty->is_float_ty())
return float32;
throw_not_implemented("from_ir");
}
/*----------------------------------------------
definition of triton.cast / triton.ir.value.to
----------------------------------------------*/
std::string cast_docstr = R"pbdoc(
Tries to cast a block to a new data type.
:param input: The input block.
:type input: triton.ir.value
)pbdoc";
ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
ir::type *src_ty = input->get_type();
ir::type *dst_ty = make_ir(_dtype, builder);
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
// FP Truncation
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp)
return builder->create_fp_trunc(input, dst_ty);
// FP Extension
bool ext_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
if (ext_fp)
return builder->create_fp_ext(input, dst_ty);
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
return builder->create_int_cast(input, dst_ty, true);
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty())
return builder->create_fp_to_si(input, dst_ty);
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty())
return builder->create_si_to_fp(input, dst_ty);
// Ptr -> Ptr
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::BitCast, input, dst_ty);
// * -> Bool
if (dst_sca_ty->is_bool_ty()) {
if (src_sca_ty->is_pointer_ty())
input = cast(input, int64, builder);
ir::value *other = builder->get_int64(0);
if (src_ty->is_bool_ty())
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
/*----------------------------------------------
definition of triton.broadcast_check
----------------------------------------------*/
std::string try_broadcast_docstr = R"pbdoc(
Tries to broadcast two blocks to a common compatible shape.
:param input: The first input block.
:type input: triton.ir.value
:param other: The second input block.
:type other: triton.ir.value
)pbdoc";
std::tuple<ir::value *, ir::value *> try_broadcast(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
// make_shape_compatible(block, scalar)
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
// make_shape_compatible(scalar, block)
else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
// make_shape_compatible(block, block)
else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
auto lhs_shape = lhs_ty->get_block_shapes();
auto rhs_shape = rhs_ty->get_block_shapes();
if (lhs_shape.size() != rhs_shape.size())
throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
ir::type::block_shapes_t ret_shape;
for (size_t i = 0; i < lhs_shape.size(); ++i) {
unsigned left = lhs_shape[i];
unsigned right = rhs_shape[i];
if (left == 1)
ret_shape.push_back(right);
else if (right == 1)
ret_shape.push_back(left);
else if (left == right)
ret_shape.push_back(left);
else
throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
": " + std::to_string(left) + " and " + std::to_string(right));
}
if (lhs_shape != ret_shape)
lhs = builder->create_broadcast(lhs, ret_shape);
if (rhs_shape != ret_shape)
rhs = builder->create_broadcast(rhs, ret_shape);
}
return std::make_tuple(lhs, rhs);
}
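The dimension-wise rule implemented above can be restated as a small standalone sketch; broadcast_shape is a hypothetical helper for illustration, not part of the bindings:
def broadcast_shape(lhs, rhs):
    # both shapes must have the same rank; size-1 dimensions stretch to match
    if len(lhs) != len(rhs):
        raise ValueError("blocks must have the same rank")
    out = []
    for i, (l, r) in enumerate(zip(lhs, rhs)):
        if l == r or r == 1:
            out.append(l)
        elif l == 1:
            out.append(r)
        else:
            raise ValueError(f"incompatible dimensions at index {i}: {l} and {r}")
    return out

# e.g. broadcast_shape([128, 1], [1, 64]) == [128, 64]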
/*----------------------------------------------
definition of triton.broadcast_to
----------------------------------------------*/
std::string broadcast_to_docstr = R"pbdoc(
Tries to broadcast a block to a new shape.
:param input: The input block.
:type input: triton.value
:param shape: The new shape.
:type shape: tuple of int
)pbdoc";
ir::value *broadcast_to(ir::value *input, const ir::type::block_shapes_t &shape, ir::builder *builder) {
if (!input->get_type()->is_block_ty())
return builder->create_splat(input, shape);
auto src_shape = input->get_type()->get_block_shapes();
if (src_shape.size() != shape.size())
throw std::runtime_error("Cannot broadcast");
return builder->create_broadcast(input, shape);
}
/*----------------------------------------------
definition of triton.load
----------------------------------------------*/
std::string load_docstr = R"pbdoc(
Return a block of data whose values are loaded, element-wise, from memory at the locations defined by `pointer`.
:param pointer: Pointer to the data to be loaded.
:type pointer: Block of triton.pointer
:param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
:type mask: Block of triton.bool, optional
:param other: if mask[idx] is false, return other[idx] instead of `pointer[idx]`
:type other: Block of triton.value, optional
)pbdoc";
ir::value *load(ir::value *pointer, std::optional<ir::value *> _mask, std::optional<ir::value *> _other, ir::builder *builder) {
if (!_mask.has_value() && !_other.has_value())
return builder->create_load(pointer);
if (!_mask.has_value())
throw std::runtime_error("`other` cannot be provided without `mask`");
ir::value *mask = _mask.value();
ir::type *elt_ty = pointer->get_type()->get_scalar_ty()->get_pointer_element_ty();
auto shape = pointer->get_type()->get_block_shapes();
ir::value *other = _other.has_value() ? _other.value() : ir::undef_value::get(elt_ty);
other = cast(other, from_ir(elt_ty), builder);
other = broadcast_to(other, shape, builder);
mask = broadcast_to(mask, shape, builder);
return builder->create_masked_load(pointer, mask, other);
}
/*----------------------------------------------
definition of triton.store
----------------------------------------------*/
std::string store_docstr = R"pbdoc(
Stores the `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.
:param pointer: The memory locations where the elements of `value` are stored.
:type pointer: Block of triton.pointer
:param value: The block of elements to be stored.
:type value: Block of triton.value
:param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
:type mask: Block of triton.bool, optional
)pbdoc";
ir::value *store(ir::value *ptr, ir::value *val, std::optional<ir::value *> _mask, ir::builder *builder) {
if (!_mask.has_value())
return builder->create_store(ptr, val);
ir::value *mask = _mask.value();
return builder->create_masked_store(ptr, val, mask);
}
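A minimal kernel exercising the masked load/store semantics documented above; the kernel and its names are illustrative only, assuming a 1D buffer of n elements:
import triton
import triton.language as tl

@triton.jit
def masked_copy(X, Z, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # where mask[idx] is false, the load returns other (0.0) instead of reading memory
    x = tl.load(X + offs, mask=mask, other=0.0)
    # where mask[idx] is false, nothing is written to Z + offs[idx]
    tl.store(Z + offs, x, mask=mask)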
/*----------------------------------------------
definition of triton.dot
----------------------------------------------*/
std::string dot_docstr = R"pbdoc(
Returns the matrix product of two blocks.
The two blocks must be two-dimensional and have compatible inner dimensions.
:param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {`float16`, `float32`}
:param other: The second block to be multiplied.
:type other: 2D block of scalar-type in {`float16`, `float32`}
)pbdoc";
ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
ir::value *_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});
return builder->create_dot(lhs, rhs, _0);
}
/*----------------------------------------------
definition of triton.where
----------------------------------------------*/
std::string where_docstr = R"pbdoc(
Returns a block of elements from either `x` or `y`, depending on `condition`.
Note that `x` and `y` are always evaluated regardless of the value of `condition`.
If you want to avoid unintended memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
:param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool
:param x: values selected at indices where condition is True.
:param y: values selected at indices where condition is False.
)pbdoc";
ir::value *where(ir::value *condition, ir::value *x, ir::value *y, ir::builder *builder) {
return builder->create_select(condition, x, y);
};
/*----------------------------------------------
definition of triton.arange
----------------------------------------------*/
std::string arange_docstr = R"pbdoc(
Returns contiguous values within the half-open interval [start, end).
:param start: Start of the interval.
:type start: int
:param end: End of the interval.
:type end: int
)pbdoc";
ir::value *arange(int start, int end, ir::builder *builder) {
return builder->get_range(start, end);
};
/*----------------------------------------------
definition of triton.program_id
----------------------------------------------*/
std::string program_id_docstr = R"pbdoc(
Returns the id of the current program instance along the given `axis`.
Triton uses an SPMD model in which instances of the same @triton.jit function run in parallel with different `program_id`s.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
)pbdoc";
ir::value *program_id(int axis, ir::builder *builder) {
return builder->create_get_program_id(axis);
};
/*----------------------------------------------
definition of triton.num_programs
----------------------------------------------*/
std::string num_programs_docstr = R"pbdoc(
Returns the number of program instances launched along the given `axis`.
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
:type axis: int
)pbdoc";
ir::value *num_programs(int axis, ir::builder *builder) {
return builder->create_get_num_programs(axis);
};
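Together with `arange`, these two intrinsics are how a kernel instance locates its tile in the launch grid; a short sketch assuming a 2D grid (names hypothetical):

pid_m = tl.program_id(0)                          # tile index along M
pid_n = tl.program_id(1)                          # tile index along N
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
num_m = tl.num_programs(0)                        # total number of tiles launched along M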
/*----------------------------------------------
definition of triton.zeros
----------------------------------------------*/
std::string zeros_docstr = R"pbdoc(
    Returns a block of the given shape, filled with the scalar value 0.
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., tl.float16
:type dtype: triton.ir.dtype
)pbdoc";
ir::value *zeros(ir::type::block_shapes_t shape, type_code _dtype, ir::builder *builder) {
ir::type *dtype = make_ir(_dtype, builder);
ir::value *_0 = ir::constant::get_null_value(dtype);
return builder->create_splat(_0, shape);
};
/*----------------------------------------------
definition of triton.exp
----------------------------------------------*/
std::string _exp_docstr = R"pbdoc(
Returns the element-wise exponential of `input`.
)pbdoc";
ir::value *_exp(ir::value *input, ir::builder *builder) {
return builder->create_exp(input);
};
/*----------------------------------------------
definition of triton.log
----------------------------------------------*/
std::string _log_docstr = R"pbdoc(
Returns the element-wise natural logarithm of `input`.
)pbdoc";
ir::value *_log(ir::value *input, ir::builder *builder) {
return builder->create_log(input);
};
/*----------------------------------------------
definition of triton.sqrt
----------------------------------------------*/
std::string sqrt_docstr = R"pbdoc(
Returns the element-wise square root of `input`.
)pbdoc";
ir::value *sqrt(ir::value *input, ir::builder *builder) {
return builder->create_sqrt(input);
};
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
if (scalar_ty->is_floating_point_ty())
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
else
throw_not_int_or_float(name);
}
std::string min_docstr = R"pbdoc(
    Returns the minimum value of `input` along the given `axis`.
)pbdoc";
ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
};
/*----------------------------------------------
definition of triton.max
----------------------------------------------*/
std::string max_docstr = R"pbdoc(
    Returns the maximum value of `input` along the given `axis`.
)pbdoc";
ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
};
/*----------------------------------------------
definition of triton.sum
----------------------------------------------*/
std::string sum_docstr = R"pbdoc(
    Returns the sum of `input` along the given `axis`.
)pbdoc";
ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
};
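As the dispatcher above shows, reductions use the FMIN/FMAX/FADD variants for floating-point inputs and MIN/MAX/ADD for integers; a small numerically-stable softmax-style sketch (hypothetical names):

row = tl.load(row_ptr + offsets, mask=mask, other=float('-inf'))
row_max = tl.max(row, axis=0)                      # FMAX reduction for float input
denom = tl.sum(tl.exp(row - row_max), axis=0)      # FADD reduction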
/*----------------------------------------------
definition of triton.atomic_cas
----------------------------------------------*/
std::string atomic_cas_docstr = R"pbdoc(
Atomic compare-and-swap.
)pbdoc";
ir::value *atomic_cas(ir::value *ptr, ir::value *cmp, ir::value *val, ir::builder *builder) {
return builder->create_atomic_cas(ptr, cmp, val);
};
/*----------------------------------------------
definition of triton.atomic_xchg
----------------------------------------------*/
std::string atomic_xchg_docstr = R"pbdoc(
Atomic exchange.
)pbdoc";
ir::value *atomic_xchg(ir::value *ptr, ir::value *val, ir::builder *builder) {
return builder->create_atomic_exch(ptr, val);
};
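A common use of these two primitives is a per-block spin lock around a read-modify-write of a shared buffer; a sketch assuming `Lock` points to a single int32 cell initialized to 0 (hypothetical, not part of this change):

while tl.atomic_cas(Lock, 0, 1) == 1:
    pass                                   # spin until the lock is acquired
# ... critical section: accumulate into the shared buffer ...
tl.atomic_xchg(Lock, 0)                    # release the lock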
/*----------------------------------------------
debug barrier
----------------------------------------------*/
std::string debug_barrier_docstr = R"pbdoc(
Temporary hacky fixup for when the compiler forgets to insert sync barriers
)pbdoc";
ir::value *debug_barrier(ir::builder *builder) {
return builder->create_barrier();
}
#define DEF_BINARY_OP(MOD, PY_NAME, C_FUNC, ...) \
MOD.def(PY_NAME, binary_op(C_FUNC), (C_FUNC##_docstr + _builder_doc).c_str(), \
ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)
template <class FN>
std::function<ir::value *(ir::value *, ir::value *, ir::builder *builder)>
binary_op(const FN &fn) {
auto ret = [&fn](ir::value *self, ir::value *other, ir::builder *builder) {
//std::tie(self, other) = try_broadcast(self, other, builder);
return fn(self, other, builder);
};
return ret;
}
/*----------------------------------------------
definition of self + other
----------------------------------------------*/
std::string add_docstr = R"pbdoc(
Returns self + other, element-wise.
)pbdoc";
ir::value *add(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// ptr + offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(self, {other});
// float + float
else if (scalar_ty->is_floating_point_ty())
return builder->create_fadd(self, other);
// int + int
else if (scalar_ty->is_integer_ty())
return builder->create_add(self, other);
throw_not_implemented("add");
}
/*----------------------------------------------
definition of self - other
----------------------------------------------*/
std::string sub_docstr = R"pbdoc(
Returns self - other, element-wise.
)pbdoc";
ir::value *sub(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // ptr - offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(self, {other});
  // float - float
if (scalar_ty->is_floating_point_ty())
return builder->create_fsub(self, other);
  // int - int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(self, other);
throw_not_implemented("sub");
}
/*----------------------------------------------
definition of self * other
----------------------------------------------*/
std::string mul_docstr = R"pbdoc(
Returns self * other, element-wise.
)pbdoc";
ir::value *mul(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float * float
if (scalar_ty->is_floating_point_ty())
return builder->create_fmul(self, other);
// int * int
else if (scalar_ty->is_integer_ty())
return builder->create_mul(self, other);
throw_not_implemented("mul");
}
/*----------------------------------------------
definition of self > other
----------------------------------------------*/
std::string greater_than_docstr = R"pbdoc(
Returns self > other, element-wise.
)pbdoc";
ir::value *greater_than(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float > float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(self, other);
// int > int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGT(self, other);
throw_not_implemented("greater_than");
}
/*----------------------------------------------
definition of self >= other
----------------------------------------------*/
std::string greater_equal_docstr = R"pbdoc(
Returns self >= other, element-wise.
)pbdoc";
ir::value *greater_equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float >= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(self, other);
// int >= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGE(self, other);
throw_not_implemented("greater_equal");
}
/*----------------------------------------------
definition of self < other
----------------------------------------------*/
std::string less_than_docstr = R"pbdoc(
Returns self < other, element-wise.
)pbdoc";
ir::value *less_than(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(self, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLT(self, other);
throw_not_implemented("less_than");
}
/*----------------------------------------------
definition of self <= other
----------------------------------------------*/
std::string less_equal_docstr = R"pbdoc(
Returns self <= other, element-wise.
)pbdoc";
ir::value *less_equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float <= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(self, other);
  // int <= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLE(self, other);
throw_not_implemented("less_equal");
}
/*----------------------------------------------
definition of self == other
----------------------------------------------*/
std::string equal_docstr = R"pbdoc(
Returns self == other, element-wise.
)pbdoc";
ir::value *equal(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOEQ(self, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(self, other);
throw_not_implemented("equal");
}
/*----------------------------------------------
definition of self / other
----------------------------------------------*/
std::string _div_docstr = R"pbdoc(
Returns self / other, element-wise.
)pbdoc";
ir::value *_div(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
// float / float
if (scalar_ty->is_floating_point_ty())
return builder->create_fdiv(self, other);
// int / int
else if (scalar_ty->is_integer_ty())
return builder->create_sdiv(self, other);
throw_not_implemented("div");
}
/*----------------------------------------------
definition of self % other
----------------------------------------------*/
std::string mod_docstr = R"pbdoc(
Returns self % other, element-wise.
)pbdoc";
ir::value *mod(ir::value *self, ir::value *other, ir::builder *builder) {
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
  // float % float
if (scalar_ty->is_floating_point_ty())
return builder->create_frem(self, other);
// int % int
else if (scalar_ty->is_integer_ty())
return builder->create_srem(self, other);
throw_not_implemented("mod");
}
/*----------------------------------------------
definition of self & other
----------------------------------------------*/
std::string _and_docstr = R"pbdoc(
Returns self & other, element-wise.
)pbdoc";
ir::value *_and(ir::value *self, ir::value *other, ir::builder *builder) {
return builder->create_and(self, other);
}
/*----------------------------------------------
definition of minimum(self, other)
----------------------------------------------*/
std::string minimum_docstr = R"pbdoc(
Returns element-wise minimum of self and other
)pbdoc";
ir::value *minimum(ir::value *self, ir::value *other, ir::builder *builder) {
return where(less_than(self, other, builder), self, other, builder);
}
/*----------------------------------------------
definition of self[slices]
----------------------------------------------*/
enum slice_mode_t {
NEWAXIS,
ALL
};
std::string subscript_docstr = R"pbdoc(
    Returns self[slices].
:param slices: The slices to subscript with.
:type slices: List of `None` or `:` slices.
)pbdoc";
ir::value *subscript(ir::value *self, std::vector<py::object> slices, ir::builder *builder) {
std::vector<slice_mode_t> modes;
for (py::object slice : slices) {
py::object none = py::none();
py::object all = py::make_tuple(none, none, none);
if (slice.is(none))
modes.push_back(NEWAXIS);
else if (all.attr("__eq__")(slice))
modes.push_back(ALL);
else
throw std::runtime_error("slice must be None or (None, None, None)");
}
ir::type::block_shapes_t shape;
size_t curr = 0;
for (slice_mode_t mode : modes) {
if (mode == NEWAXIS)
shape.push_back(1);
else {
assert(mode == ALL);
shape.push_back(self->get_type()->get_block_shapes()[curr++]);
}
}
return builder->create_reshape(self, shape);
}
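In practice this subscript support is what makes `x[:, None]`-style broadcasting work from Python; a short sketch of building a 2D tile of pointers (hypothetical names):

offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# None inserts a length-1 axis, ':' keeps an axis; broadcasting then forms the 2D tile
ptrs = base_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n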

View File

@@ -8,8 +8,4 @@ void init_cutlass(pybind11::module &m);
PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
init_triton(m);
init_superblocking(m);
#ifdef WITH_CUTLASS_BINDINGS
init_cutlass(m);
#endif
}

View File

@@ -1,119 +0,0 @@
#include <iostream>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
// row-major 3d tensor
class tensor_3d {
public:
tensor_3d(int size_0, int size_1, int size_2, int *data = nullptr) : data_(size_0 * size_1 * size_2, 0) {
if (data)
std::copy(data, data + data_.size(), data_.begin());
stride_0_ = size_1 * size_2;
stride_1_ = size_2;
stride_2_ = 1;
}
int &operator()(int i, int j, int k) {
return data_[i * stride_0_ + j * stride_1_ + k];
}
private:
std::vector<int> data_;
int stride_0_;
int stride_1_;
int stride_2_;
};
std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width, int H, int M, int N) {
tensor_3d tmp(H, M, N);
std::vector<int> current(H, 0);
int num = 0;
std::vector<int> lut(H * M * N * 4);
for (ssize_t h = 0; h < H; h++) {
// surrounding indices
std::vector<int> ii_left(max_width, -1);
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
// start the dynamic programming algorithm
for (ssize_t m = 0; m < M; m++) {
for (ssize_t n = 0; n < N; n++) {
int v = layout(h, m, n);
if (v == 0)
continue;
int n_left = ii_left[max_width - 1];
int m_top = ii_top[max_width - 1][n];
int top = (m_top >= 0) ? tmp(h, m_top, n) : 0;
int left = (n_left >= 0) ? tmp(h, m, n_left) : 0;
int topleft = (m_top >= 0 && n_left >= 0) ? tmp(h, m_top, n_left) : 0;
int width = std::min(left, std::min(top, topleft)) + 1;
// reset width if blocks cannot be
// packed together (i.e., there's a 1 "in the middle")
for (int nn = n_left + 1; nn < n; nn++)
if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n])
width = 1;
tmp(h, m, n) = width;
// update n_left ring buffer
for (int k = 0; k < max_width - 1; k++)
ii_left[k] = ii_left[k + 1];
ii_left[max_width - 1] = n;
// update ii_top ring buffer
for (int k = 0; k < max_width - 1; k++)
ii_top[k][n] = ii_top[k + 1][n];
ii_top[max_width - 1][n] = m;
// block is too small -- skip
if (width != max_width)
continue;
// retained blocks are set to zeros
for (ssize_t km = 0; km < max_width; km++)
for (ssize_t kn = 0; kn < max_width; kn++) {
int mm = ii_top[km][n];
int nn = ii_left[kn];
if (mm < 0 || nn < 0)
continue;
layout(h, mm, nn) = 0;
tmp(h, mm, nn) = 0;
lut[num++] = (int)h;
lut[num++] = (int)mm;
lut[num++] = (int)nn;
lut[num++] = idx(h, mm, nn);
}
}
}
}
lut.resize(num);
return lut;
}
typedef std::pair<int, pybind11::array_t<int>> lut_t;
std::vector<lut_t> superblock(uintptr_t LAYOUT, int H, int M, int N, int start_width) {
std::vector<lut_t> ret;
int current = 0;
tensor_3d layout(H, M, N, (int *)LAYOUT);
tensor_3d idx(H, M, N);
for (int64_t h = 0; h < H; h++)
for (int64_t m = 0; m < M; m++)
for (int64_t n = 0; n < N; n++) {
if (layout(h, m, n) == 0)
continue;
idx(h, m, n) = current++;
}
// create lut
for (int max_width = start_width; max_width > 0; max_width /= 2) {
auto lut = segment_blocks(layout, idx, max_width, H, M, N);
if (lut.size() == 0)
continue;
ret.push_back(std::make_pair(max_width, pybind11::array_t<int>(lut.size(), lut.data())));
}
return ret;
}
void init_superblocking(pybind11::module &m) {
m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
}
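The removed helper was driven from Python by the block-sparse ops, roughly as follows (a sketch, assuming the old triton._C.libtriton module layout; H, M, N are in units of blocks):

import torch
import triton._C.libtriton as libtriton

layout = torch.randint(2, (H, M, N), dtype=torch.int32).contiguous()
# returns a list of (width, lut) pairs, one per super-block width found by the DP pass
luts = libtriton.superblock(layout.data_ptr(), H, M, N, start_width)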

View File

@@ -764,19 +764,26 @@ void init_triton_ir(py::module &&m) {
.def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
.def("dump", &mlir::ModuleOp::dump)
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
return true;
return false;
})
.def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
return self.lookupSymbol<mlir::FuncOp>(funcName);
})
// dynamic_attr is used to transfer ownership of the MLIR context to the module
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module", py::dynamic_attr())
.def("dump", &mlir::ModuleOp::dump)
.def("str", [](mlir::ModuleOp &self) -> std::string {
std::string str;
llvm::raw_string_ostream os(str);
self.print(os);
return str;
})
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
return true;
return false;
})
.def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
return self.lookupSymbol<mlir::FuncOp>(funcName);
})
;
py::class_<mlir::FuncOp, mlir::OpState>(m, "function")

View File

@@ -1,164 +0,0 @@
import subprocess
import sys
import pytest
import torch
import triton
import triton.language as tl
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
DEVICE_NAME = 'v100'
#######################
# Utilities
#######################
def nvsmi(attrs):
attrs = ','.join(attrs)
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
out = subprocess.check_output(cmd)
ret = out.decode(sys.stdout.encoding).split(',')
ret = [int(x) for x in ret]
return ret
#######################
# Matrix Multiplication
#######################
sm_clocks = {'v100': 1350, 'a100': 1350}
mem_clocks = {'v100': 877, 'a100': 1215}
matmul_data = {
'v100': {
# square
(256, 256, 256): {'float16': 0.027},
(512, 512, 512): {'float16': 0.158},
(1024, 1024, 1024): {'float16': 0.466},
(2048, 2048, 2048): {'float16': 0.695},
(4096, 4096, 4096): {'float16': 0.831},
(8192, 8192, 8192): {'float16': 0.849},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0128},
(16, 4096, 4096): {'float16': 0.0883},
(16, 8192, 8192): {'float16': 0.101},
(64, 1024, 1024): {'float16': 0.073},
(64, 4096, 4096): {'float16': 0.270},
(64, 8192, 8192): {'float16': 0.459},
(1024, 64, 1024): {'float16': 0.0692},
(4096, 64, 4096): {'float16': 0.264},
(8192, 64, 8192): {'float16': 0.452},
},
'a100': {
(256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
(512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
(1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
(4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
(8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
(16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
}
# # deep reductions
# (64 , 64 , 16384) : {'a100': 0.},
# (64 , 64 , 65536) : {'a100': 0.},
# (256 , 256 , 8192 ) : {'a100': 0.},
# (256 , 256 , 32768) : {'a100': 0.},
}
@pytest.mark.parametrize('M, N, K, dtype_str',
[(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
def test_matmul(M, N, K, dtype_str):
if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
pytest.skip('Only test float32 & int8 on a100')
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
torch.manual_seed(0)
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
ref_sm_clock = sm_clocks[DEVICE_NAME]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
if dtype == torch.int8:
a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
b = b.t() # only test row-col layout
else:
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((K, N), dtype=dtype, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
#######################
# Element-Wise
#######################
@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
elementwise_data = {
'v100': {
1024 * 16: 0.0219,
1024 * 64: 0.0791,
1024 * 256: 0.243,
1024 * 1024: 0.534,
1024 * 4096: 0.796,
1024 * 16384: 0.905,
1024 * 65536: 0.939,
},
'a100': {
1024 * 16: 0.008,
1024 * 64: 0.034,
1024 * 256: 0.114,
1024 * 1024: 0.315,
1024 * 4096: 0.580,
1024 * 16384: 0.782,
1024 * 65536: 0.850,
}
}
@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
def test_elementwise(N):
torch.manual_seed(0)
ref_gpu_util = elementwise_data[DEVICE_NAME][N]
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = mem_clocks[DEVICE_NAME]
max_gpu_perf = get_dram_gbps()
    assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)

File diff suppressed because it is too large Load Diff

View File

@@ -1,177 +0,0 @@
import numpy as np
import pytest
import scipy.stats
import torch
import triton
import triton.language as tl
#####################################
# Reference Philox Implementation
#####################################
class PhiloxConfig:
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
self.DTYPE = DTYPE
# This is better for GPU
PHILOX_32 = PhiloxConfig(
PHILOX_KEY_A=0x9E3779B9,
PHILOX_KEY_B=0xBB67AE85,
PHILOX_ROUND_A=0xD2511F53,
PHILOX_ROUND_B=0xCD9E8D57,
DTYPE=np.uint32,
)
# This is what numpy implements
PHILOX_64 = PhiloxConfig(
PHILOX_KEY_A=0x9E3779B97F4A7C15,
PHILOX_KEY_B=0xBB67AE8584CAA73B,
PHILOX_ROUND_A=0xD2E7470EE14C6C93,
PHILOX_ROUND_B=0xCA5A826395121157,
DTYPE=np.uint64,
)
class CustomPhilox4x:
def __init__(self, seed, config):
self._config = config
seed = self._into_pieces(seed)
self._key = np.array(seed[:2], dtype=self._dtype)
self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)
@property
def _dtype(self):
return self._config.DTYPE
def _into_pieces(self, n, pad=4):
res = []
while len(res) < pad:
res.append(np.array(n, dtype=self._dtype))
n >>= (np.dtype(self._dtype).itemsize * 8)
assert n == 0
return tuple(res)
def _multiply_low_high(self, a, b):
low = a * b
high = int(a) * int(b)
high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
return low, high
def _single_round(self, counter, key):
lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
ret0 = hi1 ^ counter[1] ^ key[0]
ret1 = lo1
ret2 = hi0 ^ counter[3] ^ key[1]
ret3 = lo0
return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
def _raise_key(self, key):
pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
return key + np.array(pk, dtype=self._dtype)
def random_raw(self):
counter = self._counter
key = self._key
for _ in range(10):
counter = self._single_round(counter, key)
key = self._raise_key(key)
self.advance(1)
return counter
def advance(self, n_steps):
self._counter[0] += n_steps
assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"
class CustomPhilox(CustomPhilox4x):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.buffer = []
def random_raw(self):
if len(self.buffer) == 0:
self.buffer = list(super().random_raw())[::-1]
return int(self.buffer.pop())
#####################################
# Unit Tests
#####################################
BLOCK = 1024
# test generation of random uint32
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in ['10', '4,53', '10000']
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
)
def test_randint(size, seed, device='cuda'):
size = list(map(int, size.split(',')))
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randint(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.int32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
# reference result
gen = CustomPhilox4x(seed, config=PHILOX_32)
out_ref = [gen.random_raw()[0] for _ in out_tri]
assert out_tri == out_ref
# test uniform PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_rand(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.rand(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert all((x >= 0) & (x <= 1))
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG
@pytest.mark.parametrize('size, seed',
[(size, seed) for size in [1000000]
for seed in [0, 42, 124, 54]]
)
def test_randn(size, seed, device='cuda'):
@triton.jit
def kernel(X, N, seed):
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
rand = tl.randn(seed, offset)
tl.store(X + offset, rand, mask=offset < N)
# triton result
x = torch.empty(size, dtype=torch.float32, device=device)
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert abs(x.mean()) < 1e-2
assert abs(x.std() - 1) < 1e-2

View File

@@ -1,187 +0,0 @@
import pytest
import torch
import triton
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
seed = 0
torch.manual_seed(seed)
is_sdd = MODE == "sdd"
is_dsd = MODE == "dsd"
is_dds = MODE == "dds"
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
# create inputs
# create op
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
c_shape = (Z, H, M, N)
shape = {
"sdd": (M, N),
"dsd": (a_shape[2], a_shape[3]),
"dds": (b_shape[2], b_shape[3]),
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# create data
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
dc_ref, dc_tri = triton.testing.make_pair(c_shape)
# compute [torch]
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
a_ref = do_mask(a_ref) if is_dsd else a_ref
b_ref = do_mask(b_ref) if is_dds else b_ref
a_ref.retain_grad()
b_ref.retain_grad()
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
b_ref.transpose(2, 3) if TRANS_B else b_ref)
c_ref.backward(dc_ref)
c_ref = do_sparsify(c_ref) if is_sdd else c_ref
da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
# triton result
dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
a_tri = do_sparsify(a_tri) if is_dsd else a_tri
b_tri = do_sparsify(b_tri) if is_dds else b_tri
a_tri.retain_grad()
b_tri.retain_grad()
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
da_tri = a_tri.grad
db_tri = b_tri.grad
# compare
triton.testing.assert_almost_equal(c_ref, c_tri)
triton.testing.assert_almost_equal(da_ref, da_tri)
triton.testing.assert_almost_equal(db_ref, db_tri)
configs = [
(16, 256),
(32, 576),
(64, 1871),
(128, 2511),
]
@pytest.mark.parametrize("is_dense", [False, True])
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 2, 3, WIDTH, WIDTH
# initialize layout
# make sure each row has at least one non-zero element
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
if is_dense:
layout[:] = 1
else:
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# initialize data
a_shape = (Z, H, M, N)
a_ref, a_tri = triton.testing.make_pair(a_shape)
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
# compute [torch]
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
a_ref.retain_grad()
at_mask = torch.ones((M, N), device="cuda")
if is_causal:
at_mask = torch.tril(at_mask)
M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
a_ref[M == 0] = float("-inf")
out_ref = torch.softmax(a_ref * scale, -1)
out_ref.backward(dout_ref)
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
# compute [triton]
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
a_tri.retain_grad()
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
out_tri.backward(dout_tri)
da_tri = a_tri.grad
# compare
triton.testing.assert_almost_equal(out_tri, out_ref)
triton.testing.assert_almost_equal(da_tri, da_ref)
@pytest.mark.parametrize("block", [16, 32, 64])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_attention_fwd_bwd(
block,
dtype,
input_scale=1.0,
scale=1 / 8.0,
n_ctx=256,
batch_size=2,
n_heads=2,
):
# inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
]
# Triton:
n_blocks = n_ctx // block
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
query, key, value = [x.clone() for x in qkvs]
query.retain_grad()
key.retain_grad()
value.retain_grad()
attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
# ad hoc loss
loss = (attn_out ** 2).mean()
loss.backward()
grads = [query.grad, key.grad, value.grad]
# Torch version:
torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
attn_mask = torch.tril(attn_mask, diagonal=0)
attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
torch_q.retain_grad()
torch_k.retain_grad()
torch_v.retain_grad()
scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
scores = scores + attn_mask
probs = torch.softmax(scores, dim=-1)
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
# ad hoc loss
torch_loss = (torch_attn_out ** 2).mean()
torch_loss.backward()
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
# comparison
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
triton.testing.assert_almost_equal(loss, torch_loss)
for g1, g2 in zip(grads, torch_grads):
triton.testing.assert_almost_equal(g1, g2)
@pytest.mark.parametrize("block", [16, 32, 64])
def triton_attention(
layout,
block: int,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
w = sparse_dot_sdd_nt(query, key)
w = sparse_softmax(w, scale=scale, is_causal=True)
a = sparse_dot_dsd_nn(w, value)
return a

View File

@@ -1,35 +0,0 @@
import pytest
import torch
import triton
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']
for mode in ['forward', 'backward']
]
)
def test_op(M, N, dtype, mode):
dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
# create inputs
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
# forward pass
tt_y = triton.ops.cross_entropy(x, idx)
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
if mode == 'forward':
triton.testing.assert_almost_equal(th_y, tt_y)
# backward pass
elif mode == 'backward':
dy = torch.randn_like(tt_y)
# triton backward
tt_y.backward(dy)
tt_dx = x.grad.clone()
# torch backward
x.grad.zero_()
th_y.backward(dy)
th_dx = x.grad.clone()
triton.testing.assert_almost_equal(th_dx, tt_dx)

View File

@@ -1,98 +0,0 @@
import itertools
import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
itertools.chain(
*[
[
# 1 warp
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
# split-k
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
],
# n-stage
*[
[
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
# split-k
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
]
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80 and DTYPE == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
if DTYPE == "bfloat16" and SPLIT_K != 1:
pytest.skip("bfloat16 matmuls don't allow split_k for now")
torch.manual_seed(0)
# nuke kernel decorators -- will set meta-parameters manually
kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
kernel = triton.ops._matmul.kernel
decorators = kernel.kernel_decorators
kernel.kernel_decorators = []
triton.autotune(configs, [])(kernel)
kernel.kernel_decorators += decorators[1:]
# get matrix shape
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K * SPLIT_K if K is None else K
# allocate/transpose inputs
DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
# run test
th_c = torch.matmul(a, b)
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
triton.testing.assert_almost_equal(th_c, tt_c)

View File

@@ -1,132 +0,0 @@
import os
import re
import shutil
import pytest
import torch
import triton
import triton.language as tl
from triton.code_gen import JITFunction
tmpdir = ".tmp"
@triton.jit
def function_1(i):
i = i + 1
i = function_2(i)
return i
@triton.jit
def function_2(i):
i = i + 1
return i
@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
def apply_src_change(target, old, new):
kernel.hash = None
function_1.hash = None
function_2.hash = None
function_1.src = function_1.src.replace(old, new)
target.src = target.src.replace(old, new)
ret = target.cache_key
target.src = target.src.replace(new, old)
return ret
def test_nochange():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 1')
assert baseline == updated
def test_toplevel_change():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 2')
assert baseline != updated
def test_nested1_change():
baseline = kernel.cache_key
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated
def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
def test_reuse():
counter = 0
def inc_counter(*args, **kwargs):
nonlocal counter
counter += 1
JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
for i in range(10):
kernel[(1,)](x, 1, BLOCK=1024)
assert counter == 1
@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
counter = 0
def inc_counter(*args, **kwargs):
nonlocal counter
counter += 1
JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 5, 'disable': 1}[mode]
for i in [1, 2, 4, 8, 16, 32]:
function[(1,)](x, i, BLOCK=512)
assert counter == target
@pytest.mark.parametrize("value, value_type", [
(-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
(2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
(2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
cache_str = None
def get_cache_str(*args, **kwargs):
nonlocal cache_str
cache_str = kwargs['key'].split('-')
triton.code_gen.JITFunction.cache_hook = get_cache_str
reset_tmp_dir()
x = torch.tensor([3.14159], device='cuda')
kernel[(1, )](value, x)
triton.code_gen.JITFunction.cache_hook = None
cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
spec_type = None if cache_str_match is None else cache_str_match.group(1)
assert spec_type == value_type

View File

@@ -1,98 +0,0 @@
import subprocess
import numpy as np
import pytest
import torch
import triton
import triton.language as tl
def get_p2p_matrix():
try:
stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii")
except subprocess.CalledProcessError:
return pytest.skip("No multi-GPU topology", allow_module_level=True)
lines = stdout.split("Legend")[0].split('\n')[1:]
matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2])
if matrix.size <= 1:
return pytest.skip("No multi-GPU topology", allow_module_level=True)
else:
return matrix
def get_p2p_devices():
matrix = get_p2p_matrix()
idx = np.where(matrix == "OK")
return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
def get_non_p2p_devices():
matrix = get_p2p_matrix()
idx = np.where(matrix == "NS")
return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
p2p_devices = get_p2p_devices()
non_p2p_devices = get_non_p2p_devices()
@triton.jit
def _copy(from_ptr, to_ptr, N, **meta):
pid = tl.program_id(0)
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
values = tl.load(from_ptr + offsets, mask=offsets < N)
tl.store(to_ptr + offsets, values, mask=offsets < N)
@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support")
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
[(device_kernel, device_from, device_to, stream_from, stream_to)
for device_kernel in p2p_devices
for device_from in p2p_devices
for device_to in p2p_devices
for stream_from in ['default', 'custom']
for stream_to in ['default', 'custom']
])
def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
if device_to == device_from:
return pytest.skip()
torch.cuda.set_device(device_kernel)
N = 512
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
x_from = torch.randn(N, dtype=torch.float32, device=device_from)
with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
x_to = torch.empty(N, dtype=torch.float32, device=device_to)
_copy[grid](x_from, x_to, N, BLOCK=1024)
assert torch.allclose(x_from, x_to.to(device_from))
@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support")
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
[(device_kernel, device_from, device_to, stream_from, stream_to)
for device_kernel in non_p2p_devices
for device_from in non_p2p_devices
for device_to in non_p2p_devices
for stream_from in ['default', 'custom']
for stream_to in ['default', 'custom']
])
def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
if device_to == device_from:
return pytest.skip()
with pytest.raises(RuntimeError):
torch.cuda.set_device(device_kernel)
N = 512
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
x_from = torch.randn(N, dtype=torch.float32, device=device_from)
with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
x_to = torch.empty(N, dtype=torch.float32, device=device_to)
_copy[grid](x_from, x_to, N, BLOCK=1024)

View File

@@ -6,9 +6,9 @@ __version__ = '2.0.0'
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \
JITFunction, Config, Autotuner, reinterpret
from .utils import *
from .runtime import jit, Config, autotune, heuristics
from .compiler import compile
from . import language
from . import code_gen
from . import testing
from . import ops

File diff suppressed because it is too large Load Diff

806
python/triton/compiler.py Normal file
View File

@@ -0,0 +1,806 @@
from __future__ import annotations
import ast
import sys
import warnings
from typing import Dict, Union
import triton
import triton._C.libtriton.triton as _triton
def str_to_ty(name):
if name[0] == "*":
ty = str_to_ty(name[1:])
return triton.language.pointer_type(ty)
tys = {
"fp8": triton.language.float8,
"fp16": triton.language.float16,
"bf16": triton.language.bfloat16,
"fp32": triton.language.float32,
"fp64": triton.language.float64,
"i8": triton.language.int8,
"i16": triton.language.int16,
"i32": triton.language.int32,
"i64": triton.language.int64,
"u8": triton.language.uint8,
"u16": triton.language.uint16,
"u32": triton.language.uint32,
"u64": triton.language.uint64,
"B": triton.language.int1,
}
return tys[name]
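The leading '*' marks a pointer type and recurses on the element type, so for example:

str_to_ty("fp16")    # -> triton.language.float16
str_to_ty("*fp32")   # -> triton.language.pointer_type(triton.language.float32)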
def mangle_ty(ty):
if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty)
if ty.is_int():
return 'i' + str(ty.int_bitwidth)
if ty.is_fp8():
return 'fp8'
if ty.is_fp16():
return 'fp16'
if ty.is_bf16():
return 'bf16'
if ty.is_fp32():
return 'fp32'
if ty.is_fp64():
return 'fp64'
if ty.is_void():
return 'V'
if ty.is_block():
elt = mangle_ty(ty.scalar)
shape = '_'.join(map(str, ty.shape))
return f'{elt}S{shape}S'
assert False, "Unsupport type"
def mangle_fn(name, arg_tys, constants):
# doesn't mangle ret type, which must be a function of arg tys
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
mangled_constants = mangled_constants.replace('.', '_d_')
mangled_constants = mangled_constants.replace("'", '_sq_')
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
return ret
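A worked example of the scheme above, for a hypothetical kernel with an i32 argument, a *fp32 argument, and the constexpr value 16 bound at parameter index 2:

# mangle_fn("add_kernel", [int32, pointer_type(float32)], {2: 16})
#   -> "add_kernel__i32_Pfp32__2c16"
# i.e. <name>__<mangled argument types>__<index>c<constexpr value>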
class enter_sub_region:
def __init__(self, generator: CodeGenerator):
self.generator = generator
def __enter__(self):
# record lscope & local_defs in the parent scope
self.liveins = self.generator.lscope.copy()
self.prev_defs = self.generator.local_defs.copy()
self.generator.local_defs = {}
self.insert_block = self.generator.builder.get_insertion_block()
return self.liveins, self.insert_block
def __exit__(self, *args, **kwargs):
self.generator.builder.set_insertion_point_to_end(self.insert_block)
self.generator.lscope = self.liveins
self.generator.local_defs = self.prev_defs
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
self.builder = _triton.ir.builder(context)
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = function_types
self.prototype = prototype
self.gscope = gscope
self.lscope = dict()
self.attributes = attributes
self.constants = constants
self.is_kernel = is_kernel
self.last_node = None
self.builtins = {
'range': range,
'min': triton.language.minimum,
'float': float,
'int': int,
'print': print,
'isinstance': isinstance,
'getattr': getattr,
}
# SSA-construction
# name => triton.language.tensor
self.local_defs: Dict[str, triton.language.tensor] = {}
self.global_uses: Dict[str, triton.language.tensor] = {}
def get_value(self, name):
        ''' This function:
        1. makes sure `name` is defined
        2. if `name` is a triton.language.tensor, returns the stored tensor
        '''
# search node.id in local scope
ret = None
if name in self.lscope:
ret = self.lscope[name]
if name not in self.local_defs:
self.global_uses[name] = ret
# search node.id in global scope
elif name in self.gscope:
ret = self.gscope[name]
# search node.id in builtins
elif name in self.builtins:
ret = self.builtins[name]
else:
raise ValueError(f'{name} is not defined')
return ret
def set_value(self, name: str,
value: Union[triton.language.tensor, triton.language.constexpr]) -> None:
        ''' This function, called by visit_Assign() & visit_FunctionDef() to store a left value (lvalue):
        1. records the locally defined name (FIXME: should consider control flow)
        2. stores the value in self.lscope and self.local_defs
        '''
self.lscope[name] = value
self.local_defs[name] = value
def is_triton_tensor(self, value):
return isinstance(value, triton.language.tensor)
#
# AST visitor
#
def visit_compound_statement(self, stmts):
for stmt in stmts:
self.last_ret_type = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
return stmts and isinstance(stmt, ast.Return)
def visit_Module(self, node):
ast.NodeVisitor.generic_visit(self, node)
def visit_List(self, node):
ctx = self.visit(node.ctx)
assert ctx is None
elts = [self.visit(elt) for elt in node.elts]
return elts
# By design, only non-kernel functions can return
def visit_Return(self, node):
ret_value = self.visit(node.value)
if ret_value is None:
self.builder.ret([])
return None
if isinstance(ret_value, tuple):
ret_values = [triton.language.core._to_tensor(v, self.builder) for v in ret_value]
ret_types = [v.type for v in ret_values]
self.builder.ret([v.handle for v in ret_values])
return tuple(ret_types)
else:
ret = triton.language.core._to_tensor(ret_value, self.builder)
self.builder.ret([ret_value.handle])
return ret.type
def visit_FunctionDef(self, node):
arg_names, kwarg_names = self.visit(node.args)
# initialize defaults
for i, default_value in enumerate(node.args.defaults):
arg_node = node.args.args[-i - 1]
annotation = arg_node.annotation
name = arg_node.arg
st_target = ast.Name(id=name, ctx=ast.Store())
if annotation is None:
init_node = ast.Assign(targets=[st_target], value=default_value)
else:
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node)
# initialize function
fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
self.module.push_back(fn)
entry = fn.add_entry_block()
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = self.constants[i]
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
else:
pass
if i in self.attributes:
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
insert_pt = self.builder.get_insertion_block()
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
self.builder.set_insertion_point_to_start(entry)
# visit function body
has_ret = self.visit_compound_statement(node.body)
# finalize function
if not has_ret:
self.builder.ret([])
else:
# update return type
if isinstance(self.last_ret_type, tuple):
self.prototype.ret_types = list(self.last_ret_type)
fn.reset_type(self.prototype.to_ir(self.builder))
else:
self.prototype.ret_types = [self.last_ret_type]
fn.reset_type(self.prototype.to_ir(self.builder))
if insert_pt:
self.builder.set_insertion_point_to_end(insert_pt)
def visit_arguments(self, node):
arg_names = []
for arg in node.args:
arg_names += [self.visit(arg)]
kwarg_names = self.visit(node.kwarg)
return arg_names, kwarg_names
def visit_arg(self, node):
ast.NodeVisitor.generic_visit(self, node)
return node.arg
def visit_AnnAssign(self, node):
# extract attributes
annotation = self.visit(node.annotation)
target = self.visit(node.target)
value = self.visit(node.value)
# constexpr
if annotation == triton.language.constexpr:
if target in self.lscope:
raise ValueError(f'{target} is already defined.'
f' constexpr cannot be reassigned.')
if not isinstance(value, triton.language.constexpr):
value = triton.language.constexpr(value)
self.lscope[target] = value
return self.lscope[target]
# default: call visit_Assign
return self.visit_Assign(node)
def visit_Assign(self, node):
_names = []
for target in node.targets:
_names += [self.visit(target)]
assert len(_names) == 1
names = _names[0]
values = self.visit(node.value)
if not isinstance(names, tuple):
names = [names]
if not isinstance(values, tuple):
values = [values]
for name, value in zip(names, values):
# by default, constexpr are assigned into python variable
if isinstance(value, triton.language.constexpr):
value = value.value
if not isinstance(value, triton.language.tensor):
value = triton.language.core._to_tensor(value, self.builder)
self.set_value(name, value)
def visit_AugAssign(self, node):
name = node.target.id
lhs = ast.Name(id=name, ctx=ast.Load())
rhs = ast.BinOp(lhs, node.op, node.value)
assign = ast.Assign(targets=[node.target], value=rhs)
self.visit(assign)
return self.get_value(name)
def visit_Name(self, node):
if type(node.ctx) == ast.Store:
return node.id
return self.get_value(node.id)
def visit_Store(self, node):
ast.NodeVisitor.generic_visit(self, node)
def visit_Load(self, node):
ast.NodeVisitor.generic_visit(self, node)
def visit_Tuple(self, node):
args = [self.visit(x) for x in node.elts]
return tuple(args)
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
if isinstance(lhs, triton.language.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.constexpr):
rhs = rhs.value
fn = {
ast.Add: '__add__',
ast.Sub: '__sub__',
ast.Mult: '__mul__',
ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__',
ast.Mod: '__mod__',
ast.Pow: '__pow__',
ast.LShift: '__lshift__',
ast.RShift: '__rshift__',
ast.BitAnd: '__and__',
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}[type(node.op)]
if self.is_triton_tensor(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_tensor(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
def visit_If(self, node):
cond = self.visit(node.test)
if isinstance(cond, triton.language.tensor):
cond = cond.to(triton.language.int1, _builder=self.builder)
with enter_sub_region(self) as sr:
liveins, ip_block = sr
then_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(then_block)
self.visit_compound_statement(node.body)
then_defs = self.local_defs.copy()
            # we need an else block when:
# 1. we have an orelse node
# or
# 2. the then block defines new variable
if then_defs or node.orelse:
if node.orelse:
self.lscope = liveins
self.local_defs = {}
else_block = self.builder.create_block()
self.builder.set_insertion_point_to_end(else_block)
self.visit_compound_statement(node.orelse)
else_defs = self.local_defs.copy()
else:
# collect else_defs
else_defs = {}
for name in then_defs:
if name in liveins:
assert self.is_triton_tensor(then_defs[name])
assert self.is_triton_tensor(liveins[name])
else_defs[name] = liveins[name]
# collect yields
names = []
ret_types = []
for then_name in then_defs:
for else_name in else_defs:
if then_name == else_name:
if then_defs[then_name].type == else_defs[else_name].type:
names.append(then_name)
ret_types.append(then_defs[then_name].type)
self.builder.set_insertion_point_to_end(ip_block)
if then_defs or node.orelse: # with else block
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
then_block.merge_block_before(if_op.get_then_block())
self.builder.set_insertion_point_to_end(if_op.get_then_block())
self.builder.create_yield_op([then_defs[n].handle for n in names])
if not node.orelse:
else_block = if_op.get_else_block()
else:
else_block.merge_block_before(if_op.get_else_block())
self.builder.set_insertion_point_to_end(if_op.get_else_block())
self.builder.create_yield_op([else_defs[n].handle for n in names])
else: # no else block
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
then_block.merge_block_before(if_op.get_then_block())
# update values yielded by IfOp
for i, name in enumerate(names):
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
self.lscope[name] = new_tensor
self.local_defs[name] = new_tensor
else:
if isinstance(cond, triton.language.constexpr):
cond = cond.value
if cond:
self.visit_compound_statement(node.body)
else:
self.visit_compound_statement(node.orelse)
def visit_IfExp(self, node):
cond = self.visit(node.test)
if cond.value:
return self.visit(node.body)
else:
return self.visit(node.orelse)
def visit_Pass(self, node):
pass
def visit_Compare(self, node):
assert len(node.comparators) == 1
assert len(node.ops) == 1
lhs = self.visit(node.left)
rhs = self.visit(node.comparators[0])
if isinstance(lhs, triton.language.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.constexpr):
rhs = rhs.value
if type(node.ops[0]) == ast.Is:
return triton.language.constexpr(lhs is rhs)
if type(node.ops[0]) == ast.IsNot:
return triton.language.constexpr(lhs is not rhs)
fn = {
ast.Eq: '__eq__',
ast.NotEq: '__ne__',
ast.Lt: '__lt__',
ast.LtE: '__le__',
ast.Gt: '__gt__',
ast.GtE: '__ge__',
}[type(node.ops[0])]
if self.is_triton_tensor(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_tensor(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
return getattr(lhs, fn)(rhs)
def visit_UnaryOp(self, node):
op = self.visit(node.operand)
if type(node.op) == ast.Not:
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
return triton.language.constexpr(not op)
if isinstance(op, triton.language.constexpr):
op = op.value
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
ast.Invert: '__invert__',
}[type(node.op)]
if self.is_triton_tensor(op):
return getattr(op, fn)(_builder=self.builder)
return getattr(op, fn)()
def visit_While(self, node):
with enter_sub_region(self) as sr:
liveins, insert_block = sr
            # condition (the before region)
cond_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(cond_block)
cond = self.visit(node.test)
# loop body (the after region)
loop_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(loop_block)
self.visit_compound_statement(node.body)
loop_defs = self.local_defs
# collect loop-carried values
names = []
ret_types = []
init_args = []
yields = []
for name in loop_defs:
if name in liveins:
                    # We should not define new constexpr here
assert self.is_triton_tensor(loop_defs[name])
assert self.is_triton_tensor(liveins[name])
if loop_defs[name].type == liveins[name].type:
# these are loop-carried values
names.append(name)
ret_types.append(loop_defs[name].type)
init_args.append(liveins[name])
yields.append(loop_defs[name])
self.builder.set_insertion_point_to_end(insert_block)
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
[arg.handle for arg in init_args])
# merge the condition region
before_block = self.builder.create_block_with_parent(while_op.get_before(),
[ty.to_ir(self.builder) for ty in ret_types])
cond_block.merge_block_before(before_block)
self.builder.set_insertion_point_to_end(before_block)
            # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
# merge the loop body
after_block = self.builder.create_block_with_parent(while_op.get_after(),
[ty.to_ir(self.builder) for ty in ret_types])
loop_block.merge_block_before(after_block)
self.builder.set_insertion_point_to_end(after_block)
self.builder.create_yield_op([y.handle for y in yields])
# update global uses in while_op
for i, name in enumerate(names):
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
# WhileOp defines new values, update the symbol table (lscope, local_defs)
for i, name in enumerate(names):
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
self.lscope[name] = new_def
self.local_defs[name] = new_def
for stmt in node.orelse:
assert False, "Not implemented"
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Subscript(self, node):
assert node.ctx.__class__.__name__ == "Load"
lhs = self.visit(node.value)
slices = self.visit(node.slice)
if self.is_triton_tensor(lhs):
return lhs.__getitem__(slices, _builder=self.builder)
return lhs[slices]
def visit_ExtSlice(self, node):
return [self.visit(dim) for dim in node.dims]
def visit_For(self, node):
iterator = self.visit(node.iter.func)
if iterator != self.builtins['range']:
raise RuntimeError('Only `range` iterator currently supported')
# static for loops: all iterator arguments are constexpr
iter_args = [self.visit(arg) for arg in node.iter.args]
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
if is_static:
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
iter_args = [arg.value for arg in iter_args]
range = iterator(*iter_args)
if len(range) <= 10:
for i in iterator(*iter_args):
self.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
return
# collect lower bound (lb), upper bound (ub), and step
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
# lb/ub/step might be constexpr, we need to cast them to tensor
lb = triton.language.core._to_tensor(lb, self.builder).handle
ub = triton.language.core._to_tensor(ub, self.builder).handle
step = triton.language.core._to_tensor(step, self.builder).handle
# ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
lb = self.builder.create_to_index(lb)
ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step)
with enter_sub_region(self) as sr:
liveins, insert_block = sr
# create loop body block
block = self.builder.create_block()
self.builder.set_insertion_point_to_start(block)
# visit loop body
self.visit_compound_statement(node.body)
            # If a variable (name) is defined both in the parent scope and in the loop body,
            # then it is a loop-carried variable. (Both definitions must have the same type.)
init_args = []
yields = []
names = []
for name in self.local_defs:
if name in liveins:
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
assert self.is_triton_tensor(liveins[name])
if self.local_defs[name].type == liveins[name].type:
names.append(name)
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
# create ForOp
self.builder.set_insertion_point_to_end(insert_block)
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
block.merge_block_before(for_op.get_body(0))
# create YieldOp
self.builder.set_insertion_point_to_end(for_op.get_body(0))
self.builder.create_yield_op([y.handle for y in yields])
for_op_region = for_op.get_body(0).get_parent()
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
# replace global uses with block arguments
for i, name in enumerate(names):
# arg0 is the induction variable
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
# update lscope & local_defs (ForOp defines new values)
for i, name in enumerate(names):
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
for stmt in node.orelse:
assert False, "Don't know what to do with else after for"
ast.NodeVisitor.generic_visit(self, stmt)
def visit_Slice(self, node):
lower = self.visit(node.lower)
upper = self.visit(node.upper)
step = self.visit(node.step)
return slice(lower, upper, step)
def visit_Index(self, node):
return self.visit(node.value)
def visit_keyword(self, node):
return {node.arg: self.visit(node.value)}
def visit_Call(self, node):
fn = self.visit(node.func)
if isinstance(fn, triton.language.constexpr):
fn = fn.value
kws = dict()
for keyword in node.keywords:
kws.update(self.visit(keyword))
args = [self.visit(arg) for arg in node.args]
if isinstance(fn, triton.runtime.JITFunction):
from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
args = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
constants = {i: args[i] for i in constexprs}
# generate call
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
arg_vals = [arg.handle for arg in args if arg is not None]
arg_types = [arg.type for arg in args if arg is not None]
fn_name = mangle_fn(fn.__name__, arg_types, constants)
# generate function def if necessary
if not self.module.has_function(fn_name):
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
else:
callee_ret_type = self.function_ret_types[fn_name]
symbol = self.module.get_function(fn_name)
call_op = self.builder.call(symbol, arg_vals)
if call_op.get_num_results() == 0:
return None
elif call_op.get_num_results() == 1:
return triton.language.tensor(call_op.get_result(0), callee_ret_type)
else:
# should return a tuple of tl.tensor
results = []
for i in range(call_op.get_num_results()):
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results)
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws)
if fn in self.builtins.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
for arg in args]
return fn(*args, **kws)
def visit_Constant(self, node):
return triton.language.constexpr(node.value)
if sys.version_info < (3, 8):
def visit_NameConstant(self, node):
return triton.language.constexpr(node.value)
def visit_Num(self, node):
return triton.language.constexpr(node.n)
def visit_Str(self, node):
return triton.language.constexpr(ast.literal_eval(node))
def visit_Attribute(self, node):
lhs = self.visit(node.value)
return getattr(lhs, node.attr)
def visit_Expr(self, node):
ast.NodeVisitor.generic_visit(self, node)
def visit_NoneType(self, node):
return None
def visit(self, node):
if node is not None:
self.last_node = node
with warnings.catch_warnings():
# The ast library added visit_Constant and deprecated some other
# methods but we can't move to that without breaking Python 3.6 and 3.7.
warnings.simplefilter("ignore", DeprecationWarning) # python 3.9
warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
return super().visit(node)
def generic_visit(self, node):
typename = type(node).__name__
raise NotImplementedError("Unsupported node: {}".format(typename))
class CompilationError(Exception):
def __init__(self, src, node):
self.message = f'at {node.lineno}:{node.col_offset}:\n'
self.message += '\n'.join(src.split('\n')[:node.lineno])
self.message += '\n' + ' ' * node.col_offset + '^'
self.src = src
self.node = node
super().__init__(self.message)
def __reduce__(self):
# this is necessary to make CompilationError picklable
return (type(self), (self.src, self.node))
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.required = required
self.limit = limit
self.name = name
super().__init__(self.message)
def __reduce__(self):
        # this is necessary to make OutOfResources picklable
return (type(self), (self.required, self.limit, self.name))
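As a side note, here is a hedged sketch of how downstream callers commonly react to OutOfResources in released versions of Triton. The launch helper, its BLOCK_SIZE keyword, and the assumption that the kernel accepts it are illustrative only; Kernel.__call__ is still a stub in this commit, so nothing below executes as-is:
# hypothetical launch helper that backs off to smaller tiles on shared-memory exhaustion
def launch_with_fallback(kernel, grid, *args, block_sizes=(1024, 512, 256)):
    last_exc = None
    for block in block_sizes:
        try:
            # assumes `kernel` is a @triton.jit function that takes a BLOCK_SIZE constexpr
            return kernel[grid](*args, BLOCK_SIZE=block)
        except OutOfResources as exc:
            print(f"{exc.name}: required {exc.required}, limit {exc.limit}; retrying with a smaller block")
            last_exc = exc
    raise last_exc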
def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
context = _triton.ir.context()
context.load_triton()
# create kernel prototype
arg_types = signature.replace(' ','').split(',')
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
arg_types = [str_to_ty(x) for x in arg_types]
prototype = triton.language.function_type([], arg_types)
# visit kernel AST
gscope = fn.__globals__.copy()
generator = CodeGenerator(context, prototype, gscope=gscope, constants=constants, attributes=attributes, is_kernel=True)
try:
generator.visit(fn.parse())
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(fn.src, node) from e
ret = generator.module
# module takes ownership of the MLIR context
ret.context = context
return ret
def make_tritongpu_ir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.add_inliner_pass()
pm.add_triton_combine_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.run(mod)
return mod
def optimize_tritongpu_ir(mod, num_stages):
pm = _triton.ir.pass_manager(mod.context)
pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_triton_gpu_combine_pass()
pm.add_triton_gpu_verifier_pass()
pm.run(mod)
return mod
def make_ptx(mod):
# TODO
return mod
def compile(fn, signature, constants = dict(), attributes = dict(), num_warps=4, num_stages=3, output = "ttgir"):
assert output in ["ttir", "ttgir", "ptx"]
# triton-ir
module = make_triton_ir(fn, signature, constants, attributes)
if output == "ttir":
return module.str()
# tritongpu-ir
module = make_tritongpu_ir(module, num_warps)
module = optimize_tritongpu_ir(module, num_stages)
if output == "ttgir":
return module.str()
# ptx
if output == "ptx":
return make_ptx(module)
assert False
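As a quick illustration of the standalone compilation entry point above, a hedged usage sketch follows. The kernel, the signature string spellings ("*fp32", "i32"), and the assumption that str_to_ty accepts them are illustrative guesses, not something this diff guarantees:
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n):
    offs = tl.program_id(axis=0) * 1024 + tl.arange(0, 1024)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

# dump the Triton IR, then the Triton-GPU IR produced after conversion and optimization
ttir = compile(copy_kernel, signature="*fp32, *fp32, i32", output="ttir")
ttgir = compile(copy_kernel, signature="*fp32, *fp32, i32", num_warps=4, num_stages=3, output="ttgir")
print(ttir)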

View File

@@ -0,0 +1,2 @@
from .jit import JITFunction, jit
from .autotuner import Config, autotune, heuristics

View File

@@ -0,0 +1,202 @@
from __future__ import annotations
import builtins
import time
from typing import Dict
class Autotuner:
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
            'perf_model': performance model used to predict running time with different configs, returns running time
            'top_k': number of configs to bench
            'early_config_prune'(optional): a function used to do early pruning (e.g., of num_stages). It takes configs: List[Config] as its input, and returns pruned configs.
'''
if not configs:
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = dict()
self.kernel = kernel
        # hook to reset all required tensors to zero before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
        if prune_configs_by:
            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
            # default to no early pruning if the caller did not provide a function
            early_config_prune = prune_configs_by.get('early_config_prune', None)
        else:
            perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return triton.testing.do_bench(kernel_call)
def __call__(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
if key not in self.cache:
# prune configs
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
class Config:
"""
An object that represents a possible kernel configuration for the auto-tuner to try.
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
:type meta: dict[Str, Any]
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
`num_warps=8`, then each kernel instance will be automatically parallelized to
cooperatively execute using `8 * 32 = 256` threads.
:type num_warps: int
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
:type num_stages: int
    :ivar pre_hook: a function that will be called before the kernel is launched. It receives the
                     kernel arguments (a dict mapping argument names to their values) as its only parameter.
"""
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
self.kwargs = kwargs
self.num_warps = num_warps
self.num_stages = num_stages
self.pre_hook = pre_hook
def __str__(self):
res = []
for k, v in self.kwargs.items():
res.append(f'{k}: {v}')
res.append(f'num_warps: {self.num_warps}')
res.append(f'num_stages: {self.num_stages}')
return ', '.join(res)
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
    :note: When all the configurations are evaluated, the kernel will run multiple times.
           This means that whatever values the kernel updates will be updated multiple times.
           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
           resets the value of the provided tensors to zero before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
           'perf_model': performance model used to predict running time with different configs, returns running time
           'top_k': number of configs to bench
           'early_config_prune'(optional): a function used to do early pruning (e.g., of num_stages). It takes configs: List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
def wrapper(kernel):
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
fn.kernel_decorators.append(wrapper)
return fn
return decorator
def heuristics(values):
"""
Decorator for specifying how the values of certain meta-parameters may be computed.
    This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
.. highlight:: python
.. code-block:: python
    @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args['x_size'])))})
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
    :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
                   Each such function takes a dict mapping argument names (plus any meta-parameters already provided) to their values as input.
    :type values: dict[str, Callable[[dict[str, Any]], Any]]
"""
def decorator(fn):
def wrapper(kernel):
def fun(*args, **meta):
for v, heur in values.items():
assert v not in meta
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
return kernel(*args, **meta)
return fun
fn.kernel_decorators.append(wrapper)
return fn
return decorator
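For readers wiring up prune_configs_by, a hedged sketch of the intended usage follows. The toy perf_model, the top-level triton.autotune/triton.Config re-exports, and the kernel itself are illustrative assumptions; since Kernel.__call__ is still a stub in this commit, the decorated kernel is not actually launchable here:
import triton
import triton.language as tl

def toy_perf_model(x_ptr, x_size, BLOCK_SIZE, num_stages, num_warps, **kwargs):
    # made-up cost model: fewer program instances and more warps are assumed to be faster
    return x_size / (BLOCK_SIZE * num_warps)

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['x_size'],
    prune_configs_by={'perf_model': toy_perf_model, 'top_k': 1},
)
@triton.jit
def scale_kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < x_size
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2, mask=mask)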

View File

@@ -0,0 +1,268 @@
from __future__ import annotations
import ast
import functools
import hashlib
import inspect
import os
import subprocess
import tempfile
import textwrap
import triton
import triton._C.libtriton.triton as _triton
from ..tools.disasm import extract
# -----------------------------------------------------------------------------
# Binary
# -----------------------------------------------------------------------------
class Binary:
def __init__(self, backend, name, asm, shared_mem, num_warps):
self.backend = backend
self.name = name
self.asm = asm
self.shared_mem = shared_mem
self.num_warps = num_warps
class LoadedBinary:
def __init__(self, device: int, bin: Binary):
module, kernel = _triton.code_gen.load_binary(bin.backend,
bin.name,
bin.asm,
bin.shared_mem,
device)
self.bin = bin
self.asm = bin.asm
self.sass = ''
self.module = module
self.kernel = kernel
self.device = device
self.shared_mem = bin.shared_mem
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
_triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
grid_0, grid_1, grid_2,
self.bin.num_warps * 32, 1, 1,
args, self.bin.shared_mem)
def get_sass(self, fun=None):
if self.sass:
return self.sass
fd, path = tempfile.mkstemp()
try:
with open(fd, 'wb') as cubin:
cubin.write(self.asm['cubin'])
self.sass = extract(path, fun)
finally:
os.remove(path)
self.asm['sass'] = self.sass
return self.sass
# -----------------------------------------------------------------------------
# Kernel
# -----------------------------------------------------------------------------
class Kernel:
def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs):
raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.")
# -----------------------------------------------------------------------------
# Dependencies Finder
# -----------------------------------------------------------------------------
class DependenciesFinder(ast.NodeVisitor):
"""
This AST visitor is used to find dependencies of a JITFunction. This can
be used to invalidate a JITFunction's hash when its source code -- or
that of its dependencies -- changes.
"""
def __init__(self, globals, src) -> None:
super().__init__()
self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
self.globals = globals
def visit_Name(self, node):
return self.globals.get(node.id, None)
def visit_Attribute(self, node):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or lhs is triton:
return None
return getattr(lhs, node.attr)
def visit_Call(self, node):
func = self.visit(node.func)
if func is None:
return
if inspect.isbuiltin(func):
return
if func.__module__ and func.__module__.startswith('triton.'):
return
assert isinstance(func, JITFunction)
if func.hash is None:
tree = ast.parse(func.src)
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
self.ret = (self.ret + func.hash).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest()
# -----------------------------------------------------------------------------
# JITFunction
# -----------------------------------------------------------------------------
@functools.lru_cache()
def version_key():
import pkgutil
contents = []
# frontend
with open(__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# backend
with open(triton._C.libtriton.__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# language
language_path = os.path.join(*triton.__path__, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# ptxas version
try:
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
except Exception:
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
class JITFunction:
cache_hook = None
def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
# information of wrapped function
self.fn = fn
self.module = fn.__module__
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.version = version
self.inline = inline
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize]
# cache for callable driver objects (e.g. CUkernel)
self.bin_cache = dict()
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__
self.kernel_decorators = []
self.kernel = None
# annotations
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
self.__annotations__ = fn.__annotations__
# constexprs
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
# forward docs
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
self.__globals__ = fn.__globals__
self.__module__ = fn.__module__
@property
@functools.lru_cache()
def cache_key(self):
# TODO : hash should be attribute of `self`
if self.hash is None:
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.hash = dependencies_finder.ret + version_key()
return self.hash
# we do not parse `src` in the constructor because
# the user might want to monkey-patch self.src dynamically.
# Some unit tests do this, for example.
def parse(self):
tree = ast.parse(self.src)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
assert isinstance(tree.body[0], ast.FunctionDef)
return tree
def __call__(self, *args, **kwargs):
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
# - when `.src` attribute is set, cache path needs
# to be reinitialized
# - when kernel decorators change, cached kernel
# needs to be cleared
def __setattr__(self, name, value):
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
if name == 'src':
self.hash = None
JITFunction.cache_key.fget.cache_clear()
def _init_kernel(self):
if self.kernel is None:
self.kernel = Kernel(self)
for decorator in reversed(self.kernel_decorators):
self.kernel = decorator(self.kernel)
return self.kernel
def __getitem__(self, grid):
"""
A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that
memorizes the grid.
"""
class Launcher:
def __init__(self, kernel, grid):
self.kernel = kernel
self.grid = grid
def __call__(self, *wargs, **kwargs):
return self.kernel(*wargs, **kwargs, grid=self.grid)
return Launcher(self._init_kernel(), grid)
def __repr__(self):
return f"JITFunction({self.module}:{self.fn.__name__})"
# -----------------------------------------------------------------------------
# `jit` decorator
# -----------------------------------------------------------------------------
def jit(*args, **kwargs):
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* objects within the triton.language package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
"""
if args:
assert len(args) == 1
assert callable(args[0])
return JITFunction(args[0], **kwargs)
else:
def decorator(fn):
return JITFunction(fn, **kwargs)
return decorator
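A hedged sketch of the launch syntax enabled by the Launcher proxy above; the torch tensor setup and the triton.cdiv re-export are assumptions, and since Kernel.__call__ is a stub in this commit the last line only illustrates the intended fn[grid](...) API:
import torch
import triton
import triton.language as tl

@triton.jit
def fill_kernel(out_ptr, value, n):
    offs = tl.program_id(axis=0) * 256 + tl.arange(0, 256)
    tl.store(out_ptr + offs, value, mask=offs < n)

x = torch.empty(1024, device='cuda')
grid = (triton.cdiv(x.numel(), 256),)   # one program instance per 256 elements
fill_kernel[grid](x, 3.14, x.numel())   # __getitem__ returns a Launcher bound to `grid`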

View File

@@ -5,7 +5,7 @@ import sys
import torch
import triton._C.libtriton.triton as _triton
from .code_gen import OutOfResources
from .compiler import OutOfResources
try:
import triton._C.libtriton.cutlass as _cutlass

46
python/triton/utils.py Normal file
View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import torch
def cdiv(x, y):
return (x + y - 1) // y
def next_power_of_2(n):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n += 1
return n
class TensorWrapper:
def __init__(self, base, dtype):
self.dtype = dtype
self.base = base
self.is_cuda = base.is_cuda
self.device = base.device
def data_ptr(self):
return self.base.data_ptr()
def __str__(self) -> str:
return f'TensorWrapper[{self.dtype}]({self.base})'
def reinterpret(tensor, dtype):
if isinstance(tensor, TensorWrapper):
if dtype == tensor.base.dtype:
# Reinterpreting to the original interpretation; return the base.
return tensor.base
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif isinstance(tensor, torch.Tensor):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
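A few hedged usage notes for these helpers (the CUDA tensor in the reinterpret example is assumed to be available):
import torch

assert cdiv(10, 3) == 4                 # ceil(10 / 3)
assert next_power_of_2(1000) == 1024
assert next_power_of_2(1024) == 1024    # powers of two are returned unchanged

# view an int8 buffer as uint8 without copying; only the dtype tag changes
buf = torch.zeros(16, dtype=torch.int8, device='cuda')
u8 = reinterpret(buf, torch.uint8)
print(u8)                               # TensorWrapper[torch.uint8](tensor([...]))
assert reinterpret(u8, torch.int8) is buf   # reinterpreting back returns the base tensor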

View File

@@ -1,261 +0,0 @@
============================= test session starts ==============================
platform linux -- Python 3.7.7, pytest-7.1.1, pluggy-1.0.0
rootdir: /home/da/codes/triton-mlir-rewrite/triton/rewrite-test
collected 6 items
scf_tests.py .....module {
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = call @"cdiv__i32__1cconstexpr[64]"(%arg3) : (i32) -> i32
%2 = call @"cdiv__i32__1cconstexpr[64]"(%arg4) : (i32) -> i32
%c8_i32 = arith.constant 8 : i32
%3 = arith.muli %2, %c8_i32 : i32
%4 = arith.divsi %0, %3 : i32
%c8_i32_0 = arith.constant 8 : i32
%5 = arith.muli %4, %c8_i32_0 : i32
%6 = arith.subi %1, %5 : i32
%7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32
%8 = arith.remsi %0, %7 : i32
%9 = arith.addi %5, %8 : i32
%10 = arith.remsi %0, %3 : i32
%11 = arith.divsi %10, %7 : i32
%c64_i32 = arith.constant 64 : i32
%12 = arith.muli %9, %c64_i32 : i32
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%14 = tt.broadcast %12 : (i32) -> tensor<64xi32>
%15 = arith.addi %14, %13 : tensor<64xi32>
%c64_i32_1 = arith.constant 64 : i32
%16 = arith.muli %11, %c64_i32_1 : i32
%17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%18 = tt.broadcast %16 : (i32) -> tensor<64xi32>
%19 = arith.addi %18, %17 : tensor<64xi32>
%20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%21 = tt.reshape %15 : (tensor<64xi32>) -> tensor<64x1xi32>
%22 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
%23 = arith.muli %21, %22 : tensor<64x1xi32>
%24 = tt.reshape %20 : (tensor<32xi32>) -> tensor<1x32xi32>
%c1_i32 = arith.constant 1 : i32
%25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
%26 = arith.muli %24, %25 : tensor<1x32xi32>
%27 = tt.broadcast %23 : (tensor<64x1xi32>) -> tensor<64x32xi32>
%28 = tt.broadcast %26 : (tensor<1x32xi32>) -> tensor<64x32xi32>
%29 = arith.addi %27, %28 : tensor<64x32xi32>
%30 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
%31 = tt.getelementptr %30, %29, : tensor<64x32x!tt.ptr<f16>>
%32 = tt.reshape %20 : (tensor<32xi32>) -> tensor<32x1xi32>
%33 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
%34 = arith.muli %32, %33 : tensor<32x1xi32>
%35 = tt.reshape %19 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_2 = arith.constant 1 : i32
%36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x64xi32>
%37 = arith.muli %35, %36 : tensor<1x64xi32>
%38 = tt.broadcast %34 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%39 = tt.broadcast %37 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%40 = arith.addi %38, %39 : tensor<32x64xi32>
%41 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
%42 = tt.getelementptr %41, %40, : tensor<32x64x!tt.ptr<f16>>
%cst = arith.constant 0.000000e+00 : f32
%43 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%44 = arith.index_cast %c0_i32 : i32 to index
%45 = arith.index_cast %arg5 : i32 to index
%46 = arith.index_cast %c32_i32 : i32 to index
%47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
%cst_6 = arith.constant dense<true> : tensor<64x32xi1>
%cst_7 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
%77 = tt.load %arg11, %cst_6, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
%cst_8 = arith.constant dense<true> : tensor<32x64xi1>
%cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
%78 = tt.load %arg12, %cst_8, %cst_9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
%cst_10 = arith.constant 0.000000e+00 : f32
%79 = tt.broadcast %cst_10 : (f32) -> tensor<64x64xf32>
%80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
%81 = arith.addf %arg10, %80 : tensor<64x64xf32>
%c32_i32_11 = arith.constant 32 : i32
%82 = tt.broadcast %c32_i32_11 : (i32) -> tensor<64x32xi32>
%83 = tt.getelementptr %arg11, %82, : tensor<64x32x!tt.ptr<f16>>
%c32_i32_12 = arith.constant 32 : i32
%84 = arith.muli %arg7, %c32_i32_12 : i32
%85 = tt.broadcast %84 : (i32) -> tensor<32x64xi32>
%86 = tt.getelementptr %arg12, %85, : tensor<32x64x!tt.ptr<f16>>
scf.yield %81, %83, %86 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
}
%48 = arith.truncf %47#0 : tensor<64x64xf32> to tensor<64x64xf16>
%c64_i32_3 = arith.constant 64 : i32
%49 = arith.muli %9, %c64_i32_3 : i32
%50 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%51 = tt.broadcast %49 : (i32) -> tensor<64xi32>
%52 = arith.addi %51, %50 : tensor<64xi32>
%c64_i32_4 = arith.constant 64 : i32
%53 = arith.muli %11, %c64_i32_4 : i32
%54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%55 = tt.broadcast %53 : (i32) -> tensor<64xi32>
%56 = arith.addi %55, %54 : tensor<64xi32>
%57 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
%58 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
%59 = arith.muli %58, %57 : tensor<64x1xi32>
%60 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
%61 = tt.getelementptr %60, %59, : tensor<64x1x!tt.ptr<f16>>
%62 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_5 = arith.constant 1 : i32
%63 = tt.broadcast %c1_i32_5 : (i32) -> tensor<1x64xi32>
%64 = arith.muli %62, %63 : tensor<1x64xi32>
%65 = tt.broadcast %61 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
%66 = tt.broadcast %64 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%67 = tt.getelementptr %65, %66, : tensor<64x64x!tt.ptr<f16>>
%68 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
%69 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
%70 = arith.cmpi slt, %68, %69 : tensor<64x1xi32>
%71 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
%72 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
%73 = arith.cmpi slt, %71, %72 : tensor<1x64xi32>
%74 = tt.broadcast %70 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%75 = tt.broadcast %73 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%76 = arith.andi %74, %75 : tensor<64x64xi1>
tt.store %67, %48, %76, : tensor<64x64xf16>
return
}
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
%c64_i32 = arith.constant 64 : i32
%0 = arith.addi %arg0, %c64_i32 : i32
%c1_i32 = arith.constant 1 : i32
%1 = arith.subi %0, %c1_i32 : i32
%c64_i32_0 = arith.constant 64 : i32
%2 = arith.divsi %1, %c64_i32_0 : i32
return %2 : i32
}
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
%c8_i32 = arith.constant 8 : i32
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
%c8_i32_0 = arith.constant 8 : i32
%1 = select %0, %arg0, %c8_i32_0 : i32
return %1 : i32
}
}
module {
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%c8_i32 = arith.constant 8 : i32
%c63_i32 = arith.constant 63 : i32
%c64_i32 = arith.constant 64 : i32
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%cst_0 = arith.constant dense<true> : tensor<64x32xi1>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
%cst_2 = arith.constant dense<true> : tensor<32x64xi1>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
%c32_i32 = arith.constant 32 : i32
%c1_i32 = arith.constant 1 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.addi %arg3, %c63_i32 : i32
%2 = arith.divsi %1, %c64_i32 : i32
%3 = arith.addi %arg4, %c63_i32 : i32
%4 = arith.divsi %3, %c64_i32 : i32
%5 = arith.muli %4, %c8_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c8_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.cmpi slt, %8, %c8_i32 : i32
%10 = select %9, %8, %c8_i32 : i32
%11 = arith.remsi %0, %10 : i32
%12 = arith.addi %7, %11 : i32
%13 = arith.remsi %0, %5 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c64_i32 : i32
%16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%17 = tt.broadcast %15 : (i32) -> tensor<64xi32>
%18 = arith.addi %17, %16 : tensor<64xi32>
%19 = arith.muli %14, %c64_i32 : i32
%20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%21 = tt.broadcast %19 : (i32) -> tensor<64xi32>
%22 = arith.addi %21, %20 : tensor<64xi32>
%23 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%24 = tt.reshape %18 : (tensor<64xi32>) -> tensor<64x1xi32>
%25 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
%26 = arith.muli %24, %25 : tensor<64x1xi32>
%27 = tt.reshape %23 : (tensor<32xi32>) -> tensor<1x32xi32>
%28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
%29 = arith.muli %27, %28 : tensor<1x32xi32>
%30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x32xi32>
%31 = tt.broadcast %29 : (tensor<1x32xi32>) -> tensor<64x32xi32>
%32 = arith.addi %30, %31 : tensor<64x32xi32>
%33 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
%34 = tt.getelementptr %33, %32, : tensor<64x32x!tt.ptr<f16>>
%35 = tt.reshape %23 : (tensor<32xi32>) -> tensor<32x1xi32>
%36 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
%37 = arith.muli %35, %36 : tensor<32x1xi32>
%38 = tt.reshape %22 : (tensor<64xi32>) -> tensor<1x64xi32>
%39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
%40 = arith.muli %38, %39 : tensor<1x64xi32>
%41 = tt.broadcast %37 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%43 = arith.addi %41, %42 : tensor<32x64xi32>
%44 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
%45 = tt.getelementptr %44, %43, : tensor<32x64x!tt.ptr<f16>>
%46 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%47 = arith.index_cast %arg5 : i32 to index
%48:3 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
%78 = tt.load %arg11, %cst_0, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
%79 = tt.load %arg12, %cst_2, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
%80 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%81 = tt.dot %78, %79, %80 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
%82 = arith.addf %arg10, %81 : tensor<64x64xf32>
%83 = tt.broadcast %c32_i32 : (i32) -> tensor<64x32xi32>
%84 = tt.getelementptr %arg11, %83, : tensor<64x32x!tt.ptr<f16>>
%85 = arith.muli %arg7, %c32_i32 : i32
%86 = tt.broadcast %85 : (i32) -> tensor<32x64xi32>
%87 = tt.getelementptr %arg12, %86, : tensor<32x64x!tt.ptr<f16>>
scf.yield %82, %84, %87 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
}
%49 = arith.truncf %48#0 : tensor<64x64xf32> to tensor<64x64xf16>
%50 = arith.muli %12, %c64_i32 : i32
%51 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%52 = tt.broadcast %50 : (i32) -> tensor<64xi32>
%53 = arith.addi %52, %51 : tensor<64xi32>
%54 = arith.muli %14, %c64_i32 : i32
%55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%56 = tt.broadcast %54 : (i32) -> tensor<64xi32>
%57 = arith.addi %56, %55 : tensor<64xi32>
%58 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
%59 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
%60 = arith.muli %59, %58 : tensor<64x1xi32>
%61 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
%62 = tt.getelementptr %61, %60, : tensor<64x1x!tt.ptr<f16>>
%63 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
%64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
%65 = arith.muli %63, %64 : tensor<1x64xi32>
%66 = tt.broadcast %62 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
%67 = tt.broadcast %65 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%68 = tt.getelementptr %66, %67, : tensor<64x64x!tt.ptr<f16>>
%69 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
%70 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
%71 = arith.cmpi slt, %69, %70 : tensor<64x1xi32>
%72 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
%73 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
%74 = arith.cmpi slt, %72, %73 : tensor<1x64xi32>
%75 = tt.broadcast %71 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%76 = tt.broadcast %74 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%77 = arith.andi %75, %76 : tensor<64x64xi1>
tt.store %68, %49, %77, : tensor<64x64xf16>
return
}
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
%0 = arith.addi %arg0, %c63_i32 : i32
%1 = arith.divsi %0, %c64_i32 : i32
return %1 : i32
}
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
%c8_i32 = arith.constant 8 : i32
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
%1 = select %0, %arg0, %c8_i32 : i32
return %1 : i32
}
}
.
============================== 6 passed in 1.21s ===============================

View File

@@ -1,51 +0,0 @@
import triton
@triton.jit
def if_else(lb, ub, value):
if value > lb:
a = 0.0
else:
a = 1.0
c = a + a
@triton.jit
def only_if(lb, ub, value):
a = -1.0
if value > lb:
a = 0.0
c = a + a
@triton.jit
def only_if_invalid(lb, ub, value):
if value > lb:
a = 0.0
c = a + a
@triton.jit
def nested_if(lb, ub, value):
if value > lb:
if value < ub:
a = 2.0
else:
a = 1.0
else:
a = 0.0
c = a + a
mod_if_else, ctx_if_else = if_else.compile_to_ttir(2, 4, 3, grid=(1,))
mod_if_else.dump()
mod_only_if, ctx_only_if = only_if.compile_to_ttir(2, 4, 3, grid=(1,))
mod_only_if.dump()
try:
mod_only_if_invalid, ctx_only_if = only_if_invalid.compile_to_ttir(2, 4, 3, grid=(1,))
mod_only_if_invalid.dump()
except:
print('value error')
mod_nested_if, ctx_nested_if = nested_if.compile_to_ttir(2, 4, 3, grid=(1,))
mod_nested_if.dump()
print(mod_nested_if.str())

View File

@@ -1,261 +0,0 @@
module {
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c128_13c128_14c128_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = call @"cdiv__i32__1cconstexpr[128]"(%arg3) : (i32) -> i32
%2 = call @"cdiv__i32__1cconstexpr[128]"(%arg4) : (i32) -> i32
%c8_i32 = arith.constant 8 : i32
%3 = arith.muli %2, %c8_i32 : i32
%4 = arith.divsi %0, %3 : i32
%c8_i32_0 = arith.constant 8 : i32
%5 = arith.muli %4, %c8_i32_0 : i32
%6 = arith.subi %1, %5 : i32
%7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32
%8 = arith.remsi %0, %7 : i32
%9 = arith.addi %5, %8 : i32
%10 = arith.remsi %0, %3 : i32
%11 = arith.divsi %10, %7 : i32
%c128_i32 = arith.constant 128 : i32
%12 = arith.muli %9, %c128_i32 : i32
%13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%14 = tt.broadcast %12 : (i32) -> tensor<128xi32>
%15 = arith.addi %14, %13 : tensor<128xi32>
%c128_i32_1 = arith.constant 128 : i32
%16 = arith.muli %11, %c128_i32_1 : i32
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%18 = tt.broadcast %16 : (i32) -> tensor<128xi32>
%19 = arith.addi %18, %17 : tensor<128xi32>
%20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%21 = tt.reshape %15 : (tensor<128xi32>) -> tensor<128x1xi32>
%22 = tt.broadcast %arg6 : (i32) -> tensor<128x1xi32>
%23 = arith.muli %21, %22 : tensor<128x1xi32>
%24 = tt.reshape %20 : (tensor<128xi32>) -> tensor<1x128xi32>
%c1_i32 = arith.constant 1 : i32
%25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32>
%26 = arith.muli %24, %25 : tensor<1x128xi32>
%27 = tt.broadcast %23 : (tensor<128x1xi32>) -> tensor<128x128xi32>
%28 = tt.broadcast %26 : (tensor<1x128xi32>) -> tensor<128x128xi32>
%29 = arith.addi %27, %28 : tensor<128x128xi32>
%30 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>>
%31 = tt.getelementptr %30, %29, : tensor<128x128x!tt.ptr<f16>>
%32 = tt.reshape %20 : (tensor<128xi32>) -> tensor<128x1xi32>
%33 = tt.broadcast %arg7 : (i32) -> tensor<128x1xi32>
%34 = arith.muli %32, %33 : tensor<128x1xi32>
%35 = tt.reshape %19 : (tensor<128xi32>) -> tensor<1x128xi32>
%c1_i32_2 = arith.constant 1 : i32
%36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x128xi32>
%37 = arith.muli %35, %36 : tensor<1x128xi32>
%38 = tt.broadcast %34 : (tensor<128x1xi32>) -> tensor<128x128xi32>
%39 = tt.broadcast %37 : (tensor<1x128xi32>) -> tensor<128x128xi32>
%40 = arith.addi %38, %39 : tensor<128x128xi32>
%41 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>>
%42 = tt.getelementptr %41, %40, : tensor<128x128x!tt.ptr<f16>>
%cst = arith.constant 0.000000e+00 : f32
%43 = tt.broadcast %cst : (f32) -> tensor<128x128xf32>
%c0_i32 = arith.constant 0 : i32
%c128_i32_3 = arith.constant 128 : i32
%44 = arith.index_cast %c0_i32 : i32 to index
%45 = arith.index_cast %arg5 : i32 to index
%46 = arith.index_cast %c128_i32_3 : i32 to index
%47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<128x128xf32>, tensor<128x128x!tt.ptr<f16>>, tensor<128x128x!tt.ptr<f16>>) {
%cst_7 = arith.constant dense<true> : tensor<128x128xi1>
%cst_8 = arith.constant dense<0.000000e+00> : tensor<128x128xf16>
%77 = tt.load %arg11, %cst_7, %cst_8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16>
%cst_9 = arith.constant dense<true> : tensor<128x128xi1>
%cst_10 = arith.constant dense<0.000000e+00> : tensor<128x128xf16>
%78 = tt.load %arg12, %cst_9, %cst_10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16>
%cst_11 = arith.constant 0.000000e+00 : f32
%79 = tt.broadcast %cst_11 : (f32) -> tensor<128x128xf32>
%80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<128x128xf16> * tensor<128x128xf16> -> tensor<128x128xf32>
%81 = arith.addf %arg10, %80 : tensor<128x128xf32>
%c128_i32_12 = arith.constant 128 : i32
%82 = tt.broadcast %c128_i32_12 : (i32) -> tensor<128x128xi32>
%83 = tt.getelementptr %arg11, %82, : tensor<128x128x!tt.ptr<f16>>
%c128_i32_13 = arith.constant 128 : i32
%84 = arith.muli %arg7, %c128_i32_13 : i32
%85 = tt.broadcast %84 : (i32) -> tensor<128x128xi32>
%86 = tt.getelementptr %arg12, %85, : tensor<128x128x!tt.ptr<f16>>
scf.yield %81, %83, %86 : tensor<128x128xf32>, tensor<128x128x!tt.ptr<f16>>, tensor<128x128x!tt.ptr<f16>>
}
%48 = arith.truncf %47#0 : tensor<128x128xf32> to tensor<128x128xf16>
%c128_i32_4 = arith.constant 128 : i32
%49 = arith.muli %9, %c128_i32_4 : i32
%50 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%51 = tt.broadcast %49 : (i32) -> tensor<128xi32>
%52 = arith.addi %51, %50 : tensor<128xi32>
%c128_i32_5 = arith.constant 128 : i32
%53 = arith.muli %11, %c128_i32_5 : i32
%54 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%55 = tt.broadcast %53 : (i32) -> tensor<128xi32>
%56 = arith.addi %55, %54 : tensor<128xi32>
%57 = tt.reshape %52 : (tensor<128xi32>) -> tensor<128x1xi32>
%58 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32>
%59 = arith.muli %58, %57 : tensor<128x1xi32>
%60 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>>
%61 = tt.getelementptr %60, %59, : tensor<128x1x!tt.ptr<f16>>
%62 = tt.reshape %56 : (tensor<128xi32>) -> tensor<1x128xi32>
%c1_i32_6 = arith.constant 1 : i32
%63 = tt.broadcast %c1_i32_6 : (i32) -> tensor<1x128xi32>
%64 = arith.muli %62, %63 : tensor<1x128xi32>
%65 = tt.broadcast %61 : (tensor<128x1x!tt.ptr<f16>>) -> tensor<128x128x!tt.ptr<f16>>
%66 = tt.broadcast %64 : (tensor<1x128xi32>) -> tensor<128x128xi32>
%67 = tt.getelementptr %65, %66, : tensor<128x128x!tt.ptr<f16>>
%68 = tt.reshape %52 : (tensor<128xi32>) -> tensor<128x1xi32>
%69 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32>
%70 = arith.cmpi slt, %68, %69 : tensor<128x1xi32>
%71 = tt.reshape %56 : (tensor<128xi32>) -> tensor<1x128xi32>
%72 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32>
%73 = arith.cmpi slt, %71, %72 : tensor<1x128xi32>
%74 = tt.broadcast %70 : (tensor<128x1xi1>) -> tensor<128x128xi1>
%75 = tt.broadcast %73 : (tensor<1x128xi1>) -> tensor<128x128xi1>
%76 = arith.andi %74, %75 : tensor<128x128xi1>
tt.store %67, %48, %76, : tensor<128x128xf16>
return
}
func @"cdiv__i32__1cconstexpr[128]"(%arg0: i32) -> i32 {
%c128_i32 = arith.constant 128 : i32
%0 = arith.addi %arg0, %c128_i32 : i32
%c1_i32 = arith.constant 1 : i32
%1 = arith.subi %0, %c1_i32 : i32
%c128_i32_0 = arith.constant 128 : i32
%2 = arith.divsi %1, %c128_i32_0 : i32
return %2 : i32
}
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
%c8_i32 = arith.constant 8 : i32
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
%c8_i32_0 = arith.constant 8 : i32
%1 = select %0, %arg0, %c8_i32_0 : i32
return %1 : i32
}
}
module {
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c128_13c128_14c128_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%c8_i32 = arith.constant 8 : i32
%c127_i32 = arith.constant 127 : i32
%c128_i32 = arith.constant 128 : i32
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%cst_0 = arith.constant dense<true> : tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%c1_i32 = arith.constant 1 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.addi %arg3, %c127_i32 : i32
%2 = arith.divsi %1, %c128_i32 : i32
%3 = arith.addi %arg4, %c127_i32 : i32
%4 = arith.divsi %3, %c128_i32 : i32
%5 = arith.muli %4, %c8_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c8_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.cmpi slt, %8, %c8_i32 : i32
%10 = select %9, %8, %c8_i32 : i32
%11 = arith.remsi %0, %10 : i32
%12 = arith.addi %7, %11 : i32
%13 = arith.remsi %0, %5 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c128_i32 : i32
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%17 = tt.broadcast %15 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%18 = arith.addi %17, %16 : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%19 = arith.muli %14, %c128_i32 : i32
%20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%21 = tt.broadcast %19 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%22 = arith.addi %21, %20 : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%24 = tt.reshape %18 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%25 = tt.broadcast %arg6 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%26 = arith.muli %24, %25 : tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%27 = tt.reshape %23 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%29 = arith.muli %27, %28 : tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%30 = tt.broadcast %26 : (tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%31 = tt.broadcast %29 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%32 = arith.addi %30, %31 : tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%33 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%34 = tt.getelementptr %33, %32, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%35 = tt.reshape %23 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%36 = tt.broadcast %arg7 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%37 = arith.muli %35, %36 : tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%38 = tt.reshape %22 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%40 = arith.muli %38, %39 : tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%41 = tt.broadcast %37 : (tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%42 = tt.broadcast %40 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%43 = arith.addi %41, %42 : tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%44 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%45 = tt.getelementptr %44, %43, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%46 = tt.broadcast %cst : (f32) -> tensor<128x128xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%47 = arith.index_cast %arg5 : i32 to index
%48 = "triton_gpu.copy_async"(%34, %cst_0, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
%49 = "triton_gpu.copy_async"(%45, %cst_0, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
%50 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%51 = tt.getelementptr %34, %50, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%52 = arith.muli %arg7, %c128_i32 : i32
%53 = tt.broadcast %52 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%54 = tt.getelementptr %45, %53, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%55:8 = scf.for %arg9 = %c0 to %47 step %c128 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45, %arg13 = %48, %arg14 = %49, %arg15 = %51, %arg16 = %54, %arg17 = %c0) -> (tensor<128x128xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, index) {
%85 = tt.dot %arg13, %arg14, %arg10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> * tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> -> tensor<128x128xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%86 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%87 = tt.getelementptr %arg11, %86, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%88 = arith.muli %arg7, %c128_i32 : i32
%89 = tt.broadcast %88 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%90 = tt.getelementptr %arg12, %89, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%91 = arith.addi %arg17, %c128 : index
%92 = arith.cmpi slt, %91, %47 : index
%93 = tt.broadcast %92 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%94 = "triton_gpu.copy_async"(%arg15, %93, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
%95 = "triton_gpu.copy_async"(%arg16, %93, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
%96 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%97 = tt.getelementptr %arg15, %96, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%98 = arith.muli %arg7, %c128_i32 : i32
%99 = tt.broadcast %98 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%100 = tt.getelementptr %arg16, %99, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
scf.yield %85, %87, %90, %94, %95, %97, %100, %91 : tensor<128x128xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, index
}
%56 = arith.truncf %55#0 : tensor<128x128xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">> to tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%57 = arith.muli %12, %c128_i32 : i32
%58 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%59 = tt.broadcast %57 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%60 = arith.addi %59, %58 : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%61 = arith.muli %14, %c128_i32 : i32
%62 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%63 = tt.broadcast %61 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%64 = arith.addi %63, %62 : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
%65 = tt.reshape %60 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%66 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%67 = arith.muli %66, %65 : tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%68 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%69 = tt.getelementptr %68, %67, : tensor<128x1x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%70 = tt.reshape %64 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%71 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%72 = arith.muli %70, %71 : tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%73 = tt.broadcast %69 : (tensor<128x1x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%74 = tt.broadcast %72 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%75 = tt.getelementptr %73, %74, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%76 = tt.reshape %60 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%77 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%78 = "triton_gpu.cmpi"(%76, %77) {predicate = 2 : i64} : (tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x1xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%79 = tt.reshape %64 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%80 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%81 = "triton_gpu.cmpi"(%79, %80) {predicate = 2 : i64} : (tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>, tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>) -> tensor<1x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
%82 = tt.broadcast %78 : (tensor<128x1xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%83 = tt.broadcast %81 : (tensor<1x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
%84 = arith.andi %82, %83 : tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
tt.store %75, %56, %84, : tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
return
}
func @"cdiv__i32__1cconstexpr[128]"(%arg0: i32) -> i32 {
%c128_i32 = arith.constant 128 : i32
%c127_i32 = arith.constant 127 : i32
%0 = arith.addi %arg0, %c127_i32 : i32
%1 = arith.divsi %0, %c128_i32 : i32
return %1 : i32
}
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
%c8_i32 = arith.constant 8 : i32
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
%1 = select %0, %arg0, %c8_i32 : i32
return %1 : i32
}
}

View File

@@ -1,105 +0,0 @@
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton
import torch
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse
# See above `L2 Cache Optimizations` section for details
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
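# Worked example (hypothetical sizes): with num_pid_m = 4, num_pid_n = 4 and
# GROUP_SIZE_M = 2, num_pid_in_group = 8, so programs 0..7 map to the C blocks
# (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), (0,3), (1,3): pid_m alternates
# within a group of 2 block-rows while pid_n advances, so consecutive programs
# reuse the same block-column of B from L2.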
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance these pointers as we move in the K direction
# and accumulate
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# see above `Pointer Arithmetics` section for details
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
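# In other words, the address of A[m, k] is a_ptr + m * stride_am + k * stride_ak;
# broadcasting offs_am[:, None] against offs_k[None, :] evaluates that formula
# for every (m, k) pair of the [BLOCK_SIZE_M, BLOCK_SIZE_K] tile at once.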
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
# Note that for simplicity, we don't apply a mask here.
# This means that if K is not a multiple of BLOCK_SIZE_K,
# this will access out-of-bounds memory and produce an
# error or (worse!) incorrect results.
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
# We accumulate along the K dimension
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
# Write back the block of the output matrix C
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c = torch.empty((512, 512), device='cuda', dtype=torch.float16)
mod, ctx = matmul_kernel.compile_to_ttir(
a, b, c,
512, 512, 512,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
128, 128, 128,
8, grid=(2,),
num_stages=4
)
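# Note (our reading, not asserted by this test): num_stages controls how deeply
# the K loop is software-pipelined, i.e. how many prefetched copy_async tiles of
# A and B may be in flight at once.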
assert mod.verify()
# mod.dump()
res = matmul_kernel.compile_ttir_to_llir(mod, ctx)
assert mod.verify()
# assert res
mod.dump()

View File

@@ -1,27 +0,0 @@
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton
@triton.jit
def foo(a, b):
max, min = maxmin(a, b)
return max, min
@triton.jit
def maxmin(a, b):
max = tl.maximum(a, b)
min = tl.minimum(a, b)
return max, min
mod, ctx = foo.compile_to_ttir(3, 4, grid=(1,))
assert mod.verify()
mod.dump()
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.run(mod)
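# After the inliner pass runs, the call to `maxmin` should be folded into
# `foo`, so the second dump below is expected to show the maximum/minimum ops
# inline rather than a call.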
assert mod.verify()
mod.dump()

View File

@@ -1,47 +0,0 @@
import torch
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton
@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
):
# There are multiple 'programs' processing different data. Here we identify
# which program we are.
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
# This program will process inputs that are offset from the initial data.
# for instance, if you had a vector of length 256 and block_size of 64, the programs
# would each access the elements [0:64, 64:128, 128:192, 192:256].
# Note that offsets is a block of int32 indices that gets added to the base pointers below
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
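# For example, with n_elements = 1000 and BLOCK_SIZE = 256, program 3 covers
# offsets 768..1023 and the mask is True only for 768..999, so the trailing
# 24 lanes neither load nor store.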
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)
size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
z = torch.empty_like(x)
# add_kernel[(1,)](x, y, z, size, 256)
# print(add_kernel[(1,)].kernel.compile_to_ttir())
# print(add_kernel.annotations)
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
assert mod.verify()
mod.dump()
add_kernel.compile_ttir_to_llir(mod, ctx)
mod.dump()

View File

@@ -1,56 +0,0 @@
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
K,
stride
# BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# # NOTE: `constexpr` so it can be used as a shape value
):
# There are multiple 'programs' processing different data. Here we identify
# which program we are.
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
# This program will process inputs that are offset from the initial data.
# for instance, if you had a vector of length 256 and block_size of 64, the programs
# would each access the elements [0:64, 64:128, 128:192, 192:256].
# Note that offsets is a block of int32 indices that gets added to the base pointers below
block_start = pid * 256
offsets = block_start + tl.arange(0, 256)
# Create a mask to guard memory operations against out-of-bounds accesses
mask = offsets < n_elements
x_ptrs = x_ptr + offsets
y_ptrs = y_ptr + offsets
output = tl.zeros((256,), dtype=tl.float32)
for k in range(0, K, 32):
x = tl.load(x_ptrs, mask=mask, other=0.0)
y = tl.load(y_ptrs, mask=mask, other=0.0)
output += x + y
x_ptrs += stride
y_ptrs += stride
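# Each of the K/32 iterations loads a 256-element window of x and y (shifted
# by `stride` each time) and accumulates their sum into `output`; the mask is
# computed once from the initial offsets and reused unchanged across iterations.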
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)
size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
z = torch.empty_like(x)
# add_kernel[(1,)](x, y, z, size, 256)
# print(add_kernel[(1,)].kernel.compile_to_ttir())
mod, ctx = add_kernel.compile_to_ttir(
x, y, z, size, 128, 8, grid=(1,), num_stages=1)
mod.dump()
# print(mod)
res = add_kernel.compile_ttir_to_llir(mod, ctx)
mod.dump()

View File

@@ -1,82 +0,0 @@
module {
func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%c256_i32 = arith.constant 256 : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%3 = tt.broadcast %1 : (i32) -> tensor<256xi32>
%4 = arith.addi %3, %2 : tensor<256xi32>
%5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32>
%6 = arith.cmpi slt, %4, %5 : tensor<256xi32>
%7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
%8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>>
%9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
%10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr<f32>>
%cst = arith.constant 0.000000e+00 : f32
%11 = tt.broadcast %cst : (f32) -> tensor<256xf32>
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%12 = arith.index_cast %c0_i32 : i32 to index
%13 = arith.index_cast %arg4 : i32 to index
%14 = arith.index_cast %c32_i32 : i32 to index
%15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) {
%cst_0 = arith.constant 0.000000e+00 : f32
%18 = tt.broadcast %cst_0 : (f32) -> tensor<256xf32>
%19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
%cst_1 = arith.constant 0.000000e+00 : f32
%20 = tt.broadcast %cst_1 : (f32) -> tensor<256xf32>
%21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
%22 = arith.addf %19, %21 : tensor<256xf32>
%23 = arith.addf %arg7, %22 : tensor<256xf32>
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
%25 = tt.getelementptr %arg8, %24 : tensor<256x!tt.ptr<f32>>
%26 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
%27 = tt.getelementptr %arg9, %26 : tensor<256x!tt.ptr<f32>>
scf.yield %23, %25, %27 : tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>
}
%16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
%17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr<f32>>
tt.store %17, %15#0, %6, : tensor<256xf32>
return
}
}
module {
func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c256_i32 = arith.constant 256 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%4 = triton_gpu.convert_layout %3 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%5 = arith.addi %4, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%6 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%7 = triton_gpu.convert_layout %6 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%8 = "triton_gpu.cmpi"(%5, %7) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%9 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%10 = triton_gpu.convert_layout %9 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%11 = tt.getelementptr %10, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%12 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%13 = triton_gpu.convert_layout %12 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%14 = tt.getelementptr %13, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%15 = arith.index_cast %arg4 : i32 to index
%16:3 = scf.for %arg6 = %c0 to %15 step %c32 iter_args(%arg7 = %cst, %arg8 = %11, %arg9 = %14) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) {
%20 = tt.load %arg8, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%21 = tt.load %arg9, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%22 = arith.addf %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%23 = arith.addf %arg7, %22 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%25 = triton_gpu.convert_layout %24 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%26 = tt.getelementptr %arg8, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%27 = tt.getelementptr %arg9, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
scf.yield %23, %26, %27 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
}
%17 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%18 = triton_gpu.convert_layout %17 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%19 = tt.getelementptr %18, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
tt.store %19, %16#0, %8, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
return
}
}

View File

@@ -1,38 +0,0 @@
import triton
import triton.language as tl
import torch
@triton.jit
def atomic(lock):
while tl.atomic_cas(lock, 0, 1) == 1:
pass
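# The loop spins until the CAS observes 0 and swaps in 1, i.e. until this
# program acquires the lock; a matching release (not part of this test) would
# typically write 0 back, e.g. with tl.atomic_xchg(lock, 0).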
@triton.jit
def generic_while(lb, value):
c = -1
while c <= 0:
c += 1
# locks = torch.zeros(32, dtype=torch.int32, device='cuda')
# mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
# mod_atomic.dump()
# mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
# mod_generic_while.dump()
@triton.jit
def nested_cf(X, lb, ub, Z):
a = 0.0
if lb < ub:
for z in range(0, Z):
a += 2.0
else:
while a < 1.2:
a *= 2.0
for _ in range(0, Z, 2):
a *= -3.3
a -= 1.0
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
assert mod.verify(), mod.str()
mod.dump()

View File

@@ -1,58 +0,0 @@
import triton._C.libtriton.triton.ir as ir
ctx = ir.context()
ctx.load_triton()
# TODO
builder = ir.builder(ctx)
module = builder.create_module()
i1_ty = builder.get_int1_ty()
i8_ty = builder.get_int8_ty()
i16_ty = builder.get_int16_ty()
i32_ty = builder.get_int32_ty()
i64_ty = builder.get_int64_ty()
f16_ty = builder.get_half_ty()
f16_ptr_ty = builder.get_ptr_ty(f16_ty, 1)
func_ty = builder.get_function_ty([f16_ptr_ty, f16_ptr_ty, f16_ptr_ty], [])
func = builder.create_function('foo', func_ty)
module.push_back(func)
module.set_attr("num_warps", builder.get_int32_attr(4))
# ...
entry = func.add_entry_block()
builder.set_insertion_point_to_start(entry)
offsets = builder.create_make_range(0, 128)
pid = builder.create_get_program_id(0)
_128 = builder.get_int32(128)
offset = builder.create_add(pid, _128)
offset = builder.create_splat(offset, [128])
offsets = builder.create_add(offset, offsets)
a_ptrs = builder.create_splat(entry.arg(0), [128])
b_ptrs = builder.create_splat(entry.arg(1), [128])
a_ptrs = builder.create_gep(a_ptrs, offsets)
b_ptrs = builder.create_gep(b_ptrs, offsets)
a = builder.create_load(a_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False)
b = builder.create_load(b_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False)
c = builder.create_fadd(a, b)
c.set_attr("ieee_rounding", builder.get_bool_attr(True))
c_ptrs = builder.create_splat(entry.arg(2), [128])
c_ptrs = builder.create_gep(c_ptrs, offsets)
builder.create_store(c_ptrs, c)
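# Taken together, the ops above hand-build a 128-element fp16 vector add:
# each program reads a[off] and b[off] and writes a[off] + b[off] to c, where
# off = pid + 128 + arange(0, 128) (note the offset uses create_add rather than
# a multiply by the block size).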
# func.dump()
module.dump()

View File

@@ -1,417 +0,0 @@
import pytest
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton
import torch
def test_if():
ref_ir = """module {
func @only_if__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) {
%cst = arith.constant -1.000000e+00 : f32
%0 = arith.cmpi sgt, %arg2, %arg0 : i32
%1 = scf.if %0 -> (f32) {
%cst_0 = arith.constant 0.000000e+00 : f32
scf.yield %cst_0 : f32
} else {
scf.yield %cst : f32
}
%2 = arith.addf %1, %1 : f32
return
}
}
"""
@triton.jit
def only_if(lb, ub, value):
a = -1.0
if value > lb:
a = 0.0
c = a + a
mod, _ = only_if.compile_to_ttir(2, 3, 4, grid=(1,))
generated_ir = mod.str()
assert mod.verify()
assert ref_ir == generated_ir
def test_if_else():
ref_ir = """module {
func @if_else__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) {
%0 = arith.cmpi sgt, %arg2, %arg0 : i32
%1 = scf.if %0 -> (f32) {
%cst = arith.constant 0.000000e+00 : f32
scf.yield %cst : f32
} else {
%cst = arith.constant 1.000000e+00 : f32
scf.yield %cst : f32
}
%2 = arith.addf %1, %1 : f32
return
}
}
"""
@triton.jit
def if_else(lb, ub, value):
if value > lb:
a = 0.0
else:
a = 1.0
c = a + a
mod, _ = if_else.compile_to_ttir(2, 3, 4, grid=(1,))
generated_ir = mod.str()
assert mod.verify()
assert ref_ir == generated_ir
def test_for():
ref_ir = """module {
func @for_loop__i32__(%arg0: i32) {
%cst = arith.constant 1.000000e+00 : f32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%0 = arith.index_cast %c0_i32 : i32 to index
%1 = arith.index_cast %arg0 : i32 to index
%2 = arith.index_cast %c1_i32 : i32 to index
%3 = scf.for %arg1 = %0 to %1 step %2 iter_args(%arg2 = %cst) -> (f32) {
%cst_0 = arith.constant 1.000000e+00 : f32
%4 = arith.addf %arg2, %cst_0 : f32
scf.yield %4 : f32
}
return
}
}
"""
@triton.jit
def for_loop(K):
a = 1.0
for k in range(0, K):
a += 1.0
mod, _ = for_loop.compile_to_ttir(2, grid=(1,))
generated_ir = mod.str()
assert mod.verify()
assert ref_ir == generated_ir
def test_while():
ref_ir = """module {
func @generic_while__i32__(%arg0: i32) {
%c-1_i32 = arith.constant -1 : i32
%0 = scf.while (%arg1 = %c-1_i32) : (i32) -> i32 {
%c0_i32 = arith.constant 0 : i32
%1 = arith.cmpi sle, %arg1, %c0_i32 : i32
scf.condition(%1) %arg1 : i32
} do {
^bb0(%arg1: i32):
%c1_i32 = arith.constant 1 : i32
%1 = arith.addi %arg1, %c1_i32 : i32
scf.yield %1 : i32
}
return
}
}
"""
@triton.jit
def generic_while(x):
c = -1
while c <= 0:
c += 1
mod, _ = generic_while.compile_to_ttir(2, grid=(1,))
generated_ir = mod.str()
assert mod.verify()
assert ref_ir == generated_ir
def test_nested():
ref_ir = """module {
func @nested_cf__i32_i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) {
%cst = arith.constant 0.000000e+00 : f32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%0 = arith.index_cast %c0_i32 : i32 to index
%1 = arith.index_cast %arg0 : i32 to index
%2 = arith.index_cast %c1_i32 : i32 to index
%3 = scf.for %arg4 = %0 to %1 step %2 iter_args(%arg5 = %cst) -> (f32) {
%5 = arith.cmpi slt, %arg1, %arg2 : i32
%6 = scf.if %5 -> (f32) {
%c0_i32_1 = arith.constant 0 : i32
%c1_i32_2 = arith.constant 1 : i32
%7 = arith.index_cast %c0_i32_1 : i32 to index
%8 = arith.index_cast %arg3 : i32 to index
%9 = arith.index_cast %c1_i32_2 : i32 to index
%10 = scf.for %arg6 = %7 to %8 step %9 iter_args(%arg7 = %arg5) -> (f32) {
%cst_3 = arith.constant 2.000000e+00 : f32
%11 = arith.addf %arg7, %cst_3 : f32
scf.yield %11 : f32
}
scf.yield %10 : f32
} else {
%7 = scf.while (%arg6 = %arg5) : (f32) -> f32 {
%cst_1 = arith.constant 1.200000e+00 : f32
%8 = arith.cmpf olt, %arg6, %cst_1 : f32
scf.condition(%8) %arg6 : f32
} do {
^bb0(%arg6: f32):
%cst_1 = arith.constant 2.000000e+00 : f32
%8 = arith.mulf %arg6, %cst_1 : f32
scf.yield %8 : f32
}
scf.yield %7 : f32
}
scf.yield %6 : f32
}
%cst_0 = arith.constant 1.000000e+00 : f32
%4 = arith.subf %3, %cst_0 : f32
return
}
}
"""
@triton.jit
def nested_cf(X, lb, ub, Z):
a = 0.0
for x in range(0, X):
if lb < ub:
for z in range(0, Z):
a += 2.0
else:
while a < 1.2:
a *= 2.0
a -= 1.0
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
generated_ir = mod.str()
assert mod.verify(), generated_ir
assert ref_ir == generated_ir
def test_matmul():
ref_ir = """module {
func @matmul_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%c64_i32 = arith.constant 64 : i32
%1 = arith.addi %arg3, %c64_i32 : i32
%c1_i32 = arith.constant 1 : i32
%2 = arith.subi %1, %c1_i32 : i32
%c64_i32_0 = arith.constant 64 : i32
%3 = arith.divsi %2, %c64_i32_0 : i32
%c64_i32_1 = arith.constant 64 : i32
%4 = arith.addi %arg4, %c64_i32_1 : i32
%c1_i32_2 = arith.constant 1 : i32
%5 = arith.subi %4, %c1_i32_2 : i32
%c64_i32_3 = arith.constant 64 : i32
%6 = arith.divsi %5, %c64_i32_3 : i32
%c8_i32 = arith.constant 8 : i32
%7 = arith.muli %6, %c8_i32 : i32
%8 = arith.divsi %0, %7 : i32
%c8_i32_4 = arith.constant 8 : i32
%9 = arith.muli %8, %c8_i32_4 : i32
%10 = arith.subi %3, %9 : i32
%c8_i32_5 = arith.constant 8 : i32
%11 = arith.cmpi slt, %10, %c8_i32_5 : i32
%c8_i32_6 = arith.constant 8 : i32
%12 = select %11, %10, %c8_i32_6 : i32
%13 = arith.remsi %0, %12 : i32
%14 = arith.addi %9, %13 : i32
%15 = arith.remsi %0, %7 : i32
%16 = arith.divsi %15, %12 : i32
%c64_i32_7 = arith.constant 64 : i32
%17 = arith.muli %14, %c64_i32_7 : i32
%18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%19 = tt.broadcast %17 : (i32) -> tensor<64xi32>
%20 = arith.addi %19, %18 : tensor<64xi32>
%c64_i32_8 = arith.constant 64 : i32
%21 = arith.muli %16, %c64_i32_8 : i32
%22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%23 = tt.broadcast %21 : (i32) -> tensor<64xi32>
%24 = arith.addi %23, %22 : tensor<64xi32>
%25 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%26 = tt.reshape %20 : (tensor<64xi32>) -> tensor<64x1xi32>
%27 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
%28 = arith.muli %26, %27 : tensor<64x1xi32>
%29 = tt.reshape %25 : (tensor<32xi32>) -> tensor<1x32xi32>
%c1_i32_9 = arith.constant 1 : i32
%30 = tt.broadcast %c1_i32_9 : (i32) -> tensor<1x32xi32>
%31 = arith.muli %29, %30 : tensor<1x32xi32>
%32 = tt.broadcast %28 : (tensor<64x1xi32>) -> tensor<64x32xi32>
%33 = tt.broadcast %31 : (tensor<1x32xi32>) -> tensor<64x32xi32>
%34 = arith.addi %32, %33 : tensor<64x32xi32>
%35 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
%36 = tt.getelementptr %35, %34, : tensor<64x32x!tt.ptr<f16>>
%37 = tt.reshape %25 : (tensor<32xi32>) -> tensor<32x1xi32>
%38 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
%39 = arith.muli %37, %38 : tensor<32x1xi32>
%40 = tt.reshape %24 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_10 = arith.constant 1 : i32
%41 = tt.broadcast %c1_i32_10 : (i32) -> tensor<1x64xi32>
%42 = arith.muli %40, %41 : tensor<1x64xi32>
%43 = tt.broadcast %39 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%44 = tt.broadcast %42 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%45 = arith.addi %43, %44 : tensor<32x64xi32>
%46 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
%47 = tt.getelementptr %46, %45, : tensor<32x64x!tt.ptr<f16>>
%cst = arith.constant 0.000000e+00 : f32
%48 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%49 = arith.index_cast %c0_i32 : i32 to index
%50 = arith.index_cast %arg5 : i32 to index
%51 = arith.index_cast %c32_i32 : i32 to index
%52:3 = scf.for %arg9 = %49 to %50 step %51 iter_args(%arg10 = %48, %arg11 = %36, %arg12 = %47) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
%cst_14 = arith.constant dense<true> : tensor<64x32xi1>
%cst_15 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
%82 = tt.load %arg11, %cst_14, %cst_15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
%cst_16 = arith.constant dense<true> : tensor<32x64xi1>
%cst_17 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
%83 = tt.load %arg12, %cst_16, %cst_17 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
%cst_18 = arith.constant 0.000000e+00 : f32
%84 = tt.broadcast %cst_18 : (f32) -> tensor<64x64xf32>
%85 = tt.dot %82, %83, %84 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
%86 = arith.addf %arg10, %85 : tensor<64x64xf32>
%c32_i32_19 = arith.constant 32 : i32
%87 = tt.broadcast %c32_i32_19 : (i32) -> tensor<64x32xi32>
%88 = tt.getelementptr %arg11, %87, : tensor<64x32x!tt.ptr<f16>>
%c32_i32_20 = arith.constant 32 : i32
%89 = arith.muli %arg7, %c32_i32_20 : i32
%90 = tt.broadcast %89 : (i32) -> tensor<32x64xi32>
%91 = tt.getelementptr %arg12, %90, : tensor<32x64x!tt.ptr<f16>>
scf.yield %86, %88, %91 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
}
%53 = arith.truncf %52#0 : tensor<64x64xf32> to tensor<64x64xf16>
%c64_i32_11 = arith.constant 64 : i32
%54 = arith.muli %14, %c64_i32_11 : i32
%55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%56 = tt.broadcast %54 : (i32) -> tensor<64xi32>
%57 = arith.addi %56, %55 : tensor<64xi32>
%c64_i32_12 = arith.constant 64 : i32
%58 = arith.muli %16, %c64_i32_12 : i32
%59 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%60 = tt.broadcast %58 : (i32) -> tensor<64xi32>
%61 = arith.addi %60, %59 : tensor<64xi32>
%62 = tt.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32>
%63 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
%64 = arith.muli %63, %62 : tensor<64x1xi32>
%65 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
%66 = tt.getelementptr %65, %64, : tensor<64x1x!tt.ptr<f16>>
%67 = tt.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_13 = arith.constant 1 : i32
%68 = tt.broadcast %c1_i32_13 : (i32) -> tensor<1x64xi32>
%69 = arith.muli %67, %68 : tensor<1x64xi32>
%70 = tt.broadcast %66 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
%71 = tt.broadcast %69 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%72 = tt.getelementptr %70, %71, : tensor<64x64x!tt.ptr<f16>>
%73 = tt.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32>
%74 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
%75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32>
%76 = tt.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32>
%77 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
%78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32>
%79 = tt.broadcast %75 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%80 = tt.broadcast %78 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%81 = arith.andi %79, %80 : tensor<64x64xi1>
tt.store %72, %53, %81, : tensor<64x64xf16>
return
}
}
"""
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse
# See above `L2 Cache Optimizations` section for details
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance these pointers as we move in the K direction
# and accumulate
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# see above `Pointer Arithmetics` section for details
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
# Note that for simplicity, we don't apply a mask here.
# This means that if K is not a multiple of BLOCK_SIZE_K,
# this will access out-of-bounds memory and produce an
# error or (worse!) incorrect results.
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
# We accumulate along the K dimension
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
# Write back the block of the output matrix C
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c = torch.empty((512, 512), device='cuda', dtype=torch.float16)
mod, ctx = matmul_kernel.compile_to_ttir(
a, b, c,
512, 512, 512,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
64, 64, 32,
8, grid=(2,)
)
verify = mod.verify()
assert verify
# assert ref_ir == mod.str()
print(mod.str())
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.run(mod)
verify = mod.verify()
assert verify
# assert ref_ir == mod.str()
print(mod.str())