diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index c19d43e2e..43903c592 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -8,12 +8,12 @@ int main() { - bool AT = false; - bool BT = true; + bool AT = true; + bool BT = false; // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); // matrix multiplication parameters - int32_t M = 8192, N = 8192, K = 8192; + int32_t M = 2048, N = 2048, K = 2048; std::vector hc(M*N); std::vector rc(M*N); std::vector ha(M*K); diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 72ca66ad1..04a413b32 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -984,7 +984,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & distributed_tile *TC = (distributed_tile*)tmap_.at(C); Type *c_ty = llvm_type(C->get_type()->get_scalar_ty(), ctx); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty}); - unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value(); + size_t red_axis = dot->is_a_trans() ? 0 : 1; + unsigned NK = A->get_type()->get_tile_shapes()[red_axis]->get_value(); if(NK != 1) { shared_tile *TA = (shared_tile*)tmap_.at(A); @@ -1147,6 +1148,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & unsigned max_contiguous = axis_info_->get_max_contiguous(ptr); unsigned alignment = std::min(starting_multiple, max_contiguous); unsigned vector_size = std::min(result->axis(0).contiguous, alignment); +// vector_size = result->axis(0).contiguous; std::map packets; distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand()); result->for_each([&](indices_t idx){ diff --git a/lib/codegen/shmem_allocation.cpp b/lib/codegen/shmem_allocation.cpp index eb65b224f..4031864c2 100644 --- a/lib/codegen/shmem_allocation.cpp +++ b/lib/codegen/shmem_allocation.cpp @@ -21,13 +21,13 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) { bool is_op_1 = x == dot->get_operand(1); if(is_hmma && is_op_0){ if(dot->is_a_trans()) - return 20; + return 4; else return 16; } if(is_hmma && is_op_1){ if(!dot->is_b_trans()) - return 20; + return 4; else return 16; } diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 293ebf053..fcb519c4a 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -235,7 +235,7 @@ void tune::run(ir::module &mod) { continue; if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 4, 4)); + std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 2, 4)); *params_.at(i).at("nts.d0") = *tmp; } if(dynamic_cast(i) && i->get_type()->is_tile_ty()){