Source code for _pytask.nodes

"""Contains implementations of tasks and nodes following the node protocols."""
from __future__ import annotations

import functools
import hashlib
import inspect
from pathlib import Path  # noqa: TCH003
from typing import Any
from typing import Callable
from typing import TYPE_CHECKING

from _pytask.node_protocols import PNode
from _pytask.node_protocols import PPathNode
from _pytask.node_protocols import PTask
from _pytask.node_protocols import PTaskWithPath
from _pytask.typing import no_default
from _pytask.typing import NoDefault
from attrs import define
from attrs import field


if TYPE_CHECKING:
    from _pytask.tree_util import PyTree
    from _pytask.mark import Mark


__all__ = ["PathNode", "PythonNode", "Task", "TaskWithoutPath"]


@define(kw_only=True)
class TaskWithoutPath(PTask):
    """The class for tasks without a source file.

    Tasks may have no source file because

    - they are dynamically created in a REPL.
    - they are created in a Jupyter notebook.

    Attributes
    ----------
    name
        The name of the task.
    function
        The task function.
    depends_on
        A dictionary with string keys whose values are pytrees of nodes the task
        depends on.
    produces
        A dictionary with string keys whose values are pytrees of nodes the task
        produces.
    markers
        A list of markers attached to the task function.
    report_sections
        Reports with entries for when, what, and content.
    attributes
        A dictionary to store additional information of the task.

    """

    name: str
    function: Callable[..., Any]
    depends_on: dict[str, PyTree[PNode]] = field(factory=dict)
    produces: dict[str, PyTree[PNode]] = field(factory=dict)
    markers: list[Mark] = field(factory=list)
    report_sections: list[tuple[str, str, str]] = field(factory=list)
    attributes: dict[Any, Any] = field(factory=dict)

    def state(self) -> str | None:
        """Return the state of the node.

        The state is the SHA-256 hex digest of the source code of the task
        function. If the source cannot be retrieved (``inspect.getsource`` raises
        :class:`OSError`), ``None`` is returned instead.

        """
        try:
            source = inspect.getsource(self.function)
        except OSError:
            return None
        else:
            return hashlib.sha256(source.encode()).hexdigest()

    def execute(self, **kwargs: Any) -> Any:
        """Execute the task.

        The task function is called with ``kwargs`` and whatever it returns is
        handed back to the caller, hence the return annotation is ``Any``.

        """
        return self.function(**kwargs)


@define(kw_only=True)
class Task(PTaskWithPath):
    """The class for tasks which are Python functions.

    Attributes
    ----------
    base_name
        The base name of the task.
    path
        Path to the file where the task was defined.
    function
        The task function.
    name
        The name of the task. It is not an init argument; when empty it is built
        after initialization from ``path`` and ``base_name``.
    display_name
        The shortest uniquely identifiable name for task for display. It is not an
        init argument and falls back to ``name`` when empty.
    depends_on
        A dictionary with string keys whose values are pytrees of nodes the task
        depends on.
    produces
        A dictionary with string keys whose values are pytrees of nodes the task
        produces.
    markers
        A list of markers attached to the task function.
    report_sections
        Reports with entries for when, what, and content.
    attributes
        A dictionary to store additional information of the task.

    """

    base_name: str
    path: Path
    function: Callable[..., Any]
    name: str = field(default="", init=False)
    display_name: str = field(default="", init=False)
    depends_on: dict[str, PyTree[PNode]] = field(factory=dict)
    produces: dict[str, PyTree[PNode]] = field(factory=dict)
    markers: list[Mark] = field(factory=list)
    report_sections: list[tuple[str, str, str]] = field(factory=list)
    attributes: dict[Any, Any] = field(factory=dict)

    def __attrs_post_init__(self: Task) -> None:
        """Fill in ``name`` and ``display_name`` after initialization."""
        if not self.name:
            # The name has the form "<posix path>::<base name>".
            self.name = self.path.as_posix() + "::" + self.base_name

        if not self.display_name:
            self.display_name = self.name

    def state(self) -> str | None:
        """Return the state of the node.

        The state is the modification timestamp of ``path`` converted to a string,
        or ``None`` if the path does not exist.

        """
        if self.path.exists():
            return str(self.path.stat().st_mtime)
        return None

    def execute(self, **kwargs: Any) -> Any:
        """Execute the task.

        The task function is called with ``kwargs`` and whatever it returns is
        handed back to the caller, hence the return annotation is ``Any``.

        """
        return self.function(**kwargs)
