[general] fixed some warnings

This commit is contained in:
Philippe Tillet
2019-08-18 14:08:57 -07:00
parent c05445d001
commit 81571246cf
22 changed files with 75 additions and 365 deletions

View File

@@ -25,7 +25,7 @@ include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
#llvm_map_components_to_libnames(llvm_libs all)
#Default build type
# Default build type
if(NOT CMAKE_BUILD_TYPE)
message(STATUS "Default build type: Release")
set(CMAKE_BUILD_TYPE "Release")
@@ -63,7 +63,14 @@ endif()
# Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cpp)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
add_library(triton SHARED ${LIBTRITON_SRC} ${EIGHTCC_SRC} ${PYTHON_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
target_link_libraries(triton LLVM)
# Warning level
if(MSVC)
target_compile_options(triton PRIVATE /W4)
else()
target_compile_options(triton PRIVATE -Wno-unused-parameter -Wall -Wextra -pedantic)
endif()

View File

@@ -1,4 +1,4 @@
foreach(PROG dot conv shift)
foreach(PROG dot)
add_executable(${PROG} ${PROG}.cpp)
set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG})
include_directories(/usr/local/cuda/include/)

View File

@@ -1,58 +0,0 @@
#include <cstring>
#include <cstdio>
#include <sstream>
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/dnn/conv.h"
#include "triton/tools/bench.hpp"
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
// initialization
int32_t B = 16, NF = 128;
int32_t D = 1, H = 16, W = 16;
int32_t NC = 64, T = 1, R = 3, S = 3;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
// triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "float", "float", triton::dnn::conv::FPROP, 0);
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
upsample_d, upsample_h, upsample_w,
"float", "float", ty, 0);
// convolution configuration
std::vector<float> hc(configuration.c_size());
std::vector<float> rc(configuration.c_size());
std::vector<float> ha(configuration.a_size());
std::vector<float> hb(configuration.b_size());
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
rc = hc;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
configuration.enqueue(stream, {da, db, dc, nullptr});
stream->read(dc, true, 0, hc);
configuration.cpu_ref(rc.data(), ha.data(), hb.data());
for(size_t i = 0; i < hc.size(); i++){
if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
}
std::cout << "Pass!" << std::endl;
}

View File

@@ -3,7 +3,6 @@
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/dnn/dot.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"

View File

@@ -1,150 +0,0 @@
#include <cstring>
#include <cstdio>
#include <sstream>
#include "cuda.h"
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/dnn/shift.h"
#include "triton/external/half.hpp"
struct perf_t {
double triton;
double cublas;
};
perf_t do_bench(triton::driver::stream *stream,
int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C,
triton::dnn::op_t op, triton::dnn::layout_t layout,
std::string numeric_t) {
typedef float NumericT;
// driver variables
triton::driver::context* context = stream->context();
// random shifts
std::vector<int32_t> shift_h(C);
std::vector<int32_t> shift_w(C);
for(int32_t c = 0; c < C; c++){
shift_h[c] = rand() % R - R / 2;
shift_w[c] = rand() % S - S / 2;
}
// configuration
triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
shift_h.data(), shift_w.data(),
numeric_t, numeric_t,
op, false, layout);
// host buffers
size_t a_size = B*C*H*W;
size_t b_size = C*F;
size_t c_size = B*F*H*W;
if(op == triton::dnn::BPROP)
std::swap(a_size, c_size);
if(op == triton::dnn::WGRAD){
std::swap(b_size, c_size);
std::swap(a_size, b_size);
}
std::vector<NumericT> ha(a_size);
std::vector<NumericT> hb(b_size);
std::vector<float> hc(c_size);
std::vector<float> rc(hc.size());
// device buffers
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*sizeof(NumericT));
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*sizeof(NumericT));
// initialize host
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (NumericT)rand() / RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (NumericT)rand() / RAND_MAX;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
// initialize device
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
// benchmark triton
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
// benchmark cublas
// NumericT alpha = 1;
// NumericT beta = 0;
// cublasGemmAlgo_t fastest;
// cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
// &alpha, da, shift.lda(),
// db, shift.ldb(), &beta,
// dc, shift.ldc(), &fastest);
// double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
// &alpha, da, shift.lda(),
// db, shift.ldb(),
// &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
// result
auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
perf_t result;
// result.cublas = tflops(cublas_ns);
result.triton = tflops(triton_ns);
delete da;
delete db;
delete dc;
return result;
}
int main() {
using triton::dnn::op_t;
using triton::dnn::layout_t;
struct config_t{
int32_t B;
int32_t C;
int32_t H;
int32_t W;
int32_t R;
int32_t S;
int32_t F;
int32_t stride_h;
int32_t stride_w;
op_t op;
layout_t layout;
std::string ty;
std::string repr() {
std::ostringstream oss;
oss << B << ", " << C << ", " << H << ", " << W << ", " << R << ", " << S << ", " << F << ", " << op << ", " << layout << ", " << ty;
return oss.str();
}
perf_t perf(triton::driver::stream *stream){
return do_bench(stream, R, S, B, F, H, W, C, op, layout, ty);
}
};
// shapes to benchmark
std::vector<config_t> configs;
std::vector<config_t> resnet18 =
{
{128, 128, 32, 32, 3, 3, 128, 1, 1},
{128, 128, 32, 32, 3, 3, 128, 1, 1},
{128, 128, 32, 32, 3, 3, 256, 2, 2},
{128, 256, 16, 16, 3, 3, 256, 1, 1},
{128, 256, 16, 16, 3, 3, 512, 2, 2},
{128, 512, 8, 8, 3, 3, 512, 1, 1},
{128, 512, 8, 8, 3, 3, 1024, 1, 1},
{128, 1024, 8, 8, 3, 3, 1024, 1, 1}
};
for(config_t c: resnet18){
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "half"});
}
}
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream *stream = triton::driver::stream::create(context);
for(config_t c: configs){
std::string repr = c.repr();
perf_t perf = c.perf(stream);
std::cout << "// " << repr << ", " << perf.triton << ", " << perf.cublas << std::endl;
}
}

