diff --git a/include/triton/codegen/selection/machine_value.h b/include/triton/codegen/selection/machine_value.h
index 917151971..aab1f023a 100644
--- a/include/triton/codegen/selection/machine_value.h
+++ b/include/triton/codegen/selection/machine_value.h
@@ -125,7 +125,6 @@ class distributed_tile: public tile{
 private:
   void init_indices();
-  Type *make_vector_ty(Type *ty, size_t vector_size);
 
 public:
   distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder);
diff --git a/lib/codegen/selection/machine_value.cc b/lib/codegen/selection/machine_value.cc
index 72aace4b2..a94661b90 100644
--- a/lib/codegen/selection/machine_value.cc
+++ b/lib/codegen/selection/machine_value.cc
@@ -15,7 +15,7 @@ void distributed_tile::init_indices() {
   std::vector<size_t> order(id.size());
   std::iota(order.begin(), order.end(), 0);
   auto cmp = [&](int x, int y) {
-    return axes_[x].contiguous > axes_[y].contiguous;
+    return order_[x] < order_[y];
   };
   std::sort(order.begin(), order.end(), cmp);
   // build
@@ -39,11 +39,6 @@ void distributed_tile::init_indices() {
   }
 }
 
-llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) {
-  if(vector_size == 1)
-    return ty;
-  return VectorType::get(ty, vector_size);
-}
 
 distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
     : tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
diff --git a/lib/codegen/transform/cts.cc b/lib/codegen/transform/cts.cc
index f98b685e1..ae2791cc8 100644
--- a/lib/codegen/transform/cts.cc
+++ b/lib/codegen/transform/cts.cc
@@ -15,6 +15,8 @@ inline bool is_shmem_op(ir::instruction* i, int op) {
     return op==0 || op==1;
   if(i->get_id() == ir::INST_COPY_FROM_SHARED)
     return op==0;
+  if(i->get_id() == ir::INST_TRANS)
+    return op==0;
   return false;
 }
diff --git a/python/examples/einsum.py b/python/examples/einsum.py
index 2cbf2ca10..a3fdba5e0 100644
--- a/python/examples/einsum.py
+++ b/python/examples/einsum.py
@@ -168,7 +168,7 @@ for N, C, H, W, K, R, S in NCHWKRS:
 # Benchmark
 torch.set_num_threads(1)
 for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
-    dtype = torch.cuda.HalfTensor
+    dtype = torch.cuda.FloatTensor
     # initialize input tensors
     a = torch.rand(*a_shape).type(dtype).cuda()
     b = torch.rand(*b_shape).type(dtype).cuda()