[FRONTEND] Add scalar type support for some ops (#661)

This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
Shintaro Iwasaki
2022-09-15 16:12:52 -07:00
committed by GitHub
parent 2e08450c80
commit 43be75ad42
27 changed files with 203 additions and 129 deletions

View File

@@ -1245,13 +1245,13 @@ void init_triton_ir(py::module &&m) {
return mlir::Value(
self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs));
})
// GEP
.def("create_gep",
// AddPtr (similar to GEP)
.def("create_addptr",
[](mlir::OpBuilder &self, mlir::Value &ptr,
mlir::Value &offset) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::GEPOp>(loc, ptr.getType(), ptr,
offset);
return self.create<mlir::triton::AddPtrOp>(loc, ptr.getType(), ptr,
offset);
})
// Comparison (int)
.def("create_icmpSLE",

View File

@@ -121,7 +121,7 @@ def add(input: tl.tensor,
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
input, other = other, input
if input_scalar_ty.is_ptr():
return tl.tensor(builder.create_gep(input.handle, other.handle), input.type)
return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
# float + float
elif input_scalar_ty.is_floating():
return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
@@ -138,7 +138,7 @@ def sub(input: tl.tensor,
scalar_ty = input.type.scalar
# ptr - offset
if scalar_ty.is_ptr():
return tl.tensor(builder.create_gep(input.handle, minus(other, builder).handle),
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle),
input.type)
# float - float
if scalar_ty.is_floating():