View File

@@ -1,93 +0,0 @@
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//
.version 6.3
.target sm_60
.address_size 64
// .globl _Z25shift_cuda_forward_kernelPKfPKiPfiiii
.visible .entry shift(
.param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_0,
.param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_1,
.param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_2,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_3,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_4,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_5,
.param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_6
)
{
.reg .pred %p<10>;
.reg .f32 %f<2>;
.reg .b32 %r<31>;
.reg .b64 %rd<13>;
ld.param.u64 %rd1, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_0];
ld.param.u64 %rd3, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_1];
ld.param.u64 %rd2, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_2];
ld.param.u32 %r3, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_3];
ld.param.u32 %r4, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_4];
ld.param.u32 %r5, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_5];
ld.param.u32 %r6, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_6];
cvta.to.global.u64 %rd4, %rd3;
mov.u32 %r7, %ntid.x;
mov.u32 %r8, %ctaid.x;
mov.u32 %r9, %tid.x;
mad.lo.s32 %r1, %r7, %r8, %r9;
mul.lo.s32 %r10, %r4, %r3;
mul.lo.s32 %r11, %r10, %r5;
mul.lo.s32 %r12, %r11, %r6;
mul.lo.s32 %r13, %r5, %r4;
mul.lo.s32 %r14, %r13, %r6;
rem.s32 %r15, %r1, %r14;
sub.s32 %r16, %r1, %r15;
mul.lo.s32 %r17, %r6, %r5;
div.s32 %r18, %r15, %r17;
mul.lo.s32 %r19, %r18, %r17;
sub.s32 %r20, %r15, %r19;
div.s32 %r21, %r20, %r5;
mul.lo.s32 %r22, %r21, %r6;
sub.s32 %r23, %r20, %r22;
shl.b32 %r24, %r18, 1;
mul.wide.s32 %rd5, %r24, 4;
add.s64 %rd6, %rd4, %rd5;
ld.global.nc.u32 %r25, [%rd6];
add.s32 %r26, %r25, %r21;
ld.global.nc.u32 %r27, [%rd6+4];
add.s32 %r28, %r23, %r27;
add.s32 %r29, %r16, %r19;
mad.lo.s32 %r30, %r26, %r5, %r29;
add.s32 %r2, %r30, %r28;
setp.lt.s32 %p1, %r1, %r12;
setp.gt.s32 %p2, %r26, -1;
and.pred %p3, %p1, %p2;
setp.lt.s32 %p4, %r26, %r5;
and.pred %p5, %p3, %p4;
setp.gt.s32 %p6, %r28, -1;
and.pred %p7, %p5, %p6;
setp.lt.s32 %p8, %r28, %r6;
and.pred %p9, %p7, %p8;
@!%p9 bra BB0_2;
bra.uni BB0_1;
BB0_1:
cvta.to.global.u64 %rd7, %rd1;
mul.wide.s32 %rd8, %r1, 4;
add.s64 %rd9, %rd7, %rd8;
ld.global.nc.f32 %f1, [%rd9];
cvta.to.global.u64 %rd10, %rd2;
mul.wide.s32 %rd11, %r2, 4;
add.s64 %rd12, %rd10, %rd11;
st.global.f32 [%rd12], %f1;
BB0_2:
ret;
}

