[CODEGEN][TRANSFORM] some bug-fixes for FP32 einsum

Philippe Tillet
2020-01-19 19:58:40 -05:00
parent f278d9741a
commit fbf2a3f56f
4 changed files with 4 additions and 8 deletions


@@ -125,7 +125,6 @@ class distributed_tile: public tile{
 private:
   void init_indices();
-  Type *make_vector_ty(Type *ty, size_t vector_size);
 public:
   distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder);


@@ -15,7 +15,7 @@ void distributed_tile::init_indices() {
   std::vector<size_t> order(id.size());
   std::iota(order.begin(), order.end(), 0);
   auto cmp = [&](int x, int y) {
-    return axes_[x].contiguous > axes_[y].contiguous;
+    return order_[x] < order_[y];
   };
   std::sort(order.begin(), order.end(), cmp);
   // build
@@ -39,11 +39,6 @@ void distributed_tile::init_indices() {
   }
 }
-llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) {
-  if(vector_size == 1)
-    return ty;
-  return VectorType::get(ty, vector_size);
-}
 distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
   : tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
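
The comparator rewrite changes how iteration indices are ordered: they are now sorted by the tile's explicit order_ permutation rather than by each axis's contiguity, and the now-unused make_vector_ty helper is dropped (declaration in the first hunk, definition here). A minimal standalone sketch of the new ordering; the order_ values below are made up for illustration and are not from the commit:

#include <algorithm>
#include <numeric>
#include <vector>

int main() {
  // Assumed layout permutation, playing the role of order_ in the diff:
  // axis 1 comes first in the layout, then axis 2, then axis 0.
  std::vector<int> order_ = {2, 0, 1};
  // Start from the identity permutation {0, 1, 2}, as init_indices() does.
  std::vector<size_t> order(order_.size());
  std::iota(order.begin(), order.end(), 0);
  // New comparator: visit axes in the order dictated by order_.
  std::sort(order.begin(), order.end(),
            [&](int x, int y) { return order_[x] < order_[y]; });
  // order is now {1, 2, 0}. The old comparator ranked axes by their
  // 'contiguous' field instead, which can disagree with order_.
  return 0;
}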


@@ -15,6 +15,8 @@ inline bool is_shmem_op(ir::instruction* i, int op) {
   if(i->get_id() == ir::INST_DOT)
     return op==0 || op==1;
   if(i->get_id() == ir::INST_COPY_FROM_SHARED)
     return op==0;
+  if(i->get_id() == ir::INST_TRANS)
+    return op==0;
   return false;
 }
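
As this hunk reads, is_shmem_op reports which operand indices of an instruction are read from shared memory; the new case makes trans treat its single input (operand 0) like copy_from_shared does. Below is a self-contained sketch of that predicate using hypothetical stand-ins for Triton's ir::instruction and its ids (the leading dot case is assumed from the hunk's cut-off context, dot reading both of its operands from shared memory):

// Hypothetical stand-ins; only the decision logic mirrors the diff.
enum inst_id { INST_DOT, INST_COPY_FROM_SHARED, INST_TRANS, INST_OTHER };
struct instruction { inst_id id; };

bool is_shmem_op(const instruction* i, int op) {
  if (i->id == INST_DOT)               // dot reads A (op 0) and B (op 1)
    return op == 0 || op == 1;         // from shared memory
  if (i->id == INST_COPY_FROM_SHARED)  // the copied value (op 0) is shared
    return op == 0;
  if (i->id == INST_TRANS)             // new in this commit: the transposed
    return op == 0;                    // input (op 0) is read from shared
  return false;
}

int main() {
  instruction t{INST_TRANS};
  // operand 0 of trans is now a shared-memory operand; operand 1 is not
  return (is_shmem_op(&t, 0) && !is_shmem_op(&t, 1)) ? 0 : 1;
}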


@@ -168,7 +168,7 @@ for N, C, H, W, K, R, S in NCHWKRS:
 # Benchmark
 torch.set_num_threads(1)
 for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
-    dtype = torch.cuda.HalfTensor
+    dtype = torch.cuda.FloatTensor
     # initialize input tensors
     a = torch.rand(*a_shape).type(dtype).cuda()
     b = torch.rand(*b_shape).type(dtype).cuda()
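
This last change lines up with the commit title: with torch.cuda.FloatTensor inputs instead of torch.cuda.HalfTensor, the einsum benchmark exercises the FP32 path that the codegen fixes above target.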