[test] added tests for copy
This commit is contained in:
@@ -156,6 +156,8 @@ private:
|
|||||||
Constant* llvm_constant(ir::constant *cst, LLVMContext &ctx);
|
Constant* llvm_constant(ir::constant *cst, LLVMContext &ctx);
|
||||||
Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder);
|
Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder);
|
||||||
ArrayType* llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx);
|
ArrayType* llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx);
|
||||||
|
Function* llvm_fn(ir::function *fn, Builder& builder, Module &dst);
|
||||||
|
Value* alloc_shared(Builder &builder, Module& dst);
|
||||||
|
|
||||||
// grid construction
|
// grid construction
|
||||||
void create_grids(std::vector<ir::value *> &grids,
|
void create_grids(std::vector<ir::value *> &grids,
|
||||||
@@ -167,7 +169,7 @@ private:
|
|||||||
void init_strided_scan_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
void init_strided_scan_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||||
void init_hmma_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
void init_hmma_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||||
void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||||
void init_grids(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
|
void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
|
||||||
|
|
||||||
// lower scalar instruction
|
// lower scalar instruction
|
||||||
void lower_instruction(ir::instruction *src, Builder &builder);
|
void lower_instruction(ir::instruction *src, Builder &builder);
|
||||||
|
@@ -14,7 +14,7 @@ namespace ir {
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
namespace analysis{
|
namespace transform{
|
||||||
|
|
||||||
class cts {
|
class cts {
|
||||||
public:
|
public:
|
||||||
|
@@ -573,51 +573,31 @@ inline int32_t ceil(int32_t num, int32_t div){
|
|||||||
return (num + div - 1)/div;
|
return (num + div - 1)/div;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void to_warps(const std::vector<unsigned> &bs, const std::vector<int>& order, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
|
|
||||||
static const size_t warp_size = 32;
|
|
||||||
size_t nthreads = 1, nwarps = 1;
|
|
||||||
nw.resize(bs.size());
|
|
||||||
ws.resize(bs.size());
|
|
||||||
for(size_t i = 0; i < bs.size(); ++i){
|
|
||||||
nthreads *= bs[i];
|
|
||||||
nw[order[i]] = ceil(nthreads, nwarps*warp_size);
|
|
||||||
nwarps *= nw[order[i]];
|
|
||||||
}
|
|
||||||
for(size_t i = 0; i < bs.size(); ++i){
|
|
||||||
ws[i] = bs[i] / nw[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||||
auto order = tiles_->order(v);
|
auto order = tiles_->order(v);
|
||||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||||
size_t dim = shapes.size();
|
size_t dim = shapes.size();
|
||||||
std::vector<unsigned> contiguous(dim);
|
std::vector<unsigned> contiguous(dim);
|
||||||
std::vector<unsigned> block_size(dim);
|
std::vector<unsigned> block_size(dim);
|
||||||
std::vector<unsigned> warp_size(dim);
|
|
||||||
std::vector<unsigned> n_warps(dim);
|
|
||||||
for(unsigned i = 0; i < shapes.size(); i++){
|
for(unsigned i = 0; i < shapes.size(); i++){
|
||||||
contiguous[i] = tiles_->nts(v, i);
|
contiguous[i] = tiles_->nts(v, i);
|
||||||
block_size[i] = tiles_->mts(v, i);
|
block_size[i] = tiles_->mts(v, i);
|
||||||
}
|
}
|
||||||
to_warps(block_size, order, n_warps, warp_size);
|
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
|
||||||
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
|
std::vector<Value*> thread_id = delinearize(full_thread_id, order, block_size, builder);
|
||||||
std::vector<Value*> warp_id = delinearize(u_warp_id, order, n_warps, builder);
|
|
||||||
// Create axes
|
// Create axes
|
||||||
for(unsigned k = 0; k < dim; k++) {
|
for(unsigned k = 0; k < dim; k++) {
|
||||||
std::string str_k = std::to_string(k);
|
std::string str_k = std::to_string(k);
|
||||||
Value *warp_size_k = builder.getInt32(warp_size[k]);
|
|
||||||
Value *contiguous_k = builder.getInt32(contiguous[k]);
|
Value *contiguous_k = builder.getInt32(contiguous[k]);
|
||||||
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
|
Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k);
|
||||||
Value *scaled_thread_id = builder.CreateMul(thread_id, contiguous_k);
|
unsigned per_block = contiguous[k] * block_size[k];
|
||||||
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
|
|
||||||
unsigned per_thread = contiguous[k] * shapes[k] / per_block;
|
unsigned per_thread = contiguous[k] * shapes[k] / per_block;
|
||||||
std::vector<Value*> idx_list(per_thread);
|
std::vector<Value*> idx_list(per_thread);
|
||||||
for(unsigned n = 0 ; n < per_thread; n++){
|
for(unsigned n = 0 ; n < per_thread; n++){
|
||||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||||
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||||
}
|
}
|
||||||
axes_[a_axes_->get_id(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
|
axes_[a_axes_->get_id(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id[k]};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -825,7 +805,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
|||||||
create_distributed_tile(v, builder);
|
create_distributed_tile(v, builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
|
void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
|
||||||
// fetch linear ID
|
// fetch linear ID
|
||||||
Module *mod = builder.GetInsertBlock()->getParent()->getParent();
|
Module *mod = builder.GetInsertBlock()->getParent()->getParent();
|
||||||
Value *warp_size = builder.getInt32(32);
|
Value *warp_size = builder.getInt32(32);
|
||||||
@@ -1454,84 +1434,83 @@ ArrayType* selection::llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx)
|
|||||||
return ArrayType::get(llvm_type(ty->get_scalar_ty(), ctx), size);
|
return ArrayType::get(llvm_type(ty->get_scalar_ty(), ctx), size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void selection::run(ir::module &src, Module &dst) {
|
Function* selection::llvm_fn(ir::function *fn, IRBuilder<>& builder, Module& dst) {
|
||||||
vmap_.clear();
|
LLVMContext &ctx = builder.getContext();
|
||||||
LLVMContext &dst_ctx = dst.getContext();
|
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), ctx);
|
||||||
IRBuilder<> dst_builder(dst_ctx);
|
|
||||||
|
|
||||||
for(ir::alloc_const *x: src.allocs()) {
|
|
||||||
vmap_[x] = llvm_alloc_const(x, &dst, dst_builder);
|
|
||||||
}
|
|
||||||
|
|
||||||
// iterate over functions
|
|
||||||
for(ir::function *fn: src.get_function_list()) {
|
|
||||||
|
|
||||||
// create LLVM function
|
|
||||||
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx);
|
|
||||||
FunctionType *dst_fn_ty = fn_ty;
|
FunctionType *dst_fn_ty = fn_ty;
|
||||||
if(!tgt_->is_gpu()){
|
if(!tgt_->is_gpu()){
|
||||||
Type *dst_fn_ret_ty = fn_ty->getReturnType();
|
Type *dst_fn_ret_ty = fn_ty->getReturnType();
|
||||||
std::vector<Type*> dst_fn_args_ty;
|
std::vector<Type*> dst_fn_args_ty;
|
||||||
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
|
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
|
||||||
dst_fn_args_ty.push_back(fn_ty->getParamType(i));
|
dst_fn_args_ty.push_back(fn_ty->getParamType(i));
|
||||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
dst_fn_args_ty.push_back(builder.getInt32Ty());
|
||||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
dst_fn_args_ty.push_back(builder.getInt32Ty());
|
||||||
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
|
dst_fn_args_ty.push_back(builder.getInt32Ty());
|
||||||
dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
|
dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
|
||||||
}
|
}
|
||||||
|
Function *ret = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
|
||||||
// grid indices
|
|
||||||
fn->get_fn_type()->get_return_ty();
|
|
||||||
Function *dst_fn = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
|
|
||||||
|
|
||||||
// set attributes
|
// set attributes
|
||||||
for(auto attr_pair: fn->attrs()){
|
for(auto attr_pair: fn->attrs()){
|
||||||
unsigned id = attr_pair.first;
|
unsigned id = attr_pair.first;
|
||||||
for(ir::attribute attr: attr_pair.second)
|
for(ir::attribute attr: attr_pair.second)
|
||||||
if(attr.is_llvm_attr()){
|
if(attr.is_llvm_attr())
|
||||||
dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr));
|
ret->addAttribute(id, llvm_attr(ctx, attr));
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
|
|
||||||
// set metadata
|
// set metadata
|
||||||
|
tgt_->set_kernel(builder, ctx, &dst, ret);
|
||||||
Metadata *md_args[] = {
|
Metadata *md_args[] = {
|
||||||
ValueAsMetadata::get(dst_fn),
|
ValueAsMetadata::get(ret),
|
||||||
MDString::get(dst_ctx, "maxntidx"),
|
MDString::get(ctx, "maxntidx"),
|
||||||
ValueAsMetadata::get(dst_builder.getInt32(num_warps_*32))
|
ValueAsMetadata::get(builder.getInt32(num_warps_*32))
|
||||||
};
|
};
|
||||||
dst.getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(dst_ctx, md_args));
|
dst.getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
|
||||||
|
|
||||||
|
|
||||||
// map parameters
|
// map parameters
|
||||||
for(unsigned i = 0; i < fn->args().size(); i++)
|
for(unsigned i = 0; i < fn->args().size(); i++)
|
||||||
vmap_[fn->args()[i]] = &*(dst_fn->arg_begin() + i);
|
vmap_[fn->args()[i]] = &*(ret->arg_begin() + i);
|
||||||
// create blocks
|
// create blocks
|
||||||
for(ir::basic_block *block: fn->blocks()) {
|
for(ir::basic_block *block: fn->blocks()) {
|
||||||
BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_fn);
|
BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
|
||||||
vmap_[block] = dst_block;
|
vmap_[block] = dst_block;
|
||||||
}
|
}
|
||||||
dst_builder.SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
|
builder.SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
|
||||||
|
}
|
||||||
|
|
||||||
// allocate shared memory
|
Value* selection::alloc_shared(IRBuilder<> &builder, Module& dst) {
|
||||||
Value *sh_mem_ptr = nullptr;
|
Value *ret = nullptr;
|
||||||
|
LLVMContext &ctx = builder.getContext();
|
||||||
if(tgt_->is_gpu())
|
if(tgt_->is_gpu())
|
||||||
if(unsigned alloc_size = alloc_->allocated_size()){
|
if(unsigned alloc_size = alloc_->allocated_size()){
|
||||||
Type *int_8_ty = Type::getInt8Ty(dst_ctx);
|
Type *int_8_ty = Type::getInt8Ty(ctx);
|
||||||
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
|
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
|
||||||
Type *ptr_ty = PointerType::get(int_8_ty, 3);
|
Type *ptr_ty = PointerType::get(int_8_ty, 3);
|
||||||
GlobalVariable *sh_mem_array =
|
GlobalVariable *sh_mem_array =
|
||||||
new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
|
new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
|
||||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||||
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
|
ret = builder.CreateBitCast(sh_mem_array, ptr_ty);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
sh_mem_ptr_ = sh_mem_ptr;
|
|
||||||
|
|
||||||
// create grids
|
void selection::run(ir::module &src, Module &dst) {
|
||||||
init_grids(fn, dst_builder, sh_mem_ptr);
|
vmap_.clear();
|
||||||
|
tmap_.clear();
|
||||||
|
|
||||||
|
LLVMContext &dst_ctx = dst.getContext();
|
||||||
|
IRBuilder<> dst_builder(dst_ctx);
|
||||||
|
|
||||||
// iterate through block
|
// constant memory
|
||||||
|
for(ir::alloc_const *x: src.allocs())
|
||||||
|
vmap_[x] = llvm_alloc_const(x, &dst, dst_builder);
|
||||||
|
|
||||||
|
// iterate over functions
|
||||||
|
for(ir::function *fn: src.get_function_list()) {
|
||||||
|
// create LLVM function
|
||||||
|
llvm_fn(fn, dst_builder, dst);
|
||||||
|
// allocate shared memory
|
||||||
|
sh_mem_ptr_ = alloc_shared(dst_builder, dst);
|
||||||
|
// initialize layouts
|
||||||
|
init_layouts(fn, dst_builder, sh_mem_ptr_);
|
||||||
|
// generate LLVM-IR code
|
||||||
std::map<ir::basic_block*, BasicBlock*> last_block;
|
std::map<ir::basic_block*, BasicBlock*> last_block;
|
||||||
for(ir::basic_block *block: fn->blocks()) {
|
for(ir::basic_block *block: fn->blocks()) {
|
||||||
BasicBlock *parent = (BasicBlock*)vmap_[block];
|
BasicBlock *parent = (BasicBlock*)vmap_[block];
|
||||||
@@ -1547,7 +1526,7 @@ void selection::run(ir::module &src, Module &dst) {
|
|||||||
last_block[block] = dst_builder.GetInsertBlock();
|
last_block[block] = dst_builder.GetInsertBlock();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// finalize double-buffering
|
||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
for(ir::instruction *inst: block->get_inst_list()) {
|
for(ir::instruction *inst: block->get_inst_list()) {
|
||||||
if(liveness_->has_double(inst)) {
|
if(liveness_->has_double(inst)) {
|
||||||
@@ -1574,7 +1553,7 @@ void selection::run(ir::module &src, Module &dst) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// add phi operands
|
// finalize phi
|
||||||
for(ir::basic_block *block: fn->blocks())
|
for(ir::basic_block *block: fn->blocks())
|
||||||
for(ir::instruction *inst: block->get_inst_list())
|
for(ir::instruction *inst: block->get_inst_list())
|
||||||
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
|
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
|
||||||
|
@@ -9,9 +9,8 @@
|
|||||||
#include "triton/ir/type.h"
|
#include "triton/ir/type.h"
|
||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
|
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
namespace analysis{
|
namespace transform{
|
||||||
|
|
||||||
// run pass on module
|
// run pass on module
|
||||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) {
|
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) {
|
||||||
|
@@ -199,7 +199,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
llvm::LLVMContext ctx;
|
llvm::LLVMContext ctx;
|
||||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||||
// create passes
|
// create passes
|
||||||
codegen::analysis::cts cts;
|
codegen::transform::cts cts;
|
||||||
codegen::analysis::align align;
|
codegen::analysis::align align;
|
||||||
codegen::analysis::liveness shmem_liveness;
|
codegen::analysis::liveness shmem_liveness;
|
||||||
codegen::analysis::axes axes;
|
codegen::analysis::axes axes;
|
||||||
|
@@ -1,65 +1,35 @@
|
|||||||
#include <cstring>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <tuple>
|
||||||
#include <cstdio>
|
#include "copy.h"
|
||||||
#include "triton/driver/backend.h"
|
#include "triton/driver/backend.h"
|
||||||
#include "triton/driver/stream.h"
|
|
||||||
#include "triton/tools/bench.hpp"
|
|
||||||
#include "triton/external/half.hpp"
|
|
||||||
#include "triton/runtime/function.h"
|
|
||||||
#include "src/copy.h"
|
|
||||||
#include "util.h"
|
|
||||||
#include "cuda/cublas.h"
|
|
||||||
|
|
||||||
|
|
||||||
std::vector<double> do_bench(drv::stream* stream, int32_t M, int32_t N, order_t order_x, order_t order_y){
|
|
||||||
typedef float NumericT;
|
|
||||||
std::string ty = "float";
|
|
||||||
size_t dt_nbytes = sizeof(NumericT);
|
|
||||||
drv::context* context = stream->context();
|
|
||||||
// create inputs
|
|
||||||
auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
|
|
||||||
auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
|
|
||||||
// create options
|
|
||||||
rt::function::options_space_t opt;
|
|
||||||
opt.defines.push_back({"TYPE", {ty}});
|
|
||||||
opt.defines.push_back({"STRIDE_XM", {(order_x == ROWMAJOR)?"M":"1"}});
|
|
||||||
opt.defines.push_back({"STRIDE_XN", {(order_x == ROWMAJOR)?"1":"N"}});
|
|
||||||
opt.defines.push_back({"STRIDE_YM", {(order_y == ROWMAJOR)?"M":"1"}});
|
|
||||||
opt.defines.push_back({"STRIDE_YN", {(order_y == ROWMAJOR)?"1":"N"}});
|
|
||||||
opt.defines.push_back({"TM", {"32"}});
|
|
||||||
opt.defines.push_back({"TN", {"32"}});
|
|
||||||
opt.num_warps = {4};
|
|
||||||
// create function
|
|
||||||
rt::function function(src::copy2d, opt);
|
|
||||||
// benchmark available libraries
|
|
||||||
std::vector<double> result;
|
|
||||||
auto gbps = [&](double ns) { return 2*M*N*dt_nbytes / (ns * 1e-9) * 1e-9; };
|
|
||||||
// triton
|
|
||||||
double triton_ns = triton::tools::bench([&]() { function({&*dx, &*dy, M, N}, grid2d(M, N), stream);}, stream);
|
|
||||||
result.push_back(gbps(triton_ns));
|
|
||||||
// done
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
triton::driver::stream* stream = triton::driver::stream::create(context);
|
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||||
// shapes to benchmark
|
// shapes to benchmark
|
||||||
typedef std::tuple<int, int, order_t, order_t> config_t;
|
typedef std::tuple<std::vector<int>, std::vector<int>, std::vector<int>> config_t;
|
||||||
std::vector<config_t> configs = {
|
std::vector<config_t> configs = {
|
||||||
{4096, 4096, ROWMAJOR, ROWMAJOR},
|
{{4096*4096}, {0}, {0}},
|
||||||
{4096, 4096, COLMAJOR, ROWMAJOR},
|
{{4096, 4096}, {0, 1}, {1, 0}},
|
||||||
{4096, 4096, ROWMAJOR, COLMAJOR},
|
{{4096, 4096}, {0, 1}, {1, 0}},
|
||||||
{4096, 4096, COLMAJOR, COLMAJOR},
|
{{4096, 4096}, {1, 0}, {0, 1}},
|
||||||
|
{{4096, 4096}, {0, 1}, {0, 1}},
|
||||||
|
{{256, 256, 256}, {0, 1, 2}, {0, 1, 2}},
|
||||||
|
{{256, 256, 256}, {0, 1, 2}, {0, 2, 1}},
|
||||||
|
{{256, 256, 256}, {1, 0, 2}, {1, 2, 0}},
|
||||||
|
{{256, 256, 256}, {1, 2, 0}, {1, 0, 2}},
|
||||||
|
{{256, 256, 256}, {2, 0, 1}, {0, 1, 2}},
|
||||||
|
{{256, 256, 256}, {2, 1, 0}, {0, 2, 1}}
|
||||||
};
|
};
|
||||||
// does the work
|
// does the work
|
||||||
int32_t M, N;
|
std::vector<int32_t> shape;
|
||||||
order_t ord_x, ord_y;
|
std::vector<int32_t> ord_x, ord_y;
|
||||||
for(const auto& c: configs){
|
for(const auto& c: configs){
|
||||||
std::tie(M, N, ord_x, ord_y) = c;
|
std::tie(shape, ord_x, ord_y) = c;
|
||||||
std::cout << "// " << M << ", " << N << ", " << ord_x << ", " << ord_y << std::flush;
|
std::cout << "// " << c << std::flush;
|
||||||
for(auto perf: do_bench(stream, M, N, ord_x, ord_y))
|
for(auto perf: bench_copy_nd(stream, shape, ord_x, ord_y))
|
||||||
std::cout << ", " << perf << std::flush;
|
std::cout << ", " << perf << std::flush;
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
142
tests/common/copy.h
Normal file
142
tests/common/copy.h
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
#include "src/copy.h"
|
||||||
|
#include "triton/driver/stream.h"
|
||||||
|
#include "triton/runtime/function.h"
|
||||||
|
#include "triton/tools/bench.hpp"
|
||||||
|
#include "util.h"
|
||||||
|
|
||||||
|
// Linear offset of the multi-index `idx` inside a buffer laid out with the
// given per-dimension `strides`, i.e. sum over d of idx[d] * strides[d].
// `strides` must hold at least idx.size() entries; an empty index yields 0.
// Marked inline: this function is defined in a header, and a non-inline
// definition would violate the one-definition rule as soon as the header is
// included from more than one translation unit of the same program.
inline int32_t off(const std::vector<int32_t>& idx, const std::vector<int32_t>& strides) {
  int32_t res = 0;
  for(size_t d = 0; d < idx.size(); d++)
    res += idx[d] * strides[d];
  return res;
}
|
||||||
|
|
||||||
|
enum run_mode_t {
|
||||||
|
BENCH,
|
||||||
|
TEST
|
||||||
|
};
|
||||||
|
|
||||||
|
// Host-side reference copy of an N-dimensional tensor between two layouts.
//
//   x        source buffer, read through the layout described by x_order
//   y        destination buffer; must already be large enough for every
//            offset written (it is indexed, never resized)
//   shape    extent of each logical dimension
//   x_order  x_order[0] is the dimension with stride 1 in x, x_order[1] the
//            next one (stride shape[x_order[0]]), and so on
//   y_order  same convention for y
//
// Every logical index is visited with the last dimension varying fastest and
// y[offset_y(index)] = x[offset_x(index)] is performed. Any rank is supported;
// a rank of zero or a non-positive extent results in no copy at all.
template<class T>
void cc_copy_nd(const std::vector<T>& x, std::vector<T>& y,
                const std::vector<int32_t>& shape,
                const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order) {
  const size_t rank = shape.size();
  if(rank == 0)
    return;
  // an empty tensor has nothing to copy
  for(int32_t extent: shape)
    if(extent <= 0)
      return;
  // strides[order[d]] is the product of the extents of all faster dimensions
  auto make_strides = [&shape, rank](const std::vector<int32_t>& order) {
    std::vector<int32_t> strides(rank);
    int32_t running = 1;
    for(size_t d = 0; d < rank; d++){
      strides[order[d]] = running;
      running *= shape[order[d]];
    }
    return strides;
  };
  const std::vector<int32_t> x_strides = make_strides(x_order);
  const std::vector<int32_t> y_strides = make_strides(y_order);
  // odometer over the logical index space
  std::vector<int32_t> idx(rank, 0);
  for(;;){
    int32_t x_off = 0;
    int32_t y_off = 0;
    for(size_t d = 0; d < rank; d++){
      x_off += idx[d] * x_strides[d];
      y_off += idx[d] * y_strides[d];
    }
    y[y_off] = x[x_off];
    // advance the last dimension first, carrying into earlier ones
    size_t d = rank;
    while(d > 0 && ++idx[d - 1] == shape[d - 1]){
      idx[d - 1] = 0;
      --d;
    }
    // every dimension wrapped around: the whole tensor has been visited
    if(d == 0)
      break;
  }
}
|
||||||
|
|
||||||
|
// Builds the copy kernel matching shape.size() and either benchmarks it or
// checks it against the host reference.
//
//   stream    stream on which buffers are created and the kernel is launched
//   shape     extent of each dimension; ranks 1 to 3 are supported
//   x_order   layout of the source: x_order[0] is its stride-1 dimension
//   y_order   layout of the destination, same convention
//   TS        candidate values for the TS<d> macro of each dimension; when
//             empty, the candidates come from tile_nd(rank)
//   mode      BENCH appends one throughput figure to `bench`;
//             TEST stores the result of testing::diff(hy, ry) in `test`
//
// Inputs the kernels cannot handle (rank outside 1..3, order vectors or TS
// shorter than the rank) return early instead of indexing out of bounds; in
// TEST mode `test` is set to false in that case.
// Marked inline because the definition lives in a header.
inline void triton_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
                           const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order,
                           std::vector<std::vector<std::string>> TS,
                           run_mode_t mode, std::vector<double>& bench, bool &test) {
  typedef float NumericT;
  std::string ty = "float";
  size_t dtsize = sizeof(NumericT);
  drv::context* context = stream->context();

  // rank
  size_t rank = shape.size();
  std::vector<std::string> shapename = {"S0", "S1", "S2"};
  // reject anything that would index shapename / src::copy_nd / the orders
  // out of range (this also keeps `rank - 1` from wrapping around)
  if(rank == 0 || rank > shapename.size() || x_order.size() < rank || y_order.size() < rank){
    if(mode == TEST)
      test = false;
    return;
  }
  // size
  size_t size = 1;
  for(int32_t d: shape)
    size *= d;
  // strides for x, as macro expressions over the runtime shape arguments
  std::vector<std::string> x_strides = {"1"};
  for(size_t d = 0; d + 1 < rank; d++)
    x_strides.push_back(x_strides[d] + " * " + shapename[x_order[d]]);
  // strides for y
  std::vector<std::string> y_strides = {"1"};
  for(size_t d = 0; d + 1 < rank; d++)
    y_strides.push_back(y_strides[d] + " * " + shapename[y_order[d]]);

  // create inputs
  auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
  auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
  // create options
  rt::function::options_space_t opt;

  // macros
  opt.defines.push_back({"TYPE", {ty}});
  for(size_t d = 0; d < rank; d++)
    opt.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}});
  for(size_t d = 0; d < rank; d++)
    opt.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}});
  if(TS.empty())
    TS = tile_nd(rank);
  // one candidate list per dimension is required below
  if(TS.size() < rank){
    if(mode == TEST)
      test = false;
    return;
  }
  for(size_t d = 0; d < rank; d++)
    opt.defines.push_back({"TS" + std::to_string(d), TS[d]});
  opt.num_warps = {4};

  // kernel
  rt::function function(src::copy_nd[rank - 1], opt);
  std::vector<rt::arg> args = {&*dx, &*dy};
  for(int32_t d: shape)
    args.push_back(d);
  std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
  auto grid = grid_nd(shape, ts);

  // metrics
  if(mode == BENCH){
    // each element is read once and written once, hence the factor 2
    auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; };
    double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
    bench.push_back(gbps(triton_ns));
  }

  // test triton
  if(mode == TEST){
    std::vector<NumericT> hx(size);
    std::vector<NumericT> hy(size);
    std::vector<NumericT> ry(size);
    for(size_t i = 0; i < hx.size(); i++)
      hx[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
    stream->write(&*dx, true, 0, hx);
    function(args, grid, stream);
    stream->synchronize();
    stream->read(&*dy, true, 0, hy);
    // host reference for the same layouts
    cc_copy_nd(hx, ry, shape, x_order, y_order);
    test = testing::diff(hy, ry);
  }
}
|
||||||
|
|
||||||
|
// Benchmarks the copy kernel for the given shape and layouts and returns the
// figures collected by triton_copy_nd in BENCH mode. No tile sizes are passed,
// so triton_copy_nd picks its own candidates.
// Marked inline because the definition lives in a header.
inline std::vector<double> bench_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
                                         const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order) {
  std::vector<double> bench;
  // only needed to satisfy the out-parameter; initialized so that no
  // indeterminate value is ever handed across the call
  bool test = false;
  triton_copy_nd(stream, shape, x_order, y_order, {}, BENCH, bench, test);
  return bench;
}
|
||||||
|
|
||||||
|
// Runs the copy kernel once with fixed tile sizes and returns the verdict that
// triton_copy_nd produced in TEST mode.
//
//   TS  tile size of each dimension; every entry becomes the single candidate
//       for the corresponding TS<d> macro
//
// The verdict starts out as false, so a run in which the callee never assigns
// it is reported as a failure rather than as an indeterminate value.
// Marked inline because the definition lives in a header.
inline bool test_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
                         const std::vector<int32_t>& TS,
                         const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order) {
  std::vector<double> bench;
  bool test = false;
  std::vector<std::vector<std::string>> TSS;
  TSS.reserve(TS.size());
  for(int32_t d: TS)
    TSS.push_back({std::to_string(d)});
  triton_copy_nd(stream, shape, x_order, y_order, TSS, TEST, bench, test);
  return test;
}
|
@@ -1,33 +1,66 @@
|
|||||||
|
#ifndef _TRITON_TEST_SRC_COPY_H_
|
||||||
|
#define _TRITON_TEST_SRC_COPY_H_
|
||||||
|
|
||||||
namespace src {
|
namespace src {
|
||||||
|
|
||||||
// Source of the 1D copy kernel: each program instance copies TS0 consecutive
// elements from X to Y. Y is the pointer being stored through, so it is
// annotated __writeonly (as in the 2D and 3D kernels) rather than __readonly.
const char *copy1d =
R"(
void copy1d(TYPE * X __noalias __readonly __aligned(16),
            TYPE * Y __noalias __writeonly __aligned(16),
            int S0) {
  int pid0 = get_program_id(0);
  int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
  TYPE* px[TS0] = X + rs0;
  TYPE* py[TS0] = Y + rs0;
  *py = *px;
}
)";
|
||||||
|
|
||||||
|
|
||||||
const char *copy2d =
|
const char *copy2d =
|
||||||
R"(
|
R"(
|
||||||
void copy2d(TYPE * X __noalias __readonly __aligned(16),
|
void copy2d(TYPE * X __noalias __readonly __aligned(16),
|
||||||
TYPE * Y __noalias __writeonly __aligned(16),
|
TYPE * Y __noalias __writeonly __aligned(16),
|
||||||
int M __multipleof(8),
|
int S0 __multipleof(8),
|
||||||
int N __multipleof(8)) {
|
int S1 __multipleof(8)) {
|
||||||
int ridm = get_program_id(0);
|
int pid0 = get_program_id(0);
|
||||||
int ridn = get_program_id(1);
|
int pid1 = get_program_id(1);
|
||||||
int rm[TM] = ridm * TM + 0 ... TM;
|
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
|
||||||
int rn[TN] = ridn * TN + 0 ... TN;
|
int rs1[TS1] = pid1 * TS1 + 0 ... TS1;
|
||||||
TYPE* px[TM, TN] = X + rm[:, newaxis] * STRIDE_XM + rn[newaxis, :] * STRIDE_XN;
|
TYPE* px[TS0, TS1] = X + rs0[:, newaxis] * STRIDE_XS0 + rs1[newaxis, :] * STRIDE_XS1;
|
||||||
TYPE* py[TM, TN] = Y + rm[:, newaxis] * STRIDE_YM + rn[newaxis, :] * STRIDE_YN;
|
TYPE* py[TS0, TS1] = Y + rs0[:, newaxis] * STRIDE_YS0 + rs1[newaxis, :] * STRIDE_YS1;
|
||||||
*py = *px;
|
*py = *px;
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
const char *copy3d =
|
||||||
|
R"(
|
||||||
|
void copy3d(TYPE * X __noalias __readonly __aligned(16),
|
||||||
|
TYPE * Y __noalias __writeonly __aligned(16),
|
||||||
|
int S0 __multipleof(8),
|
||||||
|
int S1 __multipleof(8),
|
||||||
|
int S2 __multipleof(8)) {
|
||||||
|
// program id
|
||||||
|
int pid0 = get_program_id(0);
|
||||||
|
int pid1 = get_program_id(1);
|
||||||
|
int pid2 = get_program_id(2);
|
||||||
|
// ranges
|
||||||
|
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
|
||||||
|
int rs1[TS1] = pid1 * TS1 + 0 ... TS1;
|
||||||
|
int rs2[TS2] = pid2 * TS2 + 0 ... TS2;
|
||||||
|
// X pointers
|
||||||
|
TYPE* px[TS0, TS1, TS2] = X + rs0[:, newaxis, newaxis] * STRIDE_XS0
|
||||||
|
+ rs1[newaxis, :, newaxis] * STRIDE_XS1
|
||||||
|
+ rs2[newaxis, newaxis, :] * STRIDE_XS2;
|
||||||
|
// Y pointers
|
||||||
|
TYPE* py[TS0, TS1, TS2] = Y + rs0[:, newaxis, newaxis] * STRIDE_YS0
|
||||||
|
+ rs1[newaxis, :, newaxis] * STRIDE_YS1
|
||||||
|
+ rs2[newaxis, newaxis, :] * STRIDE_YS2;
|
||||||
|
*py = *px;
|
||||||
}
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
const char* copy_nd[] = {copy1d, copy2d, copy3d};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
@@ -4,6 +4,7 @@
|
|||||||
#define _TRITON_TESTS_UTIL_H
|
#define _TRITON_TESTS_UTIL_H
|
||||||
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
#include <cmath>
|
||||||
#include "triton/runtime/function.h"
|
#include "triton/runtime/function.h"
|
||||||
|
|
||||||
namespace drv = triton::driver;
|
namespace drv = triton::driver;
|
||||||
@@ -26,6 +27,30 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a grid callback for an N-dimensional launch: for every dimension d
// it yields ceil(shape[d], value of the integer define named ts[d]).
// `ts` must provide a define name for each dimension of `shape`.
// The closure owns copies of `shape` and `ts`: both are reference parameters,
// so capturing them by reference would leave the returned callback dangling
// as soon as the caller's arguments (possibly temporaries) are destroyed.
inline rt::function::grid_fn_ty grid_nd(const std::vector<int32_t> &shape,
                                        const std::vector<std::string>& ts) {
  return [shape, ts](const rt::function::options_t& x) {
    rt::grid_t ret;
    for(size_t d = 0; d < shape.size(); d++)
      ret.push_back(ceil(shape[d], x.D<int>(ts[d])));
    return ret;
  };
}
|
||||||
|
|
||||||
|
// Candidate tile sizes, one list of strings per dimension, for a copy of the
// given rank. Ranks 1 to 3 have candidates; any other rank that gets past the
// assertion produces an empty result.
inline std::vector<std::vector<std::string>> tile_nd(size_t rank) {
  assert(rank <= 3);
  using candidates_t = std::vector<std::string>;
  switch(rank) {
    case 1:
      return std::vector<candidates_t>(1, candidates_t{"128", "256", "512", "1024"});
    case 2:
      // both dimensions share the same candidates
      return std::vector<candidates_t>(2, candidates_t{"16", "32", "64"});
    case 3:
      // all three dimensions share the same candidates
      return std::vector<candidates_t>(3, candidates_t{"4", "16", "32"});
    default:
      return {};
  }
}
|
||||||
|
|
||||||
enum order_t {
|
enum order_t {
|
||||||
ROWMAJOR,
|
ROWMAJOR,
|
||||||
COLMAJOR
|
COLMAJOR
|
||||||
@@ -44,17 +69,30 @@ struct gen_seq<0, Is...> : seq<Is...>{};
|
|||||||
template<class Ch, class Tr, class Tuple, std::size_t... Is>
|
template<class Ch, class Tr, class Tuple, std::size_t... Is>
|
||||||
void print_tuple(std::basic_ostream<Ch,Tr>& os, Tuple const& t, seq<Is...>){
|
void print_tuple(std::basic_ostream<Ch,Tr>& os, Tuple const& t, seq<Is...>){
|
||||||
using swallow = int[];
|
using swallow = int[];
|
||||||
(void)swallow{0, (void(os << (Is == 0? "" : ", ") << std::setfill(' ') << std::setw(3) << std::get<Is>(t)), 0)...};
|
(void)swallow{0, (void(os << (Is == 0? "" : ", ") << std::get<Is>(t)), 0)...};
|
||||||
}
|
}
|
||||||
} // aux::
|
} // aux::
|
||||||
|
|
||||||
|
|
||||||
template<class Ch, class Tr, class... Args>
|
template<class Ch, class Tr, class... Args>
|
||||||
auto operator<<(std::basic_ostream<Ch, Tr>& os, std::tuple<Args...> const& t)
|
auto operator<<(std::basic_ostream<Ch, Tr>& os, std::tuple<Args...> const& t)
|
||||||
-> std::basic_ostream<Ch, Tr>&
|
-> std::basic_ostream<Ch, Tr>&
|
||||||
{
|
{
|
||||||
os << "(";
|
|
||||||
aux::print_tuple(os, t, aux::gen_seq<sizeof...(Args)>());
|
aux::print_tuple(os, t, aux::gen_seq<sizeof...(Args)>());
|
||||||
return os << ")";
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Streams a vector as "{e0, e1, ...}": elements are written with their own
// operator<<, separated by ", " and wrapped in braces. An empty vector is
// written as "{}". Returns the stream to allow chaining.
template<class Ch, class Tr, class T>
auto operator<<(std::basic_ostream<Ch, Tr>& os, std::vector<T> const& t)
    -> std::basic_ostream<Ch, Tr>&
{
  os << "{";
  bool is_first = true;
  for(const auto& element : t) {
    // the separator goes in front of every element except the first one
    if(!is_first)
      os << ", ";
    is_first = false;
    os << element;
  }
  os << "}";
  return os;
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
foreach(PROG dot)
|
foreach(PROG dot copy1d copy2d copy3d)
|
||||||
set(TARGET unit_${PROG})
|
set(TARGET unit_${PROG})
|
||||||
add_executable(${TARGET} ${PROG}.cc)
|
add_executable(${TARGET} ${PROG}.cc)
|
||||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
|
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
|
||||||
|
30
tests/unit/copy1d.cc
Normal file
30
tests/unit/copy1d.cc
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <tuple>
|
||||||
|
#include "copy.h"
|
||||||
|
#include "triton/driver/backend.h"
|
||||||
|
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
// initialize default compute device
|
||||||
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
|
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||||
|
// shapes to benchmark
|
||||||
|
typedef std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, std::vector<int>> config_t;
|
||||||
|
std::vector<config_t> configs = {
|
||||||
|
// {{65536}, {32}, {0}, {0}},
|
||||||
|
{{65536}, {128}, {0}, {0}},
|
||||||
|
{{65536}, {512}, {0}, {0}},
|
||||||
|
{{65536}, {1024}, {0}, {0}},
|
||||||
|
};
|
||||||
|
// does the work
|
||||||
|
std::vector<int32_t> shape, tile;
|
||||||
|
std::vector<int32_t> ord_x, ord_y;
|
||||||
|
bool result = true;
|
||||||
|
for(const auto& c: configs){
|
||||||
|
std::tie(shape, tile, ord_x, ord_y) = c;
|
||||||
|
bool pass = test_copy_nd(stream, shape, tile, ord_x, ord_y);
|
||||||
|
result = result && pass;
|
||||||
|
std::cout << "// " << c << ", " << pass << std::endl;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
46
tests/unit/copy2d.cc
Normal file
46
tests/unit/copy2d.cc
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <tuple>
|
||||||
|
#include "copy.h"
|
||||||
|
#include "triton/driver/backend.h"
|
||||||
|
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
// initialize default compute device
|
||||||
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
|
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||||
|
// shapes to benchmark
|
||||||
|
typedef std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, std::vector<int>> config_t;
|
||||||
|
std::vector<config_t> configs = {
|
||||||
|
{{256, 256}, {16, 16}, {0, 1}, {0, 1}},
|
||||||
|
{{256, 256}, {16, 64}, {0, 1}, {0, 1}},
|
||||||
|
{{256, 256}, {64, 16}, {0, 1}, {0, 1}},
|
||||||
|
{{256, 256}, {64, 64}, {0, 1}, {0, 1}},
|
||||||
|
|
||||||
|
{{256, 256}, {16, 16}, {0, 1}, {1, 0}},
|
||||||
|
{{256, 256}, {16, 64}, {0, 1}, {1, 0}},
|
||||||
|
{{256, 256}, {64, 16}, {0, 1}, {1, 0}},
|
||||||
|
{{256, 256}, {64, 64}, {0, 1}, {1, 0}},
|
||||||
|
|
||||||
|
{{256, 256}, {16, 16}, {1, 0}, {0, 1}},
|
||||||
|
{{256, 256}, {16, 64}, {1, 0}, {0, 1}},
|
||||||
|
{{256, 256}, {64, 16}, {1, 0}, {0, 1}},
|
||||||
|
{{256, 256}, {64, 64}, {1, 0}, {0, 1}},
|
||||||
|
|
||||||
|
{{256, 256}, {64, 64}, {1, 0}, {1, 0}},
|
||||||
|
{{256, 256}, {16, 64}, {1, 0}, {1, 0}},
|
||||||
|
{{256, 256}, {64, 16}, {1, 0}, {1, 0}},
|
||||||
|
{{256, 256}, {64, 64}, {1, 0}, {1, 0}},
|
||||||
|
};
|
||||||
|
// does the work
|
||||||
|
std::vector<int32_t> shape, tile;
|
||||||
|
std::vector<int32_t> ord_x, ord_y;
|
||||||
|
bool result = true;
|
||||||
|
for(const auto& c: configs){
|
||||||
|
std::tie(shape, tile, ord_x, ord_y) = c;
|
||||||
|
bool pass = test_copy_nd(stream, shape, tile, ord_x, ord_y);
|
||||||
|
result = result && pass;
|
||||||
|
std::cout << "// " << c << ", " << pass << std::endl;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
38
tests/unit/copy3d.cc
Normal file
38
tests/unit/copy3d.cc
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <tuple>
|
||||||
|
#include "copy.h"
|
||||||
|
#include "triton/driver/backend.h"
|
||||||
|
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
// initialize default compute device
|
||||||
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
|
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||||
|
// shapes to benchmark
|
||||||
|
typedef std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, std::vector<int>> config_t;
|
||||||
|
std::vector<config_t> configs;
|
||||||
|
std::vector<int> x_idx = {0, 1, 2};
|
||||||
|
do {
|
||||||
|
std::vector<int> y_idx = {0, 1, 2};
|
||||||
|
do {
|
||||||
|
configs.push_back(config_t{{64, 64, 32}, {16, 4, 8}, x_idx, y_idx});
|
||||||
|
configs.push_back(config_t{{64, 64, 32}, {8, 16, 2}, x_idx, y_idx});
|
||||||
|
configs.push_back(config_t{{64, 64, 32}, {32, 2, 2}, x_idx, y_idx});
|
||||||
|
configs.push_back(config_t{{64, 64, 32}, {16, 64, 4}, x_idx, y_idx});
|
||||||
|
|
||||||
|
} while(std::next_permutation(y_idx.begin(), y_idx.end()));
|
||||||
|
} while(std::next_permutation(x_idx.begin(), x_idx.end()));
|
||||||
|
// testing
|
||||||
|
std::vector<int32_t> shape, tile;
|
||||||
|
std::vector<int32_t> ord_x, ord_y;
|
||||||
|
bool result = true;
|
||||||
|
for(const auto& c: configs){
|
||||||
|
std::tie(shape, tile, ord_x, ord_y) = c;
|
||||||
|
bool pass = test_copy_nd(stream, shape, tile, ord_x, ord_y);
|
||||||
|
result = result && pass;
|
||||||
|
std::cout << "// " << c << ", " << pass << std::endl;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@@ -32,7 +32,7 @@ static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vecto
|
|||||||
float acc = 0;
|
float acc = 0;
|
||||||
for(size_t k = 0; k < K; k++)
|
for(size_t k = 0; k < K; k++)
|
||||||
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
|
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
|
||||||
c[m + n*M] = static_cast<T>(acc);
|
c[m*N + n] = static_cast<T>(acc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,7 +120,9 @@ int main() {
|
|||||||
std::cout << "Testing " << c << " ... " << std::flush;
|
std::cout << "Testing " << c << " ... " << std::flush;
|
||||||
if(do_test(stream, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
|
if(do_test(stream, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
|
||||||
std::cout << " Pass! " << std::endl;
|
std::cout << " Pass! " << std::endl;
|
||||||
else
|
else{
|
||||||
std::cout << " Fail! " << std::endl;
|
std::cout << " Fail! " << std::endl;
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user