View File

@@ -100,8 +100,8 @@ public:
private:
Value *ptr_;
bool return_vector_;
Value *offset_;
Builder &builder_;
Value *offset_;
std::map<indices_t, Value*> ptr_cache_;
unsigned vector_size_;
};
@@ -206,9 +206,9 @@ private:
tmap_t tmap_;
analysis::shmem::allocation *alloc_;
analysis::tune *params_;
target *tgt_;
analysis::shmem::info *buffer_info_;
analysis::alignment_info *alignment_;
target *tgt_;
std::map<unsigned, distributed_axis> axes_;
Value *sh_mem_ptr_;
Value *offset_a_i_, *offset_a_k_;

View File

@@ -47,11 +47,11 @@ public:
return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
}
const attribute_kind_t get_kind() const {
attribute_kind_t get_kind() const {
return kind_;
}
const unsigned get_value() const {
unsigned get_value() const {
return value_;
}

View File

@@ -344,8 +344,8 @@ public:
const expression *rvalue() const { return rvalue_; }
public:
ASSIGN_OP_T op_;
const expression *lvalue_;
ASSIGN_OP_T op_;
const expression *rvalue_;
};

View File

@@ -76,8 +76,8 @@ private:
void operator()(driver::stream *stream, const std::array<size_t, 3>& grid, const std::vector<arg>& args) const;
private:
std::shared_ptr<driver::module> parent_;
std::shared_ptr<driver::kernel> bin_;
std::shared_ptr<driver::module> parent_;
std::vector<arg_type> param_tys_;
size_t n_threads_;
};

View File

