[FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
@@ -1245,13 +1245,13 @@ void init_triton_ir(py::module &&m) {
|
||||
return mlir::Value(
|
||||
self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs));
|
||||
})
|
||||
// GEP
|
||||
.def("create_gep",
|
||||
// AddPtr (similar to GEP)
|
||||
.def("create_addptr",
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptr,
|
||||
mlir::Value &offset) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::GEPOp>(loc, ptr.getType(), ptr,
|
||||
offset);
|
||||
return self.create<mlir::triton::AddPtrOp>(loc, ptr.getType(), ptr,
|
||||
offset);
|
||||
})
|
||||
// Comparison (int)
|
||||
.def("create_icmpSLE",
|
||||
|
@@ -121,7 +121,7 @@ def add(input: tl.tensor,
|
||||
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
|
||||
input, other = other, input
|
||||
if input_scalar_ty.is_ptr():
|
||||
return tl.tensor(builder.create_gep(input.handle, other.handle), input.type)
|
||||
return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
|
||||
# float + float
|
||||
elif input_scalar_ty.is_floating():
|
||||
return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
|
||||
@@ -138,7 +138,7 @@ def sub(input: tl.tensor,
|
||||
scalar_ty = input.type.scalar
|
||||
# ptr - offset
|
||||
if scalar_ty.is_ptr():
|
||||
return tl.tensor(builder.create_gep(input.handle, minus(other, builder).handle),
|
||||
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle),
|
||||
input.type)
|
||||
# float - float
|
||||
if scalar_ty.is_floating():
|
||||
|
Reference in New Issue
Block a user