diff --git a/docs/tutorials/custom-operation.rst b/docs/tutorials/custom-operation.rst
index 30619bef9..e6abffb12 100644
--- a/docs/tutorials/custom-operation.rst
+++ b/docs/tutorials/custom-operation.rst
@@ -39,7 +39,7 @@ As you will see, a wrapper for the above Triton function can be created in just
     import torch
     import triton

-    class _add(triton.function):
+    class _add(torch.autograd.Function):
       # source-code for Triton compute kernel
       src = """
     __global__ void add(float* z, float* x, float* y, int N){
diff --git a/docs/tutorials/putting-it-all-together.rst b/docs/tutorials/putting-it-all-together.rst
index 3303750a6..4f760c83c 100644
--- a/docs/tutorials/putting-it-all-together.rst
+++ b/docs/tutorials/putting-it-all-together.rst
@@ -15,7 +15,7 @@ The PyTriton API provides a :code:`triton.function` class which automatically ha
     import triton

     # Entry point
-    class _dot(triton.function):
+    class _dot(torch.autograd.Function):

       @staticmethod
       # Forward Pass
@@ -170,7 +170,7 @@ Creating custom operations for Triton and PyTorch is very similar; programmers h

         return da, db, None, None, None, None, None, None, None

-A callable operation can be created using the :code:`apply` method of our :code:`triton.function` class.
+A callable operation can be created using the :code:`apply` method of the :code:`torch.autograd.Function` class.

 .. code:: python
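
For context, the pattern the diff switches the tutorials to looks roughly like the sketch below: subclass :code:`torch.autograd.Function` with static :code:`forward` and :code:`backward` methods, and expose the operation through its :code:`apply` method. This is a minimal illustration only; the :code:`_scale` class and its arguments are invented for this example and do not appear in the tutorials.

.. code:: python

    import torch

    class _scale(torch.autograd.Function):

        @staticmethod
        def forward(ctx, x, alpha):
            # Stash anything the backward pass will need on the context object.
            ctx.alpha = alpha
            return x * alpha

        @staticmethod
        def backward(ctx, grad_output):
            # Return one gradient per forward argument; None for non-tensor inputs.
            return grad_output * ctx.alpha, None

    # A callable operation is created from the apply method.
    scale = _scale.apply

    x = torch.randn(4, requires_grad=True)
    y = scale(x, 2.0)
    y.sum().backward()
    print(x.grad)  # gradient of 2.0 for every element

The same :code:`apply`-based entry point is what the :code:`_add` and :code:`_dot` wrappers in the updated tutorials rely on, with the Triton kernel launch placed inside :code:`forward` and :code:`backward`.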