@@ -227,7 +227,7 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
if(auto *x = dynamic_cast<ir::constant_range*>(v)){
return cache(x->get_first()->get_value());
}
if(auto *x = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(v)){
if(dynamic_cast<ir::nv_dynamic_program_idx_inst*>(v)){
return cache(128);
}
if(auto *x = dynamic_cast<ir::nv_static_program_idx*>(v)){

View File

@@ -19,7 +19,7 @@ bool info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
return br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(auto *br = dynamic_cast<ir::uncond_branch_inst*>(terminator))
else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
return false;
else
throw std::runtime_error("unreachable");
@@ -36,15 +36,15 @@ void info::replace(ir::value* before, ir::value *after) {
}
inline bool get_is_shared(ir::value* v) {
if(auto x = dynamic_cast<ir::atomic_cas_inst*>(v))
if(dynamic_cast<ir::atomic_cas_inst*>(v))
return true;
if(auto x = dynamic_cast<ir::trans_inst*>(v))
if(dynamic_cast<ir::trans_inst*>(v))
return true;
if(auto x = dynamic_cast<ir::copy_to_shared_inst*>(v))
if(dynamic_cast<ir::copy_to_shared_inst*>(v))
return true;
if(auto x = dynamic_cast<ir::reduce_inst*>(v))
if(dynamic_cast<ir::reduce_inst*>(v))
return true;
if(auto x = dynamic_cast<ir::phi_node*>(v)){
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
bool res = true;
for(unsigned inc = 0; inc < x->get_num_incoming(); inc++)
res = res && get_is_shared(x->get_incoming_value(inc));

View File

@@ -58,7 +58,7 @@ void tune::init_c_graph(ir::instruction *v) {
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
else if(auto *atom = dynamic_cast<ir::atomic_add_inst*>(v))
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
else if(dynamic_cast<ir::downcast_inst*>(v))
return;
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
unsigned axis = reduce->get_axis();
@@ -116,7 +116,7 @@ void tune::init_c_graph(ir::instruction *v) {
}
}
// Matrix multiplication
else if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
else if(dynamic_cast<ir::dot_inst*>(v)){
ir::value *A = v->get_operand(0);
ir::value *B = v->get_operand(1);
ir::value *D = v->get_operand(2);
@@ -166,7 +166,7 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
std::string suffix = ".d" + std::to_string(x.second);
for(int i = 0; i < mps.size(); i++)
for(unsigned i = 0; i < mps.size(); i++)
params_[x.first].insert({prefixes[i] + suffix, mps[i]});
ir::type *ty = x.first->get_type();
if(ty->is_tile_ty()){
@@ -254,24 +254,24 @@ void tune::init(ir::module &mod) {
create_grids(grids_, references, fn);
}
int num_threads = get_num_threads();
auto clamp = [&](int x, int lo, int hi) { return std::min(std::max(x, lo), hi); };
unsigned num_threads = get_num_threads();
auto clamp = [&](unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); };
for(ir::value *i: grids_){
if(!i->get_type()->is_tile_ty())
continue;
auto shapes = i->get_type()->get_tile_shapes();
int shape_0 = shapes[0]->get_value();
int shape_1 = shapes[1]->get_value();
int size = i->get_type()->get_tile_num_elements();
unsigned shape_0 = shapes[0]->get_value();
unsigned shape_1 = shapes[1]->get_value();
unsigned size = i->get_type()->get_tile_num_elements();
/* HMMA parameters*/
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
std::vector<int> fpw = {1, 1, 1};
std::vector<int> fpw_nm1;
int num_fragments = std::min((shape_0/8)*(shape_1/8), 4);
std::vector<unsigned> fpw = {1, 1, 1};
std::vector<unsigned> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
do {
fpw_nm1 = fpw;
if(fpw[0]*fpw[1] < num_fragments)
@@ -280,13 +280,13 @@ void tune::init(ir::module &mod) {
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
}while(fpw_nm1 != fpw);
// store parameters
for(int d = 0; d < shapes.size(); d++)
for(unsigned d = 0; d < shapes.size(); d++)
params_.at(i).at("fpw.d" + std::to_string(d))->set_value(fpw[d]);
/* warps per tile */
// try to make things as square as possible to maximize data re-use
std::vector<int> wpt = {1, 1, 1};
std::vector<int> wpt_nm1;
std::vector<unsigned> wpt = {1, 1, 1};
std::vector<unsigned> wpt_nm1;
do{
wpt_nm1 = wpt;
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
@@ -295,7 +295,7 @@ void tune::init(ir::module &mod) {
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
}while(wpt_nm1 != wpt);
// store parameters
for(int d = 0; d < shapes.size(); d++)
for(unsigned d = 0; d < shapes.size(); d++)
params_.at(i).at("wpt.d" + std::to_string(d))->set_value(wpt[d]);
/* sanity check */
@@ -309,8 +309,8 @@ void tune::init(ir::module &mod) {
/* Scan-line */
else{
int shape = shapes[0]->get_value();
int current = num_threads;
unsigned shape = shapes[0]->get_value();
unsigned current = num_threads;
params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 8));
params_.at(i).at("mts.d0")->set_value(clamp(current, 1, shape / params_.at(i).at("nts.d0")->get_value()));
current = current / params_.at(i).at("mts.d0")->get_value();

View File

@@ -226,6 +226,7 @@ llvm::Instruction::BinaryOps llvm_op(ir::binary_op_t op) {
case ttop::Or: return llop::Or;
case ttop::Xor: return llop::Xor;
}
throw std::runtime_error("unknown operator");
}
llvm::Instruction::CastOps llvm_op(ir::cast_op_t op) {
@@ -246,6 +247,7 @@ llvm::Instruction::CastOps llvm_op(ir::cast_op_t op) {
case ttop::BitCast: return llop::BitCast;
case ttop::AddrSpaceCast: return llop::AddrSpaceCast;
}
throw std::runtime_error("unknown operator");
}
llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) {
@@ -283,6 +285,7 @@ llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) {
case ttop::ICMP_SLE: return llop::ICMP_SLE;
case ttop::LAST_ICMP_PREDICATE: return llop::LAST_ICMP_PREDICATE;
}
throw std::runtime_error("unknown operator");
}
/* convert ir::type to Type */
@@ -468,7 +471,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
if(ir::atomic_add_inst* ii = dynamic_cast<ir::atomic_add_inst*>(inst)){
Value *ptr = value(ii->get_operand(0));
Value *val = value(ii->get_operand(1));
Value *atom_f_add;
Value *atom_f_add = nullptr;
if(val->getType()->isFloatTy())
atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
else if(val->getType()->isHalfTy()){
@@ -477,6 +480,8 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
FunctionType *atom_ty = FunctionType::get(fp16, {fp16->getPointerTo(), fp16}, false);
atom_f_add = InlineAsm::get(atom_ty, " atom.relaxed.global.gpu.add.noftz.f16 $0, [$1], $2;", "=h,l,h", true);
}
if(atom_f_add == nullptr)
throw std::runtime_error("unsupported atomic add");
Value *res = builder.CreateCall(atom_f_add, {ptr, val});
return (Instruction*)res;
}
@@ -607,7 +612,6 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
Value *_2 = builder.getInt32(2);
Value *_3 = builder.getInt32(3);
Value *_4 = builder.getInt32(4);
Value *_8 = builder.getInt32(8);
Value *_16 = builder.getInt32(16);
// fragments per warp
@@ -1303,11 +1307,10 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
Value *ptr = pointers->get_value(idx);
ConstantInt *cst = nullptr;
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
if(gep->getNumIndices() == 1){
cst = dyn_cast<ConstantInt>(gep->idx_begin());
}
// ConstantInt *cst = nullptr;
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
// if(gep->getNumIndices() == 1)
// cst = dyn_cast<ConstantInt>(gep->idx_begin());
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
ptr->getType()->getPointerAddressSpace()));
@@ -1374,10 +1377,6 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
unsigned id = linear / vector_size;
if(linear % vector_size == 0) {
Value *ptr = pointers->get_value(idx);
ConstantInt *cst = nullptr;
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
if(gep->getNumIndices() == 1)
cst = dyn_cast<ConstantInt>(gep->idx_begin());
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
ptr->getType()->getPointerAddressSpace()));
packets[id] = builder.CreateLoad(ptr);

View File

@@ -60,6 +60,7 @@ ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
trans->set_operand(0, i);
return trans;
}
return nullptr;
}
bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
@@ -76,6 +77,8 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
if(!phi)
return false;
ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
if(!new_phi)
return false;
trans->replace_all_uses_with(new_phi);
return true;

