[PYTHON] Cleaned up legacy code; added simple standalone compilation API (#22)
This commit is contained in:
1
.clang-format
Normal file
1
.clang-format
Normal file
@@ -0,0 +1 @@
|
||||
BasedOnStyle: LLVM
|
51
.github/workflows/documentation.yml
vendored
51
.github/workflows/documentation.yml
vendored
@@ -1,51 +0,0 @@
|
||||
name: Documentation
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
|
||||
jobs:
|
||||
|
||||
Build-Documentation:
|
||||
|
||||
runs-on: self-hosted
|
||||
|
||||
steps:
|
||||
|
||||
|
||||
- name: Checkout gh-pages
|
||||
uses: actions/checkout@v1
|
||||
with:
|
||||
ref: 'gh-pages'
|
||||
|
||||
- name: Checkout branch
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Build docs
|
||||
run: |
|
||||
git fetch origin master:master
|
||||
cd docs
|
||||
sphinx-multiversion . _build/html/
|
||||
|
||||
- name: Publish docs
|
||||
run: |
|
||||
git branch
|
||||
# update docs
|
||||
rm -r /tmp/triton-docs;
|
||||
mkdir /tmp/triton-docs;
|
||||
mv docs/_build/html/* /tmp/triton-docs/
|
||||
git checkout gh-pages
|
||||
cp -r CNAME /tmp/triton-docs/
|
||||
cp -r index.html /tmp/triton-docs/
|
||||
cp -r .nojekyll /tmp/triton-docs/
|
||||
rm -r *
|
||||
cp -r /tmp/triton-docs/* .
|
||||
# ln -s master/index.html .
|
||||
# mv master docs
|
||||
git add .
|
||||
git commit -am "[GH-PAGES] Updated website"
|
||||
# publish docs
|
||||
eval `ssh-agent -s`
|
||||
DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }}
|
||||
git remote set-url origin git@github.com:openai/triton.git
|
||||
git push
|
5
.github/workflows/integration-tests.yml
vendored
5
.github/workflows/integration-tests.yml
vendored
@@ -4,8 +4,7 @@ on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
- v2.0
|
||||
- main
|
||||
|
||||
|
||||
jobs:
|
||||
@@ -21,7 +20,7 @@ jobs:
|
||||
|
||||
- name: Clear cache
|
||||
run: |
|
||||
rm -r /tmp/triton/
|
||||
rm -r ~/.triton/
|
||||
continue-on-error: true
|
||||
|
||||
- name: Install Triton
|
||||
|
40
.github/workflows/wheels.yml
vendored
40
.github/workflows/wheels.yml
vendored
@@ -1,40 +0,0 @@
|
||||
name: Wheels
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
|
||||
jobs:
|
||||
|
||||
Build-Wheels:
|
||||
|
||||
runs-on: self-hosted
|
||||
|
||||
steps:
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Patch setup.py
|
||||
run: |
|
||||
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
|
||||
export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g')
|
||||
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
|
||||
echo "" >> python/setup.cfg
|
||||
echo "[build_ext]" >> python/setup.cfg
|
||||
echo "base-dir=/project" >> python/setup.cfg
|
||||
|
||||
- name: Build wheels
|
||||
run: |
|
||||
export CIBW_MANYLINUX_X86_64_IMAGE="manylinux2014"
|
||||
export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="manylinux2014"
|
||||
export CIBW_BEFORE_BUILD="pip install cmake;\
|
||||
yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel;"
|
||||
export CIBW_SKIP="{cp,pp}35-*"
|
||||
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
|
||||
python3 -m cibuildwheel python --output-dir wheelhouse
|
||||
|
||||
|
||||
- name: Upload wheels to PyPI
|
||||
run: |
|
||||
python3 -m twine upload wheelhouse/* --skip-existing
|
@@ -126,18 +126,10 @@ include_directories(${LLVM_INCLUDE_DIRS})
|
||||
# Python module
|
||||
if(BUILD_PYTHON_MODULE)
|
||||
message(STATUS "Adding Python module")
|
||||
# Build CUTLASS python wrapper if requested
|
||||
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
|
||||
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
|
||||
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
|
||||
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
|
||||
set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc)
|
||||
add_definitions(-DWITH_CUTLASS_BINDINGS)
|
||||
set(CUTLASS_LIBRARIES "cutlass.a")
|
||||
endif()
|
||||
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
|
||||
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
|
||||
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
|
||||
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS})
|
||||
link_directories(${PYTHON_LINK_DIRS})
|
||||
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
|
||||
endif()
|
||||
|
||||
|
||||
|
1
deps/dlfcn-win32
vendored
1
deps/dlfcn-win32
vendored
Submodule deps/dlfcn-win32 deleted from 522c301ec3
@@ -59,13 +59,13 @@ For example, a row-major coalesced layout may distribute a 64x16 tensor over 2 w
|
||||
|
||||
thread tile size 2
|
||||
- - - - - - /\ - - - - - -
|
||||
block| thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3]
|
||||
warp | thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3]
|
||||
tile | tile size 2 || A_{1, 0}[T0] A_{1, 1}[T0] ... A_{1, 6}[T3] A_{1, 7}[T3] A_{1, 8}[T0] A_{1, 9}[T0] ... A_{1, 14}[T3] A_{1, 15}[T3]
|
||||
size } ....
|
||||
32 | A_{30, 0}[T60] A_{14, 1}[T60] ... A_{14, 6}[T63] A_{14, 7}[T63] A_{14, 8}[T60] A_{14, 9}[T60] ... A_{14, 14}[T63] A_{14, 15}[T63]
|
||||
| A_{31, 0}[T60] A_{15, 1}[T60] ... A_{15, 6}[T63] A_{15, 7}[T63] A_{15, 8}[T60] A_{15, 9}[T60] ... A_{15, 14}[T63] A_{15, 15}[T63]
|
||||
-----------------------------/\-----------------------------------
|
||||
block tile size 8
|
||||
warp tile size 8
|
||||
|
||||
|
||||
A_{32, 0}[T0] A_{32, 1}[T0] ... A_{32, 6}[T3] A_{32, 7}[T3] A_{32, 8}[T0] A_{32, 9}[T0] ... A_{32, 14}[T3] A_{32, 15}[T3]
|
||||
|
@@ -1,80 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
|
||||
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class value;
|
||||
class module;
|
||||
class phi_node;
|
||||
class splat_inst;
|
||||
class cast_inst;
|
||||
class reshape_inst;
|
||||
class broadcast_inst;
|
||||
class binary_operator;
|
||||
class getelementptr_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class align {
|
||||
private:
|
||||
struct cst_info {
|
||||
unsigned num_cst;
|
||||
unsigned value;
|
||||
};
|
||||
// helpers
|
||||
std::vector<unsigned> get_shapes(ir::value *v);
|
||||
// populate is_constant
|
||||
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
|
||||
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
|
||||
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_default(ir::value* v);
|
||||
std::vector<cst_info> populate_is_constant(ir::value *v);
|
||||
// populate max_contiguous
|
||||
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
|
||||
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
|
||||
std::vector<unsigned> populate_max_contiguous(ir::value *v);
|
||||
// populate starting_multiple
|
||||
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
|
||||
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
|
||||
std::vector<unsigned> populate_starting_multiple(ir::value *v);
|
||||
// populate all maps
|
||||
void populate(ir::value *v);
|
||||
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
unsigned get(ir::value* v, unsigned ax) const;
|
||||
std::vector<unsigned> contiguous(ir::value* v) const;
|
||||
|
||||
private:
|
||||
std::map<ir::value*, std::vector<cst_info>> is_constant_;
|
||||
std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
|
||||
std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,47 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class function;
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class tiles;
|
||||
|
||||
class liveness;
|
||||
class cts;
|
||||
|
||||
class allocation {
|
||||
public:
|
||||
allocation(liveness *live)
|
||||
: liveness_(live) { }
|
||||
// accessors
|
||||
bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
|
||||
unsigned offset(const data_layout *x) const { return offsets_.at(x); }
|
||||
unsigned allocated_size() const { return allocated_size_; }
|
||||
// run
|
||||
void run(ir::module& mod);
|
||||
|
||||
private:
|
||||
std::map<const data_layout*, unsigned> offsets_;
|
||||
size_t allocated_size_;
|
||||
// dependences
|
||||
liveness *liveness_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,52 +0,0 @@
|
||||
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
|
||||
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
|
||||
|
||||
#include "triton/tools/graph.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class module;
|
||||
class instruction;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class axes {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
|
||||
private:
|
||||
// update graph
|
||||
void update_graph_store(ir::instruction *i);
|
||||
void update_graph_reduce(ir::instruction *i);
|
||||
void update_graph_reshape(ir::instruction *i);
|
||||
void update_graph_trans(ir::instruction *i);
|
||||
void update_graph_broadcast(ir::instruction *i);
|
||||
void update_graph_dot(ir::instruction *i);
|
||||
void update_graph_elementwise(ir::instruction *i,
|
||||
bool is_masked_load_async=false);
|
||||
void update_graph_no_edge(ir::instruction *i);
|
||||
void update_graph(ir::instruction *i);
|
||||
|
||||
public:
|
||||
axes();
|
||||
void run(ir::module &mod);
|
||||
// accessors
|
||||
int get(ir::value *value, unsigned dim);
|
||||
std::vector<int> get(ir::value *value);
|
||||
|
||||
private:
|
||||
tools::graph<node_t> graph_;
|
||||
std::map<node_t, size_t> axes_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,345 +0,0 @@
|
||||
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
|
||||
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "triton/tools/graph.h"
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class type;
|
||||
class module;
|
||||
class instruction;
|
||||
class phi_node;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class axes;
|
||||
class align;
|
||||
class layout_visitor;
|
||||
class data_layout;
|
||||
class mma_layout;
|
||||
class scanline_layout;
|
||||
class shared_layout;
|
||||
|
||||
|
||||
class layout_visitor {
|
||||
public:
|
||||
virtual void visit_layout(data_layout *);
|
||||
virtual void visit_layout_mma(mma_layout*) = 0;
|
||||
virtual void visit_layout_scanline(scanline_layout*) = 0;
|
||||
virtual void visit_layout_shared(shared_layout*) = 0;
|
||||
};
|
||||
|
||||
class data_layout {
|
||||
protected:
|
||||
enum id_t {
|
||||
MMA,
|
||||
SCANLINE,
|
||||
SHARED
|
||||
};
|
||||
|
||||
typedef std::vector<int> axes_t;
|
||||
typedef std::vector<unsigned> shape_t;
|
||||
typedef std::vector<int> order_t;
|
||||
typedef std::vector<ir::value*> values_t;
|
||||
|
||||
private:
|
||||
template<typename T>
|
||||
T* downcast(id_t id) {
|
||||
if(id_ == id)
|
||||
return static_cast<T*>(this);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
public:
|
||||
data_layout(id_t id,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned> &shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align);
|
||||
// visitor
|
||||
virtual void accept(layout_visitor* vst) = 0;
|
||||
// downcast
|
||||
mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
|
||||
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
|
||||
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
|
||||
// accessors
|
||||
size_t get_rank() { return shape_.size(); }
|
||||
const shape_t& get_shape() const { return shape_; }
|
||||
const order_t& get_order() const { return order_; }
|
||||
const values_t& get_values() const { return values_;}
|
||||
int get_axis(size_t k) const { return axes_.at(k); }
|
||||
std::vector<int> get_axes() const { return axes_; }
|
||||
const int get_order(size_t k) const { return order_.at(k); }
|
||||
// find the position of given axis
|
||||
int find_axis(int to_find) const;
|
||||
|
||||
|
||||
private:
|
||||
id_t id_;
|
||||
axes_t axes_;
|
||||
values_t values_;
|
||||
|
||||
protected:
|
||||
order_t order_;
|
||||
shape_t shape_;
|
||||
};
|
||||
|
||||
class distributed_layout: public data_layout{
|
||||
public:
|
||||
distributed_layout(id_t id,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value*>& values,
|
||||
analysis::align* align);
|
||||
|
||||
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
|
||||
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
|
||||
virtual int contig_per_thread(size_t k) = 0;
|
||||
|
||||
protected:
|
||||
std::vector<int> shape_per_cta_;
|
||||
};
|
||||
|
||||
class mma_layout: public distributed_layout {
|
||||
public:
|
||||
enum TensorCoreType : uint8_t {
|
||||
// floating-point tensor core instr
|
||||
FP32_FP16_FP16_FP32 = 0, // default
|
||||
FP32_BF16_BF16_FP32,
|
||||
FP32_TF32_TF32_FP32,
|
||||
// integer tensor core instr
|
||||
INT32_INT1_INT1_INT32, // Not implemented
|
||||
INT32_INT4_INT4_INT32, // Not implemented
|
||||
INT32_INT8_INT8_INT32, // Not implemented
|
||||
//
|
||||
NOT_APPLICABLE,
|
||||
};
|
||||
|
||||
// Used on nvidia GPUs with sm >= 80
|
||||
inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
|
||||
{FP32_FP16_FP16_FP32, {16, 8, 16}},
|
||||
{FP32_BF16_BF16_FP32, {16, 8, 16}},
|
||||
{FP32_TF32_TF32_FP32, {16, 8, 8}},
|
||||
|
||||
{INT32_INT1_INT1_INT32, {16, 8, 256}},
|
||||
{INT32_INT4_INT4_INT32, {16, 8, 64}},
|
||||
{INT32_INT8_INT8_INT32, {16, 8, 32}},
|
||||
};
|
||||
|
||||
// shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
|
||||
inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
|
||||
{FP32_FP16_FP16_FP32, {8, 8, 8}},
|
||||
{FP32_BF16_BF16_FP32, {8, 8, 8}},
|
||||
{FP32_TF32_TF32_FP32, {8, 8, 4}},
|
||||
|
||||
{INT32_INT1_INT1_INT32, {8, 8, 64}},
|
||||
{INT32_INT4_INT4_INT32, {8, 8, 32}},
|
||||
{INT32_INT8_INT8_INT32, {8, 8, 16}},
|
||||
};
|
||||
|
||||
inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
|
||||
{FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
|
||||
{FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
|
||||
{FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
|
||||
|
||||
{INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
|
||||
{INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
|
||||
{INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
|
||||
};
|
||||
|
||||
// vector length per ldmatrix (16*8/elelment_size_in_bits)
|
||||
inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
|
||||
{FP32_FP16_FP16_FP32, 8},
|
||||
{FP32_BF16_BF16_FP32, 8},
|
||||
{FP32_TF32_TF32_FP32, 4},
|
||||
|
||||
{INT32_INT1_INT1_INT32, 128},
|
||||
{INT32_INT4_INT4_INT32, 32},
|
||||
{INT32_INT8_INT8_INT32, 16},
|
||||
};
|
||||
|
||||
public:
|
||||
mma_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align, target *tgt,
|
||||
shared_layout* layout_a,
|
||||
shared_layout* layout_b,
|
||||
ir::value *dot);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
|
||||
// accessor
|
||||
int fpw(size_t k) { return fpw_.at(k); }
|
||||
int wpt(size_t k) { return wpt_.at(k); }
|
||||
int spw(size_t k) { return spw_.at(k); }
|
||||
int rep(size_t k) { return rep_.at(k); }
|
||||
int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }
|
||||
|
||||
// helpers for generator.cc
|
||||
std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
|
||||
std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
|
||||
std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
|
||||
int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
|
||||
int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
|
||||
|
||||
// setter
|
||||
void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }
|
||||
|
||||
private:
|
||||
// fragment per warp
|
||||
std::vector<int> fpw_;
|
||||
// shape per warp
|
||||
std::vector<int> spw_;
|
||||
// warp per tile
|
||||
std::vector<int> wpt_;
|
||||
// shape per tile
|
||||
std::vector<int> spt_;
|
||||
// repetitions
|
||||
std::vector<int> rep_;
|
||||
// contiguous per thread
|
||||
std::vector<int> contig_per_thread_;
|
||||
|
||||
TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
|
||||
};
|
||||
|
||||
struct scanline_layout: public distributed_layout {
|
||||
scanline_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align,
|
||||
target* tgt);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
|
||||
// accessor
|
||||
int mts(size_t k) { return mts_.at(k); }
|
||||
int nts(size_t k) { return nts_.at(k); }
|
||||
int contig_per_thread(size_t k) { return nts_.at(k); }
|
||||
|
||||
public:
|
||||
// micro tile size. The size of a tile held by a thread block.
|
||||
std::vector<int> mts_;
|
||||
// nano tile size. The size of a tile held by a thread.
|
||||
std::vector<int> nts_;
|
||||
};
|
||||
|
||||
struct double_buffer_info_t {
|
||||
ir::value* first;
|
||||
ir::value* latch;
|
||||
ir::phi_node* phi;
|
||||
};
|
||||
|
||||
struct N_buffer_info_t {
|
||||
std::vector<ir::value*> firsts; // not necessarily ordered as input order
|
||||
ir::value* latch;
|
||||
ir::phi_node* phi;
|
||||
std::map<ir::value*, int> firsts_idx;
|
||||
};
|
||||
|
||||
// abstract for dot and coresponding smem values
|
||||
class shared_layout: public data_layout {
|
||||
private:
|
||||
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
|
||||
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
|
||||
static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);
|
||||
|
||||
public:
|
||||
shared_layout(data_layout *arg,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shapes,
|
||||
const std::vector<ir::value *> &values_,
|
||||
ir::type *ty,
|
||||
analysis::align* align, target *tgt);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
|
||||
// accessors
|
||||
size_t get_size() { return size_; }
|
||||
ir::type* get_type() { return ty_; }
|
||||
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
|
||||
N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
|
||||
int get_num_stages() const;
|
||||
size_t get_per_stage_size() const { return size_ / get_num_stages(); }
|
||||
size_t get_per_stage_elements() const;
|
||||
size_t get_num_per_phase() { return num_per_phase_; }
|
||||
ir::value* hmma_dot_a() { return hmma_dot_a_; }
|
||||
ir::value* hmma_dot_b() { return hmma_dot_b_; }
|
||||
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
|
||||
int get_mma_vec() { return mma_vec_;}
|
||||
int get_mma_strided() { return mma_strided_; }
|
||||
bool allow_swizzle() const { return allow_swizzle_; }
|
||||
data_layout* get_arg_layout() { return arg_layout_; }
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
ir::type *ty_;
|
||||
std::shared_ptr<double_buffer_info_t> double_buffer_;
|
||||
std::shared_ptr<N_buffer_info_t> N_buffer_;
|
||||
size_t num_per_phase_;
|
||||
ir::value* hmma_dot_a_;
|
||||
ir::value* hmma_dot_b_;
|
||||
data_layout* arg_layout_;
|
||||
int mma_vec_;
|
||||
int mma_strided_;
|
||||
bool allow_swizzle_ = true;
|
||||
target *tgt_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class layouts {
|
||||
typedef ir::value* node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
||||
private:
|
||||
// graph creation
|
||||
void connect(ir::value *x, ir::value *y);
|
||||
void make_graph(ir::instruction *i);
|
||||
|
||||
void init_hmma_tile(data_layout& layouts);
|
||||
void init_scanline_tile(data_layout &layouts);
|
||||
|
||||
void create(size_t id, const std::vector<ir::value*>& values);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
|
||||
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
|
||||
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
|
||||
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
|
||||
size_t num_layouts() const { return values_.size();}
|
||||
data_layout* get(size_t id) { return layouts_.at(id); }
|
||||
data_layout* get(ir::value *v) { return get(layout_of(v));}
|
||||
std::map<size_t, data_layout*> &get_all() { return layouts_; }
|
||||
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
|
||||
int tmp(ir::value* i) { return tmp_.at(i);}
|
||||
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
|
||||
// execution
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::axes* axes_;
|
||||
analysis::align* align_;
|
||||
size_t num_warps_;
|
||||
target* tgt_;
|
||||
tools::graph<ir::value*> graph_;
|
||||
std::map<ir::value*, size_t> groups_;
|
||||
std::map<size_t, std::vector<ir::value*>> values_;
|
||||
std::map<size_t, data_layout*> layouts_;
|
||||
std::map<ir::value*, size_t> tmp_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,67 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/tools/graph.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class phi_node;
|
||||
class function;
|
||||
class module;
|
||||
class instruction;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
typedef unsigned slot_index;
|
||||
|
||||
class tiles;
|
||||
class layouts;
|
||||
class data_layout;
|
||||
|
||||
struct segment {
|
||||
slot_index start;
|
||||
slot_index end;
|
||||
|
||||
bool contains(slot_index idx) const {
|
||||
return start <= idx && idx < end;
|
||||
}
|
||||
|
||||
bool intersect(const segment &Other){
|
||||
return contains(Other.start) || Other.contains(start);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class liveness {
|
||||
private:
|
||||
typedef std::map<shared_layout*, segment> intervals_map_t;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
liveness(layouts *l): layouts_(l){ }
|
||||
// accessors
|
||||
const intervals_map_t& get() const { return intervals_; }
|
||||
segment get(shared_layout* v) const { return intervals_.at(v); }
|
||||
// run
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
// analysis
|
||||
layouts *layouts_;
|
||||
intervals_map_t intervals_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -1,43 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
|
||||
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
class target;
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class layouts;
|
||||
class data_layout;
|
||||
|
||||
class swizzle {
|
||||
public:
|
||||
// constructor
|
||||
swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
|
||||
// accessors
|
||||
int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
|
||||
int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
|
||||
int get_vec (data_layout* layout) { return vec_.at(layout); }
|
||||
// run
|
||||
void run(ir::module &mod);
|
||||
private:
|
||||
layouts* layouts_;
|
||||
target* tgt_;
|
||||
std::map<data_layout*, int> per_phase_;
|
||||
std::map<data_layout*, int> max_phase_;
|
||||
std::map<data_layout*, int> vec_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -1,42 +0,0 @@
|
||||
#ifndef _TRITON_CODEGEN_PASS_H_
|
||||
#define _TRITON_CODEGEN_PASS_H_
|
||||
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace llvm{
|
||||
class Module;
|
||||
class LLVMContext;
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace codegen {
|
||||
class target;
|
||||
}
|
||||
|
||||
namespace ir{
|
||||
class module;
|
||||
}
|
||||
namespace driver{
|
||||
class device;
|
||||
class module;
|
||||
class kernel;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
// TODO:
|
||||
// There should be a proper pass manager there!
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
|
||||
codegen::target* target,
|
||||
int sm, int num_warps,
|
||||
int num_stages, int &shared_static);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,258 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_SELECTION_GENERATOR_H_
|
||||
#define _TRITON_SELECTION_GENERATOR_H_
|
||||
|
||||
#include "triton/ir/visitor.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include <functional>
|
||||
|
||||
// forward
|
||||
namespace llvm{
|
||||
class Type;
|
||||
class Value;
|
||||
class PHINode;
|
||||
class BasicBlock;
|
||||
class Attribute;
|
||||
class Instruction;
|
||||
class Constant;
|
||||
class LLVMContext;
|
||||
class Module;
|
||||
class ConstantFolder;
|
||||
class IRBuilderDefaultInserter;
|
||||
template <typename T, typename Inserter>
|
||||
class IRBuilder;
|
||||
class ArrayType;
|
||||
class Function;
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class attribute;
|
||||
class load_inst;
|
||||
class store_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
// forward
|
||||
namespace analysis{
|
||||
class liveness;
|
||||
class tiles;
|
||||
class align;
|
||||
class allocation;
|
||||
class cts;
|
||||
class axes;
|
||||
class layouts;
|
||||
class swizzle;
|
||||
}
|
||||
// typedef
|
||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||
llvm::IRBuilderDefaultInserter> Builder;
|
||||
typedef llvm::LLVMContext LLVMContext;
|
||||
typedef llvm::Type Type;
|
||||
typedef llvm::Value Value;
|
||||
typedef llvm::Attribute Attribute;
|
||||
typedef llvm::BasicBlock BasicBlock;
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Instruction Instruction;
|
||||
typedef llvm::Constant Constant;
|
||||
typedef llvm::ArrayType ArrayType;
|
||||
typedef llvm::Function Function;
|
||||
typedef std::vector<Value*> indices_t;
|
||||
class target;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
struct distributed_axis {
|
||||
int contiguous;
|
||||
std::vector<Value*> values;
|
||||
Value* thread_id;
|
||||
};
|
||||
|
||||
class adder{
|
||||
public:
|
||||
adder(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value* x, Value* y, const std::string& name = "");
|
||||
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class multiplier{
|
||||
public:
|
||||
multiplier(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value* x, Value* y, const std::string& name = "");
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class geper{
|
||||
public:
|
||||
geper(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value *ptr, Value* off, const std::string& name = "");
|
||||
Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
|
||||
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class generator: public ir::visitor, public analysis::layout_visitor {
|
||||
private:
|
||||
void init_idx(ir::value *x);
|
||||
Instruction* add_barrier();
|
||||
Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
|
||||
void finalize_shared_layout(analysis::shared_layout*);
|
||||
void finalize_function(ir::function*);
|
||||
void finalize_phi_node(ir::phi_node*);
|
||||
|
||||
private:
|
||||
Type *cvt(ir::type *ty);
|
||||
llvm::Attribute cvt(ir::attribute attr);
|
||||
|
||||
public:
|
||||
generator(analysis::axes *a_axes,
|
||||
analysis::layouts *layouts,
|
||||
analysis::align *alignment,
|
||||
analysis::allocation *alloc,
|
||||
analysis::swizzle *swizzle,
|
||||
target *tgt,
|
||||
unsigned num_warps);
|
||||
|
||||
void visit_value(ir::value* v);
|
||||
void visit_phi_node(ir::phi_node*);
|
||||
void visit_binary_operator(ir::binary_operator*);
|
||||
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
||||
void visit_icmp_inst(ir::icmp_inst*);
|
||||
void visit_fcmp_inst(ir::fcmp_inst*);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
Value* bf16_to_fp32(Value *in0);
|
||||
Value* fp32_to_bf16(Value *in0);
|
||||
|
||||
void visit_cast_inst(ir::cast_inst*);
|
||||
void visit_return_inst(ir::return_inst*);
|
||||
void visit_cond_branch_inst(ir::cond_branch_inst*);
|
||||
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
|
||||
void visit_load_inst(ir::load_inst*);
|
||||
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
|
||||
void visit_masked_load_inst(ir::masked_load_inst*);
|
||||
void visit_store_inst(ir::store_inst*);
|
||||
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
||||
void visit_masked_store_inst(ir::masked_store_inst*);
|
||||
void visit_cat_inst(ir::cat_inst*);
|
||||
void visit_reshape_inst(ir::reshape_inst*);
|
||||
void visit_splat_inst(ir::splat_inst*);
|
||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||
void visit_downcast_inst(ir::downcast_inst*);
|
||||
void visit_exp_inst(ir::exp_inst*);
|
||||
void visit_cos_inst(ir::cos_inst*);
|
||||
void visit_umulhi_inst(ir::umulhi_inst* x);
|
||||
void visit_sin_inst(ir::sin_inst*);
|
||||
void visit_log_inst(ir::log_inst*);
|
||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
|
||||
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
||||
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
|
||||
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
|
||||
void visit_dot_inst(ir::dot_inst*);
|
||||
void visit_trans_inst(ir::trans_inst*);
|
||||
void visit_sqrt_inst(ir::sqrt_inst*);
|
||||
Value* shfl_sync(Value* acc, int32_t i);
|
||||
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reduce_inst(ir::reduce_inst*);
|
||||
void visit_select_inst(ir::select_inst*);
|
||||
void visit_layout_convert(ir::value *out, ir::value *in);
|
||||
void visit_cvt_layout_inst(ir::cvt_layout_inst*);
|
||||
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
|
||||
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
|
||||
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
|
||||
void visit_barrier_inst(ir::barrier_inst*);
|
||||
void visit_prefetch_s_inst(ir::prefetch_s_inst*);
|
||||
void visit_async_wait_inst(ir::async_wait_inst*);
|
||||
// void visit_make_range_dyn(ir::make_range_dyn*);
|
||||
void visit_make_range(ir::make_range*);
|
||||
// void visit_make_range_sta(ir::make_range_sta*);
|
||||
void visit_undef_value(ir::undef_value*);
|
||||
void visit_constant_int(ir::constant_int*);
|
||||
void visit_constant_fp(ir::constant_fp*);
|
||||
void visit_alloc_const(ir::alloc_const*);
|
||||
void visit_function(ir::function*);
|
||||
void visit_basic_block(ir::basic_block*);
|
||||
void visit_argument(ir::argument*);
|
||||
void visit(ir::module &, llvm::Module &);
|
||||
|
||||
// layouts
|
||||
void visit_layout_mma(analysis::mma_layout*);
|
||||
void visit_layout_scanline(analysis::scanline_layout*);
|
||||
void visit_layout_shared(analysis::shared_layout*);
|
||||
|
||||
|
||||
private:
|
||||
LLVMContext *ctx_;
|
||||
Builder* builder_;
|
||||
Module *mod_;
|
||||
|
||||
analysis::axes *a_axes_;
|
||||
analysis::swizzle *swizzle_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
target *tgt_;
|
||||
analysis::layouts *layouts_;
|
||||
analysis::align *alignment_;
|
||||
analysis::allocation *alloc_;
|
||||
Value *shmem_;
|
||||
std::set<ir::value*> seen_;
|
||||
|
||||
unsigned num_warps_;
|
||||
|
||||
std::map<analysis::data_layout*, Value*> offset_a_m_;
|
||||
std::map<analysis::data_layout*, Value*> offset_a_k_;
|
||||
std::map<analysis::data_layout*, Value*> offset_b_k_;
|
||||
std::map<analysis::data_layout*, Value*> offset_b_n_;
|
||||
|
||||
/// layout -> base ptr
|
||||
std::map<analysis::data_layout*, Value*> shared_ptr_;
|
||||
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
|
||||
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
|
||||
/// offset for double-buffered layout
|
||||
std::map<analysis::data_layout*, Value*> shared_off_;
|
||||
|
||||
/// Base shmem pointer of ir value
|
||||
std::map<ir::value*, Value*> shmems_;
|
||||
std::map<ir::value*, Value*> shoffs_;
|
||||
std::map<ir::value*, std::vector<indices_t>> idxs_;
|
||||
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
|
||||
/// idx for multi-stage pipeline
|
||||
std::map<analysis::data_layout*, Value*> read_smem_idx_;
|
||||
std::map<analysis::data_layout*, Value*> write_smem_idx_;
|
||||
|
||||
/// triton bb -> llvm bb
|
||||
std::map<ir::value*, BasicBlock *> bbs_;
|
||||
std::map<ir::value*, std::vector<int>> ords_;
|
||||
|
||||
// helper for creating llvm values
|
||||
adder add;
|
||||
multiplier mul;
|
||||
geper gep;
|
||||
|
||||
/// PHI nodes
|
||||
std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
|
||||
|
||||
/// Record prefetch instrs that needs to be moved
|
||||
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,105 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_TARGET_H
|
||||
|
||||
namespace llvm{
|
||||
class Type;
|
||||
class Value;
|
||||
class Instruction;
|
||||
class Constant;
|
||||
class LLVMContext;
|
||||
class Module;
|
||||
class ConstantFolder;
|
||||
class IRBuilderDefaultInserter;
|
||||
template <typename T, typename Inserter>
|
||||
class IRBuilder;
|
||||
class ArrayType;
|
||||
class Function;
|
||||
}
|
||||
|
||||
// typedefs
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||
llvm::IRBuilderDefaultInserter> Builder;
|
||||
typedef llvm::LLVMContext LLVMContext;
|
||||
typedef llvm::Type Type;
|
||||
typedef llvm::Value Value;
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Instruction Instruction;
|
||||
typedef llvm::Constant Constant;
|
||||
typedef llvm::ArrayType ArrayType;
|
||||
typedef llvm::Function Function;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
class nvidia_cu_target;
|
||||
|
||||
class target {
|
||||
public:
|
||||
target(bool is_gpu): is_gpu_(is_gpu){}
|
||||
virtual ~target() {}
|
||||
virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
|
||||
virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
|
||||
virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
|
||||
virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
|
||||
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual unsigned guaranteed_alignment() = 0;
|
||||
nvidia_cu_target* as_nvidia();
|
||||
bool is_gpu() const;
|
||||
|
||||
private:
|
||||
bool is_gpu_;
|
||||
};
|
||||
|
||||
class amd_cl_target: public target {
|
||||
public:
|
||||
amd_cl_target(): target(true){}
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
unsigned guaranteed_alignment() { return 16; }
|
||||
};
|
||||
|
||||
class nvidia_cu_target: public target {
|
||||
public:
|
||||
nvidia_cu_target(int sm): target(true), sm_(sm){}
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
int sm() { return sm_; }
|
||||
unsigned guaranteed_alignment() { return 16; }
|
||||
|
||||
private:
|
||||
int sm_;
|
||||
};
|
||||
|
||||
class cpu_target: public target {
|
||||
public:
|
||||
cpu_target(): target(false){}
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
unsigned guaranteed_alignment() { return 1; }
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,48 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
|
||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class io_inst;
|
||||
class instruction;
|
||||
class builder;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class align;
|
||||
class layouts;
|
||||
class cts;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class coalesce {
|
||||
private:
|
||||
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
|
||||
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
|
||||
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
|
||||
|
||||
public:
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
|
||||
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::align* align_;
|
||||
analysis::layouts* layout_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,36 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
|
||||
#define TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class phi_node;
|
||||
class instruction;
|
||||
class builder;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class cts {
|
||||
private:
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
|
||||
|
||||
public:
|
||||
cts(bool use_async = false): use_async_(use_async) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
bool use_async_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,24 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
|
||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
|
||||
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class dce {
|
||||
public:
|
||||
dce() {}
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,22 +0,0 @@
|
||||
#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
|
||||
#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
|
||||
|
||||
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class disassociate {
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,72 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
|
||||
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <list>
|
||||
#include <set>
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
class masked_load_async_inst;
|
||||
class value;
|
||||
class builder;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class allocation;
|
||||
class liveness;
|
||||
class layouts;
|
||||
class cts;
|
||||
class shared_layout;
|
||||
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class prefetch;
|
||||
|
||||
class membar {
|
||||
private:
|
||||
typedef std::pair<unsigned, unsigned> interval_t;
|
||||
typedef std::set<ir::value*> val_set_t;
|
||||
typedef std::vector<ir::value*> val_vec_t;
|
||||
|
||||
private:
|
||||
bool intersect(const val_set_t &X, const val_set_t &Y);
|
||||
bool check_safe_war(ir::instruction* i);
|
||||
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
|
||||
bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
|
||||
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
|
||||
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
|
||||
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
|
||||
|
||||
public:
|
||||
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
|
||||
transform::prefetch *prefetch, target* tgt):
|
||||
liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::liveness *liveness_;
|
||||
analysis::layouts *layouts_;
|
||||
analysis::allocation *alloc_;
|
||||
transform::prefetch *prefetch_;
|
||||
|
||||
target* tgt_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,53 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
|
||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
|
||||
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class instruction;
|
||||
class trans_inst;
|
||||
class builder;
|
||||
class constant_int;
|
||||
class dot_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
class layouts;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class peephole {
|
||||
private:
|
||||
// bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
|
||||
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
||||
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
|
||||
|
||||
public:
|
||||
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
target* tgt_;
|
||||
analysis::layouts* layouts_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,30 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
|
||||
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
|
||||
|
||||
// forward declaration
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
} // namespace triton
|
||||
|
||||
namespace triton {
|
||||
namespace codegen {
|
||||
namespace transform {
|
||||
|
||||
class pipeline {
|
||||
public:
|
||||
pipeline(bool has_copy_async, int num_stages)
|
||||
: has_copy_async_(has_copy_async), num_stages_(num_stages) {}
|
||||
void run(ir::module &module);
|
||||
|
||||
private:
|
||||
bool has_copy_async_;
|
||||
int num_stages_;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
@@ -1,27 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
|
||||
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
|
||||
|
||||
#include <set>
|
||||
|
||||
// forward dclaration
|
||||
namespace triton::ir{
|
||||
class module;
|
||||
class value;
|
||||
}
|
||||
|
||||
namespace triton::codegen {
|
||||
class target;
|
||||
}
|
||||
|
||||
namespace triton::codegen::transform {
|
||||
class prefetch {
|
||||
target* tgt_;
|
||||
std::set<ir::value*> prefetched_vals_;
|
||||
public:
|
||||
prefetch(target *tgt) : tgt_(tgt) {}
|
||||
void run(ir::module &module);
|
||||
bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
|
||||
};
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,26 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
|
||||
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H
|
||||
|
||||
namespace triton {
|
||||
|
||||
// forward declaration
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace transform{
|
||||
|
||||
class reorder {
|
||||
public:
|
||||
void run(ir::module& module);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,5 +0,0 @@
|
||||
file(GLOB_RECURSE CODEGEN_SRC *.cc)
|
||||
|
||||
add_library(TritonCodeGen
|
||||
${CODEGEN_SRC}
|
||||
)
|
@@ -1,533 +0,0 @@
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
// Function for extended Euclidean Algorithm
|
||||
int gcd_impl(int a, int b, int *x, int *y)
|
||||
{
|
||||
// Base Case
|
||||
if (a == 0)
|
||||
{
|
||||
*x = 0;
|
||||
*y = 1;
|
||||
return b;
|
||||
}
|
||||
|
||||
int x1, y1; // To store results of recursive call
|
||||
int gcd = gcd_impl(b%a, a, &x1, &y1);
|
||||
|
||||
// Update x and y using results of
|
||||
// recursive call
|
||||
*x = y1 - (b/a) * x1;
|
||||
*y = x1;
|
||||
|
||||
return gcd;
|
||||
}
|
||||
|
||||
int gcd(int a, int b) {
|
||||
int x, y;
|
||||
return gcd_impl(a, b, &x, &y);
|
||||
}
|
||||
|
||||
|
||||
inline int lcm(int a, int b) {
|
||||
return (a * b) / gcd(a, b);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
|
||||
return map[i] = value;
|
||||
}
|
||||
|
||||
/*
|
||||
* is constant
|
||||
*/
|
||||
|
||||
std::vector<unsigned> align::get_shapes(ir::value *v) {
|
||||
ir::type *ty = v->get_type();
|
||||
if(ty->is_block_ty())
|
||||
return ty->get_block_shapes();
|
||||
else
|
||||
return {1};
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
auto it = is_constant_.find(inc);
|
||||
if(it != is_constant_.end())
|
||||
result = it->second;
|
||||
}
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
auto cst = populate_is_constant(inc);
|
||||
for(size_t d = 0; d < cst.size(); d++)
|
||||
result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst);
|
||||
}
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
ir::value* op = x->get_operand(0);
|
||||
std::vector<cst_info> result;
|
||||
auto op_cst = populate_is_constant(op);
|
||||
for(auto d: shapes)
|
||||
result.push_back(cst_info{d, op_cst[0].value});
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto op_cst = populate_is_constant(op);
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
for(size_t d = 0; d < x_shapes.size(); d ++){
|
||||
cst_info ax ;
|
||||
if(x_shapes[d] == 1)
|
||||
ax = {1, op_cst[current].value};
|
||||
else if(!is_skewed
|
||||
&& x_shapes[d] == op_shapes[current])
|
||||
ax = {x_shapes[d], op_cst[current++].value};
|
||||
else {
|
||||
is_skewed = true;
|
||||
ax = {x_shapes[d], 0};
|
||||
}
|
||||
result.push_back(ax);
|
||||
}
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto op_cst = populate_is_constant(op);
|
||||
for(size_t d = 0; d < x_shapes.size(); d++)
|
||||
if(op_shapes[d] == 1)
|
||||
result.push_back(cst_info{x_shapes[d], op_cst[d].value});
|
||||
else
|
||||
result.push_back(op_cst[d]);
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
ir::value* lhs_op = x->get_operand(0);
|
||||
ir::value* rhs_op = x->get_operand(1);
|
||||
auto lhs = populate_is_constant(lhs_op);
|
||||
auto rhs = populate_is_constant(rhs_op);
|
||||
auto max_contiguous = populate_max_contiguous(lhs_op);
|
||||
for(size_t d = 0; d < x_shapes.size(); d++) {
|
||||
cst_info ax;
|
||||
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
|
||||
// todo might not be entirely true
|
||||
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
|
||||
ax = {num_constants, 0};
|
||||
}
|
||||
else
|
||||
ax = {std::min(lhs[d].num_cst, rhs[d].num_cst), 0};
|
||||
result.push_back(ax);
|
||||
}
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
|
||||
auto x_shapes = get_shapes(x);
|
||||
ir::value* lhs_op = x->get_operand(0);
|
||||
ir::value* rhs_op = x->get_operand(1);
|
||||
auto lhs = populate_is_constant(lhs_op);
|
||||
auto rhs = populate_is_constant(rhs_op);
|
||||
std::vector<cst_info> result;
|
||||
for(size_t d = 0; d < x_shapes.size(); d++)
|
||||
result.push_back({std::min(lhs[d].num_cst, rhs[d].num_cst), 0});
|
||||
return add_to_cache(x, result, is_constant_);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
|
||||
auto shapes = get_shapes(v);
|
||||
std::vector<cst_info> result(shapes.size(), {1, 0});
|
||||
return add_to_cache(v, result, is_constant_);
|
||||
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
||||
if(is_constant_.find(v) != is_constant_.end())
|
||||
return is_constant_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||
return populate_is_constant_phi(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
return populate_is_constant_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
return populate_is_constant_reshape(x);
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
|
||||
return populate_is_constant_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
return populate_is_constant_binop(x);
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
return populate_is_constant_gep(x);
|
||||
return populate_is_constant_default(v);
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* max contiguous
|
||||
*/
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result(shapes.size(), 1);
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
auto it = max_contiguous_.find(inc);
|
||||
if(it != max_contiguous_.end())
|
||||
result = it->second;
|
||||
}
|
||||
add_to_cache(x, result, max_contiguous_);
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
auto contiguous = populate_max_contiguous(inc);
|
||||
for(size_t d = 0; d < result.size(); d++)
|
||||
result[d] = std::min(result[d], contiguous[d]);
|
||||
}
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
|
||||
auto x_shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
for(size_t d = 0; d < x_shapes.size(); d++)
|
||||
result.push_back({1});
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto op_mc = populate_max_contiguous(op);
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
for(size_t d = 0; d < shapes.size(); d ++){
|
||||
if(shapes[d] == 1)
|
||||
result.push_back(1);
|
||||
else if(!is_skewed
|
||||
&& shapes[d] == op_shapes[current])
|
||||
result.push_back(op_mc[current++]);
|
||||
else {
|
||||
is_skewed = true;
|
||||
result.push_back(1);
|
||||
}
|
||||
}
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto op_mc = populate_max_contiguous(op);
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
if(op_shapes[d] == 1)
|
||||
result.push_back(1);
|
||||
else
|
||||
result.push_back(op_mc[d]);
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
auto lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||
auto rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||
auto lhs_cst_info = populate_is_constant(lhs);
|
||||
auto rhs_cst_info = populate_is_constant(rhs);
|
||||
auto lhs_starting_multiple = populate_starting_multiple(lhs);
|
||||
auto rhs_starting_multiple = populate_starting_multiple(rhs);
|
||||
std::vector<unsigned> result;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
unsigned value = 1;
|
||||
if(x->is_int_rem() && rhs_starting_multiple[d] > 0){
|
||||
value = std::min(lhs_max_contiguous[d], rhs_starting_multiple[d]);
|
||||
}
|
||||
if(x->is_int_mult()){
|
||||
unsigned lvalue = 1, rvalue = 1;
|
||||
if(rhs_cst_info[d].value == 1)
|
||||
lvalue = lhs_max_contiguous[d];
|
||||
if(lhs_cst_info[d].value == 1)
|
||||
rvalue = rhs_max_contiguous[d];
|
||||
value = std::max(lvalue, rvalue);
|
||||
}
|
||||
if(x->is_int_add_sub()){
|
||||
unsigned lvalue = 1, rvalue = 1;
|
||||
lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
|
||||
rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
|
||||
value = std::max(lvalue, rvalue);
|
||||
}
|
||||
result.push_back(value);
|
||||
}
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
auto lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||
auto rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||
auto lhs_cst_info = populate_is_constant(lhs);
|
||||
auto rhs_cst_info = populate_is_constant(rhs);
|
||||
std::vector<unsigned> result(shapes.size(), 1);
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
unsigned lvalue = 1, rvalue = 1;
|
||||
if(lhs_cst_info[d].num_cst)
|
||||
lvalue = rhs_max_contiguous[d];
|
||||
if(rhs_cst_info[d].num_cst)
|
||||
rvalue = lhs_max_contiguous[d];
|
||||
result[d] = std::max(lvalue, rvalue);
|
||||
}
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
|
||||
if(!v->get_type()->is_block_ty())
|
||||
return add_to_cache(v, {1}, max_contiguous_);
|
||||
auto shapes = v->get_type()->get_block_shapes();
|
||||
if(dynamic_cast<ir::make_range*>(v))
|
||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
|
||||
auto result = populate_max_contiguous(v->get_operand(0));
|
||||
return add_to_cache(v, result, max_contiguous_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
|
||||
if(max_contiguous_.find(v) != max_contiguous_.end())
|
||||
return max_contiguous_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||
unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
|
||||
if(max_contiguous > 0)
|
||||
return add_to_cache(x, {max_contiguous}, max_contiguous_);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||
return populate_max_contiguous_cast(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
return populate_max_contiguous_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
return populate_max_contiguous_reshape(x);
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
|
||||
return populate_max_contiguous_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
return populate_max_contiguous_binop(x);
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
return populate_max_contiguous_gep(x);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||
return populate_max_contiguous_phi(x);
|
||||
return populate_max_contiguous_default(v);
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* starting multiple
|
||||
*/
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
|
||||
auto shapes = get_shapes(x);
|
||||
auto op = populate_starting_multiple(x->get_operand(0));
|
||||
std::vector<unsigned> result(shapes.size(), op[0]);
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
|
||||
auto op = populate_starting_multiple(x->get_operand(0));
|
||||
auto op_shapes = get_shapes(x->get_operand(0));
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result(shapes.size(), 1);
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
for(size_t d = 0; d < shapes.size(); d ++){
|
||||
if(shapes[d] == 1)
|
||||
result[d] = 1;
|
||||
else if(!is_skewed
|
||||
&& shapes[d] == op_shapes[current])
|
||||
result[d] = op[current++];
|
||||
else {
|
||||
is_skewed = true;
|
||||
result[d] = 1;
|
||||
}
|
||||
}
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
|
||||
auto result = populate_starting_multiple(x->get_operand(0));
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
|
||||
auto lhs = populate_starting_multiple(x->get_operand(0));
|
||||
auto rhs = populate_starting_multiple(x->get_operand(1));
|
||||
std::vector<unsigned> result(lhs.size(), 1);
|
||||
for(size_t d = 0; d < lhs.size(); d++){
|
||||
if(x->is_int_mult())
|
||||
result[d] = lhs[d] * rhs[d];
|
||||
if(x->is_int_add_sub())
|
||||
result[d] = gcd(lhs[d], rhs[d]);
|
||||
if(x->is_int_div())
|
||||
result[d] = 1;
|
||||
if(x->is_int_rem() && rhs[d] > 1){
|
||||
result[d] = gcd(lhs[d], rhs[d]);
|
||||
}
|
||||
if(x->is_shl())
|
||||
result[d] = lhs[d] << rhs[d];
|
||||
if(x->is_shr())
|
||||
result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1);
|
||||
}
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
|
||||
auto lhs = populate_starting_multiple(x->get_operand(0));
|
||||
auto rhs = populate_starting_multiple(x->get_operand(1));
|
||||
std::vector<unsigned> result(lhs.size(), 1);
|
||||
for(size_t d = 0; d < lhs.size(); d++){
|
||||
result[d] = gcd(lhs[d], rhs[d]);
|
||||
// std::cout << "starting multiple: " << x->get_name() << " " << d << " " << result[d] << std::endl;
|
||||
}
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
|
||||
auto shape = get_shapes(x);
|
||||
std::vector<unsigned> result(shape.size(), 1);
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
if(starting_multiple_.find(inc) != starting_multiple_.end())
|
||||
result = starting_multiple_.at(inc);
|
||||
}
|
||||
add_to_cache(x, result, starting_multiple_);
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
auto sm = populate_starting_multiple(inc);
|
||||
for(size_t d = 0; d < result.size(); d++)
|
||||
result[d] = gcd(result[d], sm[d]);
|
||||
}
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
|
||||
auto result = populate_starting_multiple(x->get_operand(0));
|
||||
return add_to_cache(x, result, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
|
||||
ir::type* ty = v->get_type();
|
||||
if(ty->is_block_ty()) {
|
||||
return add_to_cache(v, ty->get_block_shapes(), starting_multiple_);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::argument*>(v)){
|
||||
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
|
||||
for(auto attr: attributes){
|
||||
if(attr.get_kind() == ir::multiple_of){
|
||||
return add_to_cache(x, {attr.get_value()}, starting_multiple_);
|
||||
}
|
||||
if(attr.get_kind() == ir::aligned){
|
||||
ir::type* ty = x->get_type()->get_pointer_element_ty();
|
||||
int nbits = ty->get_primitive_size_in_bits();
|
||||
int nbytes = std::max<int>(nbits / 8, 1);
|
||||
return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
|
||||
}
|
||||
}
|
||||
}
|
||||
return add_to_cache(v, {1}, starting_multiple_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
|
||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||
return starting_multiple_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||
if(multiple_of > 0)
|
||||
return add_to_cache(x, {multiple_of}, starting_multiple_);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
|
||||
return populate_starting_multiple_cast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
return populate_starting_multiple_binop(x);
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::make_range*>(v))
|
||||
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
return populate_starting_multiple_gep(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
return populate_starting_multiple_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
return populate_starting_multiple_reshape(x);
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
|
||||
return populate_starting_multiple_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||
return populate_starting_multiple_phi(x);
|
||||
return populate_starting_multiple_default(v);
|
||||
}
|
||||
|
||||
|
||||
unsigned align::get(ir::value *v, unsigned ax) const {
|
||||
unsigned starting_multiple = starting_multiple_.at(v)[ax];
|
||||
unsigned max_contiguous = max_contiguous_.at(v)[ax];
|
||||
return std::min(starting_multiple, max_contiguous);
|
||||
}
|
||||
|
||||
std::vector<unsigned> align::contiguous(ir::value* v) const {
|
||||
return max_contiguous_.at(v);
|
||||
}
|
||||
|
||||
|
||||
void align::populate(ir::value *v) {
|
||||
populate_is_constant(v);
|
||||
populate_starting_multiple(v);
|
||||
populate_max_contiguous(v);
|
||||
|
||||
}
|
||||
|
||||
void align::run(ir::module &mod) {
|
||||
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
|
||||
// ir::for_each_value(mod, [this](ir::value* v) {
|
||||
// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
// std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
|
||||
// });
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
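The align pass above boils down to gcd bookkeeping: starting_multiple records what a value is divisible by, max_contiguous how long a run of consecutive values it carries, and binary operators and getelementptr combine the two through gcd. Below is a minimal standalone sketch of that combining rule for an addition such as ptr + offsets; the struct and names are illustrative and not part of the deleted file (C++17 for std::gcd).

#include <algorithm>
#include <cstdio>
#include <numeric>  // std::gcd, C++17

// Per-dimension facts in the style of the analysis above (names illustrative).
struct axis_info {
  unsigned starting_multiple;  // the value at index 0 is a multiple of this
  unsigned max_contiguous;     // length of the consecutive run starting there
};

// Combining rule for an addition, mirroring the gcd/max logic used for
// integer add/sub and getelementptr in the pass above.
axis_info combine_add(const axis_info &lhs, const axis_info &rhs) {
  axis_info out;
  out.starting_multiple = std::gcd(lhs.starting_multiple, rhs.starting_multiple);
  out.max_contiguous = std::max(std::gcd(lhs.max_contiguous, rhs.starting_multiple),
                                std::gcd(rhs.max_contiguous, lhs.starting_multiple));
  return out;
}

int main() {
  axis_info ptr{16, 1};      // pointer known to be a multiple of 16 elements
  axis_info offs{128, 128};  // arange(0, 128): starts at 0, 128 consecutive values
  axis_info sum = combine_add(ptr, offs);
  // prints starting_multiple=16 max_contiguous=16, so a 16-wide vectorized access is safe
  std::printf("starting_multiple=%u max_contiguous=%u\n",
              sum.starting_multiple, sum.max_contiguous);
  return 0;
}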
@@ -1,101 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
void allocation::run(ir::module &mod) {
|
||||
using std::max;
|
||||
using std::min;
|
||||
typedef std::multimap<unsigned, segment> triples_map_type;
|
||||
|
||||
std::vector<shared_layout*> I;
|
||||
for(auto x: liveness_->get())
|
||||
I.push_back(x.first);
|
||||
std::vector<shared_layout*> J = I;
|
||||
|
||||
triples_map_type H;
|
||||
H.insert({0, segment{0, INT_MAX}});
|
||||
|
||||
std::vector<shared_layout*> V;
|
||||
std::map<shared_layout*, unsigned> starts;
|
||||
while(!J.empty()){
|
||||
auto h_it = H.begin();
|
||||
unsigned w = h_it->first;
|
||||
segment xh = h_it->second;
|
||||
H.erase(h_it);
|
||||
auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
|
||||
segment xj = liveness_->get(JJ);
|
||||
bool res = xj.intersect(xh);
|
||||
for(auto val: H)
|
||||
res = res && !val.second.intersect(xj);
|
||||
return res;
|
||||
});
|
||||
if(j_it != J.end()){
|
||||
unsigned size = (*j_it)->get_size();
|
||||
segment xj = liveness_->get(*j_it);
|
||||
starts[*j_it] = w;
|
||||
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
|
||||
if(xh.start < xj.start)
|
||||
H.insert({w, segment{xh.start, xj.end}});
|
||||
if(xj.end < xh.end)
|
||||
H.insert({w, segment{xj.start, xh.end}});
|
||||
V.push_back(*j_it);
|
||||
J.erase(j_it);
|
||||
}
|
||||
}
|
||||
// Build interference graph
|
||||
std::map<shared_layout*, std::set<shared_layout*>> interferences;
|
||||
for(shared_layout* x: V)
|
||||
for(shared_layout* y: V){
|
||||
if(x == y)
|
||||
continue;
|
||||
unsigned X0 = starts[x], Y0 = starts[y];
|
||||
unsigned NX = x->get_size();
|
||||
unsigned NY = y->get_size();
|
||||
segment XS = {X0, X0 + NX};
|
||||
segment YS = {Y0, Y0 + NY};
|
||||
if(liveness_->get(x).intersect(liveness_->get(y))
|
||||
&& XS.intersect(YS))
|
||||
interferences[x].insert(y);
|
||||
}
|
||||
// Initialize colors
|
||||
std::map<shared_layout*, int> colors;
|
||||
for(shared_layout* X: V)
|
||||
colors[X] = (X==V[0])?0:-1;
|
||||
// First-fit graph coloring
|
||||
std::vector<bool> available(V.size());
|
||||
for(shared_layout* x: V){
|
||||
// Non-neighboring colors are available
|
||||
std::fill(available.begin(), available.end(), true);
|
||||
for(shared_layout* Y: interferences[x]){
|
||||
int color = colors[Y];
|
||||
if(color >= 0)
|
||||
available[color] = false;
|
||||
}
|
||||
// Assigns first available color
|
||||
auto It = std::find(available.begin(), available.end(), true);
|
||||
colors[x] = std::distance(available.begin(), It);
|
||||
}
|
||||
// Finalize allocation
|
||||
for(shared_layout* x: V){
|
||||
unsigned Adj = 0;
|
||||
for(shared_layout* y: interferences[x])
|
||||
Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
|
||||
offsets_[x] = starts[x] + colors[x] * Adj;
|
||||
}
|
||||
// Save maximum size of induced memory space
|
||||
allocated_size_ = 0;
|
||||
for(shared_layout* x: V)
|
||||
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
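The allocation pass above packs live shared-memory buffers with an interval heuristic and then colors the resulting interference graph first-fit. The following is a small self-contained illustration of the coloring step, using plain integer buffer ids instead of shared_layout pointers.

#include <algorithm>
#include <cstdio>
#include <map>
#include <set>
#include <vector>

// First-fit coloring: each buffer gets the smallest color not already used
// by a buffer it interferes with.
std::map<int, int> first_fit(const std::vector<int> &buffers,
                             const std::map<int, std::set<int>> &interferences) {
  std::map<int, int> colors;
  std::vector<bool> available(buffers.size());
  for (int b : buffers) {
    std::fill(available.begin(), available.end(), true);
    auto it = interferences.find(b);
    if (it != interferences.end())
      for (int other : it->second)
        if (colors.count(other))
          available[colors.at(other)] = false;
    auto first = std::find(available.begin(), available.end(), true);
    colors[b] = static_cast<int>(std::distance(available.begin(), first));
  }
  return colors;
}

int main() {
  // Buffer 0 overlaps in time with buffers 1 and 2; 1 and 2 never coexist.
  std::map<int, std::set<int>> interf = {{0, {1, 2}}, {1, {0}}, {2, {0}}};
  for (auto [b, c] : first_fit({0, 1, 2}, interf))
    std::printf("buffer %d -> color %d\n", b, c);  // 0->0, 1->1, 2->1
  return 0;
}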
@@ -1,162 +0,0 @@
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include <iostream>
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
axes::axes() {}
|
||||
|
||||
void axes::update_graph_reduce(ir::instruction *i) {
|
||||
auto* red = static_cast<ir::reduce_inst*>(i);
|
||||
unsigned axis = red->get_axis();
|
||||
ir::value *arg = red->get_operand(0);
|
||||
auto in_shapes = arg->get_type()->get_block_shapes();
|
||||
unsigned current = 0;
|
||||
for(unsigned d = 0; d < in_shapes.size(); d++){
|
||||
if(d == axis)
|
||||
continue;
|
||||
graph_.add_edge({i, current++}, {arg, d});
|
||||
}
|
||||
}
|
||||
|
||||
void axes::update_graph_reshape(ir::instruction *i) {
|
||||
auto* reshape = static_cast<ir::reshape_inst*>(i);
|
||||
// operands
|
||||
ir::value *op = reshape->get_operand(0);
|
||||
// shapes
|
||||
auto op_shapes = op->get_type()->get_block_shapes();
|
||||
auto res_shapes = reshape->get_type()->get_block_shapes();
|
||||
// construct edges
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
for(unsigned d = 0; d < res_shapes.size(); d ++){
|
||||
bool same_shape = res_shapes[d] == op_shapes[current];
|
||||
// either add an edge between axes or just add a node to the graph
|
||||
if(!is_skewed && same_shape)
|
||||
graph_.add_edge({i, d}, {op, current++});
|
||||
else
|
||||
graph_.add_edge({i, d}, {i, d});
|
||||
// reshaping is skewed
|
||||
if(res_shapes[d] > 1 && !same_shape)
|
||||
is_skewed = true;
|
||||
}
|
||||
}
|
||||
|
||||
void axes::update_graph_trans(ir::instruction *i) {
|
||||
auto *trans = static_cast<ir::trans_inst*>(i);
|
||||
ir::value *op = trans->get_operand(0);
|
||||
auto perm = trans->get_perm();
|
||||
// add edge between axis perm[d] and axis d
|
||||
for(unsigned d = 0; d < perm.size(); d++)
|
||||
graph_.add_edge({i, perm[d]}, {op, d});
|
||||
}
|
||||
|
||||
void axes::update_graph_broadcast(ir::instruction *i) {
|
||||
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
|
||||
auto shapes = broadcast->get_type()->get_block_shapes();
|
||||
ir::value *op = broadcast->get_operand(0);
|
||||
ir::type *op_ty = op->get_type();
|
||||
const auto& op_shapes = op_ty->get_block_shapes();
|
||||
// add edge between non-broadcast axes
|
||||
for(unsigned d = 0; d < shapes.size(); d ++)
|
||||
if(op_shapes[d] == shapes[d])
|
||||
graph_.add_edge({i, d}, {op, d});
|
||||
}
|
||||
|
||||
void axes::update_graph_dot(ir::instruction *i) {
|
||||
auto *dot = static_cast<ir::dot_inst*>(i);
|
||||
auto shapes = dot->get_type()->get_block_shapes();
|
||||
ir::value *A = dot->get_operand(0);
|
||||
ir::value *B = dot->get_operand(1);
|
||||
ir::value *D = dot->get_operand(2);
|
||||
// add edges between result and accumulator
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
graph_.add_edge({dot, d}, {D, d});
|
||||
}
|
||||
|
||||
void axes::update_graph_elementwise(ir::instruction *i,
|
||||
bool is_masked_load_async) {
|
||||
if(i->get_num_operands() == 0)
|
||||
return;
|
||||
ir::value *op = i->get_operand(0);
|
||||
if(!op->get_type()->is_block_ty())
|
||||
return;
|
||||
auto rank = op->get_type()->get_tile_rank();
|
||||
for(unsigned d = 0; d < rank; d++) {
|
||||
// If we are dealing with a masked async load we need to attach the
|
||||
// dimensions so we match the behaviour of the copy_to_shared instruction
|
||||
// which the async masked load replaces.
|
||||
if (is_masked_load_async) {
|
||||
graph_.add_edge({i, d}, {i, d});
|
||||
}
|
||||
|
||||
for(ir::value* opx: i->ops())
|
||||
for(ir::value* opy: i->ops()) {
|
||||
if(!is_masked_load_async && !i->get_type()->is_void_ty())
|
||||
graph_.add_edge({i, d}, {opx, d});
|
||||
graph_.add_edge({opx, d}, {opy, d});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void axes::update_graph_no_edge(ir::instruction *i) {
|
||||
if(!i->get_type()->is_block_ty())
|
||||
return;
|
||||
auto rank = i->get_type()->get_tile_rank();
|
||||
for(unsigned d = 0; d < rank; d++)
|
||||
graph_.add_edge({i, d}, {i, d});
|
||||
}
|
||||
|
||||
void axes::update_graph(ir::instruction *i) {
|
||||
switch (i->get_id()) {
|
||||
case ir::INST_REDUCE: return update_graph_reduce(i);
|
||||
case ir::INST_RESHAPE: return update_graph_reshape(i);
|
||||
case ir::INST_SPLAT: return update_graph_no_edge(i);
|
||||
case ir::INST_CAT: return update_graph_elementwise(i, true);
|
||||
case ir::INST_TRANS: return update_graph_trans(i);
|
||||
case ir::INST_BROADCAST: return update_graph_broadcast(i);
|
||||
case ir::INST_DOT: return update_graph_dot(i);
|
||||
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
|
||||
case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
|
||||
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
|
||||
case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
|
||||
default: return update_graph_elementwise(i);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
int axes::get(ir::value *value, unsigned dim) {
|
||||
return axes_.at({value, dim});
|
||||
}
|
||||
|
||||
std::vector<int> axes::get(ir::value *value) {
|
||||
std::vector<int> result;
|
||||
for(size_t d = 0; d < value->get_type()->get_tile_rank(); d++)
|
||||
result.push_back(this->get(value, d));
|
||||
return result;
|
||||
}
|
||||
|
||||
void axes::run(ir::module &mod) {
|
||||
// make graph
|
||||
graph_.clear();
|
||||
axes_.clear();
|
||||
ir::for_each_instruction(mod, [this](ir::instruction *x) {
|
||||
update_graph(x);
|
||||
});
|
||||
// find connected components
|
||||
graph_.connected_components(nullptr, &axes_);
|
||||
std::set<size_t> uniq;
|
||||
for(auto x: axes_)
|
||||
uniq.insert(x.second);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
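The axes pass above is connected components over (value, dimension) nodes: each edge forces two dimensions to share one axis id. Here is a compact standalone sketch with a tiny union-find and integer value ids standing in for ir::value pointers.

#include <cstdio>
#include <map>
#include <utility>

// Minimal union-find over (value, dimension) pairs: every edge merges two
// nodes, and each connected component ends up as one axis.
using node = std::pair<int, int>;  // (value id, dimension)
std::map<node, node> parent;

node find(node x) {
  if (!parent.count(x)) parent[x] = x;
  if (parent[x] != x) parent[x] = find(parent[x]);  // path compression
  return parent[x];
}

void add_edge(node a, node b) { parent[find(a)] = find(b); }

int main() {
  add_edge({1, 0}, {0, 0});  // elementwise op: dim 0 of result tied to dim 0 of operand
  add_edge({1, 1}, {0, 1});  // ... and dim 1 to dim 1
  add_edge({2, 1}, {1, 0});  // a transpose ties dim 1 of its result to dim 0 of its operand
  bool same = find({2, 1}) == find({0, 0});
  std::printf("axis({2,1}) == axis({0,0}): %s\n", same ? "yes" : "no");  // yes
  return 0;
}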
@@ -1,653 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/utils.h"
|
||||
// #include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
/* -------------------------------- *
|
||||
* Helper Functions *
|
||||
* -------------------------------- */
|
||||
|
||||
inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
|
||||
unsigned lo = std::min(a, b);
|
||||
unsigned hi = std::max(a, b);
|
||||
return std::min(std::max(x, lo), hi);
|
||||
}
|
||||
|
||||
inline bool is_hmma_c(ir::value *v, int sm){
|
||||
bool result = false;
|
||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
||||
ir::value *a = x->get_operand(0);
|
||||
ir::type *a_ty = a->get_type();
|
||||
ir::value *b = x->get_operand(1);
|
||||
ir::type *b_ty = b->get_type();
|
||||
result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
|
||||
(a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
|
||||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
|
||||
x->allow_tf32() && sm >= 80) ||
|
||||
(a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
|
||||
sm >= 80);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
|
||||
mma_layout::TensorCoreType mma_type;
|
||||
if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
|
||||
ir::value* a = dot->get_operand(0);
|
||||
ir::value* b = dot->get_operand(1);
|
||||
ir::type* a_ty = a->get_type();
|
||||
ir::type* b_ty = b->get_type();
|
||||
ir::type* c_ty = v->get_type();
|
||||
|
||||
if (c_ty->get_scalar_ty()->is_fp32_ty()) {
|
||||
// floating point tensor cores
|
||||
if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
|
||||
mma_type = mma_layout::FP32_FP16_FP16_FP32;
|
||||
return mma_type;
|
||||
}
|
||||
if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
|
||||
mma_type = mma_layout::FP32_BF16_BF16_FP32;
|
||||
return mma_type;
|
||||
}
|
||||
if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
|
||||
&& dot->allow_tf32()) {
|
||||
mma_type = mma_layout::FP32_TF32_TF32_FP32;
|
||||
return mma_type;
|
||||
}
|
||||
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
|
||||
// throw std::runtime_error("integer tensor cores are not yet supported");
|
||||
// // integer tensor cores
|
||||
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
|
||||
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
|
||||
// return mma_type;
|
||||
// }
|
||||
// if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
|
||||
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
|
||||
// return mma_type;
|
||||
// }
|
||||
if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
|
||||
mma_type = mma_layout::INT32_INT8_INT8_INT32;
|
||||
return mma_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
return mma_layout::NOT_APPLICABLE;
|
||||
}
|
||||
|
||||
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||
if(i && i->get_pointer_operand() == v)
|
||||
result.insert(v);
|
||||
}
|
||||
}
|
||||
|
||||
inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::dot_inst*>(u);
|
||||
if(i && i->get_operand(n) == v)
|
||||
result = v;
|
||||
}
|
||||
}
|
||||
|
||||
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::dot_inst*>(u);
|
||||
if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
|
||||
result = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline bool is_trans(ir::value *v) {
|
||||
if(dynamic_cast<ir::trans_inst *>(v)) {
|
||||
return true;
|
||||
}
|
||||
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
|
||||
bool result = true;
|
||||
for(ir::value *op: phi->ops())
|
||||
result = result && is_trans(op);
|
||||
return result;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
/* -------------------------------- *
|
||||
* Layout Visitor *
|
||||
* -------------------------------- */
|
||||
|
||||
void layout_visitor::visit_layout(data_layout *layout) {
|
||||
layout->accept(this);
|
||||
}
|
||||
|
||||
|
||||
/* -------------------------------- *
|
||||
* Base Data Layout *
|
||||
* -------------------------------- */
|
||||
|
||||
data_layout::data_layout(id_t id,
|
||||
const std::vector<int> &axes,
|
||||
const std::vector<unsigned> &shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
|
||||
// io pointer
|
||||
std::set<ir::value*> ptr;
|
||||
for(ir::value* v: values_)
|
||||
extract_io_use(v, ptr);
|
||||
order_.resize(axes_.size());
|
||||
std::iota(order_.begin(), order_.end(), 0);
|
||||
std::vector<unsigned> max_contiguous;
|
||||
for(ir::value* p: ptr){
|
||||
std::vector<unsigned> curr = align->contiguous(p);
|
||||
if(curr.size() > max_contiguous.size())
|
||||
max_contiguous = curr;
|
||||
else if(curr.size() == max_contiguous.size()){
|
||||
if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end()))
|
||||
max_contiguous = curr;
|
||||
}
|
||||
}
|
||||
if(max_contiguous.size() > 0){
|
||||
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
|
||||
return max_contiguous[a] > max_contiguous[b];
|
||||
});
|
||||
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
|
||||
// std::cout << order_[0] << " " << order_[1] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int data_layout::find_axis(int to_find) const {
|
||||
auto it = std::find(axes_.begin(), axes_.end(), to_find);
|
||||
if(it == axes_.end())
|
||||
return -1;
|
||||
return std::distance(axes_.begin(), it);
|
||||
}
|
||||
|
||||
|
||||
distributed_layout::distributed_layout(id_t id,
|
||||
const std::vector<int> &axes,
|
||||
const std::vector<unsigned> &shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align): data_layout(id, axes, shape, values, align)
|
||||
{ }
|
||||
|
||||
/* -------------------------------- *
|
||||
* MMA Layout *
|
||||
* -------------------------------- */
|
||||
|
||||
mma_layout::mma_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align, target* tgt,
|
||||
shared_layout *layout_a, shared_layout *layout_b,
|
||||
ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
|
||||
tensor_core_type_ = get_mma_type(dot);
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
if(tgt->as_nvidia()->sm() < 80){
|
||||
fpw_ = {2, 2, 1};
|
||||
auto ord_a = layout_a->get_order();
|
||||
auto ord_b = layout_b->get_order();
|
||||
bool is_a_row = ord_a[0] != 0;
|
||||
bool is_b_row = ord_b[0] != 0;
|
||||
bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
|
||||
bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
|
||||
int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
|
||||
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
|
||||
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
|
||||
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
|
||||
contig_per_thread_ = {1, 1};
|
||||
}
|
||||
else{
|
||||
// fpw_ = {1, 1, 1};
|
||||
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
|
||||
contig_per_thread_ = {1, 2};
|
||||
// rep_ = {2, 2, 1};
|
||||
}
|
||||
order_ = {0, 1};
|
||||
|
||||
/* warps per tile */
|
||||
wpt_ = {1, 1, 1};
|
||||
// try to make warp-level tiles as square as possible to maximize data re-use
|
||||
if (tgt->as_nvidia()->sm() < 80) {
|
||||
std::vector<int> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt_;
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
|
||||
}while(wpt_nm1 != wpt_);
|
||||
} else {
|
||||
bool changed = false;
|
||||
do {
|
||||
changed = false;
|
||||
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
|
||||
break;
|
||||
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
|
||||
if (wpt_[0] < shape_[0] / spw_[0]) {
|
||||
wpt_[0] *= 2;
|
||||
changed = true;
|
||||
}
|
||||
} else {
|
||||
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
|
||||
wpt_[1] *= 2;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
} while (changed);
|
||||
}
|
||||
|
||||
/* shape per block */
|
||||
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
|
||||
}
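// Editor's sketch (standalone, not from the deleted file): the warp-partitioning
// loop above greedily doubles wpt_ along whichever dimension still has the most
// tiles per warp, until all warps are placed. A simplified version of that
// heuristic, with made-up shape and instruction-shape values:
#include <array>
#include <cstdio>

std::array<int, 2> split_warps(int num_warps, std::array<int, 2> shape,
                               std::array<int, 2> instr_shape) {
  std::array<int, 2> wpt = {1, 1};
  bool changed = true;
  while (changed && wpt[0] * wpt[1] < num_warps) {
    changed = false;
    // remaining tiles per warp along each dimension
    int tiles0 = shape[0] / instr_shape[0] / wpt[0];
    int tiles1 = shape[1] / instr_shape[1] / wpt[1];
    int d = (tiles0 >= tiles1) ? 0 : 1;  // grow the dimension with more tiles left
    if (wpt[d] < shape[d] / instr_shape[d]) {
      wpt[d] *= 2;
      changed = true;
    } else if (wpt[1 - d] < shape[1 - d] / instr_shape[1 - d]) {
      wpt[1 - d] *= 2;
      changed = true;
    }
  }
  return wpt;
}

int main() {
  auto wpt = split_warps(8, {128, 128}, {16, 16});
  std::printf("wpt = {%d, %d}\n", wpt[0], wpt[1]);  // {4, 2} for these values
  return 0;
}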
|
||||
|
||||
|
||||
/* -------------------------------- *
|
||||
* Scanline Layout *
|
||||
* -------------------------------- */
|
||||
|
||||
scanline_layout::scanline_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){
|
||||
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
|
||||
unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
|
||||
nts_.resize(shape_.size());
|
||||
mts_.resize(shape_.size());
|
||||
bool is_dot = std::any_of(values.begin(), values.end(),
|
||||
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
|
||||
|
||||
std::vector<ir::value*> ptrs;
|
||||
for(ir::value *v: values)
|
||||
for(ir::user *usr: v->get_users())
|
||||
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
|
||||
if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank())
|
||||
ptrs.push_back(io->get_pointer_operand());
|
||||
}
|
||||
|
||||
unsigned i = order_[0];
|
||||
int contiguous = 1;
|
||||
for(ir::value* ptr: ptrs){
|
||||
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
|
||||
contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
|
||||
}
|
||||
|
||||
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
|
||||
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
|
||||
size /= shape_[i];
|
||||
num_threads /= mts_[i];
|
||||
if(is_dot)
|
||||
nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
|
||||
for(size_t d = 1; d < shape_.size(); d++){
|
||||
i = order_[d];
|
||||
if(d > 1 || !is_dot)
|
||||
nts_[i] = 1;
|
||||
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
|
||||
num_threads = num_threads / mts_[i];
|
||||
}
|
||||
|
||||
shape_per_cta_.resize(shape_.size());
|
||||
for(size_t d = 0; d < shape_.size(); d++)
|
||||
shape_per_cta_[d] = mts_[d]*nts_[d];
|
||||
}
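// Editor's sketch (standalone, not from the deleted file): how nts_/mts_ fall
// out of the proven alignment and the 128-bit access limit, for an assumed
// fp16 block of shape {1024, 64} distributed over 4 warps.
#include <algorithm>
#include <cstdio>

static unsigned clamp_u(unsigned x, unsigned lo, unsigned hi) {
  return std::min(std::max(x, lo), hi);
}

int main() {
  unsigned num_threads = 4 * 32;  // 4 warps
  unsigned shape0 = 1024;         // size of the contiguous dimension
  unsigned size = 1024 * 64;      // total elements in the block
  unsigned nbits = 16;            // fp16
  unsigned align_elems = 8;       // contiguity proven by the align pass

  unsigned contiguous = std::min(align_elems, 128 / nbits);  // cap at 128-bit loads
  unsigned nts = clamp_u(size / num_threads, 1, std::min(contiguous, shape0));  // 8
  unsigned mts = clamp_u(num_threads, 1, shape0 / nts);                         // 128
  std::printf("nts=%u mts=%u\n", nts, mts);
  return 0;
}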
|
||||
|
||||
|
||||
/* -------------------------------- *
|
||||
* Shared Layout *
|
||||
* -------------------------------- */
|
||||
|
||||
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
||||
if(phi->get_parent() != terminator->get_parent())
|
||||
return false;
|
||||
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
|
||||
return br->get_true_dest() == phi->get_parent()
|
||||
|| br->get_false_dest() == phi->get_parent();
|
||||
else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
|
||||
return false;
|
||||
else
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
|
||||
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
|
||||
auto* phi = dynamic_cast<ir::phi_node*>(v);
|
||||
if(!phi || phi->get_num_incoming() != 2)
|
||||
return;
|
||||
ir::basic_block *block_0 = phi->get_incoming_block(0);
|
||||
ir::basic_block *block_1 = phi->get_incoming_block(1);
|
||||
ir::instruction *terminator_0 = block_0->get_inst_list().back();
|
||||
ir::instruction *terminator_1 = block_1->get_inst_list().back();
|
||||
bool is_latch_0 = is_loop_latch(phi, terminator_0);
|
||||
bool is_latch_1 = is_loop_latch(phi, terminator_1);
|
||||
ir::value *value_0 = phi->get_incoming_value(0);
|
||||
ir::value *value_1 = phi->get_incoming_value(1);
|
||||
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
|
||||
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
|
||||
if(!(i_0 && !i_1) &&
|
||||
!(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
|
||||
!(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
|
||||
return;
|
||||
if(is_latch_1)
|
||||
res.reset(new double_buffer_info_t{value_0, value_1, phi});
|
||||
if(is_latch_0)
|
||||
res.reset(new double_buffer_info_t{value_1, value_0, phi});
|
||||
}
|
||||
|
||||
static bool is_smem(ir::value* v) {
|
||||
if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
|
||||
dynamic_cast<ir::masked_load_async_inst*>(v))
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
/// param:
|
||||
/// value_1: next_value
|
||||
static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1,
|
||||
std::vector<ir::value*>& values_0, ir::value*& value_1) {
|
||||
ir::value* next = phi;
|
||||
while (auto cphi = dynamic_cast<ir::phi_node*>(next)) {
|
||||
// smem from previous bb & phi/smem from current bb
|
||||
ir::value* c0 = cphi->get_incoming_value(0);
|
||||
ir::value* c1 = cphi->get_incoming_value(1);
|
||||
ir::basic_block *cbb0 = cphi->get_incoming_block(0);
|
||||
ir::basic_block *cbb1 = cphi->get_incoming_block(1);
|
||||
|
||||
if (is_smem(c0)) {
|
||||
assert(cbb0 == bb0);
|
||||
values_0.push_back(c0);
|
||||
if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
|
||||
next = phi1;
|
||||
continue;
|
||||
} else {
|
||||
if (is_smem(c1)) {
|
||||
value_1 = c1;
|
||||
assert(cbb1 == bb1);
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
|
||||
auto* phi = dynamic_cast<ir::phi_node*>(v);
|
||||
// if the phi node is nested
|
||||
if (!phi)
|
||||
return;
|
||||
|
||||
ir::basic_block *bb0 = phi->get_incoming_block(0);
|
||||
ir::basic_block *bb1 = phi->get_incoming_block(1);
|
||||
|
||||
std::vector<ir::value*> values_0;
|
||||
ir::value* value_1 = nullptr;
|
||||
|
||||
if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1))
|
||||
return;
|
||||
|
||||
// double-buffer is a special case
|
||||
if (values_0.size() == 1)
|
||||
return;
|
||||
|
||||
// compute original values_0 input order
|
||||
std::map<ir::value*, int> order;
|
||||
int idx = 0;
|
||||
for (ir::instruction* instr : *bb0) {
|
||||
if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end())
|
||||
order[static_cast<ir::value*>(instr)] = idx++;
|
||||
}
|
||||
assert(order.size() == values_0.size() && "order size incorrect");
|
||||
|
||||
int curr_stages = values_0.size() + 1;
|
||||
if (curr_stages > prev_stages) {
|
||||
res.reset(new N_buffer_info_t{values_0, value_1, phi, order});
|
||||
prev_stages = curr_stages;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
shared_layout::shared_layout(data_layout *arg,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
ir::type *ty,
|
||||
analysis::align* align, target *tgt)
|
||||
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {
|
||||
|
||||
size_ = 0;
|
||||
arg_layout_ = arg;
|
||||
|
||||
// N-stage buffering
|
||||
int prev_stages = 0;
|
||||
for (ir::value *v : values)
|
||||
extract_N_bufferable(v, N_buffer_, prev_stages);
|
||||
|
||||
// double-buffering
|
||||
if (!N_buffer_)
|
||||
for(ir::value *v: values)
|
||||
extract_double_bufferable(v, double_buffer_);
|
||||
|
||||
// order
|
||||
std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
|
||||
order_ = arg_order;
|
||||
|
||||
ir::value* dot_a = nullptr;
|
||||
ir::value* dot_b = nullptr;
|
||||
ir::value* hmma_dot_a = nullptr;
|
||||
ir::value* hmma_dot_b = nullptr;
|
||||
for(ir::value* v: values){
|
||||
extract_dot_use(v, dot_a, 0);
|
||||
extract_dot_use(v, dot_b, 1);
|
||||
extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
|
||||
extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
|
||||
}
|
||||
hmma_dot_a_ = hmma_dot_a;
|
||||
hmma_dot_b_ = hmma_dot_b;
|
||||
|
||||
// Update mma_vec
|
||||
if (hmma_dot_a_) {
|
||||
assert(order_.size() == 2);
|
||||
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
|
||||
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
|
||||
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
|
||||
|
||||
// for now, disable swizzle when using lds.8
|
||||
if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
|
||||
if (order_[0] == 0) // need transpose
|
||||
allow_swizzle_ = false;
|
||||
} else if (hmma_dot_b_) {
|
||||
assert(order_.size() == 2);
|
||||
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
|
||||
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
|
||||
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
|
||||
|
||||
// for now, disable swizzle when using lds.8
|
||||
if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
|
||||
if (order_[0] == 1) // need transpose
|
||||
allow_swizzle_ = false;
|
||||
}
|
||||
|
||||
// size
|
||||
size_ = ty_->get_primitive_size_in_bits() / 8;
|
||||
for(auto s: shape_)
|
||||
size_ *= s;
|
||||
if(double_buffer_)
|
||||
size_ *= 2;
|
||||
if (N_buffer_) {
|
||||
size_ *= (N_buffer_->firsts.size() + 1);
|
||||
}
|
||||
}
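// Editor's sketch (standalone, not from the deleted file): the size computed
// above is simply element-bytes times the block shape, times the number of
// pipeline stages (2 for double buffering, N+1 for N-stage buffering).
#include <cstdio>
#include <vector>

static size_t shared_bytes(size_t elem_bytes, const std::vector<unsigned> &shape,
                           int num_stages) {
  size_t bytes = elem_bytes;
  for (unsigned s : shape) bytes *= s;
  return bytes * static_cast<size_t>(num_stages);
}

int main() {
  // fp16 tile of 128x32 with 3 pipeline stages -> 24576 bytes of shared memory
  std::printf("%zu bytes\n", shared_bytes(2, {128, 32}, 3));
  return 0;
}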
|
||||
|
||||
int shared_layout::get_num_stages() const {
|
||||
if (double_buffer_)
|
||||
return 2;
|
||||
if (N_buffer_)
|
||||
return N_buffer_->firsts.size() + 1;
|
||||
return 1;
|
||||
}
|
||||
|
||||
size_t shared_layout::get_per_stage_elements() const {
|
||||
return get_per_stage_size()/(ty_->get_primitive_size_in_bits()/8);
|
||||
}
|
||||
|
||||
/* -------------------------------- *
|
||||
* ---- Layouts Inference Pass ---- *
|
||||
* -------------------------------- */
|
||||
|
||||
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt)
|
||||
: axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ }
|
||||
|
||||
|
||||
void layouts::connect(ir::value *x, ir::value *y) {
|
||||
if(x == y)
|
||||
return;
|
||||
if(!x->get_type()->is_block_ty())
|
||||
return;
|
||||
if(!y->get_type()->is_block_ty())
|
||||
return;
|
||||
std::vector<int> x_axes = axes_->get(x);
|
||||
std::vector<int> y_axes = axes_->get(y);
|
||||
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
|
||||
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
|
||||
std::set<int> common;
|
||||
std::set_intersection(sx_axes.begin(), sx_axes.end(),
|
||||
sy_axes.begin(), sy_axes.end(),
|
||||
std::inserter(common, common.begin()));
|
||||
graph_.add_edge(x, x);
|
||||
graph_.add_edge(y, y);
|
||||
if(!common.empty())
|
||||
graph_.add_edge(x, y);
|
||||
}
|
||||
|
||||
void layouts::make_graph(ir::instruction *i) {
|
||||
for(ir::value* opx: i->ops())
|
||||
for(ir::value* opy: i->ops()){
|
||||
connect(i, opx);
|
||||
connect(opx, opy);
|
||||
}
|
||||
}
|
||||
|
||||
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
// if(layouts_.find(id) != layouts_.end())
|
||||
// return;
|
||||
auto it_hmma_c = std::find_if(values.begin(), values.end(),
|
||||
[&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
|
||||
auto cmp = [](ir::value* x, ir::value *y) {
|
||||
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
|
||||
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
|
||||
return xx < yy;
|
||||
};
|
||||
std::vector<ir::value*> lvalue = values;
|
||||
lvalue.erase(std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); }), lvalue.end());
|
||||
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
|
||||
const auto& axes = axes_->get(largest);
|
||||
const auto& shapes = largest->get_type()->get_block_shapes();
|
||||
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
|
||||
return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
|
||||
dynamic_cast<ir::masked_load_async_inst*>(v);
|
||||
});
|
||||
// type
|
||||
if(it_hmma_c != values.end()){
|
||||
ir::instruction *dot = (ir::instruction*)*it_hmma_c;
|
||||
ir::value *a = dot->get_operand(0);
|
||||
ir::value *b = dot->get_operand(1);
|
||||
create(groups_.at(a), values_.at(groups_.at(a)));
|
||||
create(groups_.at(b), values_.at(groups_.at(b)));
|
||||
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
|
||||
(shared_layout*)layouts_.at(groups_.at(a)),
|
||||
(shared_layout*)layouts_.at(groups_.at(b)),
|
||||
dot);
|
||||
}
|
||||
else if(it_cts != values.end()){
|
||||
ir::instruction *cts = (ir::instruction*)*it_cts;
|
||||
ir::value *arg = cts->get_operand(0);
|
||||
create(groups_.at(arg), values_.at(groups_.at(arg)));
|
||||
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
|
||||
}
|
||||
else{
|
||||
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
|
||||
}
|
||||
}
|
||||
|
||||
void layouts::run(ir::module &mod) {
|
||||
// make graph
|
||||
graph_.clear();
|
||||
layouts_.clear();
|
||||
groups_.clear();
|
||||
|
||||
ir::for_each_instruction(mod, [this](ir::instruction* i) {
|
||||
make_graph(i);
|
||||
});
|
||||
|
||||
|
||||
// connected components
|
||||
graph_.connected_components(&values_, &groups_);
|
||||
|
||||
// create layouts
|
||||
for(const auto& x: values_)
|
||||
create(x.first, x.second);
|
||||
|
||||
// create temporaries
|
||||
size_t id = values_.size();
|
||||
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
|
||||
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
|
||||
id++;
|
||||
ir::value *arg = red->get_operand(0);
|
||||
unsigned axis = red->get_axis();
|
||||
// shape
|
||||
auto shapes = arg->get_type()->get_block_shapes();
|
||||
scanline_layout *layout = get(arg)->to_scanline();
|
||||
shapes[axis] = layout->mts(axis);
|
||||
// create layout
|
||||
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
|
||||
tmp_[red] = id;
|
||||
}
|
||||
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
|
||||
distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
|
||||
distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
|
||||
id++;
|
||||
size_t dim = val->get_type()->get_tile_rank();
|
||||
ir::type::block_shapes_t shape(dim);
|
||||
for(size_t k = 0; k < dim; k++){
|
||||
shape[k] = std::max(in_layout->shape_per_cta(k),
|
||||
out_layout->shape_per_cta(k));
|
||||
}
|
||||
auto in_ord = in_layout->get_order();
|
||||
auto out_ord = out_layout->get_order();
|
||||
int in_vec = in_layout->contig_per_thread(in_ord[0]);
|
||||
int out_vec = out_layout->contig_per_thread(out_ord[0]);
|
||||
int pad = std::max(in_vec, out_vec);
|
||||
shape[out_ord[0]] += pad;
|
||||
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
|
||||
tmp_[val] = id;
|
||||
}
|
||||
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
|
||||
id++;
|
||||
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
|
||||
tmp_[atom] = id;
|
||||
}
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,59 +0,0 @@
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
void liveness::run(ir::module &mod) {
|
||||
intervals_.clear();
|
||||
|
||||
// Assigns index to each instruction
|
||||
std::map<ir::value*, slot_index> indices;
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
slot_index index = 0;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *instr: block->get_inst_list()){
|
||||
index += 1;
|
||||
indices.insert({instr, index});
|
||||
}
|
||||
}
|
||||
|
||||
// create live intervals
|
||||
for(auto &x: layouts_->get_all()) {
|
||||
shared_layout* layout = x.second->to_shared();
|
||||
if(!layout)
|
||||
continue;
|
||||
// users
|
||||
std::set<ir::user*> users;
|
||||
for(ir::value *v: layout->get_values()){
|
||||
for(ir::user *u: v->get_users())
|
||||
users.insert(u);
|
||||
}
|
||||
// compute intervals
|
||||
unsigned start = INT32_MAX;
|
||||
for(ir::value *v: layout->get_values())
|
||||
if(indices.find(v) != indices.end())
|
||||
start = std::min(start, indices.at(v));
|
||||
unsigned end = 0;
|
||||
for(ir::user *u: users)
|
||||
if(indices.find(u) != indices.end())
|
||||
end = std::max(end, indices.at(u));
|
||||
if(end == 0)
|
||||
end = start + 1;
|
||||
intervals_[layout] = segment{start, end};
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
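The liveness computation above is a linear scan: instructions receive increasing slot indices, and a shared buffer's interval runs from the smallest index among its defining values to the largest index among their users. A compact standalone illustration with invented instruction names:

#include <algorithm>
#include <climits>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct segment { unsigned start, end; };

// Live interval of one buffer: min index over its defs, max index over its uses.
segment live_interval(const std::map<std::string, unsigned> &index,
                      const std::vector<std::string> &defs,
                      const std::vector<std::string> &uses) {
  segment s{UINT_MAX, 0};
  for (const auto &d : defs)
    if (index.count(d)) s.start = std::min(s.start, index.at(d));
  for (const auto &u : uses)
    if (index.count(u)) s.end = std::max(s.end, index.at(u));
  if (s.end == 0) s.end = s.start + 1;  // keep empty ranges well-formed
  return s;
}

int main() {
  std::map<std::string, unsigned> index = {
      {"copy_to_shared", 3}, {"dot", 7}, {"store", 9}};
  segment s = live_interval(index, {"copy_to_shared"}, {"dot"});
  std::printf("start=%u end=%u\n", s.start, s.end);  // start=3 end=7
  return 0;
}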
@@ -1,61 +0,0 @@
|
||||
#include "triton/codegen/analysis/swizzle.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
void swizzle::run(ir::module &) {
|
||||
per_phase_.clear();
|
||||
max_phase_.clear();
|
||||
|
||||
for(auto &x: layouts_->get_all()){
|
||||
shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
|
||||
if(!layout)
|
||||
continue;
|
||||
ir::value* mma_dot_a = layout->hmma_dot_a();
|
||||
ir::value* mma_dot_b = layout->hmma_dot_b();
|
||||
|
||||
if(!mma_dot_a && !mma_dot_b){
|
||||
per_phase_[layout] = 1;
|
||||
max_phase_[layout] = 1;
|
||||
vec_[layout] = 1;
|
||||
continue;
|
||||
}
|
||||
auto ord = layout->get_order();
|
||||
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
|
||||
if(!in_layout)
|
||||
continue;
|
||||
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
|
||||
int inner = mma_dot_a ? 0 : 1;
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
|
||||
if(mma_dot_a)
|
||||
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
|
||||
else
|
||||
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
|
||||
}
|
||||
else {
|
||||
if (!layout->allow_swizzle()) {
|
||||
per_phase_[layout] = 1;
|
||||
max_phase_[layout] = 1;
|
||||
vec_[layout] = 1;
|
||||
} else {
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
|
||||
vec_[layout] = layout->get_mma_vec();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
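On sm >= 80 the swizzle parameters above reduce to a short calculation: per_phase spreads phases so that each one covers a 128-byte row of the incoming scanline layout, max_phase is the strided MMA matrix dimension divided by per_phase, and the vector width is the contiguous MMA matrix dimension. A worked example under assumed values (fp16 operand, 32 one-element threads along the contiguous axis; the per-matrix dimensions below are assumptions for illustration):

#include <algorithm>
#include <cstdio>

int main() {
  int dtsize      = 2;   // bytes per element (fp16)
  int mts         = 32;  // threads along the contiguous dimension
  int nts         = 1;   // elements per thread along that dimension
  int mma_strided = 8;   // assumed strided matrix dimension of the MMA fragment
  int mma_vec     = 16;  // assumed contiguous matrix dimension of the MMA fragment

  int per_phase = std::max(128 / (mts * nts * dtsize), 1);  // 2
  int max_phase = mma_strided / per_phase;                  // 4
  int vec       = mma_vec;                                  // 16
  std::printf("per_phase=%d max_phase=%d vec=%d\n", per_phase, max_phase, vec);
  return 0;
}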
@@ -1,86 +0,0 @@
|
||||
#include "triton/codegen/pass.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/swizzle.h"
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/disassociate.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
#include "triton/codegen/transform/prefetch.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
namespace triton {
|
||||
namespace codegen {
|
||||
|
||||
// TODO:
|
||||
// There should be a proper pass manager there!
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
|
||||
int cc, int num_warps, int num_stages, int& shared_static) {
|
||||
// generate llvm code
|
||||
std::string name = ir.get_function_list()[0]->get_name();
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
|
||||
// optimizations
|
||||
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
|
||||
// create passes
|
||||
codegen::analysis::align align;
|
||||
codegen::analysis::axes axes;
|
||||
codegen::transform::cts cts(cts_use_async);
|
||||
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
|
||||
codegen::transform::disassociate disassociate;
|
||||
codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
|
||||
codegen::analysis::liveness liveness(&layouts);
|
||||
codegen::analysis::swizzle swizzle(&layouts, target);
|
||||
codegen::analysis::allocation allocation(&liveness);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole(target, &layouts);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::transform::prefetch prefetch_s(target);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
|
||||
// run passes
|
||||
dce.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
pipeline.run(ir);
|
||||
dce.run(ir);
|
||||
disassociate.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
if (target->is_gpu())
|
||||
cts.run(ir);
|
||||
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||
coalesce.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir);
|
||||
dce.run(ir);
|
||||
if (target->is_gpu())
|
||||
cts.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||
peephole.run(ir);
|
||||
dce.run(ir);
|
||||
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||
swizzle.run(ir);
|
||||
liveness.run(ir);
|
||||
allocation.run(ir);
|
||||
prefetch_s.run(ir);
|
||||
barriers.run(ir);
|
||||
isel.visit(ir, *llvm);
|
||||
shared_static = allocation.allocated_size();
|
||||
return llvm;
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
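add_passes_to_emit_bin above is the single entry point that chains all of the analyses and transforms and lowers a Triton-IR module to LLVM-IR. A hedged usage sketch follows; only the signature comes from this diff, while the module, target, and parameter values are assumptions.

#include <memory>
#include "triton/codegen/pass.h"
#include "triton/codegen/target.h"
#include "triton/ir/module.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"

// Compile an already-built Triton-IR module for an assumed NVIDIA target.
std::unique_ptr<llvm::Module> compile(triton::ir::module &ir_mod,
                                      triton::codegen::target *tgt,
                                      llvm::LLVMContext &ctx) {
  int shared_static = 0;  // receives the shared-memory footprint from the allocation pass
  return triton::codegen::add_passes_to_emit_bin(
      ir_mod, ctx, tgt, /*cc=*/80, /*num_warps=*/4, /*num_stages=*/3,
      shared_static);  // illustrative parameter values
}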
File diff suppressed because it is too large
@@ -1,173 +0,0 @@
|
||||
#include "triton/codegen/target.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/Intrinsics.h"
|
||||
#include "llvm/IR/IntrinsicsNVPTX.h"
|
||||
#include "llvm/IR/IntrinsicsAMDGPU.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include <iostream>
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
// base
|
||||
|
||||
|
||||
nvidia_cu_target* target::as_nvidia() {
|
||||
return dynamic_cast<nvidia_cu_target*>(this);
|
||||
}
|
||||
|
||||
bool target::is_gpu() const {
|
||||
return is_gpu_;
|
||||
}
|
||||
|
||||
// AMD
|
||||
void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
|
||||
fn->setCallingConv(CallingConv::AMDGPU_KERNEL);
|
||||
}
|
||||
|
||||
Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
|
||||
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
|
||||
return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {});
|
||||
}
|
||||
|
||||
Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
|
||||
Value* group_id = get_block_id(module, builder, ax);
|
||||
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
|
||||
return result;
|
||||
}
|
||||
|
||||
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
|
||||
Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
static std::array<Intrinsic::ID, 3> ids = {
|
||||
Intrinsic::amdgcn_workgroup_id_x,
|
||||
Intrinsic::amdgcn_workgroup_id_y,
|
||||
Intrinsic::amdgcn_workgroup_id_z
|
||||
};
|
||||
Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {});
|
||||
return group_id;
|
||||
}
|
||||
|
||||
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
throw std::runtime_error("not implemented on AMD");
|
||||
}
|
||||
|
||||
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
static std::array<Intrinsic::ID, 3> ids = {
|
||||
Intrinsic::amdgcn_workitem_id_x,
|
||||
Intrinsic::amdgcn_workitem_id_y,
|
||||
Intrinsic::amdgcn_workitem_id_z
|
||||
};
|
||||
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
|
||||
return builder.CreateCall(get_local_id, {});
|
||||
}
|
||||
|
||||
// NVIDIA
|
||||
|
||||
void nvidia_cu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn){
|
||||
// mark the function as a kernel entry point via nvvm.annotations metadata
|
||||
Metadata *md_args[] = {
|
||||
ValueAsMetadata::get(fn),
|
||||
MDString::get(ctx, "kernel"),
|
||||
ValueAsMetadata::get(builder.getInt32(1))
|
||||
};
|
||||
module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
|
||||
}
|
||||
|
||||
Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder) {
|
||||
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0);
|
||||
return builder.CreateCall(barrier, {});
|
||||
}
|
||||
|
||||
Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) {
|
||||
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl);
|
||||
return builder.CreateCall(barrier, {});
|
||||
}
|
||||
|
||||
|
||||
Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
|
||||
Value* group_id = get_block_id(module, builder, ax);
|
||||
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
|
||||
return result;
|
||||
}
|
||||
|
||||
Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
static std::array<Intrinsic::ID, 3> cta_ids = {
|
||||
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
|
||||
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
|
||||
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
|
||||
};
|
||||
Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {});
|
||||
return cta_id;
|
||||
}
|
||||
|
||||
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
static std::array<Intrinsic::ID, 3> ids = {
|
||||
Intrinsic::nvvm_read_ptx_sreg_tid_x,
|
||||
Intrinsic::nvvm_read_ptx_sreg_tid_y,
|
||||
Intrinsic::nvvm_read_ptx_sreg_tid_z
|
||||
};
|
||||
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
|
||||
return builder.CreateCall(get_local_id, {});
|
||||
}
|
||||
|
||||
Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
static std::array<Intrinsic::ID, 3> ids = {
|
||||
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
|
||||
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
|
||||
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
|
||||
};
|
||||
return builder.CreateIntrinsic(ids[ax], {}, {});
|
||||
}
|
||||
|
||||
// CPU
|
||||
|
||||
void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
|
||||
// normal cpu functions can be kernels
|
||||
}
|
||||
|
||||
Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
|
||||
// no barrier on CPU
|
||||
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
|
||||
}
|
||||
|
||||
Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) {
|
||||
// no barrier on CPU
|
||||
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
|
||||
}
|
||||
|
||||
|
||||
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
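  // Note: this assumes the host-side launcher appends the three grid indices
  // (x, y, z) as the last three parameters of the generated function, so the
  // block id along axis `ax` is simply read from the matching trailing argument.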
|
||||
const Function *fn = builder.GetInsertBlock()->getParent();
|
||||
size_t num_params = fn->getFunctionType()->getNumParams();
|
||||
static std::array<const Argument*, 3> ids = {
|
||||
fn->arg_begin() + num_params - 3,
|
||||
fn->arg_begin() + num_params - 2,
|
||||
fn->arg_begin() + num_params - 1
|
||||
};
|
||||
return (Argument*)ids[ax];
|
||||
}
|
||||
|
||||
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
|
||||
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
|
||||
Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
|
||||
return result;
|
||||
}
|
||||
|
||||
Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
return builder.getInt32(0);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -1,133 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
|
||||
: align_(align), layout_(layouts) { }
|
||||
|
||||
|
||||
// simplify layout conversions using the following simple rules:
|
||||
// - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
|
||||
// - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y))
|
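// For instance (hypothetical conversions, for illustration only):
//   cvt_to_mma(cvt_from_mma(x))  ->  x                      (rule 1)
//   cvt(add(x, y))               ->  add(cvt(x), cvt(y))    (rule 2)
// i.e. a conversion cancels its inverse, and a conversion of an elementwise
// result can instead be pushed onto the operands.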
||||
//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
|
||||
// ir::value* _op = inst->get_operand(0);
|
||||
// ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
|
||||
// analysis::mma_layout* mma_in = layout_->get(op) ->to_mma();
|
||||
// analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
|
||||
// std::cout << 1 << std::endl;
|
||||
// // i must be layout conversion instruction
|
||||
// if(!mma_in && !mma_out)
|
||||
// return inst;
|
||||
// // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
|
||||
// bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
|
||||
// if((mma_in || mma_out) && is_op_cvt &&
|
||||
// (layout_->get(inst) == layout_->get(op->get_operand(0))))
|
||||
// return op->get_operand(0);
|
||||
// // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
|
||||
// if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
|
||||
// return inst;
|
||||
// std::cout << 1 << std::endl;
|
||||
// for(size_t i = 0; i < op->get_num_operands(); i++){
|
||||
// ir::value* arg_i = op->get_operand(i);
|
||||
// builder.set_insert_point(op);
|
||||
// // create new layout transform
|
||||
// ir::instruction* new_arg_i = inst->clone();
|
||||
// builder.insert(new_arg_i);
|
||||
// // set the right args
|
||||
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
|
||||
// op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
|
||||
// }
|
||||
// std::cout << 2 << std::endl;
|
||||
// return op;
|
||||
//}
|
||||
|
||||
void coalesce::run(ir::module &mod) {
|
||||
ir::builder& builder = mod.get_builder();
|
||||
// add layout conversion instructions
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
// coalesce before store
|
||||
if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
|
||||
if(ir::value* op = i->get_operand(1))
|
||||
if(op->get_type()->is_block_ty())
|
||||
if(layout_->get(op)->to_mma()){
|
||||
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
|
||||
builder.set_insert_point(i);
|
||||
builder.insert(new_op);
|
||||
i->replace_uses_of_with(op, new_op);
|
||||
}
|
||||
// uncoalesce after load
|
||||
if(auto x = dynamic_cast<ir::load_inst*>(i))
|
||||
if(x->get_type()->is_block_ty())
|
||||
if(x->get_type()->get_tile_rank()==2)
|
||||
if(layout_->get(x)->to_mma()){
|
||||
builder.set_insert_point_after(x);
|
||||
ir::instruction* new_x = ir::cvt_layout_inst::create(x);
|
||||
builder.insert(new_x);
|
||||
x->replace_all_uses_with(new_x);
|
||||
new_x->replace_uses_of_with(new_x, x);
|
||||
}
|
||||
}
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
// re-arrange scanline to promote memory coalescing
|
||||
if(auto x = dynamic_cast<ir::store_inst*>(i)){
|
||||
ir::value* ptr = x->get_pointer_operand();
|
||||
ir::value* val = x->get_value_operand();
|
||||
auto out_contig = align_->contiguous(ptr);
|
||||
auto val_inst = dynamic_cast<ir::instruction*>(val);
|
||||
if(!val_inst)
|
||||
break;
|
||||
if(dynamic_cast<ir::cvt_layout_inst*>(val))
|
||||
break;
|
||||
std::vector<unsigned> in_contig;
|
||||
std::vector<ir::instruction*> queue = {val_inst};
|
||||
std::set<ir::instruction*> seen;
|
||||
std::vector<ir::io_inst*> ios;
|
||||
while(!queue.empty()){
|
||||
ir::instruction* curr = queue.back();
|
||||
seen.insert(curr);
|
||||
queue.pop_back();
|
||||
if(auto dot_inst = dynamic_cast<ir::dot_inst*>(curr))
|
||||
break;
|
||||
if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
|
||||
in_contig = align_->contiguous(io_inst->get_pointer_operand());
|
||||
break;
|
||||
}
|
||||
for(ir::value* op: curr->ops()){
|
||||
auto inst_op = dynamic_cast<ir::instruction*>(op);
|
||||
if(!inst_op || seen.find(inst_op) != seen.end())
|
||||
continue;
|
||||
if(!op->get_type()->is_block_ty() ||
|
||||
!val->get_type()->is_block_ty())
|
||||
continue;
|
||||
if(op->get_type()->get_tile_num_elements() ==
|
||||
val->get_type()->get_tile_num_elements())
|
||||
queue.push_back(inst_op);
|
||||
}
|
||||
}
|
||||
if(in_contig.size() <= 1 || out_contig==in_contig)
|
||||
continue;
|
||||
builder.set_insert_point_after(val_inst);
|
||||
auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));
|
||||
x->replace_uses_of_with(val_inst, new_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,97 +0,0 @@
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
inline bool is_shmem_op(ir::instruction* i, int op) {
|
||||
if(i->get_id() == ir::INST_DOT)
|
||||
return op==0 || op==1;
|
||||
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
|
||||
return op==0;
|
||||
if(i->get_id() == ir::INST_TRANS)
|
||||
return op==0;
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool is_shmem_res(ir::value* v){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
return false;
|
||||
if(i->get_id() == ir::INST_TRANS)
|
||||
return true;
|
||||
if(i->get_id() == ir::INST_COPY_TO_SHARED)
|
||||
return true;
|
||||
if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// insert a copy to/from shared memory for operand `x` of instruction `parent`
|
||||
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
|
||||
auto *i = dynamic_cast<ir::instruction*>(x);
|
||||
// not an instruction
|
||||
if(!i) {
|
||||
builder.set_insert_point(parent);
|
||||
ir::value *copy;
|
||||
if(to_shared)
|
||||
copy = builder.create_copy_to_shared(x);
|
||||
else
|
||||
copy = builder.create_copy_from_shared(x);
|
||||
parent->replace_uses_of_with(x, copy);
|
||||
return;
|
||||
}
|
||||
// phi node
|
||||
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
|
||||
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
|
||||
return;
|
||||
}
|
||||
// already in shared memory
|
||||
if(to_shared && is_shmem_res(i))
|
||||
return;
|
||||
// copy
|
||||
builder.set_insert_point_after(i);
|
||||
ir::value *copy;
|
||||
if(to_shared){
|
||||
copy = builder.create_copy_to_shared(x);
|
||||
}
|
||||
else
|
||||
copy = builder.create_copy_from_shared(x);
|
||||
parent->replace_uses_of_with(x, copy);
|
||||
}
|
||||
|
||||
void cts::run(ir::module &mod) {
|
||||
// Add shared copies
|
||||
ir::builder &builder = mod.get_builder();
|
||||
for(ir::function* fn: mod.get_function_list()){
|
||||
for(ir::basic_block* block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
size_t num_op = i->get_num_operands();
|
||||
// copy to shared operands
|
||||
for(size_t k = 0; k < num_op; k++)
|
||||
if(is_shmem_op(i, k)){
|
||||
add_copy(i, i->get_operand(k), builder, true);
|
||||
}
|
||||
// copy from shared operands
|
||||
for(size_t k = 0; k < num_op; k++)
|
||||
if(!dynamic_cast<ir::phi_node*>(i) &&
|
||||
!is_shmem_op(i,k) &&
|
||||
is_shmem_res(i->get_operand(k))){
|
||||
add_copy(i, i->get_operand(k), builder, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,75 +0,0 @@
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
void dce::run(ir::module &mod) {
|
||||
std::list<ir::instruction*> work_list;
|
||||
std::set<ir::instruction*> marked;
|
||||
|
||||
// initialize work-list
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
// iterate through blocks
|
||||
for(ir::basic_block *block: rpo)
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
switch(i->get_id()){
|
||||
case ir::INST_RETURN:
|
||||
case ir::INST_UNCOND_BRANCH:
|
||||
case ir::INST_COND_BRANCH:
|
||||
case ir::INST_UNMASKED_STORE:
|
||||
case ir::INST_MASKED_STORE:
|
||||
case ir::INST_ATOMIC_CAS:
|
||||
case ir::INST_ATOMIC_RMW:
|
||||
case ir::INST_ATOMIC_EXCH:
|
||||
case ir::INST_BARRIER: {
|
||||
work_list.push_back(i);
|
||||
marked.insert(i);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mark -- ignore branches
|
||||
while(!work_list.empty()){
|
||||
ir::instruction* current = work_list.back();
|
||||
work_list.pop_back();
|
||||
// mark instruction operands
|
||||
for(ir::value* op: current->ops()) {
|
||||
if(auto *i = dynamic_cast<ir::instruction*>(op)){
|
||||
if(marked.insert(i).second)
|
||||
work_list.push_back(i);
|
||||
}
|
||||
}
|
||||
// TODO: mark the last instruction of current's reverse-dominance frontier
|
||||
}
|
||||
|
||||
// sweep -- delete non-branch unmarked instructions
|
||||
std::vector<ir::instruction*> to_delete;
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
// iterate through blocks
|
||||
for(ir::basic_block *block: rpo)
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(marked.find(i) == marked.end())
|
||||
to_delete.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// delete
|
||||
for(ir::instruction* i: to_delete)
|
||||
i->erase_from_parent();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,62 +0,0 @@
|
||||
#include "triton/codegen/transform/disassociate.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
|
||||
std::set<ir::value*>& seen) {
|
||||
if (dynamic_cast<ir::phi_node*>(root))
|
||||
return root;
|
||||
if(!seen.insert(root).second)
|
||||
return root;
|
||||
if(!root->get_type()->is_block_ty())
|
||||
return root;
|
||||
|
||||
bld.set_insert_point(root);
|
||||
ir::instruction *new_root = bld.insert(root->clone());
|
||||
for(ir::value *op: root->ops()){
|
||||
ir::instruction *i = dynamic_cast<ir::instruction*>(op);
|
||||
if(!i || i->get_id() == ir::INST_REDUCE)
|
||||
continue;
|
||||
ir::instruction* new_op = rematerialize(bld, i, seen);
|
||||
new_root->replace_uses_of_with(op, new_op);
|
||||
}
|
||||
return new_root;
|
||||
}
|
||||
|
||||
void disassociate::run(ir::module &mod) {
|
||||
ir::builder &bld = mod.get_builder();
|
||||
|
||||
// ir::for_each_instruction(mod, [&](ir::instruction *i){
|
||||
// bld.set_insert_point(i);
|
||||
// for(ir::value* op: i->ops()){
|
||||
// auto reshape = dynamic_cast<ir::make_range*>(op);
|
||||
// if(!reshape)
|
||||
// continue;
|
||||
// ir::instruction* new_op = bld.insert(reshape->clone());
|
||||
// i->replace_uses_of_with(op, new_op);
|
||||
// }
|
||||
// });
|
||||
|
||||
|
||||
ir::for_each_instruction(mod, [&](ir::instruction *i){
|
||||
if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){
|
||||
std::set<ir::value*> seen;
|
||||
ir::instruction* new_i = rematerialize(bld, i, seen);
|
||||
i->replace_all_uses_with(new_i);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,244 +0,0 @@
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/prefetch.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
|
||||
int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
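  // Roughly speaking, this returns the index of the asynchronous copy in
  // `async_write` that last wrote the shared buffer read through `v`: phi
  // nodes recurse into their incoming values and take the largest group,
  // values with a temporary layout map to the newest asynchronous write,
  // and anything else is looked up directly in the list.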
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
|
||||
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
|
||||
if (analysis::double_buffer_info_t* info = layout->get_double_buffer())
|
||||
return group_of(info->first, async_write);
|
||||
else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) {
|
||||
if (v == info->phi)
|
||||
return group_of(info->firsts[0], async_write);
|
||||
else // prefetched value
|
||||
return group_of(info->firsts[1], async_write);
|
||||
}
|
||||
std::vector<int> groups(phi->get_num_operands());
|
||||
std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
|
||||
return *std::max_element(groups.begin(), groups.end());
|
||||
}
|
||||
else{
|
||||
if(layouts_->has_tmp(v))
|
||||
return async_write.size() - 1;
|
||||
auto it = std::find(async_write.begin(), async_write.end(), v);
|
||||
return std::distance(async_write.begin(), it);
|
||||
}
|
||||
}
|
||||
|
||||
inline bool membar::intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout) {
|
||||
if(!a_layout || !b_layout)
|
||||
return false;
|
||||
int a_start = alloc_->offset(a_layout);
|
||||
int a_end = a_start + a_layout->get_size();
|
||||
int b_start = alloc_->offset(b_layout);
|
||||
int b_end = b_start + b_layout->get_size();
|
||||
if(a_start < b_end || b_start < a_end)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
|
||||
val_set_t ret;
|
||||
for(ir::value* a: as){
|
||||
if(!a->get_type()->is_block_ty())
|
||||
continue;
|
||||
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
|
||||
analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
|
||||
for(ir::value* b: bs){
|
||||
if(!b->get_type()->is_block_ty())
|
||||
continue;
|
||||
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
|
||||
analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
|
||||
if(intersect_with(a_layout, b_layout) ||
|
||||
intersect_with(a_layout, b_tmp) ||
|
||||
intersect_with(a_tmp, b_layout) ||
|
||||
intersect_with(a_tmp, b_tmp))
|
||||
ret.insert(b);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool membar::check_safe_war(ir::instruction* i) {
|
||||
bool is_i_shared_block = i->get_type()->is_block_ty() &&
|
||||
layouts_->get(i)->to_shared();
|
||||
bool is_i_double_buffered = is_i_shared_block &&
|
||||
layouts_->get(i)->to_shared()->get_double_buffer();
|
||||
bool is_i_n_buffered = is_i_shared_block &&
|
||||
layouts_->get(i)->to_shared()->get_N_buffer();
|
||||
|
||||
if (is_i_double_buffered || is_i_n_buffered) {
|
||||
// with async copy & prefetch_s disabled, WARs are not safe
|
||||
if (dynamic_cast<ir::masked_load_async_inst*>(i) && !prefetch_->is_prefetched(i))
|
||||
return false;
|
||||
else
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void membar::transfer(ir::basic_block *block,
|
||||
val_vec_t& async_write,
|
||||
val_set_t& sync_write,
|
||||
val_set_t& sync_read,
|
||||
std::set<ir::value*>& safe_war,
|
||||
bool& inserted, ir::builder& builder) {
|
||||
std::vector<ir::async_wait_inst*> async_waits;
|
||||
ir::basic_block::inst_list_t instructions = block->get_inst_list();
|
||||
for(ir::instruction *i: instructions){
|
||||
if(dynamic_cast<ir::phi_node*>(i))
|
||||
continue;
|
||||
if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
|
||||
dynamic_cast<ir::masked_load_async_inst*>(i)){
|
||||
async_write.push_back(i);
|
||||
}
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(i))
|
||||
sync_write.insert(i);
|
||||
ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
|
||||
ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
|
||||
// Get shared memory reads
|
||||
std::set<ir::value*> read;
|
||||
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
|
||||
[&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
|
||||
if(layouts_->has_tmp(i))
|
||||
read.insert(i);
|
||||
// RAW (async)
|
||||
val_set_t tmp;
|
||||
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
|
||||
if(intersect_with(read, tmp).size()){
|
||||
std::vector<int> groups(read.size());
|
||||
std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
|
||||
int N = *std::max_element(groups.begin(), groups.end());
|
||||
if(N < async_write.size()){
|
||||
builder.set_insert_point(i);
|
||||
async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
|
||||
barrier = (ir::barrier_inst*)builder.create_barrier();
|
||||
inserted = true;
|
||||
async_waits.push_back(async_wait);
|
||||
}
|
||||
}
|
||||
// RAW, WAR
|
||||
bool is_safe_war = check_safe_war(i);
|
||||
// WAR barrier is not required when data is double-buffered
|
||||
if(!intersect_with(read, sync_write).empty() ||
|
||||
(!intersect_with({i}, sync_read).empty() && !is_safe_war)) {
|
||||
builder.set_insert_point(i);
|
||||
barrier = (ir::barrier_inst*)builder.create_barrier();
|
||||
inserted = true;
|
||||
}
|
||||
// update state of asynchronous copies
|
||||
if(async_wait){
|
||||
int N = async_write.size() - async_wait->get_N();
|
||||
async_write.erase(async_write.begin(), async_write.begin() + N);
|
||||
}
|
||||
// all copy_to_shared writes and shared-memory reads are synchronized after a barrier
|
||||
if(barrier){
|
||||
sync_write.clear();
|
||||
sync_read.clear();
|
||||
}
|
||||
sync_read.insert(read.begin(), read.end());
|
||||
}
|
||||
|
||||
// coalesce barriers
|
||||
// FIXME: support more general cases
|
||||
if (async_waits.size() == 2) {
|
||||
// (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;)
|
||||
for (int idx=0; idx<async_waits.size()-1; ++idx) {
|
||||
ir::async_wait_inst *first_async_wait = async_waits[idx];
|
||||
std::vector<ir::instruction*> to_erase;
|
||||
ir::basic_block::inst_list_t instructions = block->get_inst_list();
|
||||
for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){
|
||||
ir::instruction *i = *iter;
|
||||
if (static_cast<ir::instruction*>(first_async_wait) == i) {
|
||||
// peek at the next 5 instructions
|
||||
auto peak_iter = std::next(iter);
|
||||
if (std::distance(peak_iter, instructions.end()) >= 5) {
|
||||
auto first_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
|
||||
auto first_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter++);
|
||||
auto second_async_wait = dynamic_cast<ir::async_wait_inst*>(*peak_iter++);
|
||||
auto second_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
|
||||
auto second_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter);
|
||||
if (first_bar && first_pf && second_async_wait && second_bar && second_pf) {
|
||||
int first_n = first_async_wait->get_N();
|
||||
int second_n = second_async_wait->get_N();
|
||||
to_erase.push_back(second_async_wait);
|
||||
to_erase.push_back(second_bar);
|
||||
first_async_wait->set_N(second_n);
|
||||
}
|
||||
} else
|
||||
break;
|
||||
for (ir::instruction *i : to_erase)
|
||||
block->erase(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void membar::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
// extract phi-nodes associated with double-buffered
|
||||
// shared-memory copies. These can be read from and written to
|
||||
// without needing synchronization
|
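// (to see why this is safe, consider standard double buffering: iteration i
// writes buffer A while the consumer reads buffer B, and the roles swap on the
// next iteration, so a write never targets the buffer currently being read and
// the write-after-read hazard needs no extra barrier)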
||||
std::set<ir::value*> safe_war;
|
||||
for(const auto& x: layouts_->get_all()){
|
||||
analysis::shared_layout* layout = x.second->to_shared();
|
||||
if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer())
|
||||
continue;
|
||||
for(ir::value *v: layout->get_values())
|
||||
if(v != layout->get_double_buffer()->phi){
|
||||
safe_war.insert(v);
|
||||
}
|
||||
}
|
||||
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
std::map<ir::basic_block*, val_vec_t> async_writes;
|
||||
std::map<ir::basic_block*, val_set_t> sync_writes;
|
||||
std::map<ir::basic_block*, val_set_t> sync_reads;
|
||||
std::list<ir::value *> pipelined;
|
||||
bool inserted;
|
||||
do{
|
||||
inserted = false;
|
||||
// find barrier location
|
||||
for(ir::basic_block *block: rpo){
|
||||
// join inputs
|
||||
val_vec_t async_write;
|
||||
val_set_t sync_write;
|
||||
val_set_t sync_read;
|
||||
val_set_t tmp;
|
||||
for(ir::basic_block* pred: block->get_predecessors()){
|
||||
for(ir::value* v: async_writes[pred])
|
||||
if(tmp.insert(v).second)
|
||||
async_write.push_back(v);
|
||||
sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
|
||||
sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
|
||||
}
|
||||
transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
|
||||
async_writes[block] = async_write;
|
||||
sync_writes[block] = sync_write;
|
||||
sync_reads[block] = sync_read;
|
||||
}
|
||||
}while(inserted);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,309 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
|
||||
const std::vector<int>& perm) {
|
||||
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
|
||||
// transpose operands
|
||||
std::vector<ir::value*> incs;
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
|
||||
incs.push_back(rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm));
|
||||
// create phi for transposed values
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
|
||||
result->add_incoming(incs[n], phi->get_incoming_block(n));
|
||||
return result;
|
||||
}
|
||||
else if(auto i = dynamic_cast<ir::instruction*>(value)){
|
||||
ir::basic_block* block = i->get_parent();
|
||||
auto it = std::find(block->begin(), block->end(), i);
|
||||
it++;
|
||||
builder.set_insert_point(it);
|
||||
ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
|
||||
trans->set_operand(0, i);
|
||||
return trans;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
|
||||
auto trans = dynamic_cast<ir::trans_inst*>(value);
|
||||
if(!trans)
|
||||
return false;
|
||||
auto users = trans->get_users();
|
||||
auto ops = trans->ops();
|
||||
if(users.size() > 1 || ops.size() > 1)
|
||||
return false;
|
||||
ir::value* op = *ops.begin();
|
||||
// trans(phi) -> phi(trans(), trans()...)
|
||||
auto* phi = dynamic_cast<ir::phi_node*>(op);
|
||||
if(!phi)
|
||||
return false;
|
||||
ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
|
||||
if(!new_phi)
|
||||
return false;
|
||||
trans->replace_all_uses_with(new_phi);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
// dot(a, b, c) + d -> dot(a, b, c + d)
|
||||
// d + dot(a, b, c) -> dot(a, b, c + d)
|
||||
auto add = dynamic_cast<ir::binary_operator*>(value);
|
||||
if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
|
||||
bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
|
||||
ir::value *lhs = add->get_operand(0);
|
||||
ir::value *rhs = add->get_operand(1);
|
||||
ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
|
||||
ir::dot_inst *rhs_dot = dynamic_cast<ir::dot_inst*>(rhs);
|
||||
if(!lhs_dot && !rhs_dot)
|
||||
return false;
|
||||
ir::dot_inst *dot = lhs_dot ? lhs_dot : rhs_dot;
|
||||
ir::value *other = (dot == lhs) ? rhs : lhs;
|
||||
ir::value *acc = dot->get_operand(2);
|
||||
ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
|
||||
ir::constant *_0 = nullptr;
|
||||
if(splat)
|
||||
_0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
|
||||
if(!_0)
|
||||
return false;
|
||||
if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
|
||||
if (fp_0->get_value() != 0.0)
|
||||
return false;
|
||||
if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
|
||||
if (int_0->get_value() != 0)
|
||||
return false;
|
||||
ir::value *a = dot->get_operand(0);
|
||||
ir::value *b = dot->get_operand(1);
|
||||
builder.set_insert_point(add);
|
||||
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
|
||||
add->replace_all_uses_with(new_dot);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
//bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
|
||||
// auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
|
||||
// if(cfs) {
|
||||
// ir::value *arg = cfs->get_operand(0);
|
||||
// ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
|
||||
// if(!cts)
|
||||
// return false;
|
||||
// cfs->replace_all_uses_with(cts->get_operand(0));
|
||||
// return true;
|
||||
// }
|
||||
|
||||
//}
|
||||
|
||||
bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
|
||||
auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
|
||||
if(!copy_to_shared)
|
||||
return false;
|
||||
ir::value *arg = copy_to_shared->get_operand(0);
|
||||
ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
|
||||
if(!ld)
|
||||
return false;
|
||||
builder.set_insert_point(copy_to_shared);
|
||||
ir::value *ptr = ld->get_pointer_operand();
|
||||
ir::value *msk = ld->get_mask_operand();
|
||||
ir::value *val = ld->get_false_value_operand();
|
||||
analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
|
||||
int nts = layout->nts(layout->get_order()[0]);
|
||||
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
if(nts*dtsize >= 4){
|
||||
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
|
||||
copy_to_shared->replace_all_uses_with(new_load);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
// analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
|
||||
// std::cout << layout->nts(layout->get_order(0)) << std::endl;
|
||||
// return true;
|
||||
|
||||
}
|
||||
|
||||
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
||||
auto x = dynamic_cast<ir::reduce_inst*>(value);
|
||||
if(!x)
|
||||
return false;
|
||||
ir::value *arg = x->get_operand(0);
|
||||
auto shapes = arg->get_type()->get_block_shapes();
|
||||
if(shapes[x->get_axis()] == 1){
|
||||
builder.set_insert_point(x);
|
||||
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes());
|
||||
x->replace_all_uses_with(new_red);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
|
||||
auto binop = dynamic_cast<ir::binary_operator*>(value);
|
||||
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
|
||||
ir::value *lhs = binop->get_operand(0);
|
||||
ir::value *rhs = binop->get_operand(1);
|
||||
ir::constant_int *_1_lhs = nullptr;
|
||||
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
|
||||
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
|
||||
if(cst && cst->get_value() == 1)
|
||||
_1_lhs = cst;
|
||||
}
|
||||
ir::constant_int *_1_rhs = nullptr;
|
||||
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
|
||||
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
|
||||
if(cst && cst->get_value() == 1)
|
||||
_1_rhs = cst;
|
||||
}
|
||||
if(_1_lhs){
|
||||
binop->replace_all_uses_with(rhs);
|
||||
return true;
|
||||
}
|
||||
else if(_1_rhs){
|
||||
binop->replace_all_uses_with(lhs);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
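  // pattern-matches gep(gep(p, 0 - off), off) and folds it back to p, since
  // subtracting an offset and then re-adding the same offset is a no-op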
|
||||
auto x = dynamic_cast<ir::getelementptr_inst*>(value);
|
||||
if(!x)
|
||||
return false;
|
||||
auto y = dynamic_cast<ir::getelementptr_inst*>(x->get_pointer_operand());
|
||||
if(!y)
|
||||
return false;
|
||||
auto idx = *y->idx_begin();
|
||||
auto z = dynamic_cast<ir::binary_operator*>(idx);
|
||||
if(!z)
|
||||
return false;
|
||||
bool is_sub = z->get_op() == ir::binary_op_t::Sub;
|
||||
auto *lhs = dynamic_cast<ir::constant_int*>(z->get_operand(0));
|
||||
bool is_lhs_0 = lhs && (lhs->get_value()==0);
|
||||
bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin();
|
||||
if(is_sub && is_lhs_0 && is_rhs_eq_x_rhs){
|
||||
x->replace_all_uses_with(y->get_pointer_operand());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& builder){
|
||||
auto select = dynamic_cast<ir::select_inst*>(value);
|
||||
if(!select)
|
||||
return false;
|
||||
auto if_value = dynamic_cast<ir::masked_load_inst*>(select->get_if_value_op());
|
||||
if(!if_value)
|
||||
return false;
|
||||
if(select->get_pred_op() != if_value->get_mask_operand())
|
||||
return false;
|
||||
builder.set_insert_point(select);
|
||||
ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
|
||||
if_value->get_mask_operand(),
|
||||
select->get_else_value_op(),
|
||||
if_value->get_cache_modifier(),
|
||||
if_value->get_eviction_policy(),
|
||||
if_value->get_is_volatile());
|
||||
select->replace_all_uses_with(new_load);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){
|
||||
auto cvt = dynamic_cast<ir::cvt_layout_inst*>(value);
|
||||
if(!cvt)
|
||||
return false;
|
||||
ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0));
|
||||
if(!op)
|
||||
return false;
|
||||
// // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
|
||||
// if(op->get_id() == ir::INST_BINOP){
|
||||
// for(size_t i = 0; i < op->get_num_operands(); i++){
|
||||
// ir::value* arg_i = op->get_operand(i);
|
||||
// builder.set_insert_point(op);
|
||||
// // create new layout transform
|
||||
// ir::instruction* new_arg_i = cvt->clone();
|
||||
// layouts_->copy(new_arg_i, op);
|
||||
// builder.insert(new_arg_i);
|
||||
// // set the right args
|
||||
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
|
||||
// op->replace_uses_of_with(arg_i, new_arg_i);
|
||||
// }
|
||||
// cvt->replace_all_uses_with(op);
|
||||
// return true;
|
||||
// }
|
||||
auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op);
|
||||
if(!cvt_op)
|
||||
return false;
|
||||
// convert1(convert2(x)) if convert1 is the inverse of convert2
|
||||
ir::value* op_op = cvt_op->get_operand(0);
|
||||
if(layouts_->has(cvt) && layouts_->has(op_op) &&
|
||||
layouts_->get(cvt) && layouts_->get(op_op)){
|
||||
cvt->replace_all_uses_with(op_op);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void peephole::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
// keep track of whether any modification was made
|
||||
std::set<ir::value*> seen;
|
||||
size_t n_seen;
|
||||
|
||||
// rewrite dots first
|
||||
do{
|
||||
n_seen = seen.size();
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
if(seen.find(i) != seen.end())
|
||||
continue;
|
||||
bool was_modified = rewrite_dot(i, builder);
|
||||
if(was_modified){
|
||||
seen.insert(i);
|
||||
}
|
||||
}
|
||||
}while(seen.size() != n_seen);
|
||||
|
||||
// rewrite other ops
|
||||
seen.clear();
|
||||
do{
|
||||
n_seen = seen.size();
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction* i: block->get_inst_list()){
|
||||
if(seen.find(i) != seen.end())
|
||||
continue;
|
||||
bool was_modified = false;
|
||||
was_modified = was_modified || rewrite_mult(i, builder);
|
||||
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
|
||||
// was_modified = was_modified || rewrite_trans_phi(i, builder);
|
||||
was_modified = was_modified || rewrite_unit_red(i, builder);
|
||||
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
||||
// TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
|
||||
// was_modified = was_modified || rewrite_select_masked_load(i, builder);
|
||||
was_modified = was_modified || rewrite_cvt_layout(i, builder);
|
||||
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
||||
was_modified = was_modified || rewrite_load_to_shared(i, builder);
|
||||
if(was_modified)
|
||||
seen.insert(i);
|
||||
}
|
||||
}while(seen.size() != n_seen);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,330 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i || i->get_parent() != block)
|
||||
return;
|
||||
if(i->get_id()==ir::INST_PHI)
|
||||
return;
|
||||
ret.push_back(i);
|
||||
for(ir::user* u: i->get_users())
|
||||
recursive_deps(u, block, ret);
|
||||
}
|
||||
|
||||
void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
|
||||
auto instr = dynamic_cast<ir::instruction*>(cond);
|
||||
for (auto op : instr->ops()) {
|
||||
if (auto phi_op = dynamic_cast<ir::phi_node*>(op)) {
|
||||
phis.insert(phi_op);
|
||||
return;
|
||||
}
|
||||
if (dynamic_cast<ir::instruction*>(op))
|
||||
get_induction_vars(op, phis);
|
||||
}
|
||||
}
|
||||
|
||||
/// assumes incoming block 1 is the loop back-edge
|
||||
ir::value* rematerialize_vals(ir::builder& builder, ir::basic_block* block, ir::value* v,
|
||||
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i || i->get_parent() != block)
|
||||
return v;
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
|
||||
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
|
||||
throw std::runtime_error("Don't have that phi node\n");
|
||||
return prev_phi_vals.at(phi);
|
||||
}
|
||||
|
||||
std::vector<ir::value*> new_ops;
|
||||
for(ir::value* op: i->ops()){
|
||||
new_ops.push_back(rematerialize_vals(builder, block, op, prev_phi_vals));
|
||||
}
|
||||
ir::instruction* ret = i->clone();
|
||||
for(size_t k = 0; k < new_ops.size(); k++)
|
||||
ret->set_operand(k, new_ops[k]);
|
||||
builder.insert(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
ir::value* rematerialize(ir::builder& builder, ir::basic_block* block,
|
||||
ir::value* v, size_t phi_idx){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i || i->get_parent() != block)
|
||||
return v;
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
|
||||
return phi->get_incoming_value(phi_idx);
|
||||
|
||||
std::vector<ir::value*> new_ops;
|
||||
for(ir::value* op: i->ops()){
|
||||
new_ops.push_back(rematerialize(builder, block, op, phi_idx));
|
||||
}
|
||||
ir::instruction* ret = i->clone();
|
||||
for(size_t k = 0; k < new_ops.size(); k++)
|
||||
ret->set_operand(k, new_ops[k]);
|
||||
builder.insert(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// moving the prev phi vals to the next iteration
|
||||
std::map<ir::phi_node*, ir::value*> update_prev_phi_vals(
|
||||
ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
|
||||
std::map<ir::phi_node*, ir::value*> next_phi_vals;
|
||||
for (auto &[phi, val] : prev_phi_vals) {
|
||||
next_phi_vals[phi] = rematerialize_vals(builder, block, phi->get_incoming_value(1), prev_phi_vals);
|
||||
}
|
||||
return next_phi_vals;
|
||||
}
|
||||
|
||||
void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& load_ivs,
|
||||
std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
|
||||
for (auto& [phi, val] : load_ivs) {
|
||||
if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
|
||||
ir::value* next_k = rematerialize_vals(builder, block, phi->get_incoming_value(1), load_ivs);
|
||||
assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
|
||||
new_phi->add_incoming(next_k, phi->get_incoming_block(1));
|
||||
// cache next_k (to be used by next_mask)
|
||||
next_load_ivs[phi] = next_k;
|
||||
} else
|
||||
throw std::runtime_error("must be phi");
|
||||
}
|
||||
}
|
||||
|
||||
struct pipeline_info_t {
|
||||
ir::load_inst* load;
|
||||
ir::phi_node* ptr;
|
||||
ir::dot_inst* dot;
|
||||
|
||||
pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
|
||||
: load(load), ptr(ptr), dot(dot) {}
|
||||
};
|
||||
|
||||
void pipeline::run(ir::module &mod) {
|
||||
if (num_stages_ <= 1)
|
||||
return;
|
||||
// *Very* conservative heuristics for pre-fetching.
|
||||
// A load instruction can be pipelined if:
|
||||
// - the pointer is a phi node that references a value
|
||||
// in its basic block (i.e., pointer induction variable)
|
||||
// - the load has only a single use in a dot instruction
|
||||
// As more use cases become apparent, this pass will be improved
|
||||
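// As a rough sketch (IR-flavored pseudocode, names invented for illustration),
// the qualifying shape is the typical GEMM-style inner loop:
//
//   loop:
//     %ptr      = phi [%ptr_init, %header], [%ptr_next, %loop]  ; pointer induction variable
//     %a        = load %ptr                                     ; only user is the dot below
//     %acc      = dot %a, %b, %acc_prev
//     %ptr_next = getelementptr %ptr, %stride
//
// so the load for iteration i+1 can be issued while iteration i's dot computes.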
std::vector<pipeline_info_t> to_pipeline;
|
||||
ir::for_each_instruction(mod, [&](ir::instruction *i){
|
||||
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
|
||||
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
|
||||
auto users = load->get_users();
|
||||
auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
|
||||
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
|
||||
&& users.size() == 1 && dot)
|
||||
to_pipeline.push_back({load, ptr, dot});
|
||||
}});
|
||||
// do the pipelining
|
||||
std::vector<ir::phi_node*> new_loads;
|
||||
ir::builder &builder = mod.get_builder();
|
||||
const int num_stages = num_stages_;
|
||||
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
|
||||
for(auto info: to_pipeline){
|
||||
ir::load_inst* load = info.load;
|
||||
ir::phi_node* ptr = info.ptr;
|
||||
ir::basic_block* block = load->get_parent();
|
||||
ir::basic_block* header = block->get_predecessors()[0];
|
||||
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
|
||||
auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
|
||||
assert(block_br);
|
||||
assert(header_br);
|
||||
ir::type* ty = load->get_type();
|
||||
// multi-stage pipe
|
||||
if (has_copy_async_ && num_stages > 2) {
|
||||
ir::value* header_cond = header_br->get_cond();
|
||||
ir::value* block_cond = block_br->get_cond();
|
||||
// 1. collect induction variables
|
||||
std::set<ir::phi_node*> induction_vars;
|
||||
get_induction_vars(block_cond, induction_vars);
|
||||
|
||||
std::vector<ir::value*> first_ptrs(num_stages-1);
|
||||
std::vector<ir::value*> first_loads(num_stages-1);
|
||||
std::vector<ir::value*> first_masks(num_stages-1);
|
||||
std::vector<ir::value*> loop_conds(num_stages-1);
|
||||
|
||||
std::map<ir::phi_node*, ir::value*> prev_phi_vals;
|
||||
// initialize prev_phi_vals
|
||||
// Add all phi nodes. The following DCE pass will delete dead ones.
|
||||
for (ir::instruction *instr : block->get_inst_list())
|
||||
if (auto *phi = dynamic_cast<ir::phi_node*>(instr))
|
||||
if (phi->get_incoming_block(1) == block)
|
||||
prev_phi_vals[phi] = phi->get_value_for_block(header);
|
||||
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
first_ptrs[0] = ptr->get_value_for_block(header);
|
||||
loop_conds[0] = header_cond;
|
||||
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
|
||||
ir::value* false_value = nullptr;
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask =rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals) ;
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
|
||||
false_value = remat_false_value;
|
||||
} else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
|
||||
for (int stage = 1; stage < num_stages-1; ++stage) {
|
||||
// mask is the loop condition of the previous iteration
|
||||
loop_conds[stage] = rematerialize_vals(builder, block, block_cond, prev_phi_vals);
|
||||
prev_phi_vals = update_prev_phi_vals(builder, block, prev_phi_vals);
|
||||
first_ptrs[stage] = rematerialize_vals(builder, block, ptr, prev_phi_vals);
|
||||
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals);
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
}
|
||||
|
||||
// create new phis for induction variables
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
std::map<ir::phi_node*, ir::value*> load_ivs;
|
||||
std::map<ir::phi_node*, ir::value*> next_load_ivs;
|
||||
for (auto& [iv, val] : prev_phi_vals) {
|
||||
ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
|
||||
pn->add_incoming(prev_phi_vals[iv], header);
|
||||
load_ivs[iv] = pn;
|
||||
}
|
||||
// add incoming for phis & update next_load_ivs
|
||||
finalize_iv_vals(builder, block, load_ivs, next_load_ivs);
|
||||
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
// ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs);
|
||||
ir::value* next_mask = builder.create_splat(
|
||||
rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes());
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), next_load_ivs);
|
||||
// TODO: the false value may depend on some other phi nodes
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), next_load_ivs);
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
|
||||
|
||||
// phi node
|
||||
ptr->set_incoming_value(0, first_ptrs.back());
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
// nested phis for load
|
||||
std::vector<ir::phi_node*> new_load_phis(num_stages-1);
|
||||
for (auto& pn : new_load_phis)
|
||||
pn = builder.create_phi(ty, 2);
|
||||
for (int i=0; i<num_stages-2; ++i) {
|
||||
new_load_phis[i]->add_incoming(first_loads[i], header);
|
||||
new_load_phis[i]->add_incoming(new_load_phis[i+1], block);
|
||||
}
|
||||
new_load_phis.back()->add_incoming(first_loads.back(), header);
|
||||
new_load_phis.back()->add_incoming(next_load, block);
|
||||
load->replace_all_uses_with(new_load_phis.front());
|
||||
new_loads.push_back(new_load_phis.back());
|
||||
|
||||
// record first_loads to reorder them
|
||||
preheader_loads.push_back({new_load_phis.front(), first_loads});
|
||||
} else {
|
||||
// pre-fetch first iteration
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
ir::value* first_ptr = ptr->get_value_for_block(header);
|
||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
|
||||
ir::value* false_value;
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 0);
|
||||
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 0);
|
||||
first_mask = builder.create_and(first_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
|
||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 1);
|
||||
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 1);
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
// phi node
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
ir::phi_node* new_load = builder.create_phi(ty, 2);
|
||||
new_load->add_incoming(first_load, header);
|
||||
new_load->add_incoming(next_load, block);
|
||||
load->replace_all_uses_with(new_load);
|
||||
new_loads.push_back(new_load);
|
||||
}
|
||||
}
|
||||
|
||||
// try to reorder prefetched value from a0, a1, a2, ..., b0, b1, b2, ... to
|
||||
// a0, b0, a1, b1, ...
|
||||
if (!preheader_loads.empty()) {
|
||||
ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0);
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
for (int i=1; i<num_stages-1; ++i) {
|
||||
for (auto iter = preheader_loads.begin(); iter != preheader_loads.end(); ++iter) {
|
||||
ir::instruction* original_load = static_cast<ir::instruction*>(iter->second.at(i));
|
||||
ir::instruction* moved_load = original_load->clone();
|
||||
builder.insert(moved_load);
|
||||
original_load->replace_all_uses_with(moved_load);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// try to move dot_inst after loads
|
||||
// for better overlap of io and compute
|
||||
struct move_config_t{
|
||||
std::vector<ir::instruction*> insts;
|
||||
ir::load_inst* dst;
|
||||
};
|
||||
std::vector<move_config_t> to_move(to_pipeline.size());
|
||||
|
||||
if(has_copy_async_){
|
||||
for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
|
||||
auto info = to_pipeline[idx];
|
||||
ir::load_inst* load = info.load;
|
||||
ir::phi_node* ptr = info.ptr;
|
||||
ir::dot_inst* dot = info.dot;
|
||||
ir::basic_block* bb = dot->get_parent();
|
||||
recursive_deps(dot, bb, to_move[idx].insts);
|
||||
to_move[idx].dst = load;
|
||||
}
|
||||
|
||||
for(auto& move_config: to_move){
|
||||
builder.set_insert_point_after(move_config.dst);
|
||||
for(ir::instruction* i: move_config.insts){
|
||||
i->get_parent()->erase(i);
|
||||
builder.insert(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,133 +0,0 @@
|
||||
#include "triton/codegen/transform/prefetch.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
namespace triton::codegen::transform {
|
||||
|
||||
/// collect the defining instructions of v within bb, stopping at phi nodes
|
||||
static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::instruction*> &ret) {
|
||||
ir::instruction *i = dynamic_cast<ir::instruction*>(v);
|
||||
if (!i || i->get_parent() != bb)
|
||||
return;
|
||||
if (i->get_id() == ir::INST_PHI)
|
||||
return;
|
||||
ret.push_back(i);
|
||||
for (ir::value *op : i->ops())
|
||||
recursive_defs(op, bb, ret);
|
||||
}
|
||||
|
||||
void prefetch::run(ir::module &mod) {
|
||||
// 1. collect dots that can be prefetched
|
||||
std::vector<ir::dot_inst*> to_prefetch;
|
||||
ir::for_each_instruction(mod, [&](ir::instruction *i) {
|
||||
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
|
||||
// For now, only do prefetching when the dot is using tensor cores
|
||||
if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
|
||||
dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
|
||||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
|
||||
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
|
||||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
|
||||
&& dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
|
||||
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
||||
)
|
||||
)
|
||||
return;
|
||||
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
|
||||
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
|
||||
if (a && a->get_incoming_block(1) == a->get_parent() &&
|
||||
b && b->get_incoming_block(1) == b->get_parent())
|
||||
to_prefetch.push_back(dot);
|
||||
}
|
||||
});
|
||||
|
||||
assert(to_prefetch.size() <= 1 && "Don't know what to do with multiple dots");
|
||||
ir::builder &builder = mod.get_builder();
|
||||
// 2. do the prefetching
|
||||
for (ir::dot_inst* dot : to_prefetch) {
|
||||
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
|
||||
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
|
||||
assert(a->get_incoming_block(0) == b->get_incoming_block(0));
|
||||
ir::basic_block *loop_header = a->get_incoming_block(0);
|
||||
ir::basic_block *loop_body = a->get_parent();
|
||||
|
||||
// mark as prefetched
|
||||
dot->set_prefetched(true);
|
||||
|
||||
// 1. in the loop header (first iteration)
|
||||
builder.set_insert_point(loop_header->get_inst_list().back());
|
||||
assert(a && b);
|
||||
builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
|
||||
builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);
|
||||
|
||||
// 2. at the end of the loop body (next iteration)
|
||||
builder.set_insert_point(loop_body->get_inst_list().back());
|
||||
builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
|
||||
builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);
|
||||
|
||||
prefetched_vals_.insert(a->get_incoming_value(0));
|
||||
prefetched_vals_.insert(b->get_incoming_value(0));
|
||||
// nested phis
|
||||
ir::value* next_a = a->get_incoming_value(1);
|
||||
while (auto* next_a_phi = dynamic_cast<ir::phi_node*>(next_a)) {
|
||||
prefetched_vals_.insert(next_a_phi->get_incoming_value(0));
|
||||
next_a = next_a_phi->get_incoming_value(1);
|
||||
}
|
||||
prefetched_vals_.insert(next_a);
|
||||
|
||||
ir::value* next_b = b->get_incoming_value(1);
|
||||
while (auto* next_b_phi = dynamic_cast<ir::phi_node*>(next_b)) {
|
||||
prefetched_vals_.insert(next_b_phi->get_incoming_value(0));
|
||||
next_b = next_b_phi->get_incoming_value(1);
|
||||
}
|
||||
prefetched_vals_.insert(next_b);
|
||||
}
|
||||
|
||||
// move loads to the beginning of the loop
|
||||
if (tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80) {
|
||||
for (ir::function *fn : mod.get_function_list())
|
||||
for (ir::basic_block *bb : fn->blocks()) {
|
||||
// only apply to loop body
|
||||
if (bb->get_predecessors().size() != 2 || bb->get_predecessors()[1] != bb)
|
||||
continue;
|
||||
// record loads (& dependency) to move
|
||||
std::vector<ir::instruction*> loads;
|
||||
// record original inst order
|
||||
std::map<ir::instruction*, size_t> idx_map;
|
||||
size_t idx = 0;
|
||||
for (ir::instruction *inst : bb->get_inst_list()) {
|
||||
if (auto *i = dynamic_cast<ir::masked_load_inst*>(inst))
|
||||
recursive_defs(i, bb, loads);
|
||||
idx_map[inst] = idx;
|
||||
idx++;
|
||||
}
|
||||
|
||||
// remove duplicates & keep the original input order
|
||||
std::sort(loads.begin(), loads.end());
|
||||
loads.erase(std::unique(loads.begin(), loads.end()), loads.end());
|
||||
std::sort(loads.begin(), loads.end(), [&idx_map](ir::instruction *a, ir::instruction *b) {
|
||||
return idx_map[a] < idx_map[b];
|
||||
});
|
||||
|
||||
builder.set_insert_point(bb->get_first_non_phi());
|
||||
auto& inst_list = bb->get_inst_list();
|
||||
for (ir::instruction *i : loads){
|
||||
auto it = std::find(inst_list.begin(), inst_list.end(), i);
|
||||
// make sure we don't invalidate insert point
|
||||
// in case instruction already at the top
|
||||
if(it == builder.get_insert_point())
|
||||
continue;
|
||||
bb->erase(i);
|
||||
builder.insert(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace triton::codegen::transform
|
@@ -1,51 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/codegen/transform/reorder.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
void reorder::run(ir::module& mod){
|
||||
// ir::builder &builder = mod.get_builder();
|
||||
// std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
|
||||
|
||||
// for(ir::function *fn: mod.get_function_list())
|
||||
// for(ir::basic_block *block: fn->blocks())
|
||||
// for(ir::instruction* i: block->get_inst_list()){
|
||||
// if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
|
||||
// ir::value* _ptr = ld->get_pointer_operand();
|
||||
// ir::value* _msk = ld->get_mask_operand();
|
||||
// ir::value* _val = ld->get_false_value_operand();
|
||||
// auto ptr = std::find(block->begin(), block->end(), _ptr);
|
||||
// auto msk = std::find(block->begin(), block->end(), _msk);
|
||||
// auto val = std::find(block->begin(), block->end(), _val);
|
||||
// if(ptr == block->end() || msk == block->end() || val == block->end())
|
||||
// continue;
|
||||
// auto it = std::find(block->begin(), block->end(), i);
|
||||
// int dist_ptr = std::distance(ptr, it);
|
||||
// int dist_msk = std::distance(msk, it);
|
||||
// int dist_val = std::distance(val, it);
|
||||
// if(dist_ptr < dist_msk && dist_ptr < dist_val)
|
||||
// builder.set_insert_point(++ptr);
|
||||
// if(dist_msk < dist_ptr && dist_msk < dist_val)
|
||||
// builder.set_insert_point(++msk);
|
||||
// if(dist_val < dist_ptr && dist_val < dist_msk)
|
||||
// builder.set_insert_point(++val);
|
||||
// ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
|
||||
// to_replace.push_back(std::make_pair(ld, new_ld));
|
||||
// }
|
||||
// }
|
||||
|
||||
// for(auto& x: to_replace)
|
||||
// x.first->replace_all_uses_with(x.second);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,5 +0,0 @@
|
||||
# Run the benchmarks
|
||||
|
||||
Install the required dependencies via `pip install -r requirements-bench.txt` from the triton/python/bench folder.
|
||||
|
||||
Run the benchmarks with `python3 bench/run.py`; this will produce an HTML report in the `results` folder.
|
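For reference, the same entry point can also be driven programmatically; a minimal sketch assuming it is run from the `python/bench` folder (the `matmul` filter below is just an example value, substring-matched against module names by `run_all`):

# equivalent to: python3 bench/run.py --result-dir results --names matmul
from run import main
main(["--result-dir", "results", "--names", "matmul"])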
@@ -1,92 +0,0 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
|
||||
# -------------------------------
|
||||
# Matrix Multiplication
|
||||
# -------------------------------
|
||||
|
||||
nt = {False: 'n', True: 't'}
|
||||
square_confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'],
|
||||
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
|
||||
line_arg='block',
|
||||
line_vals=[16, 32, 64, 128],
|
||||
line_names=['Block16', 'Block32', 'Block64', 'Block128'],
|
||||
ylabel='TFLOPS',
|
||||
plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
|
||||
args={'layout_mode': layout_mode, 'op_mode': op_mode,
|
||||
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
|
||||
)
|
||||
for AT in [False] for BT in [False]
|
||||
for op_mode in ['dsd'] for layout_mode in ['dense']
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(square_confs)
|
||||
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
|
||||
Z, H = 1, 1
|
||||
make_layout = {
|
||||
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
|
||||
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
|
||||
}[layout_mode]
|
||||
# create layout
|
||||
shape = {'sdd': (M, N), 'dsd': (K, M) if AT else (M, K), 'dds': (N, K) if BT else (K, N)}[op_mode]
|
||||
layout = make_layout(H, shape[0] // block, shape[1] // block)
|
||||
# create inputs
|
||||
a = torch.randn((Z, H, K, M) if AT else (Z, H, M, K), dtype=dtype, device='cuda')
|
||||
b = torch.randn((Z, H, N, K) if BT else (Z, H, K, N), dtype=dtype, device='cuda')
|
||||
# create op
|
||||
tflops = lambda ms: num_flops / ms * 1e3
|
||||
if provider == 'triton':
|
||||
op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT)
|
||||
# inputs
|
||||
a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a
|
||||
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
|
||||
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
|
||||
num_flops = {
|
||||
'sdd': 2 * Z * K * float(layout.sum()) * block * block,
|
||||
'dsd': 2 * Z * N * float(layout.sum()) * block * block,
|
||||
'dds': 2 * Z * M * float(layout.sum()) * block * block
|
||||
}[op_mode] * 1e-12
|
||||
return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Softmax
|
||||
# -------------------------------
|
||||
|
||||
square_confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N'],
|
||||
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
|
||||
line_arg='block',
|
||||
line_vals=[16, 32, 64],
|
||||
line_names=['Block16', 'Block32', 'Block64'],
|
||||
ylabel='GBPS',
|
||||
plot_name=f'{layout_mode}-square',
|
||||
args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
|
||||
)
|
||||
for layout_mode in ['dense', 'tril']
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(square_confs)
|
||||
def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
|
||||
Z, H = 1, 1
|
||||
make_layout = {
|
||||
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
|
||||
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
|
||||
}[layout_mode]
|
||||
layout = make_layout(H, M // block, N // block)
|
||||
a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda')
|
||||
if provider == 'triton':
|
||||
a = triton.testing.sparsify_tensor(a, layout, block)
|
||||
op = triton.ops.blocksparse.softmax(layout, block)
|
||||
gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
|
||||
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
|
||||
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
|
||||
|
||||
|
||||
bench_matmul.run(print_data=True, show_plots=True)
|
@@ -1,41 +0,0 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
|
||||
confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N'],
|
||||
x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
|
||||
line_arg='provider',
|
||||
line_vals=['triton', 'torch'],
|
||||
line_names=['Triton', 'Torch'],
|
||||
ylabel='GBPS',
|
||||
plot_name=f'{mode}-2048',
|
||||
args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
|
||||
)
|
||||
for mode in ['forward', 'backward']
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(confs)
|
||||
def bench_op(M, N, dtype, mode, provider):
|
||||
# create inputs
|
||||
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
|
||||
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
|
||||
num_gb = (2 * x.numel() * x.element_size() * 1e-9)
|
||||
gbps = lambda ms: num_gb / ms * 1e3
|
||||
# forward pass
|
||||
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
|
||||
'triton': triton.ops.cross_entropy}[provider]
|
||||
if mode == 'forward':
|
||||
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
|
||||
if mode == 'backward':
|
||||
y = op(x, idx)
|
||||
dy = torch.randn_like(y)
|
||||
fn = lambda: y.backward(dy, retain_graph=True)
|
||||
mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
|
||||
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
bench_op.run(print_data=True)
|
@@ -1,67 +0,0 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
def rounded_linspace(low, high, steps, div):
|
||||
ret = torch.linspace(low, high, steps)
|
||||
ret = (ret.int() + div - 1) // div * div
|
||||
ret = torch.unique(ret)
|
||||
return list(map(int, ret))
|
||||
|
||||
|
||||
# Square benchmarks
|
||||
nt = {False: "n", True: "t"}
|
||||
square_confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["M", "N", "K"],
|
||||
x_vals=rounded_linspace(512, 8192, 32, 128),
|
||||
line_arg="provider",
|
||||
line_vals=["cublas", "triton", "cutlass"],
|
||||
line_names=["cuBLAS", "Triton", "CUTLASS"],
|
||||
ylabel="TFLOPS",
|
||||
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
|
||||
args={"AT": AT, "BT": BT, "dtype": torch.float16},
|
||||
) for AT in [False] for BT in [False]
|
||||
]
|
||||
|
||||
# Transformer training benchmarks
|
||||
transformer_confs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=[x],
|
||||
x_vals=rounded_linspace(NK // 16, NK, 32, 128),
|
||||
line_arg="provider",
|
||||
line_vals=["cublas", "triton", "cutlass"],
|
||||
line_names=["cuBLAS", "Triton", "CUTLASS"],
|
||||
ylabel="TFLOPS",
|
||||
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
|
||||
args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
|
||||
) for NK in [12288]
|
||||
for i, x in enumerate(["N", "K"])
|
||||
for M in [2048]
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(square_confs)
|
||||
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
|
||||
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
|
||||
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
|
||||
if AT:
|
||||
a = a.t()
|
||||
if BT:
|
||||
b = b.t()
|
||||
tflops = lambda ms: 2. * M * N * K / ms * 1e-9
|
||||
if provider == "cublas":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
if provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
if provider == "cutlass":
|
||||
cutlass_matmul = triton.testing.cutlass_matmul
|
||||
try:
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
@@ -1,2 +0,0 @@
|
||||
pandas >= 1.3.3
|
||||
matplotlib >= 3.4.3
|
@@ -1,44 +0,0 @@
|
||||
import argparse
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
def run_all(result_dir, names):
|
||||
if not os.path.exists(result_dir):
|
||||
os.makedirs(result_dir)
|
||||
for mod in os.listdir(os.path.dirname(os.path.realpath(__file__))):
|
||||
# skip non-Python files
|
||||
if not mod.endswith('.py'):
|
||||
continue
|
||||
# skip files that do not match the provided name filter
|
||||
if names and names not in mod:
|
||||
continue
|
||||
# skip files that don't start with 'bench_'
|
||||
if not mod.startswith('bench_'):
|
||||
continue
|
||||
print(f'running {mod}...')
|
||||
mod = __import__(os.path.splitext(mod)[0])
|
||||
benchmarks = inspect.getmembers(mod, lambda x: isinstance(x, triton.testing.Mark))
|
||||
for name, bench in benchmarks:
|
||||
curr_dir = os.path.join(result_dir, mod.__name__.replace('bench_', ''))
|
||||
if len(benchmarks) > 1:
|
||||
curr_dir = os.path.join(curr_dir, name.replace('bench_', ''))
|
||||
if not os.path.exists(curr_dir):
|
||||
os.makedirs(curr_dir)
|
||||
bench.run(save_path=curr_dir)
|
||||
|
||||
|
||||
def main(args):
|
||||
parser = argparse.ArgumentParser(description="Run the benchmark suite.")
|
||||
parser.add_argument("-r", "--result-dir", type=str, default='results', required=False)
|
||||
parser.add_argument("-n", "--names", type=str, default='', required=False)
|
||||
parser.set_defaults(feature=False)
|
||||
args = parser.parse_args(args)
|
||||
run_all(args.result_dir, args.names)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv[1:])
|
18
python/examples/copy_strided.py
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xn,
|
||||
Z, stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
|
||||
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
||||
tl.store(Zs, tl.load(Xs))
|
||||
|
||||
|
||||
ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir")
|
||||
print(ret)
|
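The standalone compilation API shown above can be pointed at any @triton.jit function. A minimal additional sketch, reusing only the signature tokens (`*fp32`, `i32`) and the `output="ttgir"` value that appear in these examples; the kernel and its constants are illustrative:

import triton
import triton.language as tl

@triton.jit
def copy1d(X, Y, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(Y + offs, tl.load(X + offs))

# two fp32 pointers and one i32 scalar; BLOCK is baked in as a constexpr
print(triton.compile(copy1d, "*fp32,*fp32,i32", constants={"BLOCK": 1024}, output="ttgir"))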
8
python/examples/empty.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
|
||||
pass
|
||||
|
||||
ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")
|
@@ -1,202 +0,0 @@
|
||||
#include "cutlass/library/handle.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/operation_table.h"
|
||||
#include "cutlass/library/singleton.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
using namespace cutlass;
|
||||
using namespace cutlass::library;
|
||||
|
||||
std::map<std::vector<size_t>, const Operation *> op_cache_;
|
||||
|
||||
static int const kHostWorkspaceSize = (4 << 10);
|
||||
static int const kDeviceWorkspaceSize = (4 << 20);
|
||||
|
||||
void run(int M, int N, int K,
|
||||
int lda, int ldb, int ldc, int ldd,
|
||||
void const *ptr_A, void const *ptr_B, void const *ptr_C, void *ptr_D,
|
||||
void const *alpha, void const *beta,
|
||||
ScalarPointerMode scalar_mode,
|
||||
const Operation *operation,
|
||||
cudaStream_t stream) {
|
||||
|
||||
GemmUniversalConfiguration configuration{
|
||||
GemmUniversalMode::kGemm,
|
||||
{M, N, K},
|
||||
1,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd};
|
||||
|
||||
// host workspace size
|
||||
uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);
|
||||
if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed)
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
char host_workspace[kHostWorkspaceSize];
|
||||
|
||||
// device workspace size
|
||||
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
|
||||
if (uint64_t(kDeviceWorkspaceSize) < device_workspace_size_needed)
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
static void *device_workspace;
|
||||
|
||||
// Initialize host and device workspaces
|
||||
Status status = operation->initialize(&configuration, host_workspace, device_workspace, stream);
|
||||
if (status != cutlass::Status::kSuccess)
|
||||
throw std::runtime_error("Unable to initialize workspace");
|
||||
|
||||
// Run the operator
|
||||
GemmArguments arguments{ptr_A, ptr_B, ptr_C, ptr_D, alpha, beta, scalar_mode};
|
||||
operation->run(&arguments, host_workspace, device_workspace, stream);
|
||||
}
|
||||
|
||||
const Operation *autotune(int M, int N, int K,
|
||||
NumericTypeID element_compute,
|
||||
NumericTypeID element_scalar,
|
||||
void const *alpha,
|
||||
NumericTypeID element_A,
|
||||
LayoutTypeID layout_A,
|
||||
ComplexTransform transform_A,
|
||||
void const *ptr_A,
|
||||
int lda,
|
||||
NumericTypeID element_B,
|
||||
LayoutTypeID layout_B,
|
||||
ComplexTransform transform_B,
|
||||
void const *ptr_B,
|
||||
int ldb,
|
||||
void const *beta,
|
||||
NumericTypeID element_C,
|
||||
void const *ptr_C,
|
||||
int ldc,
|
||||
void *ptr_D,
|
||||
int ldd,
|
||||
ScalarPointerMode scalar_mode,
|
||||
int device_id,
|
||||
cudaStream_t stream) {
|
||||
|
||||
// index operation table with functional key
|
||||
GemmFunctionalKey key(
|
||||
Provider::kCUTLASS,
|
||||
GemmKind::kUniversal,
|
||||
element_compute,
|
||||
element_scalar,
|
||||
element_A,
|
||||
layout_A,
|
||||
transform_A,
|
||||
element_B,
|
||||
layout_B,
|
||||
transform_B,
|
||||
element_C);
|
||||
|
||||
auto operators_it = Singleton::get().operation_table.gemm_operations.find(key);
|
||||
if (operators_it == Singleton::get().operation_table.gemm_operations.end())
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
if (operators_it->second.empty())
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
|
||||
cudaDeviceProp device_prop;
|
||||
cudaError_t error = cudaGetDeviceProperties(&device_prop, device_id);
|
||||
if (error != cudaSuccess)
|
||||
throw std::runtime_error("Unable to get device properties");
|
||||
int cc = device_prop.major * 10 + device_prop.minor;
|
||||
|
||||
// index operation table with preference key
|
||||
// assume 8-bytes aligned memory pointers
|
||||
int alignment = 8;
|
||||
GemmPreferenceKey preference_key(cc, alignment);
|
||||
auto autotune_it = operators_it->second.find(preference_key);
|
||||
if (autotune_it == operators_it->second.end())
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
const std::vector<const Operation *> &operations = autotune_it->second;
|
||||
if (operations.empty())
|
||||
throw std::runtime_error("Unable to find gemm operation");
|
||||
|
||||
// auto-tune
|
||||
const Operation *best = nullptr;
|
||||
double best_ms = std::numeric_limits<double>::max();
|
||||
for (const Operation *op : operations) {
|
||||
auto fn = [&]() { run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D,
|
||||
alpha, beta, scalar_mode, op, stream); };
|
||||
triton::driver::cu_stream tt_stream((CUstream)stream, false);
|
||||
double ms = triton::tools::bench(fn, &tt_stream, 10, 25);
|
||||
if (ms < best_ms) {
|
||||
best_ms = ms;
|
||||
best = op;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
|
||||
// map of torch datatypes to cutlass datatypes
|
||||
std::map<std::string, NumericTypeID> type_map = {
|
||||
{"float16", NumericTypeID::kF16},
|
||||
{"float32", NumericTypeID::kF32},
|
||||
{"float64", NumericTypeID::kF64}};
|
||||
|
||||
void cutlass_matmul(uintptr_t A, uintptr_t B, uintptr_t C,
|
||||
size_t M, size_t N, size_t K,
|
||||
size_t stride_a_0, size_t stride_a_1,
|
||||
size_t stride_b_0, size_t stride_b_1,
|
||||
size_t stride_c_0, size_t stride_c_1,
|
||||
std::string type_a, std::string type_b, std::string type_c,
|
||||
size_t dev_id, uint64_t stream_handle) {
|
||||
void *ptr_A = (void *)A;
|
||||
void *ptr_B = (void *)B;
|
||||
void *ptr_C = (void *)C;
|
||||
void *ptr_D = ptr_C;
|
||||
size_t lda = stride_a_0;
|
||||
size_t ldb = stride_b_0;
|
||||
size_t ldc = stride_c_1;
|
||||
size_t ldd = ldc;
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
// layout for A
|
||||
LayoutTypeID layout_A;
|
||||
if (stride_a_0 == 1)
|
||||
layout_A = LayoutTypeID::kColumnMajor;
|
||||
else if (stride_a_1 == 1)
|
||||
layout_A = LayoutTypeID::kRowMajor;
|
||||
else
|
||||
throw std::runtime_error("A layout is not supported");
|
||||
// layout for B
|
||||
LayoutTypeID layout_B;
|
||||
if (stride_b_0 == 1)
|
||||
layout_B = LayoutTypeID::kColumnMajor;
|
||||
else if (stride_b_1 == 1)
|
||||
layout_B = LayoutTypeID::kRowMajor;
|
||||
else
|
||||
throw std::runtime_error("B layout is not supported");
|
||||
// data types
|
||||
NumericTypeID element_compute = NumericTypeID::kF32;
|
||||
NumericTypeID element_A = type_map[type_a];
|
||||
NumericTypeID element_B = type_map[type_b];
|
||||
NumericTypeID element_C = type_map[type_c];
|
||||
// misc. flags
|
||||
ScalarPointerMode scalar_mode = ScalarPointerMode::kHost;
|
||||
NumericTypeID element_scalar = NumericTypeID::kF32;
|
||||
ComplexTransform transform_A = ComplexTransform::kNone;
|
||||
ComplexTransform transform_B = ComplexTransform::kNone;
|
||||
// runtime flags
|
||||
cudaStream_t stream = (cudaStream_t)stream_handle;
|
||||
// auto-tune
|
||||
std::vector<size_t> tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C,
|
||||
dev_id, (size_t)element_compute, (size_t)scalar_mode};
|
||||
auto it = op_cache_.find(tune_key);
|
||||
if (it == op_cache_.end()) {
|
||||
const Operation *op = autotune(M, N, K, element_compute, element_scalar, &alpha,
|
||||
element_A, layout_A, transform_A, ptr_A, lda,
|
||||
element_B, layout_B, transform_B, ptr_B, ldb,
|
||||
&beta, element_C, ptr_C, ldc, ptr_D, ldd, scalar_mode,
|
||||
dev_id, stream);
|
||||
it = op_cache_.insert({tune_key, op}).first;
|
||||
}
|
||||
run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D, &alpha, &beta,
|
||||
scalar_mode, it->second, stream);
|
||||
}
|
||||
|
||||
void init_cutlass(pybind11::module &m) {
|
||||
pybind11::module subm = m.def_submodule("cutlass");
|
||||
subm.def("matmul", &cutlass_matmul, "matrix multiplication");
|
||||
}
|
@@ -1,676 +0,0 @@
|
||||
#include "triton/ir/builder.h"
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace ir = triton::ir;
|
||||
namespace py = pybind11;
|
||||
|
||||
static const std::string _builder_doc = R"pbdoc(
|
||||
:param builder: IR builder to generate code into, optional, set automatically when called inside a @triton.jit function
|
||||
:type builder: triton.ir.builder
|
||||
)pbdoc";
|
||||
|
||||
#define VA_ARGS(...) , ##__VA_ARGS__
|
||||
#define DEF_FUNC(MOD, PY_NAME, C_FUNC, ...) \
|
||||
MOD.def(PY_NAME, C_FUNC, (C_FUNC##_docstr + _builder_doc).c_str(), \
|
||||
ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)
|
||||
|
||||
void throw_not_implemented(std::string key) {
|
||||
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. This is likely a bug on our side.");
|
||||
}
|
||||
|
||||
void throw_not_int_or_float(std::string key) {
|
||||
throw std::runtime_error("`" + key + "` only supported for integer and floating point types.");
|
||||
}
|
||||
|
||||
enum type_code {
|
||||
_bool,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
float16,
|
||||
float32,
|
||||
float64
|
||||
};
|
||||
|
||||
ir::type *make_ir(type_code ty, ir::builder *builder) {
|
||||
switch (ty) {
|
||||
case float16:
|
||||
return builder->get_half_ty();
|
||||
case float32:
|
||||
return builder->get_float_ty();
|
||||
default:
|
||||
throw_not_implemented("make_ir");
|
||||
}
|
||||
}
|
||||
|
||||
type_code from_ir(ir::type *ty) {
|
||||
if (ty->is_half_ty())
|
||||
return float16;
|
||||
if (ty->is_float_ty())
|
||||
return float32;
|
||||
throw_not_implemented("from_ir");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.cast / triton.ir.value.to
|
||||
----------------------------------------------*/
|
||||
std::string cast_docstr = R"pbdoc(
|
||||
Tries to cast a block to a new data type.
|
||||
|
||||
:param input: The input block.
|
||||
:type input: triton.ir.value
|
||||
)pbdoc";
|
||||
|
||||
ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
|
||||
ir::type *src_ty = input->get_type();
|
||||
ir::type *dst_ty = make_ir(_dtype, builder);
|
||||
if (src_ty->is_block_ty())
|
||||
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
|
||||
ir::type *src_sca_ty = src_ty->get_scalar_ty();
|
||||
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
|
||||
// FP Truncation
|
||||
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
|
||||
dst_sca_ty->is_floating_point_ty() &&
|
||||
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
|
||||
if (truncate_fp)
|
||||
return builder->create_fp_trunc(input, dst_ty);
|
||||
// FP Extension
|
||||
bool ext_fp = src_sca_ty->is_floating_point_ty() &&
|
||||
dst_sca_ty->is_floating_point_ty() &&
|
||||
src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
|
||||
if (ext_fp)
|
||||
return builder->create_fp_ext(input, dst_ty);
|
||||
// Int cast
|
||||
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
|
||||
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
|
||||
return builder->create_int_cast(input, dst_ty, true);
|
||||
// Float -> Int
|
||||
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty())
|
||||
return builder->create_fp_to_si(input, dst_ty);
|
||||
// int -> Float
|
||||
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty())
|
||||
return builder->create_si_to_fp(input, dst_ty);
|
||||
// Ptr -> Ptr
|
||||
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
|
||||
return builder->create_cast(ir::BitCast, input, dst_ty);
|
||||
// * -> Bool
|
||||
if (dst_sca_ty->is_bool_ty()) {
|
||||
if (src_sca_ty->is_pointer_ty())
|
||||
input = cast(input, int64, builder);
|
||||
ir::value *other = builder->get_int64(0);
|
||||
if (src_ty->is_bool_ty())
|
||||
other = builder->create_splat(other, src_ty->get_block_shapes());
|
||||
return builder->create_icmpNE(input, other);
|
||||
}
|
||||
throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
|
||||
}
|
||||
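At the Python level, this legacy cast logic was reached through `.to(...)` on values inside a @triton.jit function; a minimal sketch, assuming the usual `.to(...)` entry point and an fp32 input buffer (kernel and names are illustrative):

import triton
import triton.language as tl

@triton.jit
def downcast(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)                  # fp32 values in memory
    tl.store(Y + offs, x.to(tl.float16))   # exercises the FP-truncation path above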
|
||||
/*----------------------------------------------
|
||||
definition of triton.broadcast_check
|
||||
----------------------------------------------*/
|
||||
std::string try_broadcast_docstr = R"pbdoc(
|
||||
Tries to broadcast two blocks to a common compatible shape.
|
||||
|
||||
:param input: The first input block.
|
||||
:type input: triton.ir.value
|
||||
:param other: The second input block.
|
||||
:type other: triton.ir.value
|
||||
)pbdoc";
|
||||
|
||||
std::tuple<ir::value *, ir::value *> try_broadcast(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
|
||||
ir::type *lhs_ty = lhs->get_type();
|
||||
ir::type *rhs_ty = rhs->get_type();
|
||||
// make_shape_compatible(block, scalar)
|
||||
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
|
||||
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
|
||||
// make_shape_compatible(scalar, block)
|
||||
else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
|
||||
lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
|
||||
// make_shape_compatible(block, block)
|
||||
else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
|
||||
auto lhs_shape = lhs_ty->get_block_shapes();
|
||||
auto rhs_shape = rhs_ty->get_block_shapes();
|
||||
if (lhs_shape.size() != rhs_shape.size())
|
||||
throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
|
||||
ir::type::block_shapes_t ret_shape;
|
||||
for (size_t i = 0; i < lhs_shape.size(); ++i) {
|
||||
unsigned left = lhs_shape[i];
|
||||
unsigned right = rhs_shape[i];
|
||||
if (left == 1)
|
||||
ret_shape.push_back(right);
|
||||
else if (right == 1)
|
||||
ret_shape.push_back(left);
|
||||
else if (left == right)
|
||||
ret_shape.push_back(left);
|
||||
else
|
||||
throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
|
||||
": " + std::to_string(left) + " and " + std::to_string(right));
|
||||
}
|
||||
if (lhs_shape != ret_shape)
|
||||
lhs = builder->create_broadcast(lhs, ret_shape);
|
||||
if (rhs_shape != ret_shape)
|
||||
rhs = builder->create_broadcast(rhs, ret_shape);
|
||||
}
|
||||
return std::make_tuple(lhs, rhs);
|
||||
}
|
||||
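The per-dimension rules above follow NumPy-style broadcasting under the equal-rank constraint enforced by `try_broadcast`. A small plain-Python illustration of the shape computation only (no IR calls):

def broadcast_shape(lhs, rhs):
    if len(lhs) != len(rhs):
        raise ValueError("blocks must have the same rank")
    out = []
    for i, (l, r) in enumerate(zip(lhs, rhs)):
        if l == r or l == 1 or r == 1:
            out.append(max(l, r))
        else:
            raise ValueError(f"incompatible dimensions at index {i}: {l} and {r}")
    return tuple(out)

print(broadcast_shape((128, 1), (1, 64)))   # (128, 64)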
|
||||
/*----------------------------------------------
|
||||
definition of triton.broadcast_to
|
||||
----------------------------------------------*/
|
||||
std::string broadcast_to_docstr = R"pbdoc(
|
||||
Tries to broadcast a block to a new shape.
|
||||
|
||||
:param input: The input block.
|
||||
:type input: triton.value
|
||||
:param shape: The new shape.
|
||||
:type shape: tuple of int
|
||||
)pbdoc";
|
||||
|
||||
ir::value *broadcast_to(ir::value *input, const ir::type::block_shapes_t &shape, ir::builder *builder) {
|
||||
if (!input->get_type()->is_block_ty())
|
||||
return builder->create_splat(input, shape);
|
||||
auto src_shape = input->get_type()->get_block_shapes();
|
||||
if (src_shape.size() != shape.size())
|
||||
throw std::runtime_error("Cannot broadcast");
|
||||
return builder->create_broadcast(input, shape);
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.load
|
||||
----------------------------------------------*/
|
||||
std::string load_docstr = R"pbdoc(
|
||||
Return a block of data whose values are loaded, element-wise, from memory at the locations defined by `pointer`.
|
||||
|
||||
:param pointer: Pointer to the data to be loaded.
|
||||
:type pointer: Block of triton.pointer
|
||||
:param mask: if mask[idx] is false, do not load the data at `pointer[idx]`.
|
||||
:type mask: Block of triton.bool, optional
|
||||
:param other: if mask[idx] is false, return `other[idx]` instead of the value at `pointer[idx]`
|
||||
:type other: Block of triton.value, optional
|
||||
)pbdoc";
|
||||
|
||||
ir::value *load(ir::value *pointer, std::optional<ir::value *> _mask, std::optional<ir::value *> _other, ir::builder *builder) {
|
||||
if (!_mask.has_value() && !_other.has_value())
|
||||
return builder->create_load(pointer);
|
||||
if (!_mask.has_value())
|
||||
throw std::runtime_error("`other` cannot be provided without `mask`");
|
||||
ir::value *mask = _mask.value();
|
||||
ir::type *elt_ty = pointer->get_type()->get_scalar_ty()->get_pointer_element_ty();
|
||||
auto shape = pointer->get_type()->get_block_shapes();
|
||||
ir::value *other = _other.has_value() ? _other.value() : ir::undef_value::get(elt_ty);
|
||||
other = cast(other, from_ir(elt_ty), builder);
|
||||
other = broadcast_to(other, shape, builder);
|
||||
mask = broadcast_to(mask, shape, builder);
|
||||
return builder->create_masked_load(pointer, mask, other);
|
||||
}
|
||||
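A minimal sketch of how `mask` and `other` are typically used from a kernel: the mask guards the tail of a buffer and `other` supplies the value returned for masked-off lanes (names and sizes are illustrative):

import triton
import triton.language as tl

@triton.jit
def guarded_copy(X, Y, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n                              # only the first n lanes touch memory
    x = tl.load(X + offs, mask=mask, other=0.0)  # masked-off lanes get 0.0 instead
    tl.store(Y + offs, x, mask=mask)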
|
||||
/*----------------------------------------------
|
||||
definition of triton.store
|
||||
----------------------------------------------*/
|
||||
std::string store_docstr = R"pbdoc(
|
||||
Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`.
|
||||
|
||||
:param pointer: The memory locations where the elements of `value` are stored.
|
||||
:type pointer: Block of triton.pointer
|
||||
:param value: The block of elements to be stored.
|
||||
:type value: Block of triton.value
|
||||
:param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`.
|
||||
:type mask: Block of triton.bool, optional
|
||||
)pbdoc";
|
||||
ir::value *store(ir::value *ptr, ir::value *val, std::optional<ir::value *> _mask, ir::builder *builder) {
|
||||
if (!_mask.has_value())
|
||||
return builder->create_store(ptr, val);
|
||||
ir::value *mask = _mask.value();
|
||||
return builder->create_masked_store(ptr, val, mask);
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.dot
|
||||
----------------------------------------------*/
|
||||
std::string dot_docstr = R"pbdoc(
|
||||
Returns the matrix product of two blocks.
|
||||
The two blocks must be two-dimensional and have compatible inner dimensions.
|
||||
|
||||
:param input: The first block to be multiplied.
|
||||
:type input: 2D block of scalar-type in {`float16`, `float32`}
|
||||
:param other: The second block to be multiplied.
|
||||
:type other: 2D block of scalar-type in {`float16`, `float32`}
|
||||
)pbdoc";
|
||||
ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
|
||||
ir::value *_0 = builder->get_float32(0);
|
||||
unsigned M = lhs->get_type()->get_block_shapes()[0];
|
||||
unsigned N = rhs->get_type()->get_block_shapes()[1];
|
||||
_0 = builder->create_splat(_0, {M, N});
|
||||
return builder->create_dot(lhs, rhs, _0);
|
||||
}
|
||||
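A minimal single-tile sketch of a dot call from a kernel, assuming row-major fp16 inputs small enough to be loaded as one block each (names, strides, and sizes are illustrative):

import triton
import triton.language as tl

@triton.jit
def tiny_matmul(A, B, C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(A + offs_m[:, None] * K + offs_k[None, :])  # (M, K) tile
    b = tl.load(B + offs_k[:, None] * N + offs_n[None, :])  # (K, N) tile
    tl.store(C + offs_m[:, None] * N + offs_n[None, :], tl.dot(a, b))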
|
||||
/*----------------------------------------------
|
||||
definition of triton.where
|
||||
----------------------------------------------*/
|
||||
std::string where_docstr = R"pbdoc(
|
||||
Returns a block of elements from either `x` or `y`, depending on `condition`.
|
||||
Note that `x` and `y` are always evaluated regardless of the value of `condition`.
|
||||
If you want to avoid unintended memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
|
||||
|
||||
:param condition: When True (nonzero), yield x, otherwise yield y.
|
||||
:type condition: Block of triton.bool
|
||||
:param x: values selected at indices where condition is True.
|
||||
:param y: values selected at indices where condition is False.
|
||||
)pbdoc";
|
||||
ir::value *where(ir::value *condition, ir::value *x, ir::value *y, ir::builder *builder) {
|
||||
return builder->create_select(condition, x, y);
|
||||
};
|
||||
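A small sketch of the distinction drawn above: `where` selects between two already-computed values per element, whereas masking in load/store prevents the memory access itself (illustrative kernel):

import triton
import triton.language as tl

@triton.jit
def relu_like(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    tl.store(Y + offs, tl.where(x > 0, x, 0.0))  # both x and 0.0 are evaluated; where only picks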
|
||||
/*----------------------------------------------
|
||||
definition of triton.arange
|
||||
----------------------------------------------*/
|
||||
std::string arange_docstr = R"pbdoc(
|
||||
Returns contiguous values within the half-open interval [start, end).
|
||||
|
||||
:param start: Start of the interval.
|
||||
:type start: int
|
||||
:param end: End of the interval (exclusive).
|
||||
:type end: int
|
||||
)pbdoc";
|
||||
ir::value *arange(int start, int end, ir::builder *builder) {
|
||||
return builder->get_range(start, end);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.program_id
|
||||
----------------------------------------------*/
|
||||
std::string program_id_docstr = R"pbdoc(
|
||||
Returns the id of the current program instance along the given `axis`.
|
||||
Triton uses an SPMD model in which multiple instances of the same @triton.jit function run in parallel, each with a different `program_id`.
|
||||
|
||||
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
||||
:type axis: int
|
||||
)pbdoc";
|
||||
ir::value *program_id(int axis, ir::builder *builder) {
|
||||
return builder->create_get_program_id(axis);
|
||||
};
|
||||
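A minimal sketch of how `program_id` partitions work across a 1D launch grid (the kernel and the launch comment are illustrative):

import triton
import triton.language as tl

@triton.jit
def add_one(X, Y, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)                     # which chunk this program instance handles
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(Y + offs, tl.load(X + offs, mask=mask) + 1, mask=mask)

# launched with one program instance per chunk, e.g.:
# add_one[(triton.cdiv(n, BLOCK),)](x, y, n, BLOCK=1024)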
|
||||
/*----------------------------------------------
|
||||
definition of triton.num_programs
|
||||
----------------------------------------------*/
|
||||
std::string num_programs_docstr = R"pbdoc(
|
||||
Returns the number of program instances launched along the given `axis`.
|
||||
|
||||
:param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2.
|
||||
:type axis: int
|
||||
)pbdoc";
|
||||
ir::value *num_programs(int axis, ir::builder *builder) {
|
||||
return builder->create_get_num_programs(axis);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.zeros
|
||||
----------------------------------------------*/
|
||||
std::string zeros_docstr = R"pbdoc(
|
||||
Returns a block filled with the scalar value 0 and the given shape.
|
||||
|
||||
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
|
||||
:type shape: tuple of ints
|
||||
:param dtype: Data-type of the new array, e.g., tl.float16
|
||||
:type dtype: triton.ir.dtype
|
||||
)pbdoc";
|
||||
ir::value *zeros(ir::type::block_shapes_t shape, type_code _dtype, ir::builder *builder) {
|
||||
ir::type *dtype = make_ir(_dtype, builder);
|
||||
ir::value *_0 = ir::constant::get_null_value(dtype);
|
||||
return builder->create_splat(_0, shape);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.exp
|
||||
----------------------------------------------*/
|
||||
std::string _exp_docstr = R"pbdoc(
|
||||
Returns the element-wise exponential of `input`.
|
||||
)pbdoc";
|
||||
ir::value *_exp(ir::value *input, ir::builder *builder) {
|
||||
return builder->create_exp(input);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.log
|
||||
----------------------------------------------*/
|
||||
std::string _log_docstr = R"pbdoc(
|
||||
Returns the element-wise natural logarithm of `input`.
|
||||
)pbdoc";
|
||||
ir::value *_log(ir::value *input, ir::builder *builder) {
|
||||
return builder->create_log(input);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.sqrt
|
||||
----------------------------------------------*/
|
||||
std::string sqrt_docstr = R"pbdoc(
|
||||
Returns the element-wise square root of `input`.
|
||||
)pbdoc";
|
||||
ir::value *sqrt(ir::value *input, ir::builder *builder) {
|
||||
return builder->create_sqrt(input);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.min
|
||||
----------------------------------------------*/
|
||||
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
|
||||
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_reduce(input, FLOAT_OP, axis);
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_reduce(input, INT_OP, axis);
|
||||
else
|
||||
throw_not_int_or_float(name);
|
||||
}
|
||||
|
||||
std::string min_docstr = R"pbdoc(
|
||||
Returns the minimum value of `input`.
|
||||
)pbdoc";
|
||||
ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.max
|
||||
----------------------------------------------*/
|
||||
std::string max_docstr = R"pbdoc(
|
||||
Returns the maximum value of `input`.
|
||||
)pbdoc";
|
||||
ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.sum
|
||||
----------------------------------------------*/
|
||||
std::string sum_docstr = R"pbdoc(
|
||||
Returns the sum of `input`.
|
||||
)pbdoc";
|
||||
ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
|
||||
};
|
||||
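A minimal sketch of a row-per-program reduction built on the sum reduction above (row-major layout and names are illustrative):

import triton
import triton.language as tl

@triton.jit
def row_sums(X, Out, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    x = tl.load(X + row * n_cols + cols, mask=mask, other=0.0)
    tl.store(Out + row, tl.sum(x, 0))   # FADD reduction for floating-point inputs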
|
||||
/*----------------------------------------------
|
||||
definition of triton.atomic_cas
|
||||
----------------------------------------------*/
|
||||
std::string atomic_cas_docstr = R"pbdoc(
|
||||
Atomic compare-and-swap.
|
||||
)pbdoc";
|
||||
ir::value *atomic_cas(ir::value *ptr, ir::value *cmp, ir::value *val, ir::builder *builder) {
|
||||
return builder->create_atomic_cas(ptr, cmp, val);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of triton.atomic_xchg
|
||||
----------------------------------------------*/
|
||||
std::string atomic_xchg_docstr = R"pbdoc(
|
||||
Atomic exchange.
|
||||
)pbdoc";
|
||||
ir::value *atomic_xchg(ir::value *ptr, ir::value *val, ir::builder *builder) {
|
||||
return builder->create_atomic_exch(ptr, val);
|
||||
};
|
||||
|
||||
/*----------------------------------------------
|
||||
debug barrier
|
||||
----------------------------------------------*/
|
||||
std::string debug_barrier_docstr = R"pbdoc(
|
||||
Temporary hacky fixup for when the compiler forgets to insert sync barriers
|
||||
)pbdoc";
|
||||
ir::value *debug_barrier(ir::builder *builder) {
|
||||
return builder->create_barrier();
|
||||
}
|
||||
|
||||
#define DEF_BINARY_OP(MOD, PY_NAME, C_FUNC, ...) \
|
||||
MOD.def(PY_NAME, binary_op(C_FUNC), (C_FUNC##_docstr + _builder_doc).c_str(), \
|
||||
ret::reference VA_ARGS(__VA_ARGS__), "builder"_a)
|
||||
|
||||
template <class FN>
|
||||
std::function<ir::value *(ir::value *, ir::value *, ir::builder *builder)>
|
||||
binary_op(const FN &fn) {
|
||||
auto ret = [&fn](ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
//std::tie(self, other) = try_broadcast(self, other, builder);
|
||||
return fn(self, other, builder);
|
||||
};
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self + other
|
||||
----------------------------------------------*/
|
||||
std::string add_docstr = R"pbdoc(
|
||||
Returns self + other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *add(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
|
||||
// ptr + offset
|
||||
if (scalar_ty->is_pointer_ty())
|
||||
return builder->create_gep(self, {other});
|
||||
// float + float
|
||||
else if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fadd(self, other);
|
||||
// int + int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_add(self, other);
|
||||
throw_not_implemented("add");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self - other
|
||||
----------------------------------------------*/
|
||||
std::string sub_docstr = R"pbdoc(
|
||||
Returns self - other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *sub(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
|
||||
// ptr - offset
|
||||
if (scalar_ty->is_pointer_ty())
|
||||
return builder->create_gep(self, {other});
|
||||
// float - float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fsub(self, other);
|
||||
// int - int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_sub(self, other);
|
||||
throw_not_implemented("sub");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self * other
|
||||
----------------------------------------------*/
|
||||
std::string mul_docstr = R"pbdoc(
|
||||
Returns self * other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *mul(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
|
||||
// float * float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fmul(self, other);
|
||||
// int * int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_mul(self, other);
|
||||
throw_not_implemented("mul");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self > other
|
||||
----------------------------------------------*/
|
||||
std::string greater_than_docstr = R"pbdoc(
|
||||
Returns self > other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *greater_than(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
|
||||
// float > float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOGT(self, other);
|
||||
// int > int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSGT(self, other);
|
||||
throw_not_implemented("greater_than");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self >= other
|
||||
----------------------------------------------*/
|
||||
std::string greater_equal_docstr = R"pbdoc(
|
||||
Returns self >= other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *greater_equal(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
|
||||
// float >= float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOGE(self, other);
|
||||
// int >= int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSGE(self, other);
|
||||
throw_not_implemented("greater_equal");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self < other
|
||||
----------------------------------------------*/
|
||||
std::string less_than_docstr = R"pbdoc(
|
||||
Returns self < other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *less_than(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
|
||||
// float < float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOLT(self, other);
|
||||
// int < int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSLT(self, other);
|
||||
throw_not_implemented("less_than");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self <= other
|
||||
----------------------------------------------*/
|
||||
std::string less_equal_docstr = R"pbdoc(
|
||||
Returns self <= other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *less_equal(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
|
||||
// float <= float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOLE(self, other);
|
||||
// int <= int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSLE(self, other);
|
||||
throw_not_implemented("less_equal");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self == other
|
||||
----------------------------------------------*/
|
||||
std::string equal_docstr = R"pbdoc(
|
||||
Returns self == other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *equal(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
|
||||
// float == float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOEQ(self, other);
|
||||
// int == int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpEQ(self, other);
|
||||
throw_not_implemented("equal");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self / other
|
||||
----------------------------------------------*/
|
||||
std::string _div_docstr = R"pbdoc(
|
||||
Returns self / other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *_div(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
|
||||
// float / float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fdiv(self, other);
|
||||
// int / int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_sdiv(self, other);
|
||||
throw_not_implemented("div");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self % other
|
||||
----------------------------------------------*/
|
||||
std::string mod_docstr = R"pbdoc(
|
||||
Returns self % other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *mod(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
ir::type *scalar_ty = self->get_type()->get_scalar_ty();
|
||||
// float % float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_frem(self, other);
|
||||
// int % int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_srem(self, other);
|
||||
throw_not_implemented("mod");
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self & other
|
||||
----------------------------------------------*/
|
||||
std::string _and_docstr = R"pbdoc(
|
||||
Returns self & other, element-wise.
|
||||
)pbdoc";
|
||||
ir::value *_and(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
return builder->create_and(self, other);
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of minimum(self, other)
|
||||
----------------------------------------------*/
|
||||
std::string minimum_docstr = R"pbdoc(
|
||||
Returns the element-wise minimum of self and other.
|
||||
)pbdoc";
|
||||
ir::value *minimum(ir::value *self, ir::value *other, ir::builder *builder) {
|
||||
return where(less_than(self, other, builder), self, other, builder);
|
||||
}
|
||||
|
||||
/*----------------------------------------------
|
||||
definition of self[slices]
|
||||
----------------------------------------------*/
|
||||
|
||||
enum slice_mode_t {
|
||||
NEWAXIS,
|
||||
ALL
|
||||
};
|
||||
|
||||
std::string subscript_docstr = R"pbdoc(
|
||||
Returns self[slices].
|
||||
|
||||
:param slices: The slices to subscript with.
|
||||
:type slices: List of `None` or `:` slices.
|
||||
)pbdoc";
|
||||
ir::value *subscript(ir::value *self, std::vector<py::object> slices, ir::builder *builder) {
|
||||
std::vector<slice_mode_t> modes;
|
||||
for (py::object slice : slices) {
|
||||
py::object none = py::none();
|
||||
py::object all = py::make_tuple(none, none, none);
|
||||
if (slice.is(none))
|
||||
modes.push_back(NEWAXIS);
|
||||
else if (all.attr("__eq__")(slice))
|
||||
modes.push_back(ALL);
|
||||
else
|
||||
throw std::runtime_error("slice must be None or (None, None, None)");
|
||||
}
|
||||
|
||||
ir::type::block_shapes_t shape;
|
||||
size_t curr = 0;
|
||||
for (slice_mode_t mode : modes) {
|
||||
if (mode == NEWAXIS)
|
||||
shape.push_back(1);
|
||||
else {
|
||||
assert(mode == ALL);
|
||||
shape.push_back(self->get_type()->get_block_shapes()[curr++]);
|
||||
}
|
||||
}
|
||||
return builder->create_reshape(self, shape);
|
||||
}
|
@@ -8,8 +8,4 @@ void init_cutlass(pybind11::module &m);
|
||||
PYBIND11_MODULE(libtriton, m) {
|
||||
m.doc() = "Python bindings to the C++ Triton API";
|
||||
init_triton(m);
|
||||
init_superblocking(m);
|
||||
#ifdef WITH_CUTLASS_BINDINGS
|
||||
init_cutlass(m);
|
||||
#endif
|
||||
}
|
||||
|
@@ -1,119 +0,0 @@
#include <iostream>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

// row-major 3d tensor
class tensor_3d {
public:
  tensor_3d(int size_0, int size_1, int size_2, int *data = nullptr) : data_(size_0 * size_1 * size_2, 0) {
    if (data)
      std::copy(data, data + data_.size(), data_.begin());
    stride_0_ = size_1 * size_2;
    stride_1_ = size_2;
    stride_2_ = 1;
  }

  int &operator()(int i, int j, int k) {
    return data_[i * stride_0_ + j * stride_1_ + k];
  }

private:
  std::vector<int> data_;
  int stride_0_;
  int stride_1_;
  int stride_2_;
};

std::vector<int> segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width, int H, int M, int N) {
  tensor_3d tmp(H, M, N);
  std::vector<int> current(H, 0);
  int num = 0;
  std::vector<int> lut(H * M * N * 4);
  for (ssize_t h = 0; h < H; h++) {
    // surrounding indices
    std::vector<int> ii_left(max_width, -1);
    std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
    // start the dynamic programming algorithm
    for (ssize_t m = 0; m < M; m++) {
      for (ssize_t n = 0; n < N; n++) {
        int v = layout(h, m, n);
        if (v == 0)
          continue;
        int n_left = ii_left[max_width - 1];
        int m_top = ii_top[max_width - 1][n];
        int top = (m_top >= 0) ? tmp(h, m_top, n) : 0;
        int left = (n_left >= 0) ? tmp(h, m, n_left) : 0;
        int topleft = (m_top >= 0 && n_left >= 0) ? tmp(h, m_top, n_left) : 0;
        int width = std::min(left, std::min(top, topleft)) + 1;
        // reset width if blocks cannot be
        // packed together (i.e., there's a 1 "in the middle")
        for (int nn = n_left + 1; nn < n; nn++)
          if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n])
            width = 1;
        tmp(h, m, n) = width;
        // update n_left ring buffer
        for (int k = 0; k < max_width - 1; k++)
          ii_left[k] = ii_left[k + 1];
        ii_left[max_width - 1] = n;
        // update ii_top ring buffer
        for (int k = 0; k < max_width - 1; k++)
          ii_top[k][n] = ii_top[k + 1][n];
        ii_top[max_width - 1][n] = m;
        // block is too small -- skip
        if (width != max_width)
          continue;
        // retained blocks are set to zeros
        for (ssize_t km = 0; km < max_width; km++)
          for (ssize_t kn = 0; kn < max_width; kn++) {
            int mm = ii_top[km][n];
            int nn = ii_left[kn];
            if (mm < 0 || nn < 0)
              continue;
            layout(h, mm, nn) = 0;
            tmp(h, mm, nn) = 0;
            lut[num++] = (int)h;
            lut[num++] = (int)mm;
            lut[num++] = (int)nn;
            lut[num++] = idx(h, mm, nn);
          }
      }
    }
  }
  lut.resize(num);
  return lut;
}

typedef std::pair<int, pybind11::array_t<int>> lut_t;

std::vector<lut_t> superblock(uintptr_t LAYOUT, int H, int M, int N, int start_width) {
  std::vector<lut_t> ret;
  int current = 0;
  tensor_3d layout(H, M, N, (int *)LAYOUT);
  tensor_3d idx(H, M, N);
  for (int64_t h = 0; h < H; h++)
    for (int64_t m = 0; m < M; m++)
      for (int64_t n = 0; n < N; n++) {
        if (layout(h, m, n) == 0)
          continue;
        idx(h, m, n) = current++;
      }
  // create lut
  for (int max_width = start_width; max_width > 0; max_width /= 2) {
    auto lut = segment_blocks(layout, idx, max_width, H, M, N);
    if (lut.size() == 0)
      continue;
    ret.push_back(std::make_pair(max_width, pybind11::array_t<int>(lut.size(), lut.data())));
  }
  return ret;
}

void init_superblocking(pybind11::module &m) {
  m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication");
}
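The file above is deleted by this commit; for context, a hedged sketch of how its `superblock` binding was typically driven from Python. The import path, the contiguous int32 layout tensor, and passing `data_ptr()` as the `uintptr_t` argument are assumptions inferred from the signature, not code taken from the diff.

# Hedged sketch, not part of the diff.
import torch
import triton._C.libtriton as libtriton  # assumed import path

H, M, N = 2, 8, 8
layout = torch.randint(0, 2, (H, M, N), dtype=torch.int32).contiguous()
# returns a list of (block_width, lut) pairs, widest blocks first;
# superblock() zeroes the blocks it retains in `layout` as it scans
luts = libtriton.superblock(layout.data_ptr(), H, M, N, 4)
for width, lut in luts:
    print(width, lut.shape)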
@@ -764,19 +764,26 @@ void init_triton_ir(py::module &&m) {
      .def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
  py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");

  py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
      .def("dump", &mlir::ModuleOp::dump)
      .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
        self.push_back(funcOp);
      })
      .def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool {
        if (self.lookupSymbol(funcName))
          return true;
        return false;
      })
      .def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
        return self.lookupSymbol<mlir::FuncOp>(funcName);
      })
  // dynamic_attr is used to transfer ownership of the MLIR context to the module
  py::class_<mlir::ModuleOp, mlir::OpState>(m, "module", py::dynamic_attr())
      .def("dump", &mlir::ModuleOp::dump)
      .def("str", [](mlir::ModuleOp &self) -> std::string {
        std::string str;
        llvm::raw_string_ostream os(str);
        self.print(os);
        return str;
      })
      .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
        self.push_back(funcOp);
      })
      .def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool {
        if (self.lookupSymbol(funcName))
          return true;
        return false;
      })
      .def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
        return self.lookupSymbol<mlir::FuncOp>(funcName);
      })
      ;

  py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
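A hedged sketch of how the updated `module` bindings are exercised from the Python front-end; `mod` is assumed to be an `ir.module` built by the IR builder, as in the new `compiler.py` added later in this diff.

# Hedged sketch, not part of the diff.
def dump_function(mod, name):
    # has_function / get_function wrap mlir::ModuleOp::lookupSymbol
    if not mod.has_function(name):
        raise KeyError(f"no function named {name!r} in module")
    fn = mod.get_function(name)
    print(mod.str())  # new binding: textual MLIR rather than dump() to stderr
    return fn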
@@ -1,164 +0,0 @@
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
|
||||
|
||||
DEVICE_NAME = 'v100'
|
||||
|
||||
#######################
|
||||
# Utilities
|
||||
#######################
|
||||
|
||||
|
||||
def nvsmi(attrs):
|
||||
attrs = ','.join(attrs)
|
||||
cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
|
||||
out = subprocess.check_output(cmd)
|
||||
ret = out.decode(sys.stdout.encoding).split(',')
|
||||
ret = [int(x) for x in ret]
|
||||
return ret
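For reference, a hedged example of what this helper returns; the numbers are illustrative, not measured values.

# e.g. on a V100 running at the reference clocks used below, one might see
# nvsmi(['clocks.current.sm', 'clocks.current.memory'])  ->  [1350, 877]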
|
||||
|
||||
|
||||
#######################
|
||||
# Matrix Multiplication
|
||||
#######################
|
||||
|
||||
sm_clocks = {'v100': 1350, 'a100': 1350}
|
||||
mem_clocks = {'v100': 877, 'a100': 1215}
|
||||
|
||||
matmul_data = {
|
||||
'v100': {
|
||||
# square
|
||||
(256, 256, 256): {'float16': 0.027},
|
||||
(512, 512, 512): {'float16': 0.158},
|
||||
(1024, 1024, 1024): {'float16': 0.466},
|
||||
(2048, 2048, 2048): {'float16': 0.695},
|
||||
(4096, 4096, 4096): {'float16': 0.831},
|
||||
(8192, 8192, 8192): {'float16': 0.849},
|
||||
# tall-skinny
|
||||
(16, 1024, 1024): {'float16': 0.0128},
|
||||
(16, 4096, 4096): {'float16': 0.0883},
|
||||
(16, 8192, 8192): {'float16': 0.101},
|
||||
(64, 1024, 1024): {'float16': 0.073},
|
||||
(64, 4096, 4096): {'float16': 0.270},
|
||||
(64, 8192, 8192): {'float16': 0.459},
|
||||
(1024, 64, 1024): {'float16': 0.0692},
|
||||
(4096, 64, 4096): {'float16': 0.264},
|
||||
(8192, 64, 8192): {'float16': 0.452},
|
||||
},
|
||||
'a100': {
|
||||
(256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
|
||||
(512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
|
||||
(1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
|
||||
(2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
|
||||
(4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
|
||||
(8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
|
||||
# tall-skinny
|
||||
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
|
||||
(16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
|
||||
(16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
|
||||
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
|
||||
(64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
|
||||
(64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
|
||||
(1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
|
||||
(4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
|
||||
(8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
|
||||
}
|
||||
# # deep reductions
|
||||
# (64 , 64 , 16384) : {'a100': 0.},
|
||||
# (64 , 64 , 65536) : {'a100': 0.},
|
||||
# (256 , 256 , 8192 ) : {'a100': 0.},
|
||||
# (256 , 256 , 32768) : {'a100': 0.},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M, N, K, dtype_str',
|
||||
[(M, N, K, dtype_str)
|
||||
for M, N, K in matmul_data[DEVICE_NAME].keys()
|
||||
for dtype_str in ['float16']])
|
||||
def test_matmul(M, N, K, dtype_str):
|
||||
if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100':
|
||||
pytest.skip('Only test float32 & int8 on a100')
|
||||
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
|
||||
torch.manual_seed(0)
|
||||
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
|
||||
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
|
||||
ref_sm_clock = sm_clocks[DEVICE_NAME]
|
||||
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
|
||||
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
|
||||
if dtype == torch.int8:
|
||||
a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
|
||||
b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
|
||||
b = b.t() # only test row-col layout
|
||||
else:
|
||||
a = torch.randn((M, K), dtype=dtype, device='cuda')
|
||||
b = torch.randn((K, N), dtype=dtype, device='cuda')
|
||||
fn = lambda: triton.ops.matmul(a, b)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
|
||||
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
|
||||
|
||||
|
||||
#######################
|
||||
# Element-Wise
|
||||
#######################
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x + y
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
elementwise_data = {
|
||||
'v100': {
|
||||
1024 * 16: 0.0219,
|
||||
1024 * 64: 0.0791,
|
||||
1024 * 256: 0.243,
|
||||
1024 * 1024: 0.534,
|
||||
1024 * 4096: 0.796,
|
||||
1024 * 16384: 0.905,
|
||||
1024 * 65536: 0.939,
|
||||
},
|
||||
'a100': {
|
||||
1024 * 16: 0.008,
|
||||
1024 * 64: 0.034,
|
||||
1024 * 256: 0.114,
|
||||
1024 * 1024: 0.315,
|
||||
1024 * 4096: 0.580,
|
||||
1024 * 16384: 0.782,
|
||||
1024 * 65536: 0.850,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys())
|
||||
def test_elementwise(N):
|
||||
torch.manual_seed(0)
|
||||
ref_gpu_util = elementwise_data[DEVICE_NAME][N]
|
||||
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
|
||||
ref_mem_clock = mem_clocks[DEVICE_NAME]
|
||||
max_gpu_perf = get_dram_gbps()
|
||||
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memory must run at {ref_mem_clock} MHz'
|
||||
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
|
||||
x = torch.randn_like(z)
|
||||
y = torch.randn_like(z)
|
||||
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
|
||||
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
|
||||
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
|
File diff suppressed because it is too large
@@ -1,177 +0,0 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import scipy.stats
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
#####################################
|
||||
# Reference Philox Implementation
|
||||
#####################################
|
||||
|
||||
|
||||
class PhiloxConfig:
|
||||
def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE):
|
||||
self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE)
|
||||
self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE)
|
||||
self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE)
|
||||
self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE)
|
||||
self.DTYPE = DTYPE
|
||||
|
||||
|
||||
# This is better for GPU
|
||||
PHILOX_32 = PhiloxConfig(
|
||||
PHILOX_KEY_A=0x9E3779B9,
|
||||
PHILOX_KEY_B=0xBB67AE85,
|
||||
PHILOX_ROUND_A=0xD2511F53,
|
||||
PHILOX_ROUND_B=0xCD9E8D57,
|
||||
DTYPE=np.uint32,
|
||||
)
|
||||
|
||||
# This is what numpy implements
|
||||
PHILOX_64 = PhiloxConfig(
|
||||
PHILOX_KEY_A=0x9E3779B97F4A7C15,
|
||||
PHILOX_KEY_B=0xBB67AE8584CAA73B,
|
||||
PHILOX_ROUND_A=0xD2E7470EE14C6C93,
|
||||
PHILOX_ROUND_B=0xCA5A826395121157,
|
||||
DTYPE=np.uint64,
|
||||
)
|
||||
|
||||
|
||||
class CustomPhilox4x:
|
||||
def __init__(self, seed, config):
|
||||
self._config = config
|
||||
seed = self._into_pieces(seed)
|
||||
self._key = np.array(seed[:2], dtype=self._dtype)
|
||||
self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype)
|
||||
|
||||
@property
|
||||
def _dtype(self):
|
||||
return self._config.DTYPE
|
||||
|
||||
def _into_pieces(self, n, pad=4):
|
||||
res = []
|
||||
while len(res) < pad:
|
||||
res.append(np.array(n, dtype=self._dtype))
|
||||
n >>= (np.dtype(self._dtype).itemsize * 8)
|
||||
assert n == 0
|
||||
return tuple(res)
|
||||
|
||||
def _multiply_low_high(self, a, b):
|
||||
low = a * b
|
||||
high = int(a) * int(b)
|
||||
high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype)
|
||||
return low, high
|
||||
|
||||
def _single_round(self, counter, key):
|
||||
lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0])
|
||||
lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2])
|
||||
ret0 = hi1 ^ counter[1] ^ key[0]
|
||||
ret1 = lo1
|
||||
ret2 = hi0 ^ counter[3] ^ key[1]
|
||||
ret3 = lo0
|
||||
return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype)
|
||||
|
||||
def _raise_key(self, key):
|
||||
pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B]
|
||||
return key + np.array(pk, dtype=self._dtype)
|
||||
|
||||
def random_raw(self):
|
||||
counter = self._counter
|
||||
key = self._key
|
||||
for _ in range(10):
|
||||
counter = self._single_round(counter, key)
|
||||
key = self._raise_key(key)
|
||||
self.advance(1)
|
||||
return counter
|
||||
|
||||
def advance(self, n_steps):
|
||||
self._counter[0] += n_steps
|
||||
assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets"
|
||||
|
||||
|
||||
class CustomPhilox(CustomPhilox4x):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.buffer = []
|
||||
|
||||
def random_raw(self):
|
||||
if len(self.buffer) == 0:
|
||||
self.buffer = list(super().random_raw())[::-1]
|
||||
return int(self.buffer.pop())
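A brief hedged sketch relating the two reference generators above, using the classes and the `PHILOX_32` config defined in this file; the seed is illustrative.

# Hedged sketch, not part of the test file.
gen4 = CustomPhilox4x(seed=0, config=PHILOX_32)
block = gen4.random_raw()                      # one 4-lane counter block
gen1 = CustomPhilox(seed=0, config=PHILOX_32)
lanes = [gen1.random_raw() for _ in range(4)]  # same lanes, one int at a time
assert lanes == [int(x) for x in block]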
|
||||
|
||||
|
||||
#####################################
|
||||
# Unit Tests
|
||||
#####################################
|
||||
|
||||
BLOCK = 1024
|
||||
|
||||
# test generation of random uint32
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in ['10', '4,53', '10000']
|
||||
for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
|
||||
)
|
||||
def test_randint(size, seed, device='cuda'):
|
||||
size = list(map(int, size.split(',')))
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
|
||||
rand = tl.randint(seed, offset)
|
||||
tl.store(X + offset, rand, mask=offset < N)
|
||||
# triton result
|
||||
x = torch.empty(size, dtype=torch.int32, device=device)
|
||||
N = x.numel()
|
||||
grid = (triton.cdiv(N, BLOCK),)
|
||||
kernel[grid](x, N, seed)
|
||||
out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist()
|
||||
# reference result
|
||||
gen = CustomPhilox4x(seed, config=PHILOX_32)
|
||||
out_ref = [gen.random_raw()[0] for _ in out_tri]
|
||||
assert out_tri == out_ref
|
||||
|
||||
# test uniform PRNG
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in [1000000]
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
def test_rand(size, seed, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
|
||||
rand = tl.rand(seed, offset)
|
||||
tl.store(X + offset, rand, mask=offset < N)
|
||||
# triton result
|
||||
x = torch.empty(size, dtype=torch.float32, device=device)
|
||||
N = x.numel()
|
||||
grid = (triton.cdiv(N, BLOCK),)
|
||||
kernel[grid](x, N, seed)
|
||||
assert all((x >= 0) & (x <= 1))
|
||||
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
|
||||
|
||||
# test normal PRNG
|
||||
|
||||
|
||||
@pytest.mark.parametrize('size, seed',
|
||||
[(size, seed) for size in [1000000]
|
||||
for seed in [0, 42, 124, 54]]
|
||||
)
|
||||
def test_randn(size, seed, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(X, N, seed):
|
||||
offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
|
||||
rand = tl.randn(seed, offset)
|
||||
tl.store(X + offset, rand, mask=offset < N)
|
||||
# triton result
|
||||
x = torch.empty(size, dtype=torch.float32, device=device)
|
||||
N = x.numel()
|
||||
grid = (triton.cdiv(N, BLOCK),)
|
||||
kernel[grid](x, N, seed)
|
||||
assert abs(x.mean()) < 1e-2
|
||||
assert abs(x.std() - 1) < 1e-2
|
@@ -1,187 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
|
||||
@pytest.mark.parametrize("TRANS_A", [False, True])
|
||||
@pytest.mark.parametrize("TRANS_B", [False, True])
|
||||
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
|
||||
@pytest.mark.parametrize("DTYPE", [torch.float16])
|
||||
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
|
||||
seed = 0
|
||||
torch.manual_seed(seed)
|
||||
is_sdd = MODE == "sdd"
|
||||
is_dsd = MODE == "dsd"
|
||||
is_dds = MODE == "dds"
|
||||
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
|
||||
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
|
||||
# create inputs
|
||||
# create op
|
||||
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
|
||||
b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
|
||||
c_shape = (Z, H, M, N)
|
||||
shape = {
|
||||
"sdd": (M, N),
|
||||
"dsd": (a_shape[2], a_shape[3]),
|
||||
"dds": (b_shape[2], b_shape[3]),
|
||||
}[MODE]
|
||||
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
|
||||
layout[1, 2, :] = 0
|
||||
layout[1, :, 1] = 0
|
||||
# create data
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1)
|
||||
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1)
|
||||
dc_ref, dc_tri = triton.testing.make_pair(c_shape)
|
||||
# compute [torch]
|
||||
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
|
||||
a_ref = do_mask(a_ref) if is_dsd else a_ref
|
||||
b_ref = do_mask(b_ref) if is_dds else b_ref
|
||||
a_ref.retain_grad()
|
||||
b_ref.retain_grad()
|
||||
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
|
||||
b_ref.transpose(2, 3) if TRANS_B else b_ref)
|
||||
c_ref.backward(dc_ref)
|
||||
c_ref = do_sparsify(c_ref) if is_sdd else c_ref
|
||||
da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
|
||||
db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad
|
||||
# triton result
|
||||
dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
|
||||
a_tri = do_sparsify(a_tri) if is_dsd else a_tri
|
||||
b_tri = do_sparsify(b_tri) if is_dds else b_tri
|
||||
a_tri.retain_grad()
|
||||
b_tri.retain_grad()
|
||||
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
|
||||
c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
|
||||
triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
|
||||
da_tri = a_tri.grad
|
||||
db_tri = b_tri.grad
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(c_ref, c_tri)
|
||||
triton.testing.assert_almost_equal(da_ref, da_tri)
|
||||
triton.testing.assert_almost_equal(db_ref, db_tri)
|
||||
|
||||
|
||||
configs = [
|
||||
(16, 256),
|
||||
(32, 576),
|
||||
(64, 1871),
|
||||
(128, 2511),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_dense", [False, True])
|
||||
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
|
||||
def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
Z, H, M, N = 2, 3, WIDTH, WIDTH
|
||||
# initialize layout
|
||||
# make sure each row has at least one non-zero element
|
||||
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
|
||||
if is_dense:
|
||||
layout[:] = 1
|
||||
else:
|
||||
layout[1, 2, :] = 0
|
||||
layout[1, :, 1] = 0
|
||||
# initialize data
|
||||
a_shape = (Z, H, M, N)
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape)
|
||||
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
|
||||
# compute [torch]
|
||||
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
|
||||
a_ref.retain_grad()
|
||||
at_mask = torch.ones((M, N), device="cuda")
|
||||
if is_causal:
|
||||
at_mask = torch.tril(at_mask)
|
||||
M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
|
||||
a_ref[M == 0] = float("-inf")
|
||||
out_ref = torch.softmax(a_ref * scale, -1)
|
||||
out_ref.backward(dout_ref)
|
||||
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
|
||||
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
|
||||
# compute [triton]
|
||||
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
|
||||
a_tri.retain_grad()
|
||||
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
|
||||
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
|
||||
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
|
||||
out_tri.backward(dout_tri)
|
||||
da_tri = a_tri.grad
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(out_tri, out_ref)
|
||||
triton.testing.assert_almost_equal(da_tri, da_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_attention_fwd_bwd(
|
||||
block,
|
||||
dtype,
|
||||
input_scale=1.0,
|
||||
scale=1 / 8.0,
|
||||
n_ctx=256,
|
||||
batch_size=2,
|
||||
n_heads=2,
|
||||
):
|
||||
# inputs
|
||||
qkv_shape = (batch_size, n_heads, n_ctx, 64)
|
||||
qkvs = [
|
||||
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
|
||||
]
|
||||
|
||||
# Triton:
|
||||
n_blocks = n_ctx // block
|
||||
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
|
||||
query, key, value = [x.clone() for x in qkvs]
|
||||
query.retain_grad()
|
||||
key.retain_grad()
|
||||
value.retain_grad()
|
||||
attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale)
|
||||
# ad hoc loss
|
||||
loss = (attn_out ** 2).mean()
|
||||
loss.backward()
|
||||
grads = [query.grad, key.grad, value.grad]
|
||||
|
||||
# Torch version:
|
||||
torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
|
||||
attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype)
|
||||
attn_mask = torch.tril(attn_mask, diagonal=0)
|
||||
attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
|
||||
torch_q.retain_grad()
|
||||
torch_k.retain_grad()
|
||||
torch_v.retain_grad()
|
||||
scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
|
||||
scores = scores + attn_mask
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
|
||||
# ad hoc loss
|
||||
torch_loss = (torch_attn_out ** 2).mean()
|
||||
torch_loss.backward()
|
||||
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
|
||||
|
||||
# comparison
|
||||
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
|
||||
triton.testing.assert_almost_equal(loss, torch_loss)
|
||||
for g1, g2 in zip(grads, torch_grads):
|
||||
triton.testing.assert_almost_equal(g1, g2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", [16, 32, 64])
|
||||
def triton_attention(
|
||||
layout,
|
||||
block: int,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device)
|
||||
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device)
|
||||
sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
|
||||
|
||||
w = sparse_dot_sdd_nt(query, key)
|
||||
w = sparse_softmax(w, scale=scale, is_causal=True)
|
||||
a = sparse_dot_dsd_nn(w, value)
|
||||
return a
|
@@ -1,35 +0,0 @@
import pytest
import torch

import triton


@pytest.mark.parametrize("M, N, dtype, mode",
                         [
                             (M, N, dtype, mode) for M in [1024, 821]
                             for N in [512, 857, 1871, 2089, 8573, 31000]
                             for dtype in ['float16', 'float32']
                             for mode in ['forward', 'backward']
                         ]
                         )
def test_op(M, N, dtype, mode):
    dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype]
    # create inputs
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
    # forward pass
    tt_y = triton.ops.cross_entropy(x, idx)
    th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
    if mode == 'forward':
        triton.testing.assert_almost_equal(th_y, tt_y)
    # backward pass
    elif mode == 'backward':
        dy = torch.randn_like(tt_y)
        # triton backward
        tt_y.backward(dy)
        tt_dx = x.grad.clone()
        # torch backward
        x.grad.zero_()
        th_y.backward(dy)
        th_dx = x.grad.clone()
        triton.testing.assert_almost_equal(th_dx, tt_dx)
@@ -1,98 +0,0 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# 1 warp
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 2 warp
|
||||
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 4 warp
|
||||
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 8 warp
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
# variable input
|
||||
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
# n-stage
|
||||
*[
|
||||
[
|
||||
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
|
||||
]
|
||||
),
|
||||
)
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80 and DTYPE == "bfloat16":
|
||||
pytest.skip("Only test bfloat16 on devices with sm >= 80")
|
||||
if DTYPE == "bfloat16" and SPLIT_K != 1:
|
||||
pytest.skip("bfloat16 matmuls don't allow split_k for now")
|
||||
torch.manual_seed(0)
|
||||
# nuke kernel decorators -- will set meta-parameters manually
|
||||
kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
|
||||
pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
|
||||
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
|
||||
kernel = triton.ops._matmul.kernel
|
||||
decorators = kernel.kernel_decorators
|
||||
kernel.kernel_decorators = []
|
||||
triton.autotune(configs, [])(kernel)
|
||||
kernel.kernel_decorators += decorators[1:]
|
||||
# get matrix shape
|
||||
M = BLOCK_M if M is None else M
|
||||
N = BLOCK_N if N is None else N
|
||||
K = BLOCK_K * SPLIT_K if K is None else K
|
||||
# allocate/transpose inputs
|
||||
DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
|
||||
a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
|
||||
b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
|
||||
a = a.t() if AT else a
|
||||
b = b.t() if BT else b
|
||||
# run test
|
||||
th_c = torch.matmul(a, b)
|
||||
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
|
||||
triton.testing.assert_almost_equal(th_c, tt_c)
|
@@ -1,132 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.code_gen import JITFunction
|
||||
|
||||
tmpdir = ".tmp"
|
||||
|
||||
|
||||
@triton.jit
|
||||
def function_1(i):
|
||||
i = i + 1
|
||||
i = function_2(i)
|
||||
return i
|
||||
|
||||
|
||||
@triton.jit
|
||||
def function_2(i):
|
||||
i = i + 1
|
||||
return i
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, i, BLOCK: tl.constexpr):
|
||||
i = i + 1
|
||||
i = function_1(i)
|
||||
tl.store(X, i)
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["i"])
|
||||
def kernel_nospec(X, i, BLOCK: tl.constexpr):
|
||||
i = i + 1
|
||||
i = function_1(i)
|
||||
tl.store(X, i)
|
||||
|
||||
|
||||
def apply_src_change(target, old, new):
|
||||
kernel.hash = None
|
||||
function_1.hash = None
|
||||
function_2.hash = None
|
||||
function_1.src = function_1.src.replace(old, new)
|
||||
target.src = target.src.replace(old, new)
|
||||
ret = target.cache_key
|
||||
target.src = target.src.replace(new, old)
|
||||
return ret
|
||||
|
||||
|
||||
def test_nochange():
|
||||
baseline = kernel.cache_key
|
||||
updated = apply_src_change(kernel, 'i + 1', 'i + 1')
|
||||
assert baseline == updated
|
||||
|
||||
|
||||
def test_toplevel_change():
|
||||
baseline = kernel.cache_key
|
||||
updated = apply_src_change(kernel, 'i + 1', 'i + 2')
|
||||
assert baseline != updated
|
||||
|
||||
|
||||
def test_nested1_change():
|
||||
baseline = kernel.cache_key
|
||||
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
|
||||
assert baseline != updated
|
||||
|
||||
|
||||
def reset_tmp_dir():
|
||||
os.environ["TRITON_CACHE_DIR"] = tmpdir
|
||||
if os.path.exists(tmpdir):
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
def test_reuse():
|
||||
counter = 0
|
||||
|
||||
def inc_counter(*args, **kwargs):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
JITFunction.cache_hook = inc_counter
|
||||
reset_tmp_dir()
|
||||
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
||||
for i in range(10):
|
||||
kernel[(1,)](x, 1, BLOCK=1024)
|
||||
assert counter == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mode', ['enable', 'disable'])
|
||||
def test_specialize(mode):
|
||||
counter = 0
|
||||
|
||||
def inc_counter(*args, **kwargs):
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
JITFunction.cache_hook = inc_counter
|
||||
reset_tmp_dir()
|
||||
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
||||
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
|
||||
target = {'enable': 5, 'disable': 1}[mode]
|
||||
for i in [1, 2, 4, 8, 16, 32]:
|
||||
function[(1,)](x, i, BLOCK=512)
|
||||
assert counter == target
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
|
||||
(2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
|
||||
(2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
|
||||
@triton.jit
|
||||
def kernel(VALUE, X):
|
||||
pass
|
||||
|
||||
cache_str = None
|
||||
|
||||
def get_cache_str(*args, **kwargs):
|
||||
nonlocal cache_str
|
||||
cache_str = kwargs['key'].split('-')
|
||||
triton.code_gen.JITFunction.cache_hook = get_cache_str
|
||||
reset_tmp_dir()
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
kernel[(1, )](value, x)
|
||||
triton.code_gen.JITFunction.cache_hook = None
|
||||
|
||||
cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
|
||||
spec_type = None if cache_str_match is None else cache_str_match.group(1)
|
||||
assert spec_type == value_type
|
@@ -1,98 +0,0 @@
|
||||
import subprocess
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def get_p2p_matrix():
|
||||
try:
|
||||
stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii")
|
||||
except subprocess.CalledProcessError:
|
||||
return pytest.skip("No multi-GPU topology", allow_module_level=True)
|
||||
|
||||
lines = stdout.split("Legend")[0].split('\n')[1:]
|
||||
matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2])
|
||||
if matrix.size <= 1:
|
||||
return pytest.skip("No multi-GPU topology", allow_module_level=True)
|
||||
else:
|
||||
return matrix
|
||||
|
||||
|
||||
def get_p2p_devices():
|
||||
matrix = get_p2p_matrix()
|
||||
idx = np.where(matrix == "OK")
|
||||
return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
|
||||
|
||||
|
||||
def get_non_p2p_devices():
|
||||
matrix = get_p2p_matrix()
|
||||
idx = np.where(matrix == "NS")
|
||||
return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
|
||||
|
||||
|
||||
p2p_devices = get_p2p_devices()
|
||||
non_p2p_devices = get_non_p2p_devices()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _copy(from_ptr, to_ptr, N, **meta):
|
||||
pid = tl.program_id(0)
|
||||
offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
|
||||
values = tl.load(from_ptr + offsets, mask=offsets < N)
|
||||
tl.store(to_ptr + offsets, values, mask=offsets < N)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support")
|
||||
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
|
||||
[(device_kernel, device_from, device_to, stream_from, stream_to)
|
||||
for device_kernel in p2p_devices
|
||||
for device_from in p2p_devices
|
||||
for device_to in p2p_devices
|
||||
for stream_from in ['default', 'custom']
|
||||
for stream_to in ['default', 'custom']
|
||||
])
|
||||
def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
|
||||
if device_to == device_from:
|
||||
return pytest.skip()
|
||||
|
||||
torch.cuda.set_device(device_kernel)
|
||||
N = 512
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
|
||||
|
||||
with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
|
||||
x_from = torch.randn(N, dtype=torch.float32, device=device_from)
|
||||
with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
|
||||
x_to = torch.empty(N, dtype=torch.float32, device=device_to)
|
||||
|
||||
_copy[grid](x_from, x_to, N, BLOCK=1024)
|
||||
assert torch.allclose(x_from, x_to.to(device_from))
|
||||
|
||||
|
||||
@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support")
|
||||
@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
|
||||
[(device_kernel, device_from, device_to, stream_from, stream_to)
|
||||
for device_kernel in non_p2p_devices
|
||||
for device_from in non_p2p_devices
|
||||
for device_to in non_p2p_devices
|
||||
for stream_from in ['default', 'custom']
|
||||
for stream_to in ['default', 'custom']
|
||||
])
|
||||
def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
|
||||
if device_to == device_from:
|
||||
return pytest.skip()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
torch.cuda.set_device(device_kernel)
|
||||
N = 512
|
||||
grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
|
||||
|
||||
with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
|
||||
x_from = torch.randn(N, dtype=torch.float32, device=device_from)
|
||||
with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
|
||||
x_to = torch.empty(N, dtype=torch.float32, device=device_to)
|
||||
|
||||
_copy[grid](x_from, x_to, N, BLOCK=1024)
|
@@ -6,9 +6,9 @@ __version__ = '2.0.0'
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \
    JITFunction, Config, Autotuner, reinterpret
from .utils import *
from .runtime import jit, Config, autotune, heuristics
from .compiler import compile
from . import language
from . import code_gen
from . import testing
from . import ops
File diff suppressed because it is too large
806
python/triton/compiler.py
Normal file
@@ -0,0 +1,806 @@
|
||||
from __future__ import annotations
|
||||
import ast
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Dict, Union
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
def str_to_ty(name):
|
||||
if name[0] == "*":
|
||||
ty = str_to_ty(name[1:])
|
||||
return triton.language.pointer_type(ty)
|
||||
tys = {
|
||||
"fp8": triton.language.float8,
|
||||
"fp16": triton.language.float16,
|
||||
"bf16": triton.language.bfloat16,
|
||||
"fp32": triton.language.float32,
|
||||
"fp64": triton.language.float64,
|
||||
"i8": triton.language.int8,
|
||||
"i16": triton.language.int16,
|
||||
"i32": triton.language.int32,
|
||||
"i64": triton.language.int64,
|
||||
"u8": triton.language.uint8,
|
||||
"u16": triton.language.uint16,
|
||||
"u32": triton.language.uint32,
|
||||
"u64": triton.language.uint64,
|
||||
"B": triton.language.int1,
|
||||
}
|
||||
return tys[name]
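For illustration, two hedged examples of what `str_to_ty` resolves; the calls are hypothetical and simply follow the table above.

# Hedged examples, not part of the committed file.
assert str_to_ty("fp16") is triton.language.float16
assert str_to_ty("*i32").element_ty is triton.language.int32  # '*' recurses into a pointer type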
|
||||
|
||||
def mangle_ty(ty):
|
||||
if ty.is_ptr():
|
||||
return 'P' + mangle_ty(ty.element_ty)
|
||||
if ty.is_int():
|
||||
return 'i' + str(ty.int_bitwidth)
|
||||
if ty.is_fp8():
|
||||
return 'fp8'
|
||||
if ty.is_fp16():
|
||||
return 'fp16'
|
||||
if ty.is_bf16():
|
||||
return 'bf16'
|
||||
if ty.is_fp32():
|
||||
return 'fp32'
|
||||
if ty.is_fp64():
|
||||
return 'fp64'
|
||||
if ty.is_void():
|
||||
return 'V'
|
||||
if ty.is_block():
|
||||
elt = mangle_ty(ty.scalar)
|
||||
shape = '_'.join(map(str, ty.shape))
|
||||
return f'{elt}S{shape}S'
|
||||
assert False, "Unsupported type"
|
||||
|
||||
|
||||
def mangle_fn(name, arg_tys, constants):
|
||||
# doesn't mangle ret type, which must be a function of arg tys
|
||||
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
|
||||
mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
|
||||
mangled_constants = mangled_constants.replace('.', '_d_')
|
||||
mangled_constants = mangled_constants.replace("'", '_sq_')
|
||||
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
|
||||
return ret
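A hedged worked example of the mangling scheme above, with hypothetical argument types and constants (types rendered per `mangle_ty`):

# Hedged example, not part of the committed file.
# A kernel taking (*fp16, i32) with the constant 64 bound at position 2:
#   mangle_ty         -> 'Pfp16', 'i32'
#   constants {2: 64} -> '2c64'
#   mangle_fn('kernel', [ptr_fp16, int32], {2: 64}) == 'kernel__Pfp16_i32__2c64'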
|
||||
|
||||
class enter_sub_region:
|
||||
def __init__(self, generator: CodeGenerator):
|
||||
self.generator = generator
|
||||
|
||||
def __enter__(self):
|
||||
# record lscope & local_defs in the parent scope
|
||||
self.liveins = self.generator.lscope.copy()
|
||||
self.prev_defs = self.generator.local_defs.copy()
|
||||
self.generator.local_defs = {}
|
||||
self.insert_block = self.generator.builder.get_insertion_block()
|
||||
return self.liveins, self.insert_block
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
self.generator.builder.set_insertion_point_to_end(self.insert_block)
|
||||
self.generator.lscope = self.liveins
|
||||
self.generator.local_defs = self.prev_defs
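The context manager above is how the code generator scopes sub-regions; its usage pattern, as it appears in `visit_If`, `visit_While` and `visit_For` later in this file, is:

# Usage pattern (excerpt; `self` is the CodeGenerator instance):
# with enter_sub_region(self) as sr:
#     liveins, insert_block = sr
#     ...build blocks and visit the body; scope is restored on exit...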
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
self.module = self.builder.create_module() if module is None else module
|
||||
self.function_ret_types = function_types
|
||||
self.prototype = prototype
|
||||
self.gscope = gscope
|
||||
self.lscope = dict()
|
||||
self.attributes = attributes
|
||||
self.constants = constants
|
||||
self.is_kernel = is_kernel
|
||||
self.last_node = None
|
||||
self.builtins = {
|
||||
'range': range,
|
||||
'min': triton.language.minimum,
|
||||
'float': float,
|
||||
'int': int,
|
||||
'print': print,
|
||||
'isinstance': isinstance,
|
||||
'getattr': getattr,
|
||||
}
|
||||
# SSA-construction
|
||||
# name => triton.language.tensor
|
||||
self.local_defs: Dict[str, triton.language.tensor] = {}
|
||||
self.global_uses: Dict[str, triton.language.tensor] = {}
|
||||
|
||||
def get_value(self, name):
|
||||
''' This function:
|
||||
1. make sure `name` is defined
|
||||
2. if `name` is triton.language.tensor, get stored tensor by calling
|
||||
`self._get_tensor()`
|
||||
'''
|
||||
# search node.id in local scope
|
||||
ret = None
|
||||
if name in self.lscope:
|
||||
ret = self.lscope[name]
|
||||
if name not in self.local_defs:
|
||||
self.global_uses[name] = ret
|
||||
# search node.id in global scope
|
||||
elif name in self.gscope:
|
||||
ret = self.gscope[name]
|
||||
# search node.id in builtins
|
||||
elif name in self.builtins:
|
||||
ret = self.builtins[name]
|
||||
else:
|
||||
raise ValueError(f'{name} is not defined')
|
||||
return ret
|
||||
|
||||
def set_value(self, name: str,
|
||||
value: Union[triton.language.tensor, triton.language.constexpr]) -> None:
|
||||
''' This function:
|
||||
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
|
||||
1. record local defined name (FIXME: should consider control flow)
|
||||
2. store tensor in self.lvalue
|
||||
'''
|
||||
self.lscope[name] = value
|
||||
self.local_defs[name] = value
|
||||
|
||||
def is_triton_tensor(self, value):
|
||||
return isinstance(value, triton.language.tensor)
|
||||
|
||||
#
|
||||
# AST visitor
|
||||
#
|
||||
def visit_compound_statement(self, stmts):
|
||||
for stmt in stmts:
|
||||
self.last_ret_type = self.visit(stmt)
|
||||
if isinstance(stmt, ast.Return):
|
||||
break
|
||||
return stmts and isinstance(stmt, ast.Return)
|
||||
|
||||
def visit_Module(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
def visit_List(self, node):
|
||||
ctx = self.visit(node.ctx)
|
||||
assert ctx is None
|
||||
elts = [self.visit(elt) for elt in node.elts]
|
||||
return elts
|
||||
|
||||
# By design, only non-kernel functions can return
|
||||
def visit_Return(self, node):
|
||||
ret_value = self.visit(node.value)
|
||||
if ret_value is None:
|
||||
self.builder.ret([])
|
||||
return None
|
||||
if isinstance(ret_value, tuple):
|
||||
ret_values = [triton.language.core._to_tensor(v, self.builder) for v in ret_value]
|
||||
ret_types = [v.type for v in ret_values]
|
||||
self.builder.ret([v.handle for v in ret_values])
|
||||
return tuple(ret_types)
|
||||
else:
|
||||
ret = triton.language.core._to_tensor(ret_value, self.builder)
|
||||
self.builder.ret([ret_value.handle])
|
||||
return ret.type
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
# initialize defaults
|
||||
for i, default_value in enumerate(node.args.defaults):
|
||||
arg_node = node.args.args[-i - 1]
|
||||
annotation = arg_node.annotation
|
||||
name = arg_node.arg
|
||||
st_target = ast.Name(id=name, ctx=ast.Store())
|
||||
if annotation is None:
|
||||
init_node = ast.Assign(targets=[st_target], value=default_value)
|
||||
else:
|
||||
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
|
||||
fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
|
||||
self.module.push_back(fn)
|
||||
entry = fn.add_entry_block()
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
if i in self.constants:
|
||||
cst = self.constants[i]
|
||||
if not isinstance(cst, triton.language.constexpr):
|
||||
cst = triton.language.constexpr(self.constants[i])
|
||||
arg_values.append(cst)
|
||||
else:
|
||||
pass
|
||||
if i in self.attributes:
|
||||
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
|
||||
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
|
||||
insert_pt = self.builder.get_insertion_block()
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value)
|
||||
self.builder.set_insertion_point_to_start(entry)
|
||||
# visit function body
|
||||
has_ret = self.visit_compound_statement(node.body)
|
||||
# finalize function
|
||||
if not has_ret:
|
||||
self.builder.ret([])
|
||||
else:
|
||||
# update return type
|
||||
if isinstance(self.last_ret_type, tuple):
|
||||
self.prototype.ret_types = list(self.last_ret_type)
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
else:
|
||||
self.prototype.ret_types = [self.last_ret_type]
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
if insert_pt:
|
||||
self.builder.set_insertion_point_to_end(insert_pt)
|
||||
|
||||
def visit_arguments(self, node):
|
||||
arg_names = []
|
||||
for arg in node.args:
|
||||
arg_names += [self.visit(arg)]
|
||||
kwarg_names = self.visit(node.kwarg)
|
||||
return arg_names, kwarg_names
|
||||
|
||||
def visit_arg(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
return node.arg
|
||||
|
||||
def visit_AnnAssign(self, node):
|
||||
# extract attributes
|
||||
annotation = self.visit(node.annotation)
|
||||
target = self.visit(node.target)
|
||||
value = self.visit(node.value)
|
||||
# constexpr
|
||||
if annotation == triton.language.constexpr:
|
||||
if target in self.lscope:
|
||||
raise ValueError(f'{target} is already defined.'
|
||||
f' constexpr cannot be reassigned.')
|
||||
if not isinstance(value, triton.language.constexpr):
|
||||
value = triton.language.constexpr(value)
|
||||
self.lscope[target] = value
|
||||
return self.lscope[target]
|
||||
# default: call visit_Assign
|
||||
return self.visit_Assign(node)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
_names = []
|
||||
for target in node.targets:
|
||||
_names += [self.visit(target)]
|
||||
assert len(_names) == 1
|
||||
names = _names[0]
|
||||
values = self.visit(node.value)
|
||||
if not isinstance(names, tuple):
|
||||
names = [names]
|
||||
if not isinstance(values, tuple):
|
||||
values = [values]
|
||||
for name, value in zip(names, values):
|
||||
# by default, constexpr are assigned into python variable
|
||||
if isinstance(value, triton.language.constexpr):
|
||||
value = value.value
|
||||
if not isinstance(value, triton.language.tensor):
|
||||
value = triton.language.core._to_tensor(value, self.builder)
|
||||
self.set_value(name, value)
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
name = node.target.id
|
||||
lhs = ast.Name(id=name, ctx=ast.Load())
|
||||
rhs = ast.BinOp(lhs, node.op, node.value)
|
||||
assign = ast.Assign(targets=[node.target], value=rhs)
|
||||
self.visit(assign)
|
||||
return self.get_value(name)
|
||||
|
||||
def visit_Name(self, node):
|
||||
if type(node.ctx) == ast.Store:
|
||||
return node.id
|
||||
return self.get_value(node.id)
|
||||
|
||||
def visit_Store(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
def visit_Load(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
def visit_Tuple(self, node):
|
||||
args = [self.visit(x) for x in node.elts]
|
||||
return tuple(args)
|
||||
|
||||
def visit_BinOp(self, node):
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
rhs = rhs.value
|
||||
fn = {
|
||||
ast.Add: '__add__',
|
||||
ast.Sub: '__sub__',
|
||||
ast.Mult: '__mul__',
|
||||
ast.Div: '__truediv__',
|
||||
ast.FloorDiv: '__floordiv__',
|
||||
ast.Mod: '__mod__',
|
||||
ast.Pow: '__pow__',
|
||||
ast.LShift: '__lshift__',
|
||||
ast.RShift: '__rshift__',
|
||||
ast.BitAnd: '__and__',
|
||||
ast.BitOr: '__or__',
|
||||
ast.BitXor: '__xor__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_tensor(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_tensor(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
return getattr(lhs, fn)(rhs)
|
||||
|
||||
def visit_If(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if isinstance(cond, triton.language.tensor):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, ip_block = sr
|
||||
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
then_defs = self.local_defs.copy()
|
||||
|
||||
# when need an else block when:
|
||||
# 1. we have an orelse node
|
||||
# or
|
||||
# 2. the then block defines new variable
|
||||
if then_defs or node.orelse:
|
||||
if node.orelse:
|
||||
self.lscope = liveins
|
||||
self.local_defs = {}
|
||||
else_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_end(else_block)
|
||||
self.visit_compound_statement(node.orelse)
|
||||
else_defs = self.local_defs.copy()
|
||||
else:
|
||||
# collect else_defs
|
||||
else_defs = {}
|
||||
for name in then_defs:
|
||||
if name in liveins:
|
||||
assert self.is_triton_tensor(then_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
else_defs[name] = liveins[name]
|
||||
# collect yields
|
||||
names = []
|
||||
ret_types = []
|
||||
for then_name in then_defs:
|
||||
for else_name in else_defs:
|
||||
if then_name == else_name:
|
||||
if then_defs[then_name].type == else_defs[else_name].type:
|
||||
names.append(then_name)
|
||||
ret_types.append(then_defs[then_name].type)
|
||||
|
||||
self.builder.set_insertion_point_to_end(ip_block)
|
||||
|
||||
if then_defs or node.orelse: # with else block
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
self.builder.create_yield_op([then_defs[n].handle for n in names])
|
||||
if not node.orelse:
|
||||
else_block = if_op.get_else_block()
|
||||
else:
|
||||
else_block.merge_block_before(if_op.get_else_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||
self.builder.create_yield_op([else_defs[n].handle for n in names])
|
||||
else: # no else block
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
|
||||
# update values yielded by IfOp
|
||||
for i, name in enumerate(names):
|
||||
new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i])
|
||||
self.lscope[name] = new_tensor
|
||||
self.local_defs[name] = new_tensor
|
||||
|
||||
else:
|
||||
if isinstance(cond, triton.language.constexpr):
|
||||
cond = cond.value
|
||||
if cond:
|
||||
self.visit_compound_statement(node.body)
|
||||
else:
|
||||
self.visit_compound_statement(node.orelse)
|
||||
|
||||
def visit_IfExp(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if cond.value:
|
||||
return self.visit(node.body)
|
||||
else:
|
||||
return self.visit(node.orelse)
|
||||
|
||||
def visit_Pass(self, node):
|
||||
pass
|
||||
|
||||
def visit_Compare(self, node):
|
||||
assert len(node.comparators) == 1
|
||||
assert len(node.ops) == 1
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.comparators[0])
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
rhs = rhs.value
|
||||
if type(node.ops[0]) == ast.Is:
|
||||
return triton.language.constexpr(lhs is rhs)
|
||||
if type(node.ops[0]) == ast.IsNot:
|
||||
return triton.language.constexpr(lhs is not rhs)
|
||||
fn = {
|
||||
ast.Eq: '__eq__',
|
||||
ast.NotEq: '__ne__',
|
||||
ast.Lt: '__lt__',
|
||||
ast.LtE: '__le__',
|
||||
ast.Gt: '__gt__',
|
||||
ast.GtE: '__ge__',
|
||||
}[type(node.ops[0])]
|
||||
if self.is_triton_tensor(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_tensor(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
return getattr(lhs, fn)(rhs)
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
op = self.visit(node.operand)
|
||||
if type(node.op) == ast.Not:
|
||||
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
|
||||
return triton.language.constexpr(not op)
|
||||
if isinstance(op, triton.language.constexpr):
|
||||
op = op.value
|
||||
fn = {
|
||||
ast.USub: '__neg__',
|
||||
ast.UAdd: '__pos__',
|
||||
ast.Invert: '__invert__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_tensor(op):
|
||||
return getattr(op, fn)(_builder=self.builder)
|
||||
return getattr(op, fn)()
|
||||
|
||||
def visit_While(self, node):
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, insert_block = sr
|
||||
|
||||
# condition (the before region)
|
||||
cond_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(cond_block)
|
||||
cond = self.visit(node.test)
|
||||
|
||||
# loop body (the after region)
|
||||
loop_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(loop_block)
|
||||
self.visit_compound_statement(node.body)
|
||||
loop_defs = self.local_defs
|
||||
|
||||
# collect loop-carried values
|
||||
names = []
|
||||
ret_types = []
|
||||
init_args = []
|
||||
yields = []
|
||||
for name in loop_defs:
|
||||
if name in liveins:
|
||||
# We should not def new constexpr
|
||||
assert self.is_triton_tensor(loop_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if loop_defs[name].type == liveins[name].type:
|
||||
# these are loop-carried values
|
||||
names.append(name)
|
||||
ret_types.append(loop_defs[name].type)
|
||||
init_args.append(liveins[name])
|
||||
yields.append(loop_defs[name])
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
|
||||
[arg.handle for arg in init_args])
|
||||
# merge the condition region
|
||||
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
cond_block.merge_block_before(before_block)
|
||||
self.builder.set_insertion_point_to_end(before_block)
|
||||
# create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
||||
self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
|
||||
# merge the loop body
|
||||
after_block = self.builder.create_block_with_parent(while_op.get_after(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
loop_block.merge_block_before(after_block)
|
||||
self.builder.set_insertion_point_to_end(after_block)
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
|
||||
# update global uses in while_op
|
||||
for i, name in enumerate(names):
|
||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||
|
||||
# WhileOp defines new values, update the symbol table (lscope, local_defs)
|
||||
for i, name in enumerate(names):
|
||||
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
|
||||
self.lscope[name] = new_def
|
||||
self.local_defs[name] = new_def
|
||||
|
||||
for stmt in node.orelse:
|
||||
assert False, "Not implemented"
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
|
||||
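# Note on the lowering above (illustrative fragment, hypothetical names): a
# kernel-level
#
#     while x < n:
#         x = x + 1
#
# becomes an scf.while whose "before" region evaluates the condition and
# forwards the loop-carried values (here `x`) through scf.condition, and whose
# "after" region holds the body and yields the updated values.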
def visit_Subscript(self, node):
|
||||
assert node.ctx.__class__.__name__ == "Load"
|
||||
lhs = self.visit(node.value)
|
||||
slices = self.visit(node.slice)
|
||||
if self.is_triton_tensor(lhs):
|
||||
return lhs.__getitem__(slices, _builder=self.builder)
|
||||
return lhs[slices]
|
||||
|
||||
def visit_ExtSlice(self, node):
|
||||
return [self.visit(dim) for dim in node.dims]
|
||||
|
||||
def visit_For(self, node):
|
||||
iterator = self.visit(node.iter.func)
|
||||
if iterator != self.builtins['range']:
|
||||
raise RuntimeError('Only `range` iterator currently supported')
|
||||
# static for loops: all iterator arguments are constexpr
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
|
||||
if is_static:
|
||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
iter_args = [arg.value for arg in iter_args]
|
||||
range = iterator(*iter_args)
|
||||
if len(range) <= 10:
|
||||
for i in iterator(*iter_args):
|
||||
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
for stmt in node.orelse:
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
return
|
||||
|
||||
# collect lower bound (lb), upper bound (ub), and step
|
||||
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
||||
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
||||
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
||||
# lb/ub/step might be constexpr; we need to cast them to tensors
|
||||
lb = triton.language.core._to_tensor(lb, self.builder).handle
|
||||
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
||||
step = triton.language.core._to_tensor(step, self.builder).handle
|
||||
# ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
|
||||
lb = self.builder.create_to_index(lb)
|
||||
ub = self.builder.create_to_index(ub)
|
||||
step = self.builder.create_to_index(step)
|
||||
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, insert_block = sr
|
||||
|
||||
# create loop body block
|
||||
block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(block)
|
||||
|
||||
# visit loop body
|
||||
self.visit_compound_statement(node.body)
|
||||
|
||||
# If a variable (name) is defined in both the parent scope and the loop body, then it's
|
||||
# a loop-carried variable. (They must be of the same type)
|
||||
init_args = []
|
||||
yields = []
|
||||
names = []
|
||||
for name in self.local_defs:
|
||||
if name in liveins:
|
||||
assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor'
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if self.local_defs[name].type == liveins[name].type:
|
||||
names.append(name)
|
||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
||||
# create ForOp
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
||||
block.merge_block_before(for_op.get_body(0))
|
||||
# create YieldOp
|
||||
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
for_op_region = for_op.get_body(0).get_parent()
|
||||
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
||||
# replace global uses with block arguments
|
||||
for i, name in enumerate(names):
|
||||
# arg0 is the induction variable
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||
|
||||
# update lscope & local_defs (ForOp defines new values)
|
||||
for i, name in enumerate(names):
|
||||
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
|
||||
for stmt in node.orelse:
|
||||
assert False, "Don't know what to do with else after for"
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
|
||||
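# Note on the two lowering paths above (illustrative fragment, hypothetical
# names): when every range() argument is a constexpr and the trip count is
# at most 10, the loop is unrolled at compile time, e.g.
#
#     for i in range(4):          # i is bound to a constexpr 0..3
#         acc += tl.load(ptr + i)
#
# whereas range(0, n) with a runtime `n` is lowered to an scf.for op whose
# loop-carried values are collected from local_defs and liveins.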
def visit_Slice(self, node):
|
||||
lower = self.visit(node.lower)
|
||||
upper = self.visit(node.upper)
|
||||
step = self.visit(node.step)
|
||||
return slice(lower, upper, step)
|
||||
|
||||
def visit_Index(self, node):
|
||||
return self.visit(node.value)
|
||||
|
||||
def visit_keyword(self, node):
|
||||
return {node.arg: self.visit(node.value)}
|
||||
|
||||
def visit_Call(self, node):
|
||||
fn = self.visit(node.func)
|
||||
if isinstance(fn, triton.language.constexpr):
|
||||
fn = fn.value
|
||||
kws = dict()
|
||||
for keyword in node.keywords:
|
||||
kws.update(self.visit(keyword))
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
from inspect import getcallargs
|
||||
args = getcallargs(fn.fn, *args, **kws)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if isinstance(arg, triton.language.tensor)
|
||||
else triton.language.constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
|
||||
constants = {i: args[i] for i in constexprs}
|
||||
# generate call
|
||||
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
|
||||
arg_vals = [arg.handle for arg in args if arg is not None]
|
||||
arg_types = [arg.type for arg in args if arg is not None]
|
||||
fn_name = mangle_fn(fn.__name__, arg_types, constants)
|
||||
# generate function def if necessary
|
||||
if not self.module.has_function(fn_name):
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type([ret_type], arg_types)
|
||||
gscope = sys.modules[fn.fn.__module__].__dict__
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
else:
|
||||
callee_ret_type = self.function_ret_types[fn_name]
|
||||
symbol = self.module.get_function(fn_name)
|
||||
call_op = self.builder.call(symbol, arg_vals)
|
||||
if call_op.get_num_results() == 0:
|
||||
return None
|
||||
elif call_op.get_num_results() == 1:
|
||||
return triton.language.tensor(call_op.get_result(0), callee_ret_type)
|
||||
else:
|
||||
# should return a tuple of tl.tensor
|
||||
results = []
|
||||
for i in range(call_op.get_num_results()):
|
||||
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
|
||||
return tuple(results)
|
||||
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
if fn in self.builtins.values():
|
||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||
for arg in args]
|
||||
return fn(*args, **kws)
|
||||
|
||||
def visit_Constant(self, node):
|
||||
return triton.language.constexpr(node.value)
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
def visit_NameConstant(self, node):
|
||||
return triton.language.constexpr(node.value)
|
||||
|
||||
def visit_Num(self, node):
|
||||
return triton.language.constexpr(node.n)
|
||||
|
||||
def visit_Str(self, node):
|
||||
return triton.language.constexpr(ast.literal_eval(node))
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
lhs = self.visit(node.value)
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
def visit_Expr(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
def visit_NoneType(self, node):
|
||||
return None
|
||||
|
||||
def visit(self, node):
|
||||
if node is not None:
|
||||
self.last_node = node
|
||||
with warnings.catch_warnings():
|
||||
# The ast library added visit_Constant and deprecated some other
|
||||
# methods but we can't move to that without breaking Python 3.6 and 3.7.
|
||||
warnings.simplefilter("ignore", DeprecationWarning) # python 3.9
|
||||
warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
|
||||
return super().visit(node)
|
||||
|
||||
def generic_visit(self, node):
|
||||
typename = type(node).__name__
|
||||
raise NotImplementedError("Unsupported node: {}".format(typename))
|
||||
|
||||
|
||||
|
||||
class CompilationError(Exception):
|
||||
def __init__(self, src, node):
|
||||
self.message = f'at {node.lineno}:{node.col_offset}:\n'
|
||||
self.message += '\n'.join(src.split('\n')[:node.lineno])
|
||||
self.message += '\n' + ' ' * node.col_offset + '^'
|
||||
self.src = src
|
||||
self.node = node
|
||||
super().__init__(self.message)
|
||||
|
||||
def __reduce__(self):
|
||||
# this is necessary to make CompilationError picklable
|
||||
return (type(self), (self.src, self.node))
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
def __init__(self, required, limit, name):
|
||||
self.message = f'out of resource: {name}, '\
|
||||
f'Required: {required}, '\
|
||||
f'Hardware limit: {limit}'
|
||||
self.required = required
|
||||
self.limit = limit
|
||||
self.name = name
|
||||
super().__init__(self.message)
|
||||
|
||||
def __reduce__(self):
|
||||
# this is necessary to make OutOfResources picklable
|
||||
return (type(self), (self.required, self.limit, self.name))
|
||||
|
||||
|
||||
def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# create kernel prototype
|
||||
arg_types = signature.replace(' ','').split(',')
|
||||
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
|
||||
arg_types = [str_to_ty(x) for x in arg_types]
|
||||
prototype = triton.language.function_type([], arg_types)
|
||||
# visit kernel AST
|
||||
gscope = fn.__globals__.copy()
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, constants=constants, attributes=attributes, is_kernel=True)
|
||||
try:
|
||||
generator.visit(fn.parse())
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
||||
raise e
|
||||
raise CompilationError(fn.src, node) from e
|
||||
ret = generator.module
|
||||
# module takes ownership of the MLIR context
|
||||
ret.context = context
|
||||
return ret
|
||||
|
||||
def make_tritongpu_ir(mod, num_warps):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_inliner_pass()
|
||||
pm.add_triton_combine_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_triton_gpu_combine_pass()
|
||||
pm.add_triton_gpu_verifier_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
def make_ptx(mod):
|
||||
# TODO
|
||||
return mod
|
||||
|
||||
def compile(fn, signature, constants = dict(), attributes = dict(), num_warps=4, num_stages=3, output = "ttgir"):
|
||||
assert output in ["ttir", "ttgir", "ptx"]
|
||||
# triton-ir
|
||||
module = make_triton_ir(fn, signature, constants, attributes)
|
||||
if output == "ttir":
|
||||
return module.str()
|
||||
# tritongpu-ir
|
||||
module = make_tritongpu_ir(module, num_warps)
|
||||
module = optimize_tritongpu_ir(module, num_stages)
|
||||
if output == "ttgir":
|
||||
return module.str()
|
||||
# ptx
|
||||
if output == "ptx":
|
||||
return make_ptx(module)
|
||||
assert False
|
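# Example use of the standalone compile() API above. This is a sketch only:
# `add_kernel` is a hypothetical kernel, and the exact type spellings accepted
# by str_to_ty (e.g. "*fp32", "i32") may differ from what is shown here.
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, n):
    pid = tl.program_id(axis=0)
    mask = pid < n
    x = tl.load(x_ptr + pid, mask=mask)
    y = tl.load(y_ptr + pid, mask=mask)
    tl.store(z_ptr + pid, x + y, mask=mask)


# Emit Triton IR only, or the TritonGPU IR produced by the full pass pipeline.
ttir = compile(add_kernel, signature="*fp32, *fp32, *fp32, i32", output="ttir")
ttgir = compile(add_kernel, signature="*fp32, *fp32, *fp32, i32",
                num_warps=4, num_stages=3, output="ttgir")
print(ttgir)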
2
python/triton/runtime/__init__.py
Normal file
2
python/triton/runtime/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .jit import JITFunction, jit
|
||||
from .autotuner import Config, autotune, heuristics
|
202
python/triton/runtime/autotuner.py
Normal file
202
python/triton/runtime/autotuner.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from __future__ import annotations
|
||||
import builtins
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
|
||||
|
||||
class Autotuner:
|
||||
def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
|
||||
'''
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predict running time with different configs; returns running time
|
||||
'top_k': number of configs to bench
|
||||
'early_config_prune' (optional): a function used to do early pruning (e.g., of num_stages). It takes configs: List[Config] as its input and returns pruned configs.
|
||||
'''
|
||||
if not configs:
|
||||
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
|
||||
else:
|
||||
self.configs = configs
|
||||
self.key_idx = [arg_names.index(k) for k in key]
|
||||
self.cache = dict()
|
||||
self.kernel = kernel
|
||||
# hook to reset all required tensors to zero before relaunching a kernel
|
||||
self.hook = lambda args: 0
|
||||
if reset_to_zero is not None:
|
||||
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
|
||||
|
||||
def _hook(args):
|
||||
for i in self.reset_idx:
|
||||
args[i].zero_()
|
||||
self.hook = _hook
|
||||
self.arg_names = arg_names
|
||||
# prune configs
|
||||
if prune_configs_by:
|
||||
perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
|
||||
early_config_prune = prune_configs_by.get('early_config_prune')
|
||||
else:
|
||||
perf_model, top_k, early_config_prune = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.early_config_prune = early_config_prune
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
# as kwargs and by the autotuner
|
||||
conflicts = meta.keys() & config.kwargs.keys()
|
||||
if conflicts:
|
||||
raise ValueError(
|
||||
f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols."
|
||||
)
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
|
||||
def kernel_call():
|
||||
if config.pre_hook:
|
||||
config.pre_hook(self.nargs)
|
||||
self.hook(args)
|
||||
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
return triton.testing.do_bench(kernel_call)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
if len(self.configs) > 1:
|
||||
key = tuple([args[i] for i in self.key_idx])
|
||||
if key not in self.cache:
|
||||
# prune configs
|
||||
pruned_configs = self.configs
|
||||
if self.early_config_prune:
|
||||
pruned_configs = self.early_config_prune(self.configs, self.nargs)
|
||||
if self.perf_model:
|
||||
top_k = self.configs_top_k
|
||||
if isinstance(top_k, float) and top_k <= 1.0:
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
|
||||
bench_start = time.time()
|
||||
timings = {config: self._bench(*args, config=config, **kwargs)
|
||||
for config in pruned_configs}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
self.hook(args)
|
||||
self.configs_timings = timings
|
||||
config = self.cache[key]
|
||||
else:
|
||||
config = self.configs[0]
|
||||
self.best_config = config
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(self.nargs)
|
||||
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
An object that represents a possible kernel configuration for the auto-tuner to try.
|
||||
|
||||
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
|
||||
:type meta: dict[Str, Any]
|
||||
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
|
||||
`num_warps=8`, then each kernel instance will be automatically parallelized to
|
||||
cooperatively execute using `8 * 32 = 256` threads.
|
||||
:type num_warps: int
|
||||
:ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
|
||||
Mostly useful for matrix multiplication workloads on SM80+ GPUs.
|
||||
:type num_stages: int
|
||||
:ivar pre_hook: a function that will be called before the kernel is called. It takes the full
|
||||
dict of kernel arguments (mapping argument names to values) as its only parameter.
|
||||
"""
|
||||
|
||||
def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
|
||||
self.kwargs = kwargs
|
||||
self.num_warps = num_warps
|
||||
self.num_stages = num_stages
|
||||
self.pre_hook = pre_hook
|
||||
|
||||
def __str__(self):
|
||||
res = []
|
||||
for k, v in self.kwargs.items():
|
||||
res.append(f'{k}: {v}')
|
||||
res.append(f'num_warps: {self.num_warps}')
|
||||
res.append(f'num_stages: {self.num_stages}')
|
||||
return ', '.join(res)
|
||||
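# For example (sketch), str(Config({'BLOCK_SIZE': 128}, num_warps=8)) is
# "BLOCK_SIZE: 128, num_warps: 8, num_stages: 2".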
|
||||
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
],
|
||||
key=['x_size'] # the two above configs will be evaluated anytime
|
||||
# the value of x_size changes
|
||||
)
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
|
||||
:note: When all the configurations are evaluated, the kernel will run multiple times.
|
||||
This means that whatever value the kernel updates will be updated multiple times.
|
||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||
resets the value of the provided tensors to zero before running any configuration.
|
||||
|
||||
:param configs: a list of :code:`triton.Config` objects
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
:type key: list[str]
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predict running time with different configs; returns running time
|
||||
'top_k': number of configs to bench
|
||||
'early_config_prune' (optional): a function used to do early pruning (e.g., of num_stages). It takes configs: List[Config] as its input and returns pruned configs.
|
||||
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
"""
|
||||
def decorator(fn):
|
||||
def wrapper(kernel):
|
||||
return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
|
||||
|
||||
fn.kernel_decorators.append(wrapper)
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
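# Hypothetical illustration of prune_configs_by with the decorator above; the
# perf model and kernel below are assumptions, not part of this API. Only the
# two best-ranked configs are actually benchmarked.
#
#     def my_perf_model(x_ptr, x_size, BLOCK_SIZE, num_stages, num_warps):
#         return x_size / (BLOCK_SIZE * num_warps)  # estimated time; lower is better
#
#     @autotune(
#         configs=[Config({'BLOCK_SIZE': 2 ** n}, num_warps=4) for n in range(7, 11)],
#         key=['x_size'],
#         prune_configs_by={'perf_model': my_perf_model, 'top_k': 2},
#     )
#     @triton.jit
#     def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
#         ...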
|
||||
|
||||
def heuristics(values):
|
||||
"""
|
||||
Decorator for specifying how the values of certain meta-parameters may be computed.
|
||||
This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args['x_size'])))})
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
|
||||
|
||||
|
||||
:param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
|
||||
each such function takes a dict mapping the kernel's argument names to their values as input.
|
||||
:type values: dict[str, Callable[[dict[str, Any]], Any]]
|
||||
"""
|
||||
def decorator(fn):
|
||||
def wrapper(kernel):
|
||||
def fun(*args, **meta):
|
||||
for v, heur in values.items():
|
||||
assert v not in meta
|
||||
meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
|
||||
return kernel(*args, **meta)
|
||||
return fun
|
||||
|
||||
fn.kernel_decorators.append(wrapper)
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
268
python/triton/runtime/jit.py
Normal file
268
python/triton/runtime/jit.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import textwrap
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from ..tools.disasm import extract
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Binary
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class Binary:
|
||||
def __init__(self, backend, name, asm, shared_mem, num_warps):
|
||||
self.backend = backend
|
||||
self.name = name
|
||||
self.asm = asm
|
||||
self.shared_mem = shared_mem
|
||||
self.num_warps = num_warps
|
||||
|
||||
|
||||
class LoadedBinary:
|
||||
def __init__(self, device: int, bin: Binary):
|
||||
module, kernel = _triton.code_gen.load_binary(bin.backend,
|
||||
bin.name,
|
||||
bin.asm,
|
||||
bin.shared_mem,
|
||||
device)
|
||||
self.bin = bin
|
||||
self.asm = bin.asm
|
||||
self.sass = ''
|
||||
self.module = module
|
||||
self.kernel = kernel
|
||||
self.device = device
|
||||
self.shared_mem = bin.shared_mem
|
||||
|
||||
def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
|
||||
_triton.runtime.enqueue(self.bin.backend, stream, self.kernel,
|
||||
grid_0, grid_1, grid_2,
|
||||
self.bin.num_warps * 32, 1, 1,
|
||||
args, self.bin.shared_mem)
|
||||
|
||||
def get_sass(self, fun=None):
|
||||
if self.sass:
|
||||
return self.sass
|
||||
fd, path = tempfile.mkstemp()
|
||||
try:
|
||||
with open(fd, 'wb') as cubin:
|
||||
cubin.write(self.asm['cubin'])
|
||||
self.sass = extract(path, fun)
|
||||
finally:
|
||||
os.remove(path)
|
||||
self.asm['sass'] = self.sass
|
||||
return self.sass
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Kernel
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class Kernel:
|
||||
|
||||
def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs):
|
||||
raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.")
|
||||
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
class DependenciesFinder(ast.NodeVisitor):
|
||||
"""
|
||||
This AST visitor is used to find dependencies of a JITFunction. This can
|
||||
be used to invalidate a JITFunction's hash when its source code -- or
|
||||
that of its dependencies -- changes.
|
||||
"""
|
||||
|
||||
def __init__(self, globals, src) -> None:
|
||||
super().__init__()
|
||||
self.ret = hashlib.md5(src.encode("utf-8")).hexdigest()
|
||||
self.globals = globals
|
||||
|
||||
def visit_Name(self, node):
|
||||
return self.globals.get(node.id, None)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
lhs = self.visit(lhs.value)
|
||||
if lhs is None or lhs is triton:
|
||||
return None
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
def visit_Call(self, node):
|
||||
func = self.visit(node.func)
|
||||
if func is None:
|
||||
return
|
||||
if inspect.isbuiltin(func):
|
||||
return
|
||||
if func.__module__ and func.__module__.startswith('triton.'):
|
||||
return
|
||||
assert isinstance(func, JITFunction)
|
||||
if func.hash is None:
|
||||
tree = ast.parse(func.src)
|
||||
finder = DependenciesFinder(func.__globals__, func.src)
|
||||
finder.visit(tree)
|
||||
func.hash = finder.ret
|
||||
self.ret = (self.ret + func.hash).encode("utf-8")
|
||||
self.ret = hashlib.md5(self.ret).hexdigest()
|
||||
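# Quick illustration of how the dependency hash is derived (sketch only). The
# plain function below calls no other jit'd functions, so the result is just
# the md5 of the source text; a source that calls jit'd helpers would fold
# their hashes in as well.
_example_src = "def f(x):\n    return x + 1\n"
_finder = DependenciesFinder(globals={}, src=_example_src)
_finder.visit(ast.parse(_example_src))
_example_hash = _finder.ret  # hex digest; changes whenever the source text changes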
|
||||
# -----------------------------------------------------------------------------
|
||||
# JITFunction
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@functools.lru_cache()
|
||||
def version_key():
|
||||
import pkgutil
|
||||
contents = []
|
||||
# frontend
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# backend
|
||||
with open(triton._C.libtriton.__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# language
|
||||
language_path = os.path.join(*triton.__path__, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# ptxas version
|
||||
try:
|
||||
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
|
||||
except Exception:
|
||||
ptxas_version = ''
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
class JITFunction:
|
||||
|
||||
cache_hook = None
|
||||
|
||||
def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
|
||||
# information of wrapped function
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
|
||||
self.version = version
|
||||
self.inline = inline
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize]
|
||||
# cache for callable driver objects (e.g. CUkernel)
|
||||
self.bin_cache = dict()
|
||||
self.hash = None
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
self.kernel = None
|
||||
# annotations
|
||||
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
|
||||
self.__annotations__ = fn.__annotations__
|
||||
# constexprs
|
||||
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
|
||||
# forward docs
|
||||
self.__doc__ = fn.__doc__
|
||||
self.__name__ = fn.__name__
|
||||
self.__globals__ = fn.__globals__
|
||||
self.__module__ = fn.__module__
|
||||
|
||||
@property
|
||||
@functools.lru_cache()
|
||||
def cache_key(self):
|
||||
# TODO : hash should be attribute of `self`
|
||||
if self.hash is None:
|
||||
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
|
||||
dependencies_finder.visit(self.parse())
|
||||
self.hash = dependencies_finder.ret + version_key()
|
||||
return self.hash
|
||||
|
||||
# we do not parse `src` in the constructor because
|
||||
# the user might want to monkey-patch self.src dynamically.
|
||||
# Some unit tests do this, for example.
|
||||
def parse(self):
|
||||
tree = ast.parse(self.src)
|
||||
assert isinstance(tree, ast.Module)
|
||||
assert len(tree.body) == 1
|
||||
assert isinstance(tree.body[0], ast.FunctionDef)
|
||||
return tree
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
|
||||
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
# - when kernel decorators change, cached kernel
|
||||
# needs to be cleared
|
||||
def __setattr__(self, name, value):
|
||||
if name == 'kernel_decorators':
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
if name == 'src':
|
||||
self.hash = None
|
||||
JITFunction.cache_key.fget.cache_clear()
|
||||
|
||||
def _init_kernel(self):
|
||||
if self.kernel is None:
|
||||
self.kernel = Kernel(self)
|
||||
for decorator in reversed(self.kernel_decorators):
|
||||
self.kernel = decorator(self.kernel)
|
||||
return self.kernel
|
||||
|
||||
def __getitem__(self, grid):
|
||||
"""
|
||||
A JIT function is launched with: fn[grid](*args, **kwargs).
|
||||
Hence JITFunction.__getitem__ returns a callable proxy that
|
||||
memorizes the grid.
|
||||
"""
|
||||
class Launcher:
|
||||
def __init__(self, kernel, grid):
|
||||
self.kernel = kernel
|
||||
self.grid = grid
|
||||
|
||||
def __call__(self, *wargs, **kwargs):
|
||||
return self.kernel(*wargs, **kwargs, grid=self.grid)
|
||||
|
||||
return Launcher(self._init_kernel(), grid)
|
||||
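# Illustrative launch syntax enabled by the proxy above (hypothetical kernel
# and argument names; Kernel.__call__ is stubbed out in this commit, so an
# actual launch would raise):
#
#     grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
#     add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)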
|
||||
def __repr__(self):
|
||||
return f"JITFunction({self.module}:{self.fn.__name__})"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# `jit` decorator
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
def jit(*args, **kwargs):
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
|
||||
:note: This function will be compiled and run on the GPU. It will only have access to:
|
||||
|
||||
* python primitives,
|
||||
* objects within the triton.language package,
|
||||
* arguments to this function,
|
||||
* other jit'd functions
|
||||
|
||||
:param fn: the function to be jit-compiled
|
||||
:type fn: Callable
|
||||
"""
|
||||
if args:
|
||||
assert len(args) == 1
|
||||
assert callable(args[0])
|
||||
return JITFunction(args[0], **kwargs)
|
||||
else:
|
||||
def decorator(fn):
|
||||
return JITFunction(fn, **kwargs)
|
||||
return decorator
|
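# Both decorator forms are accepted: a bare `@jit` and a parameterized
# `@jit(...)` produce the same JITFunction wrapper. Sketch (the kernel bodies
# are hypothetical placeholders):
#
#     @jit
#     def kernel_a(x_ptr):
#         ...
#
#     @jit(do_not_specialize=['n'])
#     def kernel_b(x_ptr, n):
#         ...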
@@ -5,7 +5,7 @@ import sys
|
||||
import torch
|
||||
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .code_gen import OutOfResources
|
||||
from .compiler import OutOfResources
|
||||
|
||||
try:
|
||||
import triton._C.libtriton.cutlass as _cutlass
|
||||
|
46
python/triton/utils.py
Normal file
46
python/triton/utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
|
||||
|
||||
def cdiv(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
"""Return the smallest power of 2 greater than or equal to n"""
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
n |= n >> 2
|
||||
n |= n >> 4
|
||||
n |= n >> 8
|
||||
n |= n >> 16
|
||||
n += 1
|
||||
return n
|
||||
|
||||
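# Sanity checks for the two helpers above (illustrative only):
assert cdiv(10, 3) == 4
assert next_power_of_2(1) == 1
assert next_power_of_2(100) == 128
assert next_power_of_2(128) == 128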
class TensorWrapper:
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
self.base = base
|
||||
self.is_cuda = base.is_cuda
|
||||
self.device = base.device
|
||||
|
||||
def data_ptr(self):
|
||||
return self.base.data_ptr()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'TensorWrapper[{self.dtype}]({self.base})'
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
|
||||
if isinstance(tensor, TensorWrapper):
|
||||
if dtype == tensor.base.dtype:
|
||||
# Reinterpreting to the original interpretation; return the base.
|
||||
return tensor.base
|
||||
else:
|
||||
# Reinterpreting a wrapped tensor to a different type.
|
||||
return TensorWrapper(tensor.base, dtype)
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
# A new wrapper is needed around an unwrapped tensor.
|
||||
return TensorWrapper(tensor, dtype)
|
||||
else:
|
||||
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
|
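# Illustrative round-trip through reinterpret (torch dtypes are used purely
# for the sketch; callers decide what `dtype` means to them):
_x = torch.empty(4, dtype=torch.int8)
_wrapped = reinterpret(_x, torch.uint8)           # wrapped view with a new dtype
assert isinstance(_wrapped, TensorWrapper)
assert reinterpret(_wrapped, torch.int8) is _x    # matching dtype unwraps to the base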
@@ -1,261 +0,0 @@
|
||||
============================= test session starts ==============================
|
||||
platform linux -- Python 3.7.7, pytest-7.1.1, pluggy-1.0.0
|
||||
rootdir: /home/da/codes/triton-mlir-rewrite/triton/rewrite-test
|
||||
collected 6 items
|
||||
|
||||
scf_tests.py .....module {
|
||||
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = call @"cdiv__i32__1cconstexpr[64]"(%arg3) : (i32) -> i32
|
||||
%2 = call @"cdiv__i32__1cconstexpr[64]"(%arg4) : (i32) -> i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%3 = arith.muli %2, %c8_i32 : i32
|
||||
%4 = arith.divsi %0, %3 : i32
|
||||
%c8_i32_0 = arith.constant 8 : i32
|
||||
%5 = arith.muli %4, %c8_i32_0 : i32
|
||||
%6 = arith.subi %1, %5 : i32
|
||||
%7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32
|
||||
%8 = arith.remsi %0, %7 : i32
|
||||
%9 = arith.addi %5, %8 : i32
|
||||
%10 = arith.remsi %0, %3 : i32
|
||||
%11 = arith.divsi %10, %7 : i32
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%12 = arith.muli %9, %c64_i32 : i32
|
||||
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%14 = tt.broadcast %12 : (i32) -> tensor<64xi32>
|
||||
%15 = arith.addi %14, %13 : tensor<64xi32>
|
||||
%c64_i32_1 = arith.constant 64 : i32
|
||||
%16 = arith.muli %11, %c64_i32_1 : i32
|
||||
%17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%18 = tt.broadcast %16 : (i32) -> tensor<64xi32>
|
||||
%19 = arith.addi %18, %17 : tensor<64xi32>
|
||||
%20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
|
||||
%21 = tt.reshape %15 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%22 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
|
||||
%23 = arith.muli %21, %22 : tensor<64x1xi32>
|
||||
%24 = tt.reshape %20 : (tensor<32xi32>) -> tensor<1x32xi32>
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
|
||||
%26 = arith.muli %24, %25 : tensor<1x32xi32>
|
||||
%27 = tt.broadcast %23 : (tensor<64x1xi32>) -> tensor<64x32xi32>
|
||||
%28 = tt.broadcast %26 : (tensor<1x32xi32>) -> tensor<64x32xi32>
|
||||
%29 = arith.addi %27, %28 : tensor<64x32xi32>
|
||||
%30 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
|
||||
%31 = tt.getelementptr %30, %29, : tensor<64x32x!tt.ptr<f16>>
|
||||
%32 = tt.reshape %20 : (tensor<32xi32>) -> tensor<32x1xi32>
|
||||
%33 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
|
||||
%34 = arith.muli %32, %33 : tensor<32x1xi32>
|
||||
%35 = tt.reshape %19 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%c1_i32_2 = arith.constant 1 : i32
|
||||
%36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x64xi32>
|
||||
%37 = arith.muli %35, %36 : tensor<1x64xi32>
|
||||
%38 = tt.broadcast %34 : (tensor<32x1xi32>) -> tensor<32x64xi32>
|
||||
%39 = tt.broadcast %37 : (tensor<1x64xi32>) -> tensor<32x64xi32>
|
||||
%40 = arith.addi %38, %39 : tensor<32x64xi32>
|
||||
%41 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
|
||||
%42 = tt.getelementptr %41, %40, : tensor<32x64x!tt.ptr<f16>>
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%43 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c32_i32 = arith.constant 32 : i32
|
||||
%44 = arith.index_cast %c0_i32 : i32 to index
|
||||
%45 = arith.index_cast %arg5 : i32 to index
|
||||
%46 = arith.index_cast %c32_i32 : i32 to index
|
||||
%47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
|
||||
%cst_6 = arith.constant dense<true> : tensor<64x32xi1>
|
||||
%cst_7 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
|
||||
%77 = tt.load %arg11, %cst_6, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
|
||||
%cst_8 = arith.constant dense<true> : tensor<32x64xi1>
|
||||
%cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
|
||||
%78 = tt.load %arg12, %cst_8, %cst_9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
|
||||
%cst_10 = arith.constant 0.000000e+00 : f32
|
||||
%79 = tt.broadcast %cst_10 : (f32) -> tensor<64x64xf32>
|
||||
%80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
|
||||
%81 = arith.addf %arg10, %80 : tensor<64x64xf32>
|
||||
%c32_i32_11 = arith.constant 32 : i32
|
||||
%82 = tt.broadcast %c32_i32_11 : (i32) -> tensor<64x32xi32>
|
||||
%83 = tt.getelementptr %arg11, %82, : tensor<64x32x!tt.ptr<f16>>
|
||||
%c32_i32_12 = arith.constant 32 : i32
|
||||
%84 = arith.muli %arg7, %c32_i32_12 : i32
|
||||
%85 = tt.broadcast %84 : (i32) -> tensor<32x64xi32>
|
||||
%86 = tt.getelementptr %arg12, %85, : tensor<32x64x!tt.ptr<f16>>
|
||||
scf.yield %81, %83, %86 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
|
||||
}
|
||||
%48 = arith.truncf %47#0 : tensor<64x64xf32> to tensor<64x64xf16>
|
||||
%c64_i32_3 = arith.constant 64 : i32
|
||||
%49 = arith.muli %9, %c64_i32_3 : i32
|
||||
%50 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%51 = tt.broadcast %49 : (i32) -> tensor<64xi32>
|
||||
%52 = arith.addi %51, %50 : tensor<64xi32>
|
||||
%c64_i32_4 = arith.constant 64 : i32
|
||||
%53 = arith.muli %11, %c64_i32_4 : i32
|
||||
%54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%55 = tt.broadcast %53 : (i32) -> tensor<64xi32>
|
||||
%56 = arith.addi %55, %54 : tensor<64xi32>
|
||||
%57 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%58 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
|
||||
%59 = arith.muli %58, %57 : tensor<64x1xi32>
|
||||
%60 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
|
||||
%61 = tt.getelementptr %60, %59, : tensor<64x1x!tt.ptr<f16>>
|
||||
%62 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%c1_i32_5 = arith.constant 1 : i32
|
||||
%63 = tt.broadcast %c1_i32_5 : (i32) -> tensor<1x64xi32>
|
||||
%64 = arith.muli %62, %63 : tensor<1x64xi32>
|
||||
%65 = tt.broadcast %61 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
|
||||
%66 = tt.broadcast %64 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||
%67 = tt.getelementptr %65, %66, : tensor<64x64x!tt.ptr<f16>>
|
||||
%68 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%69 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
|
||||
%70 = arith.cmpi slt, %68, %69 : tensor<64x1xi32>
|
||||
%71 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%72 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
|
||||
%73 = arith.cmpi slt, %71, %72 : tensor<1x64xi32>
|
||||
%74 = tt.broadcast %70 : (tensor<64x1xi1>) -> tensor<64x64xi1>
|
||||
%75 = tt.broadcast %73 : (tensor<1x64xi1>) -> tensor<64x64xi1>
|
||||
%76 = arith.andi %74, %75 : tensor<64x64xi1>
|
||||
tt.store %67, %48, %76, : tensor<64x64xf16>
|
||||
return
|
||||
}
|
||||
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%0 = arith.addi %arg0, %c64_i32 : i32
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%1 = arith.subi %0, %c1_i32 : i32
|
||||
%c64_i32_0 = arith.constant 64 : i32
|
||||
%2 = arith.divsi %1, %c64_i32_0 : i32
|
||||
return %2 : i32
|
||||
}
|
||||
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
|
||||
%c8_i32_0 = arith.constant 8 : i32
|
||||
%1 = select %0, %arg0, %c8_i32_0 : i32
|
||||
return %1 : i32
|
||||
}
|
||||
}
|
||||
|
||||
module {
|
||||
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%c63_i32 = arith.constant 63 : i32
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%c0 = arith.constant 0 : index
|
||||
%c32 = arith.constant 32 : index
|
||||
%cst_0 = arith.constant dense<true> : tensor<64x32xi1>
|
||||
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
|
||||
%cst_2 = arith.constant dense<true> : tensor<32x64xi1>
|
||||
%cst_3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
|
||||
%c32_i32 = arith.constant 32 : i32
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.addi %arg3, %c63_i32 : i32
|
||||
%2 = arith.divsi %1, %c64_i32 : i32
|
||||
%3 = arith.addi %arg4, %c63_i32 : i32
|
||||
%4 = arith.divsi %3, %c64_i32 : i32
|
||||
%5 = arith.muli %4, %c8_i32 : i32
|
||||
%6 = arith.divsi %0, %5 : i32
|
||||
%7 = arith.muli %6, %c8_i32 : i32
|
||||
%8 = arith.subi %2, %7 : i32
|
||||
%9 = arith.cmpi slt, %8, %c8_i32 : i32
|
||||
%10 = select %9, %8, %c8_i32 : i32
|
||||
%11 = arith.remsi %0, %10 : i32
|
||||
%12 = arith.addi %7, %11 : i32
|
||||
%13 = arith.remsi %0, %5 : i32
|
||||
%14 = arith.divsi %13, %10 : i32
|
||||
%15 = arith.muli %12, %c64_i32 : i32
|
||||
%16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%17 = tt.broadcast %15 : (i32) -> tensor<64xi32>
|
||||
%18 = arith.addi %17, %16 : tensor<64xi32>
|
||||
%19 = arith.muli %14, %c64_i32 : i32
|
||||
%20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%21 = tt.broadcast %19 : (i32) -> tensor<64xi32>
|
||||
%22 = arith.addi %21, %20 : tensor<64xi32>
|
||||
%23 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
|
||||
%24 = tt.reshape %18 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%25 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
|
||||
%26 = arith.muli %24, %25 : tensor<64x1xi32>
|
||||
%27 = tt.reshape %23 : (tensor<32xi32>) -> tensor<1x32xi32>
|
||||
%28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
|
||||
%29 = arith.muli %27, %28 : tensor<1x32xi32>
|
||||
%30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x32xi32>
|
||||
%31 = tt.broadcast %29 : (tensor<1x32xi32>) -> tensor<64x32xi32>
|
||||
%32 = arith.addi %30, %31 : tensor<64x32xi32>
|
||||
%33 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
|
||||
%34 = tt.getelementptr %33, %32, : tensor<64x32x!tt.ptr<f16>>
|
||||
%35 = tt.reshape %23 : (tensor<32xi32>) -> tensor<32x1xi32>
|
||||
%36 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
|
||||
%37 = arith.muli %35, %36 : tensor<32x1xi32>
|
||||
%38 = tt.reshape %22 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
|
||||
%40 = arith.muli %38, %39 : tensor<1x64xi32>
|
||||
%41 = tt.broadcast %37 : (tensor<32x1xi32>) -> tensor<32x64xi32>
|
||||
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<32x64xi32>
|
||||
%43 = arith.addi %41, %42 : tensor<32x64xi32>
|
||||
%44 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
|
||||
%45 = tt.getelementptr %44, %43, : tensor<32x64x!tt.ptr<f16>>
|
||||
%46 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
|
||||
%47 = arith.index_cast %arg5 : i32 to index
|
||||
%48:3 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
|
||||
%78 = tt.load %arg11, %cst_0, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
|
||||
%79 = tt.load %arg12, %cst_2, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
|
||||
%80 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
|
||||
%81 = tt.dot %78, %79, %80 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
|
||||
%82 = arith.addf %arg10, %81 : tensor<64x64xf32>
|
||||
%83 = tt.broadcast %c32_i32 : (i32) -> tensor<64x32xi32>
|
||||
%84 = tt.getelementptr %arg11, %83, : tensor<64x32x!tt.ptr<f16>>
|
||||
%85 = arith.muli %arg7, %c32_i32 : i32
|
||||
%86 = tt.broadcast %85 : (i32) -> tensor<32x64xi32>
|
||||
%87 = tt.getelementptr %arg12, %86, : tensor<32x64x!tt.ptr<f16>>
|
||||
scf.yield %82, %84, %87 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
|
||||
}
|
||||
%49 = arith.truncf %48#0 : tensor<64x64xf32> to tensor<64x64xf16>
|
||||
%50 = arith.muli %12, %c64_i32 : i32
|
||||
%51 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%52 = tt.broadcast %50 : (i32) -> tensor<64xi32>
|
||||
%53 = arith.addi %52, %51 : tensor<64xi32>
|
||||
%54 = arith.muli %14, %c64_i32 : i32
|
||||
%55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%56 = tt.broadcast %54 : (i32) -> tensor<64xi32>
|
||||
%57 = arith.addi %56, %55 : tensor<64xi32>
|
||||
%58 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%59 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
|
||||
%60 = arith.muli %59, %58 : tensor<64x1xi32>
|
||||
%61 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
|
||||
%62 = tt.getelementptr %61, %60, : tensor<64x1x!tt.ptr<f16>>
|
||||
%63 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
|
||||
%65 = arith.muli %63, %64 : tensor<1x64xi32>
|
||||
%66 = tt.broadcast %62 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
|
||||
%67 = tt.broadcast %65 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||
%68 = tt.getelementptr %66, %67, : tensor<64x64x!tt.ptr<f16>>
|
||||
%69 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%70 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
|
||||
%71 = arith.cmpi slt, %69, %70 : tensor<64x1xi32>
|
||||
%72 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%73 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
|
||||
%74 = arith.cmpi slt, %72, %73 : tensor<1x64xi32>
|
||||
%75 = tt.broadcast %71 : (tensor<64x1xi1>) -> tensor<64x64xi1>
|
||||
%76 = tt.broadcast %74 : (tensor<1x64xi1>) -> tensor<64x64xi1>
|
||||
%77 = arith.andi %75, %76 : tensor<64x64xi1>
|
||||
tt.store %68, %49, %77, : tensor<64x64xf16>
|
||||
return
|
||||
}
|
||||
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%c63_i32 = arith.constant 63 : i32
|
||||
%0 = arith.addi %arg0, %c63_i32 : i32
|
||||
%1 = arith.divsi %0, %c64_i32 : i32
|
||||
return %1 : i32
|
||||
}
|
||||
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
|
||||
%1 = select %0, %arg0, %c8_i32 : i32
|
||||
return %1 : i32
|
||||
}
|
||||
}
|
||||
|
||||
.
|
||||
|
||||
============================== 6 passed in 1.21s ===============================
|
@@ -1,51 +0,0 @@
|
||||
import triton
|
||||
|
||||
@triton.jit
|
||||
def if_else(lb, ub, value):
|
||||
if value > lb:
|
||||
a = 0.0
|
||||
else:
|
||||
a = 1.0
|
||||
c = a + a
|
||||
|
||||
@triton.jit
|
||||
def only_if(lb, ub, value):
|
||||
a = -1.0
|
||||
if value > lb:
|
||||
a = 0.0
|
||||
c = a + a
|
||||
|
||||
@triton.jit
|
||||
def only_if_invalid(lb, ub, value):
|
||||
if value > lb:
|
||||
a = 0.0
|
||||
c = a + a
|
||||
|
||||
@triton.jit
|
||||
def nested_if(lb, ub, value):
|
||||
if value > lb:
|
||||
if value < ub:
|
||||
a = 2.0
|
||||
else:
|
||||
a = 1.0
|
||||
else:
|
||||
a = 0.0
|
||||
c = a + a
|
||||
|
||||
|
||||
mod_if_else, ctx_if_else = if_else.compile_to_ttir(2, 4, 3, grid=(1,))
|
||||
mod_if_else.dump()
|
||||
|
||||
mod_only_if, ctx_only_if = only_if.compile_to_ttir(2, 4, 3, grid=(1,))
|
||||
mod_only_if.dump()
|
||||
|
||||
try:
|
||||
mod_only_if_invalid, ctx_only_if = only_if_invalid.compile_to_ttir(2, 4, 3, grid=(1,))
|
||||
mod_only_if_invalid.dump()
|
||||
except:
|
||||
print('value error')
|
||||
|
||||
mod_nested_if, ctx_nested_if = nested_if.compile_to_ttir(2, 4, 3, grid=(1,))
|
||||
mod_nested_if.dump()
|
||||
|
||||
print(mod_nested_if.str())
|
@@ -1,261 +0,0 @@
|
||||
module {
|
||||
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c128_13c128_14c128_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = call @"cdiv__i32__1cconstexpr[128]"(%arg3) : (i32) -> i32
|
||||
%2 = call @"cdiv__i32__1cconstexpr[128]"(%arg4) : (i32) -> i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%3 = arith.muli %2, %c8_i32 : i32
|
||||
%4 = arith.divsi %0, %3 : i32
|
||||
%c8_i32_0 = arith.constant 8 : i32
|
||||
%5 = arith.muli %4, %c8_i32_0 : i32
|
||||
%6 = arith.subi %1, %5 : i32
|
||||
%7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32
|
||||
%8 = arith.remsi %0, %7 : i32
|
||||
%9 = arith.addi %5, %8 : i32
|
||||
%10 = arith.remsi %0, %3 : i32
|
||||
%11 = arith.divsi %10, %7 : i32
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
%12 = arith.muli %9, %c128_i32 : i32
|
||||
%13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
%14 = tt.broadcast %12 : (i32) -> tensor<128xi32>
|
||||
%15 = arith.addi %14, %13 : tensor<128xi32>
|
||||
%c128_i32_1 = arith.constant 128 : i32
|
||||
%16 = arith.muli %11, %c128_i32_1 : i32
|
||||
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
%18 = tt.broadcast %16 : (i32) -> tensor<128xi32>
|
||||
%19 = arith.addi %18, %17 : tensor<128xi32>
|
||||
%20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
%21 = tt.reshape %15 : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
%22 = tt.broadcast %arg6 : (i32) -> tensor<128x1xi32>
|
||||
%23 = arith.muli %21, %22 : tensor<128x1xi32>
|
||||
%24 = tt.reshape %20 : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32>
|
||||
%26 = arith.muli %24, %25 : tensor<1x128xi32>
|
||||
%27 = tt.broadcast %23 : (tensor<128x1xi32>) -> tensor<128x128xi32>
|
||||
%28 = tt.broadcast %26 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
%29 = arith.addi %27, %28 : tensor<128x128xi32>
|
||||
%30 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>>
|
||||
%31 = tt.getelementptr %30, %29, : tensor<128x128x!tt.ptr<f16>>
|
||||
%32 = tt.reshape %20 : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
%33 = tt.broadcast %arg7 : (i32) -> tensor<128x1xi32>
|
||||
%34 = arith.muli %32, %33 : tensor<128x1xi32>
|
||||
%35 = tt.reshape %19 : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
%c1_i32_2 = arith.constant 1 : i32
|
||||
%36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x128xi32>
|
||||
%37 = arith.muli %35, %36 : tensor<1x128xi32>
|
||||
%38 = tt.broadcast %34 : (tensor<128x1xi32>) -> tensor<128x128xi32>
|
||||
%39 = tt.broadcast %37 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
%40 = arith.addi %38, %39 : tensor<128x128xi32>
|
||||
%41 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>>
|
||||
%42 = tt.getelementptr %41, %40, : tensor<128x128x!tt.ptr<f16>>
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%43 = tt.broadcast %cst : (f32) -> tensor<128x128xf32>
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c128_i32_3 = arith.constant 128 : i32
|
||||
%44 = arith.index_cast %c0_i32 : i32 to index
|
||||
%45 = arith.index_cast %arg5 : i32 to index
|
||||
%46 = arith.index_cast %c128_i32_3 : i32 to index
|
||||
%47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<128x128xf32>, tensor<128x128x!tt.ptr<f16>>, tensor<128x128x!tt.ptr<f16>>) {
|
||||
%cst_7 = arith.constant dense<true> : tensor<128x128xi1>
|
||||
%cst_8 = arith.constant dense<0.000000e+00> : tensor<128x128xf16>
|
||||
%77 = tt.load %arg11, %cst_7, %cst_8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16>
|
||||
%cst_9 = arith.constant dense<true> : tensor<128x128xi1>
|
||||
%cst_10 = arith.constant dense<0.000000e+00> : tensor<128x128xf16>
|
||||
%78 = tt.load %arg12, %cst_9, %cst_10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16>
|
||||
%cst_11 = arith.constant 0.000000e+00 : f32
|
||||
%79 = tt.broadcast %cst_11 : (f32) -> tensor<128x128xf32>
|
||||
%80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<128x128xf16> * tensor<128x128xf16> -> tensor<128x128xf32>
|
||||
%81 = arith.addf %arg10, %80 : tensor<128x128xf32>
|
||||
%c128_i32_12 = arith.constant 128 : i32
|
||||
%82 = tt.broadcast %c128_i32_12 : (i32) -> tensor<128x128xi32>
|
||||
%83 = tt.getelementptr %arg11, %82, : tensor<128x128x!tt.ptr<f16>>
|
||||
%c128_i32_13 = arith.constant 128 : i32
|
||||
%84 = arith.muli %arg7, %c128_i32_13 : i32
|
||||
%85 = tt.broadcast %84 : (i32) -> tensor<128x128xi32>
|
||||
%86 = tt.getelementptr %arg12, %85, : tensor<128x128x!tt.ptr<f16>>
|
||||
scf.yield %81, %83, %86 : tensor<128x128xf32>, tensor<128x128x!tt.ptr<f16>>, tensor<128x128x!tt.ptr<f16>>
|
||||
}
|
||||
%48 = arith.truncf %47#0 : tensor<128x128xf32> to tensor<128x128xf16>
|
||||
%c128_i32_4 = arith.constant 128 : i32
|
||||
%49 = arith.muli %9, %c128_i32_4 : i32
|
||||
%50 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
%51 = tt.broadcast %49 : (i32) -> tensor<128xi32>
|
||||
%52 = arith.addi %51, %50 : tensor<128xi32>
|
||||
%c128_i32_5 = arith.constant 128 : i32
|
||||
%53 = arith.muli %11, %c128_i32_5 : i32
|
||||
%54 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
%55 = tt.broadcast %53 : (i32) -> tensor<128xi32>
|
||||
%56 = arith.addi %55, %54 : tensor<128xi32>
|
||||
%57 = tt.reshape %52 : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
%58 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32>
|
||||
%59 = arith.muli %58, %57 : tensor<128x1xi32>
|
||||
%60 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>>
|
||||
%61 = tt.getelementptr %60, %59, : tensor<128x1x!tt.ptr<f16>>
|
||||
%62 = tt.reshape %56 : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
%c1_i32_6 = arith.constant 1 : i32
|
||||
%63 = tt.broadcast %c1_i32_6 : (i32) -> tensor<1x128xi32>
|
||||
%64 = arith.muli %62, %63 : tensor<1x128xi32>
|
||||
%65 = tt.broadcast %61 : (tensor<128x1x!tt.ptr<f16>>) -> tensor<128x128x!tt.ptr<f16>>
|
||||
%66 = tt.broadcast %64 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
%67 = tt.getelementptr %65, %66, : tensor<128x128x!tt.ptr<f16>>
|
||||
%68 = tt.reshape %52 : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
%69 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32>
|
||||
%70 = arith.cmpi slt, %68, %69 : tensor<128x1xi32>
|
||||
%71 = tt.reshape %56 : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
%72 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32>
|
||||
%73 = arith.cmpi slt, %71, %72 : tensor<1x128xi32>
|
||||
%74 = tt.broadcast %70 : (tensor<128x1xi1>) -> tensor<128x128xi1>
|
||||
%75 = tt.broadcast %73 : (tensor<1x128xi1>) -> tensor<128x128xi1>
|
||||
%76 = arith.andi %74, %75 : tensor<128x128xi1>
|
||||
tt.store %67, %48, %76, : tensor<128x128xf16>
|
||||
return
|
||||
}
|
||||
func @"cdiv__i32__1cconstexpr[128]"(%arg0: i32) -> i32 {
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
%0 = arith.addi %arg0, %c128_i32 : i32
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%1 = arith.subi %0, %c1_i32 : i32
|
||||
%c128_i32_0 = arith.constant 128 : i32
|
||||
%2 = arith.divsi %1, %c128_i32_0 : i32
|
||||
return %2 : i32
|
||||
}
|
||||
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
|
||||
%c8_i32_0 = arith.constant 8 : i32
|
||||
%1 = select %0, %arg0, %c8_i32_0 : i32
|
||||
return %1 : i32
|
||||
}
|
||||
}
|
||||
module {
|
||||
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c128_13c128_14c128_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%c127_i32 = arith.constant 127 : i32
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%c0 = arith.constant 0 : index
|
||||
%c128 = arith.constant 128 : index
|
||||
%cst_0 = arith.constant dense<true> : tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.addi %arg3, %c127_i32 : i32
|
||||
%2 = arith.divsi %1, %c128_i32 : i32
|
||||
%3 = arith.addi %arg4, %c127_i32 : i32
|
||||
%4 = arith.divsi %3, %c128_i32 : i32
|
||||
%5 = arith.muli %4, %c8_i32 : i32
|
||||
%6 = arith.divsi %0, %5 : i32
|
||||
%7 = arith.muli %6, %c8_i32 : i32
|
||||
%8 = arith.subi %2, %7 : i32
|
||||
%9 = arith.cmpi slt, %8, %c8_i32 : i32
|
||||
%10 = select %9, %8, %c8_i32 : i32
|
||||
%11 = arith.remsi %0, %10 : i32
|
||||
%12 = arith.addi %7, %11 : i32
|
||||
%13 = arith.remsi %0, %5 : i32
|
||||
%14 = arith.divsi %13, %10 : i32
|
||||
%15 = arith.muli %12, %c128_i32 : i32
|
||||
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%17 = tt.broadcast %15 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%18 = arith.addi %17, %16 : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%19 = arith.muli %14, %c128_i32 : i32
|
||||
%20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%21 = tt.broadcast %19 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%22 = arith.addi %21, %20 : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%24 = tt.reshape %18 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%25 = tt.broadcast %arg6 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%26 = arith.muli %24, %25 : tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%27 = tt.reshape %23 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%29 = arith.muli %27, %28 : tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%30 = tt.broadcast %26 : (tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%31 = tt.broadcast %29 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%32 = arith.addi %30, %31 : tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%33 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%34 = tt.getelementptr %33, %32, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%35 = tt.reshape %23 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%36 = tt.broadcast %arg7 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%37 = arith.muli %35, %36 : tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%38 = tt.reshape %22 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%40 = arith.muli %38, %39 : tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%41 = tt.broadcast %37 : (tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%42 = tt.broadcast %40 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%43 = arith.addi %41, %42 : tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%44 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%45 = tt.getelementptr %44, %43, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%46 = tt.broadcast %cst : (f32) -> tensor<128x128xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%47 = arith.index_cast %arg5 : i32 to index
|
||||
%48 = "triton_gpu.copy_async"(%34, %cst_0, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
|
||||
%49 = "triton_gpu.copy_async"(%45, %cst_0, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
|
||||
%50 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%51 = tt.getelementptr %34, %50, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%52 = arith.muli %arg7, %c128_i32 : i32
|
||||
%53 = tt.broadcast %52 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%54 = tt.getelementptr %45, %53, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%55:8 = scf.for %arg9 = %c0 to %47 step %c128 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45, %arg13 = %48, %arg14 = %49, %arg15 = %51, %arg16 = %54, %arg17 = %c0) -> (tensor<128x128xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, index) {
|
||||
%85 = tt.dot %arg13, %arg14, %arg10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> * tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> -> tensor<128x128xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%86 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%87 = tt.getelementptr %arg11, %86, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%88 = arith.muli %arg7, %c128_i32 : i32
|
||||
%89 = tt.broadcast %88 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%90 = tt.getelementptr %arg12, %89, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%91 = arith.addi %arg17, %c128 : index
|
||||
%92 = arith.cmpi slt, %91, %47 : index
|
||||
%93 = tt.broadcast %92 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%94 = "triton_gpu.copy_async"(%arg15, %93, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
|
||||
%95 = "triton_gpu.copy_async"(%arg16, %93, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
|
||||
%96 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%97 = tt.getelementptr %arg15, %96, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%98 = arith.muli %arg7, %c128_i32 : i32
|
||||
%99 = tt.broadcast %98 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%100 = tt.getelementptr %arg16, %99, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
scf.yield %85, %87, %90, %94, %95, %97, %100, %91 : tensor<128x128xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, index
|
||||
}
|
||||
%56 = arith.truncf %55#0 : tensor<128x128xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">> to tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%57 = arith.muli %12, %c128_i32 : i32
|
||||
%58 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%59 = tt.broadcast %57 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%60 = arith.addi %59, %58 : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%61 = arith.muli %14, %c128_i32 : i32
|
||||
%62 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%63 = tt.broadcast %61 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%64 = arith.addi %63, %62 : tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
%65 = tt.reshape %60 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%66 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%67 = arith.muli %66, %65 : tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%68 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%69 = tt.getelementptr %68, %67, : tensor<128x1x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%70 = tt.reshape %64 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%71 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%72 = arith.muli %70, %71 : tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%73 = tt.broadcast %69 : (tensor<128x1x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%74 = tt.broadcast %72 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%75 = tt.getelementptr %73, %74, : tensor<128x128x!tt.ptr<f16>, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%76 = tt.reshape %60 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%77 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%78 = "triton_gpu.cmpi"(%76, %77) {predicate = 2 : i64} : (tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>, tensor<128x1xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x1xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%79 = tt.reshape %64 : (tensor<128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%80 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%81 = "triton_gpu.cmpi"(%79, %80) {predicate = 2 : i64} : (tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>, tensor<1x128xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>) -> tensor<1x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>
|
||||
%82 = tt.broadcast %78 : (tensor<128x1xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%83 = tt.broadcast %81 : (tensor<1x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 1, 32, order = 0, 1>">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
%84 = arith.andi %82, %83 : tensor<128x128xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
tt.store %75, %56, %84, : tensor<128x128xf16, #triton_gpu<"coalesced encoding<threadTileSize = 1, 1, blockTileSize = 32, 1, order = 0, 1>">>
|
||||
return
|
||||
}
|
||||
func @"cdiv__i32__1cconstexpr[128]"(%arg0: i32) -> i32 {
|
||||
%c128_i32 = arith.constant 128 : i32
|
||||
%c127_i32 = arith.constant 127 : i32
|
||||
%0 = arith.addi %arg0, %c127_i32 : i32
|
||||
%1 = arith.divsi %0, %c128_i32 : i32
|
||||
return %1 : i32
|
||||
}
|
||||
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
|
||||
%1 = select %0, %arg0, %c8_i32 : i32
|
||||
return %1 : i32
|
||||
}
|
||||
}
|
@@ -1,105 +0,0 @@
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton

import torch


@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N).
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse;
    # see the `L2 Cache Optimizations` section of the matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance these pointers as we move in the K direction and accumulate.
    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # see the `Pointer Arithmetics` section of the matmul tutorial for details.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Note that for simplicity, we don't apply a mask here.
        # This means that if K is not a multiple of BLOCK_SIZE_K,
        # this will access out-of-bounds memory and produce an
        # error or (worse!) incorrect results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c = torch.empty((512, 512), device='cuda', dtype=torch.float16)


mod, ctx = matmul_kernel.compile_to_ttir(
    a, b, c,
    512, 512, 512,
    a.stride(0), a.stride(1),
    b.stride(0), b.stride(1),
    c.stride(0), c.stride(1),
    128, 128, 128,
    8, grid=(2,),
    num_stages=4
)

assert mod.verify()
# mod.dump()

res = matmul_kernel.compile_ttir_to_llir(mod, ctx)

assert mod.verify()
# assert res
mod.dump()
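The grouped program-id remapping at the top of matmul_kernel is easiest to see with plain integers. The following is a minimal sketch that only repeats the arithmetic already written in the kernel above; the helper names (cdiv, show_grouping) are illustrative and not part of the commit.

def cdiv(a, b):
    return (a + b - 1) // b

def show_grouping(M=512, N=512, BLOCK_M=128, BLOCK_N=128, GROUP_M=8):
    # Mirrors the pid -> (pid_m, pid_n) mapping used by matmul_kernel above, so the
    # grouped ordering can be inspected without launching anything.
    num_pid_m = cdiv(M, BLOCK_M)
    num_pid_n = cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        print(pid, '->', (pid_m, pid_n))

# show_grouping() prints pids 0..3 covering rows 0..3 of the first block-column before
# moving on to the second column, which is the L2-friendly ordering the comment describes.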
@@ -1,27 +0,0 @@
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton


@triton.jit
def foo(a, b):
    max, min = maxmin(a, b)
    return max, min


@triton.jit
def maxmin(a, b):
    max = tl.maximum(a, b)
    min = tl.minimum(a, b)
    return max, min


mod, ctx = foo.compile_to_ttir(3, 4, grid=(1,))
assert mod.verify()
mod.dump()


pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.run(mod)
assert mod.verify()
mod.dump()
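The compile / verify / run-passes / re-verify rhythm above recurs in every script in this commit, so it can be wrapped once. This is a sketch under the assumption that compile_to_ttir, _triton.ir.pass_manager, add_inliner_pass, run, verify and dump behave exactly as used above; the helper name run_inliner is made up for illustration.

import triton._C.libtriton.triton as _triton

def run_inliner(kernel, *args, grid=(1,), **kwargs):
    # Lower the jitted kernel to TTIR, inline all calls, and re-verify the module.
    mod, ctx = kernel.compile_to_ttir(*args, grid=grid, **kwargs)
    assert mod.verify(), mod.str()
    pm = _triton.ir.pass_manager(ctx)
    pm.add_inliner_pass()
    pm.run(mod)
    assert mod.verify(), mod.str()
    return mod, ctx

# mod, ctx = run_inliner(foo, 3, 4)
# mod.dump()  # the body of maxmin should now be inlined into foo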
@@ -1,47 +0,0 @@
import torch
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
    # NOTE: `constexpr` so it can be used as a shape value
):
    # There are multiple 'programs' processing different data. We identify which
    # program we are here.
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)


size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
z = torch.empty_like(x)
# add_kernel[(1,)](x, y, z, size, 256)
# print(add_kernel[(1,)].kernel.compile_to_ttir())
# print(add_kernel.annotations)
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
assert mod.verify()
mod.dump()
add_kernel.compile_ttir_to_llir(mod, ctx)
mod.dump()
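The commented-out launch above hints at the usual kernel[grid](...) path. If that path is still wired up in this branch (an assumption, since this script only exercises the standalone compilation API), the result can also be checked numerically; a sketch:

# Sketch: launch the kernel and compare against torch. Assumes the regular launch
# path still works alongside compile_to_ttir in this branch.
grid = lambda meta: (triton.cdiv(size, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, z, size, BLOCK_SIZE=256)
assert torch.allclose(z, x + y)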
@@ -1,56 +0,0 @@
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    K,
    stride
    # BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
    # # NOTE: `constexpr` so it can be used as a shape value
):
    # There are multiple 'programs' processing different data. We identify which
    # program we are here.
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers.
    block_start = pid * 256
    offsets = block_start + tl.arange(0, 256)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements

    x_ptrs = x_ptr + offsets
    y_ptrs = y_ptr + offsets
    output = tl.zeros((256,), dtype=tl.float32)
    for k in range(0, K, 32):
        x = tl.load(x_ptrs, mask=mask, other=0.0)
        y = tl.load(y_ptrs, mask=mask, other=0.0)
        output += x + y

        x_ptrs += stride
        y_ptrs += stride

    # Write the accumulated result back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)


size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
z = torch.empty_like(x)
# add_kernel[(1,)](x, y, z, size, 256)
# print(add_kernel[(1,)].kernel.compile_to_ttir())
mod, ctx = add_kernel.compile_to_ttir(
    x, y, z, size, 128, 8, grid=(1,), num_stages=1)
mod.dump()
# print(mod)

res = add_kernel.compile_ttir_to_llir(mod, ctx)

mod.dump()
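The script above only compiles the kernel, but for the arguments it passes (size=1024, K=128, stride=8, one program of 256 elements) what the K-loop computes is easy to state in plain torch. A reference sketch, not executed by the original script:

# Reference for the K-loop above: four iterations (k = 0, 32, 64, 96), each adding
# x and y at the current offsets before advancing both pointers by `stride` = 8.
idx = torch.arange(256, device='cuda')
ref = sum(x[idx + i] + y[idx + i] for i in (0, 8, 16, 24))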
@@ -1,82 +0,0 @@
|
||||
module {
|
||||
func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
|
||||
%3 = tt.broadcast %1 : (i32) -> tensor<256xi32>
|
||||
%4 = arith.addi %3, %2 : tensor<256xi32>
|
||||
%5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32>
|
||||
%6 = arith.cmpi slt, %4, %5 : tensor<256xi32>
|
||||
%7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
|
||||
%8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>>
|
||||
%9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
|
||||
%10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr<f32>>
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%11 = tt.broadcast %cst : (f32) -> tensor<256xf32>
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c32_i32 = arith.constant 32 : i32
|
||||
%12 = arith.index_cast %c0_i32 : i32 to index
|
||||
%13 = arith.index_cast %arg4 : i32 to index
|
||||
%14 = arith.index_cast %c32_i32 : i32 to index
|
||||
%15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) {
|
||||
%cst_0 = arith.constant 0.000000e+00 : f32
|
||||
%18 = tt.broadcast %cst_0 : (f32) -> tensor<256xf32>
|
||||
%19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
|
||||
%cst_1 = arith.constant 0.000000e+00 : f32
|
||||
%20 = tt.broadcast %cst_1 : (f32) -> tensor<256xf32>
|
||||
%21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32>
|
||||
%22 = arith.addf %19, %21 : tensor<256xf32>
|
||||
%23 = arith.addf %arg7, %22 : tensor<256xf32>
|
||||
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
|
||||
%25 = tt.getelementptr %arg8, %24 : tensor<256x!tt.ptr<f32>>
|
||||
%26 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
|
||||
%27 = tt.getelementptr %arg9, %26 : tensor<256x!tt.ptr<f32>>
|
||||
scf.yield %23, %25, %27 : tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>
|
||||
}
|
||||
%16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
|
||||
%17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr<f32>>
|
||||
tt.store %17, %15#0, %6, : tensor<256xf32>
|
||||
return
|
||||
}
|
||||
}
|
||||
module {
|
||||
func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
|
||||
%c32 = arith.constant 32 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = arith.muli %0, %c256_i32 : i32
|
||||
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
|
||||
%4 = triton_gpu.convert_layout %3 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%5 = arith.addi %4, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%6 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
|
||||
%7 = triton_gpu.convert_layout %6 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%8 = "triton_gpu.cmpi"(%5, %7) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%9 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
|
||||
%10 = triton_gpu.convert_layout %9 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%11 = tt.getelementptr %10, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%12 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
|
||||
%13 = triton_gpu.convert_layout %12 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%14 = tt.getelementptr %13, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%15 = arith.index_cast %arg4 : i32 to index
|
||||
%16:3 = scf.for %arg6 = %c0 to %15 step %c32 iter_args(%arg7 = %cst, %arg8 = %11, %arg9 = %14) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) {
|
||||
%20 = tt.load %arg8, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%21 = tt.load %arg9, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%22 = arith.addf %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%23 = arith.addf %arg7, %22 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
|
||||
%25 = triton_gpu.convert_layout %24 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%26 = tt.getelementptr %arg8, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%27 = tt.getelementptr %arg9, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
scf.yield %23, %26, %27 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
}
|
||||
%17 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
|
||||
%18 = triton_gpu.convert_layout %17 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
%19 = tt.getelementptr %18, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
tt.store %19, %16#0, %8, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
|
||||
return
|
||||
}
|
||||
}
|
@@ -1,38 +0,0 @@
import triton
import triton.language as tl
import torch


@triton.jit
def atomic(lock):
    while tl.atomic_cas(lock, 0, 1) == 1:
        pass


@triton.jit
def generic_while(lb, value):
    c = -1
    while c <= 0:
        c += 1


# locks = torch.zeros(32, dtype=torch.int32, device='cuda')
# mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,))
# mod_atomic.dump()

# mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,))
# mod_generic_while.dump()


@triton.jit
def nested_cf(X, lb, ub, Z):
    a = 0.0
    if lb < ub:
        for z in range(0, Z):
            a += 2.0
    else:
        while a < 1.2:
            a *= 2.0
        for _ in range(0, Z, 2):
            a *= -3.3
    a -= 1.0


mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
assert mod.verify(), mod.str()
mod.dump()
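The while-loop cases are left commented out above, but the reference-IR tests later in this commit compile an equivalent generic_while and verify it, so a standalone check should amount to the following sketch (under the assumption that compile_to_ttir accepts the two scalar arguments exactly as the commented-out lines suggest):

# Sketch: exercise the scf.while lowering on its own, mirroring the commented-out
# lines above.
mod_w, _ = generic_while.compile_to_ttir(8, 9, grid=(1,))
assert mod_w.verify(), mod_w.str()
mod_w.dump()  # expect an scf.while with an scf.condition / scf.yield pair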
@@ -1,58 +0,0 @@
import triton._C.libtriton.triton.ir as ir

ctx = ir.context()
ctx.load_triton()

# TODO
builder = ir.builder(ctx)

module = builder.create_module()

i1_ty = builder.get_int1_ty()
i8_ty = builder.get_int8_ty()
i16_ty = builder.get_int16_ty()
i32_ty = builder.get_int32_ty()
i64_ty = builder.get_int64_ty()

f16_ty = builder.get_half_ty()

f16_ptr_ty = builder.get_ptr_ty(f16_ty, 1)

func_ty = builder.get_function_ty([f16_ptr_ty, f16_ptr_ty, f16_ptr_ty], [])
func = builder.create_function('foo', func_ty)

module.push_back(func)
module.set_attr("num_warps", builder.get_int32_attr(4))

# ...
entry = func.add_entry_block()
builder.set_insertion_point_to_start(entry)
offsets = builder.create_make_range(0, 128)
pid = builder.create_get_program_id(0)
_128 = builder.get_int32(128)
offset = builder.create_add(pid, _128)
offset = builder.create_splat(offset, [128])
offsets = builder.create_add(offset, offsets)

a_ptrs = builder.create_splat(entry.arg(0), [128])
b_ptrs = builder.create_splat(entry.arg(1), [128])

a_ptrs = builder.create_gep(a_ptrs, offsets)
b_ptrs = builder.create_gep(b_ptrs, offsets)

a = builder.create_load(a_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False)
b = builder.create_load(b_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False)

c = builder.create_fadd(a, b)
c.set_attr("ieee_rounding", builder.get_bool_attr(True))

c_ptrs = builder.create_splat(entry.arg(2), [128])
c_ptrs = builder.create_gep(c_ptrs, offsets)
builder.create_store(c_ptrs, c)

# func.dump()

module.dump()
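For orientation, the same 128-element vector add can be written with the @triton.jit front end used elsewhere in this commit. This is a hand-written sketch for comparison only; note that it deliberately keeps the script's pid + 128 offset (the builder call is create_add, not a multiply).

import triton
import triton.language as tl

@triton.jit
def foo_jit(a_ptr, b_ptr, c_ptr):
    # Mirrors the builder calls above: make_range + splat(pid + 128), unmasked loads,
    # an fadd, and a store.
    offsets = (tl.program_id(axis=0) + 128) + tl.arange(0, 128)
    a = tl.load(a_ptr + offsets)
    b = tl.load(b_ptr + offsets)
    tl.store(c_ptr + offsets, a + b)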
@@ -1,417 +0,0 @@
|
||||
import pytest
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
def test_if():
|
||||
ref_ir = """module {
|
||||
func @only_if__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) {
|
||||
%cst = arith.constant -1.000000e+00 : f32
|
||||
%0 = arith.cmpi sgt, %arg2, %arg0 : i32
|
||||
%1 = scf.if %0 -> (f32) {
|
||||
%cst_0 = arith.constant 0.000000e+00 : f32
|
||||
scf.yield %cst_0 : f32
|
||||
} else {
|
||||
scf.yield %cst : f32
|
||||
}
|
||||
%2 = arith.addf %1, %1 : f32
|
||||
return
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
@triton.jit
|
||||
def only_if(lb, ub, value):
|
||||
a = -1.0
|
||||
if value > lb:
|
||||
a = 0.0
|
||||
c = a + a
|
||||
|
||||
mod, _ = only_if.compile_to_ttir(2, 3, 4, grid=(1,))
|
||||
generated_ir = mod.str()
|
||||
assert mod.verify()
|
||||
assert ref_ir == generated_ir
|
||||
|
||||
def test_if_else():
|
||||
ref_ir = """module {
|
||||
func @if_else__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) {
|
||||
%0 = arith.cmpi sgt, %arg2, %arg0 : i32
|
||||
%1 = scf.if %0 -> (f32) {
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
scf.yield %cst : f32
|
||||
} else {
|
||||
%cst = arith.constant 1.000000e+00 : f32
|
||||
scf.yield %cst : f32
|
||||
}
|
||||
%2 = arith.addf %1, %1 : f32
|
||||
return
|
||||
}
|
||||
}
|
||||
"""
|
||||
@triton.jit
|
||||
def if_else(lb, ub, value):
|
||||
if value > lb:
|
||||
a = 0.0
|
||||
else:
|
||||
a = 1.0
|
||||
c = a + a
|
||||
|
||||
mod, _ = if_else.compile_to_ttir(2, 3, 4, grid=(1,))
|
||||
generated_ir = mod.str()
|
||||
assert mod.verify()
|
||||
assert ref_ir == generated_ir
|
||||
|
||||
def test_for():
|
||||
ref_ir = """module {
|
||||
func @for_loop__i32__(%arg0: i32) {
|
||||
%cst = arith.constant 1.000000e+00 : f32
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%0 = arith.index_cast %c0_i32 : i32 to index
|
||||
%1 = arith.index_cast %arg0 : i32 to index
|
||||
%2 = arith.index_cast %c1_i32 : i32 to index
|
||||
%3 = scf.for %arg1 = %0 to %1 step %2 iter_args(%arg2 = %cst) -> (f32) {
|
||||
%cst_0 = arith.constant 1.000000e+00 : f32
|
||||
%4 = arith.addf %arg2, %cst_0 : f32
|
||||
scf.yield %4 : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
@triton.jit
|
||||
def for_loop(K):
|
||||
a = 1.0
|
||||
for k in range(0, K):
|
||||
a += 1.0
|
||||
|
||||
mod, _ = for_loop.compile_to_ttir(2, grid=(1,))
|
||||
generated_ir = mod.str()
|
||||
assert mod.verify()
|
||||
assert ref_ir == generated_ir
|
||||
|
||||
def test_while():
|
||||
ref_ir = """module {
|
||||
func @generic_while__i32__(%arg0: i32) {
|
||||
%c-1_i32 = arith.constant -1 : i32
|
||||
%0 = scf.while (%arg1 = %c-1_i32) : (i32) -> i32 {
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%1 = arith.cmpi sle, %arg1, %c0_i32 : i32
|
||||
scf.condition(%1) %arg1 : i32
|
||||
} do {
|
||||
^bb0(%arg1: i32):
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%1 = arith.addi %arg1, %c1_i32 : i32
|
||||
scf.yield %1 : i32
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
"""
|
||||
@triton.jit
|
||||
def generic_while(x):
|
||||
c = -1
|
||||
while c <= 0:
|
||||
c += 1
|
||||
|
||||
mod, _ = generic_while.compile_to_ttir(2, grid=(1,))
|
||||
generated_ir = mod.str()
|
||||
assert mod.verify()
|
||||
assert ref_ir == generated_ir
|
||||
|
||||
def test_nested():
|
||||
ref_ir = """module {
|
||||
func @nested_cf__i32_i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) {
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%0 = arith.index_cast %c0_i32 : i32 to index
|
||||
%1 = arith.index_cast %arg0 : i32 to index
|
||||
%2 = arith.index_cast %c1_i32 : i32 to index
|
||||
%3 = scf.for %arg4 = %0 to %1 step %2 iter_args(%arg5 = %cst) -> (f32) {
|
||||
%5 = arith.cmpi slt, %arg1, %arg2 : i32
|
||||
%6 = scf.if %5 -> (f32) {
|
||||
%c0_i32_1 = arith.constant 0 : i32
|
||||
%c1_i32_2 = arith.constant 1 : i32
|
||||
%7 = arith.index_cast %c0_i32_1 : i32 to index
|
||||
%8 = arith.index_cast %arg3 : i32 to index
|
||||
%9 = arith.index_cast %c1_i32_2 : i32 to index
|
||||
%10 = scf.for %arg6 = %7 to %8 step %9 iter_args(%arg7 = %arg5) -> (f32) {
|
||||
%cst_3 = arith.constant 2.000000e+00 : f32
|
||||
%11 = arith.addf %arg7, %cst_3 : f32
|
||||
scf.yield %11 : f32
|
||||
}
|
||||
scf.yield %10 : f32
|
||||
} else {
|
||||
%7 = scf.while (%arg6 = %arg5) : (f32) -> f32 {
|
||||
%cst_1 = arith.constant 1.200000e+00 : f32
|
||||
%8 = arith.cmpf olt, %arg6, %cst_1 : f32
|
||||
scf.condition(%8) %arg6 : f32
|
||||
} do {
|
||||
^bb0(%arg6: f32):
|
||||
%cst_1 = arith.constant 2.000000e+00 : f32
|
||||
%8 = arith.mulf %arg6, %cst_1 : f32
|
||||
scf.yield %8 : f32
|
||||
}
|
||||
scf.yield %7 : f32
|
||||
}
|
||||
scf.yield %6 : f32
|
||||
}
|
||||
%cst_0 = arith.constant 1.000000e+00 : f32
|
||||
%4 = arith.subf %3, %cst_0 : f32
|
||||
return
|
||||
}
|
||||
}
|
||||
"""
|
||||
@triton.jit
|
||||
def nested_cf(X, lb, ub, Z):
|
||||
a = 0.0
|
||||
for x in range(0, X):
|
||||
if lb < ub:
|
||||
for z in range(0, Z):
|
||||
a += 2.0
|
||||
else:
|
||||
while a < 1.2:
|
||||
a *= 2.0
|
||||
a -= 1.0
|
||||
|
||||
mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,))
|
||||
generated_ir = mod.str()
|
||||
assert mod.verify(), generated_ir
|
||||
assert ref_ir == generated_ir
|
||||
|
||||
def test_matmul():
|
||||
ref_ir = """module {
|
||||
func @matmul_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%c64_i32 = arith.constant 64 : i32
|
||||
%1 = arith.addi %arg3, %c64_i32 : i32
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%2 = arith.subi %1, %c1_i32 : i32
|
||||
%c64_i32_0 = arith.constant 64 : i32
|
||||
%3 = arith.divsi %2, %c64_i32_0 : i32
|
||||
%c64_i32_1 = arith.constant 64 : i32
|
||||
%4 = arith.addi %arg4, %c64_i32_1 : i32
|
||||
%c1_i32_2 = arith.constant 1 : i32
|
||||
%5 = arith.subi %4, %c1_i32_2 : i32
|
||||
%c64_i32_3 = arith.constant 64 : i32
|
||||
%6 = arith.divsi %5, %c64_i32_3 : i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%7 = arith.muli %6, %c8_i32 : i32
|
||||
%8 = arith.divsi %0, %7 : i32
|
||||
%c8_i32_4 = arith.constant 8 : i32
|
||||
%9 = arith.muli %8, %c8_i32_4 : i32
|
||||
%10 = arith.subi %3, %9 : i32
|
||||
%c8_i32_5 = arith.constant 8 : i32
|
||||
%11 = arith.cmpi slt, %10, %c8_i32_5 : i32
|
||||
%c8_i32_6 = arith.constant 8 : i32
|
||||
%12 = select %11, %10, %c8_i32_6 : i32
|
||||
%13 = arith.remsi %0, %12 : i32
|
||||
%14 = arith.addi %9, %13 : i32
|
||||
%15 = arith.remsi %0, %7 : i32
|
||||
%16 = arith.divsi %15, %12 : i32
|
||||
%c64_i32_7 = arith.constant 64 : i32
|
||||
%17 = arith.muli %14, %c64_i32_7 : i32
|
||||
%18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%19 = tt.broadcast %17 : (i32) -> tensor<64xi32>
|
||||
%20 = arith.addi %19, %18 : tensor<64xi32>
|
||||
%c64_i32_8 = arith.constant 64 : i32
|
||||
%21 = arith.muli %16, %c64_i32_8 : i32
|
||||
%22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%23 = tt.broadcast %21 : (i32) -> tensor<64xi32>
|
||||
%24 = arith.addi %23, %22 : tensor<64xi32>
|
||||
%25 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
|
||||
%26 = tt.reshape %20 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%27 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
|
||||
%28 = arith.muli %26, %27 : tensor<64x1xi32>
|
||||
%29 = tt.reshape %25 : (tensor<32xi32>) -> tensor<1x32xi32>
|
||||
%c1_i32_9 = arith.constant 1 : i32
|
||||
%30 = tt.broadcast %c1_i32_9 : (i32) -> tensor<1x32xi32>
|
||||
%31 = arith.muli %29, %30 : tensor<1x32xi32>
|
||||
%32 = tt.broadcast %28 : (tensor<64x1xi32>) -> tensor<64x32xi32>
|
||||
%33 = tt.broadcast %31 : (tensor<1x32xi32>) -> tensor<64x32xi32>
|
||||
%34 = arith.addi %32, %33 : tensor<64x32xi32>
|
||||
%35 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
|
||||
%36 = tt.getelementptr %35, %34, : tensor<64x32x!tt.ptr<f16>>
|
||||
%37 = tt.reshape %25 : (tensor<32xi32>) -> tensor<32x1xi32>
|
||||
%38 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
|
||||
%39 = arith.muli %37, %38 : tensor<32x1xi32>
|
||||
%40 = tt.reshape %24 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%c1_i32_10 = arith.constant 1 : i32
|
||||
%41 = tt.broadcast %c1_i32_10 : (i32) -> tensor<1x64xi32>
|
||||
%42 = arith.muli %40, %41 : tensor<1x64xi32>
|
||||
%43 = tt.broadcast %39 : (tensor<32x1xi32>) -> tensor<32x64xi32>
|
||||
%44 = tt.broadcast %42 : (tensor<1x64xi32>) -> tensor<32x64xi32>
|
||||
%45 = arith.addi %43, %44 : tensor<32x64xi32>
|
||||
%46 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
|
||||
%47 = tt.getelementptr %46, %45, : tensor<32x64x!tt.ptr<f16>>
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%48 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c32_i32 = arith.constant 32 : i32
|
||||
%49 = arith.index_cast %c0_i32 : i32 to index
|
||||
%50 = arith.index_cast %arg5 : i32 to index
|
||||
%51 = arith.index_cast %c32_i32 : i32 to index
|
||||
%52:3 = scf.for %arg9 = %49 to %50 step %51 iter_args(%arg10 = %48, %arg11 = %36, %arg12 = %47) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
|
||||
%cst_14 = arith.constant dense<true> : tensor<64x32xi1>
|
||||
%cst_15 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
|
||||
%82 = tt.load %arg11, %cst_14, %cst_15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
|
||||
%cst_16 = arith.constant dense<true> : tensor<32x64xi1>
|
||||
%cst_17 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
|
||||
%83 = tt.load %arg12, %cst_16, %cst_17 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
|
||||
%cst_18 = arith.constant 0.000000e+00 : f32
|
||||
%84 = tt.broadcast %cst_18 : (f32) -> tensor<64x64xf32>
|
||||
%85 = tt.dot %82, %83, %84 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
|
||||
%86 = arith.addf %arg10, %85 : tensor<64x64xf32>
|
||||
%c32_i32_19 = arith.constant 32 : i32
|
||||
%87 = tt.broadcast %c32_i32_19 : (i32) -> tensor<64x32xi32>
|
||||
%88 = tt.getelementptr %arg11, %87, : tensor<64x32x!tt.ptr<f16>>
|
||||
%c32_i32_20 = arith.constant 32 : i32
|
||||
%89 = arith.muli %arg7, %c32_i32_20 : i32
|
||||
%90 = tt.broadcast %89 : (i32) -> tensor<32x64xi32>
|
||||
%91 = tt.getelementptr %arg12, %90, : tensor<32x64x!tt.ptr<f16>>
|
||||
scf.yield %86, %88, %91 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
|
||||
}
|
||||
%53 = arith.truncf %52#0 : tensor<64x64xf32> to tensor<64x64xf16>
|
||||
%c64_i32_11 = arith.constant 64 : i32
|
||||
%54 = arith.muli %14, %c64_i32_11 : i32
|
||||
%55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%56 = tt.broadcast %54 : (i32) -> tensor<64xi32>
|
||||
%57 = arith.addi %56, %55 : tensor<64xi32>
|
||||
%c64_i32_12 = arith.constant 64 : i32
|
||||
%58 = arith.muli %16, %c64_i32_12 : i32
|
||||
%59 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||
%60 = tt.broadcast %58 : (i32) -> tensor<64xi32>
|
||||
%61 = arith.addi %60, %59 : tensor<64xi32>
|
||||
%62 = tt.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%63 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
|
||||
%64 = arith.muli %63, %62 : tensor<64x1xi32>
|
||||
%65 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
|
||||
%66 = tt.getelementptr %65, %64, : tensor<64x1x!tt.ptr<f16>>
|
||||
%67 = tt.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%c1_i32_13 = arith.constant 1 : i32
|
||||
%68 = tt.broadcast %c1_i32_13 : (i32) -> tensor<1x64xi32>
|
||||
%69 = arith.muli %67, %68 : tensor<1x64xi32>
|
||||
%70 = tt.broadcast %66 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
|
||||
%71 = tt.broadcast %69 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||
%72 = tt.getelementptr %70, %71, : tensor<64x64x!tt.ptr<f16>>
|
||||
%73 = tt.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%74 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
|
||||
%75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32>
|
||||
%76 = tt.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||
%77 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
|
||||
%78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32>
|
||||
%79 = tt.broadcast %75 : (tensor<64x1xi1>) -> tensor<64x64xi1>
|
||||
%80 = tt.broadcast %78 : (tensor<1x64xi1>) -> tensor<64x64xi1>
|
||||
%81 = arith.andi %79, %80 : tensor<64x64xi1>
|
||||
tt.store %72, %53, %81, : tensor<64x64xf16>
|
||||
return
|
||||
}
|
||||
}
|
||||
"""
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
# Matrix dimensions
|
||||
M, N, K,
|
||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
|
||||
# by to get the element one row down (A has M rows)
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""Kernel for computing the matmul C = A x B.
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
"""
|
||||
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse;
    # see the `L2 Cache Optimizations` section of the matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

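    # Worked example (added for illustration, not part of the original kernel):
    # with M = N = 512, BLOCK_SIZE_M = BLOCK_SIZE_N = 64 and GROUP_SIZE_M = 8
    # (the values used by the test driver below), num_pid_m = num_pid_n = 8 and
    # num_pid_in_group = 64, so all 64 program ids fall into group 0; pid = 9
    # then maps to pid_m = 9 % 8 = 1 and pid_n = (9 % 64) // 8 = 1.
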
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetics` section of the matmul tutorial for details.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

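    # Added note for clarity (not in the original code): element (i, j) of the A
    # tile sits at a_ptr + offs_am[i] * stride_am + offs_k[j] * stride_ak, and
    # element (i, j) of the B tile at b_ptr + offs_k[i] * stride_bk +
    # offs_bn[j] * stride_bn; the broadcasts above build these 2D pointer blocks.
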
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Note that for simplicity, we don't apply a mask here.
        # This means that if K is not a multiple of BLOCK_SIZE_K,
        # this will access out-of-bounds memory and produce an
        # error or (worse!) incorrect results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.float16)

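    # Sketch of a masked variant (illustrative only; not exercised by this test):
    # if K were not a multiple of BLOCK_SIZE_K, the loads in the loop above could
    # guard the tail, e.g.
    #   a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k, other=0.0)
    #   b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k, other=0.0)
    # so that out-of-bounds elements contribute zeros to the accumulator.
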
    # -----------------------------------------------------------
    # Write back the block of the output matrix C.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c = torch.empty((512, 512), device='cuda', dtype=torch.float16)


# Lower the kernel to Triton IR (TTIR) without launching it.
mod, ctx = matmul_kernel.compile_to_ttir(
    a, b, c,
    512, 512, 512,
    a.stride(0), a.stride(1),
    b.stride(0), b.stride(1),
    c.stride(0), c.stride(1),
    64, 64, 32,
    8, grid=(2,)
)
verify = mod.verify()
assert verify
# assert ref_ir == mod.str()
print(mod.str())

# Run the inliner pass and make sure the module still verifies afterwards.
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.run(mod)

verify = mod.verify()
assert verify
# assert ref_ir == mod.str()
print(mod.str())
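
# Possible convenience wrapper (illustrative sketch only, not part of this change):
# it simply repackages the calls exercised above: build the TTIR module with
# compile_to_ttir, check that it verifies, run the inliner pass, and check again.
def compile_and_inline(kernel, *args, grid):
    mod, ctx = kernel.compile_to_ttir(*args, grid=grid)
    assert mod.verify()
    pm = _triton.ir.pass_manager(ctx)
    pm.add_inliner_pass()
    pm.run(mod)
    assert mod.verify()
    return mod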