@define(kw_only=True)
class PathNode(PPathNode):
    """A node in the DAG that represents a file on disk.

    Attributes
    ----------
    name
        Name of the node which makes it identifiable in the DAG.
    path
        The path to the file.

    """

    name: str
    path: Path

    @classmethod
    @functools.lru_cache
    def from_path(cls, path: Path) -> PathNode:
        """Create a node from an absolute path.

        The node's name is the POSIX representation of the path. Because of the
        ``lru_cache`` decorator, repeated calls with the same path return the same
        object so it is not collected twice.

        Raises
        ------
        ValueError
            If ``path`` is not absolute.

        """
        if path.is_absolute():
            return cls(name=path.as_posix(), path=path)

        msg = "Node must be instantiated from absolute path."
        raise ValueError(msg)

    def state(self) -> str | None:
        """Return the modification timestamp of the file as a string.

        ``None`` is returned when the path does not exist.

        """
        if not self.path.exists():
            return None
        modification_time = self.path.stat().st_mtime
        return str(modification_time)

    def load(self) -> Path:
        """Return the path of the node as its value."""
        return self.path

    def save(self, value: bytes | str) -> None:
        """Write a string or bytes to the file at ``path``.

        Raises
        ------
        TypeError
            If ``value`` is neither :class:`str` nor :class:`bytes`.

        """
        if isinstance(value, str):
            self.path.write_text(value)
            return

        if isinstance(value, bytes):
            self.path.write_bytes(value)
            return

        msg = f"'PathNode' can only save 'str' and 'bytes', not {type(value)}"
        raise TypeError(msg)
@define(kw_only=True)
class PythonNode(PNode):
    """The class for a node which is a Python object.

    Attributes
    ----------
    name
        Name of the node that is set internally.
    value
        Value of the node.
    hash
        Whether the value should be hashed to determine the state. Either a boolean
        or a callable which receives the loaded value and returns its hash.

    """

    name: str = ""
    value: Any | NoDefault = no_default
    # The callable's result is used as the hash value (it is passed through
    # ``str()`` in ``state``), so it is annotated as returning a hash, not a flag.
    hash: bool | Callable[[Any], int | str] = False  # noqa: A003

    def load(self) -> Any:
        """Load the value.

        If the stored value is itself a :class:`PythonNode`, the value of that
        inner node is loaded instead.

        """
        if isinstance(self.value, PythonNode):
            return self.value.load()
        return self.value

    def save(self, value: Any) -> None:
        """Save the value."""
        self.value = value

    def state(self) -> str | None:
        """Calculate state of the node.

        If ``hash = False``, the function returns ``"0"``, a constant hash value, so
        the :class:`PythonNode` is ignored when checking for a changed state of the
        task.

        If ``hash`` is a callable, then use this function to calculate a hash.

        If ``hash = True``, the builtin ``hash()`` function (`link
        <https://docs.python.org/3.11/library/functions.html?highlight=hash#hash>`_)
        is used for all types except strings and bytes.

        The hash for strings and bytes is calculated using hashlib because
        ``hash("asd")`` returns a different value every invocation since the hash of
        strings is salted with a random integer and it would confuse users. See
        :meth:`object.__hash__` for more information.

        """
        if self.hash:
            value = self.load()
            if callable(self.hash):
                return str(self.hash(value))
            if isinstance(value, str):
                # ``hexdigest`` already returns a string.
                return hashlib.sha256(value.encode()).hexdigest()
            if isinstance(value, bytes):
                return hashlib.sha256(value).hexdigest()
            return str(hash(value))
        return "0"