View File

@@ -67,8 +67,7 @@ constant_range::constant_range(type *ty, constant_int *first, constant_int *last
constant *constant_range::get(constant_int *first, constant_int *last) {
assert(first->get_type()->is_integer_ty());
assert(first->get_type() == last->get_type());
unsigned vfirst = ((constant_int*)first)->get_value();
assert(vfirst == 0);
assert(((constant_int*)first)->get_value() == 0);
type *ty = tile_type::get(first->get_type(), {last});
return new constant_range(ty, first, last);
}

View File

@@ -359,8 +359,11 @@ getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::
: instruction(get_return_type(pointee_ty, ptr, idx), 1 + idx.size(), 1, name, next),
source_elt_ty(pointee_ty),
res_elt_ty(get_indexed_type(pointee_ty, idx)){
type *expected_ty = ((pointer_type*)(get_type()->get_scalar_ty()))->get_element_ty();
// sanity check
type *expected_ty = get_type()->get_scalar_ty();
expected_ty = ((pointer_type*)expected_ty)->get_element_ty();
assert(res_elt_ty == expected_ty);
// set operands
set_operand(0, ptr);
for(size_t i = 0; i < idx.size(); i++)
set_operand(1 + i, idx[i]);
@@ -574,7 +577,7 @@ ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<constant_int*> perm)
// permutate argument shapes
perm = init_perm(ty, perm);
ir::tile_type::tile_shapes_t res_shapes = arg_shapes;
for(int i = 0; i < perm.size(); i++)
for(size_t i = 0; i < perm.size(); i++)
res_shapes[i] = arg_shapes[perm[i]->get_value()];
// construct type
return tile_type::get(ty->get_scalar_ty(), res_shapes);
@@ -587,16 +590,17 @@ std::vector<constant_int*> trans_inst::init_perm(ir::type* ty, const std::vector
ir::type* int32_ty = type::get_int32_ty(ty->get_context());
std::vector<constant_int*> result;
result.push_back(ir::constant_int::get(int32_ty, size - 1));
for(int i = 0; i < size - 1; i++)
for(size_t i = 0; i < size - 1; i++)
result.push_back(ir::constant_int::get(int32_ty, i));
return result;
}
trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next)
: builtin_inst(get_res_ty(arg->get_type(), perm), 1, 1, name, next) {
// sanity check
perm_ = init_perm(arg->get_type(), perm);
auto size = arg->get_type()->get_tile_shapes().size();
assert(perm_.size() == size);
//auto size = arg->get_type()->get_tile_shapes().size();
//assert(perm_.size() == size);
set_operand(0, arg);
}

