[BACKEND] Add backend support for arith::AddIOp, arith::AddFOp, GetProgramIdOp & GEPOp, and bug fixes for SplatOp, StoreOp, FuncOp (#60)
Add backend support of arith::AddIOp, arith::AddFOp, GetProgramIdOp, GEPOp and bugfix for SplatOp, StoreOp, FuncOp Co-authored-by: gzhu <gzhu@nvidia.com>
This commit is contained in:
@@ -1504,7 +1504,7 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUVerifier());
|
||||
})
|
||||
.def("triton_gpu_to_llvm", [](mlir::PassManager &self) {
|
||||
.def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
||||
});
|
||||
}
|
||||
|
30
python/test/vecadd_no_scf.py
Normal file
30
python/test/vecadd_no_scf.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
NUM_WARPS = 4  # warps per CTA passed to triton.compile below; fixed for this smoke test
|
||||
|
||||
# triton kernel
|
||||
|
||||
|
||||
@triton.jit
def kernel(x_ptr, stride_xn,
           y_ptr, stride_yn,
           z_ptr, stride_zn,
           BLOCK_SIZE_N: tl.constexpr):
    """Elementwise vector add without an SCF loop: z[i] = x[i] + y[i].

    Each program instance (one per grid axis-0 index) processes exactly one
    contiguous block of BLOCK_SIZE_N elements, so no loop is needed.
    """
    # Starting element index for this program instance's block.
    block_start = tl.program_id(axis=0) * BLOCK_SIZE_N
    offsets = block_start + tl.arange(0, BLOCK_SIZE_N)
    # NOTE(review): the stride_* parameters are accepted but never used —
    # presumably unit stride is assumed; confirm against the caller.
    lhs = tl.load(x_ptr + offsets)
    rhs = tl.load(y_ptr + offsets)
    tl.store(z_ptr + offsets, lhs + rhs)
|
||||
|
||||
|
||||
# Ahead-of-time compile the kernel to PTX and dump it for manual inspection.
# Signature string order matches the kernel parameters: (ptr, stride) x 3.
ret = triton.compile(
    kernel,
    "*fp32,i32,*fp32,i32,*fp32,i32",
    constants={"BLOCK_SIZE_N": 256},
    num_warps=NUM_WARPS,
    device=0,
    output="ptx",
)
print(ret)
|
||||
|
||||
# TODO: base class for python end2end tests,
|
||||
# runtime execution, correctness comparison etc.
|
Reference in New Issue
Block a user