View File

@@ -96,8 +96,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
bool is_const = const_.find(name) != const_.end();
auto &preds = block->get_predecessors();
ir::type *ty = get_scope().types.at(name);
if(block)
if(!is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
incomplete_phis_[block][name] = make_phi(ty, 1, block);
result = (ir::value*)incomplete_phis_[block][name];
}
@@ -106,9 +105,9 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
result = get_value(name, has_pred?preds.front():nullptr);
}
else{
result = make_phi(ty, 1, block);
set_value(name, block, result);
result = add_phi_operands(name, (ir::phi_node*&)result);
ir::phi_node* phi = make_phi(ty, 1, block);
set_value(name, block, phi);
result = add_phi_operands(name, phi);
}
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
result = try_remove_trivial_phis(phi);

View File

@@ -106,7 +106,7 @@ void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs)
size_t res_size = std::max(lhs_size, rhs_size);
ir::type::tile_shapes_t res_shapes(res_size);
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
for(int i = 0; i < res_size; i++){
for(size_t i = 0; i < res_size; i++){
if(i >= res_size - lhs_size && i >= res_size - rhs_size)
res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i];
else if(i >= res_size - lhs_size)
@@ -147,7 +147,7 @@ void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){
int src_dim = src_shapes.size();
// Pad
int off = dst_dim - src_dim;
for(size_t i = 0; i < off; i++)
for(int i = 0; i < off; i++)
src_shapes.insert(src_shapes.begin(), one);
if(off > 0)
src = builder.create_reshape(src, src_shapes);

View File

@@ -88,10 +88,10 @@ arg_type convert(ir::type *ty) {
}
function::caller::caller(ir::function *ir, std::shared_ptr<driver::module> parent, size_t n_threads)
: bin_(driver::kernel::create(&*parent, ir->get_name().c_str())), n_threads_(n_threads), parent_(parent) {
: bin_(driver::kernel::create(&*parent, ir->get_name().c_str())), parent_(parent), n_threads_(n_threads) {
// extract signature
ir::function_type* ty = ir->get_fn_type();
for(int i = 0; i < ty->get_num_params(); i++)
for(size_t i = 0; i < ty->get_num_params(); i++)
param_tys_.push_back(convert(ty->get_param_ty(i)));
}

View File

@@ -11,7 +11,8 @@ void matmul(restrict read_only align(16) half *A,
restrict read_only align(16) half *B,
restrict read_only align(16) half *C,
int M, int N, int K,
multiple_of(8) int lda, multiple_of(8) int ldb, int ldc) {
multiple_of(8) int lda, multiple_of(8) int ldb, int ldc)
{
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rxa[TM] = ridx * TM + (0 ... TM);

View File

@@ -17,8 +17,8 @@ import tensorflow as tf
extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
def make_bindings(src, outputs, grids):
return libtriton.make_tensorflow_src(src, outputs, grids)
def make_bindings(src, out, grid):
return libtriton.make_tensorflow_src(src, out, grid)
def make_cache_path(src):
md5 = hashlib.sha1(src.encode())