# Source code for astronat.units.decorators

# -*- coding: utf-8 -*-

"""Decorators for functions accepting Astropy Quantities."""

__author__ = "Nathaniel Starkman"
__credit__ = "astropy"


__all__ = ["quantity_output", "QuantityInputOutput", "quantity_io"]


##############################################################################
# IMPORTS

import textwrap
import typing as T

from astropy.units import dimensionless_unscaled
from astropy.units.core import Unit, add_enabled_equivalencies
from astropy.units.decorators import _get_allowed_units, _validate_arg_value
from astropy.utils.decorators import format_doc
from astropy.utils.misc import isiterable
from utilipy.utils import functools, inspect
from utilipy.utils.typing import UnitableType

from .core import _doc_base_params, _doc_base_raises, quantity_return_

###############################################################################
# PARAMETERS

# Names of the decorator settings that `QuantityInputOutput.as_decorator`
# looks up in its local namespace (and reprs) when it sets the docstring of
# the instance it creates.
_aioattrs = (
    "unit",
    "to_value",
    "equivalencies",
    "decompose",
    "assumed_units",
    "assume_annotation_units",
)


# ----------------------------------------


# Doctest-style usage examples for `quantity_output`.  The text is spliced
# into the ``Examples`` section of `quantity_output` and of the wrapper it
# builds (via `format_doc`), and into `_doc_quantity_output_wrapped`.
_doc_quantity_output_examples: str = """
`quantity_output` decorated function

    >>> from astronat.units.decorators import quantity_output
    >>> @quantity_output(unit=u.m, to_value=True)
    ... def example_function(x):
    ...     return x
    >>> example_function(10 * u.km)
    10000.0
    >>> example_function(10)
    10
    >>> example_function(10 * u.km, to_value=False)  # doctest: +FLOAT_CMP
    <Quantity 10000. m>

"""
# drop the leading newline that comes from opening the literal on its own line
_doc_quantity_output_examples = _doc_quantity_output_examples[1:]

# Pre-assembled "Other Parameters / Raises / Examples" docstring block built
# from the shared base fragments and the examples above.
# NOTE(review): within this file it is referenced only from commented-out
# `format_doc` arguments in `quantity_output`.
_doc_quantity_output_wrapped: str = """
Other Parameters
----------------
{parameters}

Raises
-------
{raises}

Examples
--------
{examples}
""".format(
    parameters=_doc_base_params,
    raises=_doc_base_raises,
    examples=_doc_quantity_output_examples,
)


# ----------------------------------------

# Parameter documentation for `QuantityInputOutput`: the shared base
# parameters (`_doc_base_params`) framed by the extra `function`,
# `assumed_units` and `assume_annotation_units` entries.  It is inserted into
# the ``Parameters`` section of `QuantityInputOutput.as_decorator` and
# `QuantityInputOutput.__init__`, and into `_funcdec`.
_doc_qio_params: str = """function: Callable
    the function to decorate (default None)
{parameters}

assumed_units: dict
    dictionary of default units
    (default dict())

    >>> from astronat.units.decorators import quantity_io
    >>> dfu = dict(x=u.km)
    >>> x = 10
    >>> y = 20*u.km
    >>> @quantity_io(assumed_units=dfu)
    ... def add(x, y):
    ...     return x + y
    >>> add(x, y) # doctest: +SKIP
    <Quantity 30.0 km>

assume_annotation_units: bool, optional
    whether to interpret function annotations as default units
    (default False)
    function annotations have lower precedence than `assumed_units`

""".format(
    parameters=_doc_base_params,
)

# Text for the ``Notes`` section of `QuantityInputOutput.as_decorator` and
# `QuantityInputOutput.__init__`: how units given at call time, as decorator
# keyword arguments and as function annotations rank against each other.
_doc_qio_notes: str = """
Order of Precedence:

1. Function Arguments
2. Decorator Arguments
3. Function Annotation Arguments

Decorator Key-Word Arguments:

    Unit specifications can be provided as keyword arguments
    to the decorator, or by using function annotation syntax.
    Arguments to the decorator take precedence
    over any function annotations present.
    **note**
    decorator key-word arguments are NEVER interpreted as `assumed_units`

    >>> from astronat.units.decorators import quantity_io
    >>> @quantity_io(x=u.m, y=u.s)
    ... def func(x, y):
    ...     pass

Function Annotation Arguments:

    Unit specifications can be provided as keyword arguments
    to the decorator, or by using function annotation syntax.
    Arguments to the function and decorator take precedence
    over any function annotations present.

    >>> def func(x: u.m, y: u.s) -> u.m / u.s:
    ...     pass

    if `assume_annotation_units` is True (default False)
    function annotations are interpreted as default units
    function annotations have lower precedence than `assumed_units`

"""

# TODO replace
# "Other Parameters" block wrapping `_doc_qio_params`.
# NOTE(review): within this file it is referenced only by a commented-out
# line in `QuantityInputOutput.__call__` that would append it to the wrapped
# function's docstring.
_funcdec: str = """

Other Parameters
----------------
{parameters}

""".format(
    parameters=_doc_qio_params
)


###############################################################################
# CODE
###############################################################################


@format_doc(
    None,
    parameters=textwrap.indent(_doc_base_params, " " * 4)[4:],
    # NOTE(review): the inner wrapper strips the first-line indent of this
    # fragment as well; left untouched here because the layout of
    # `_doc_base_raises` is not visible from this module -- confirm.
    raises=textwrap.indent(_doc_base_raises, " " * 4),
    # The ``{examples}`` placeholder already sits at a 4-space indent, so the
    # indent added to the first line is stripped again (same treatment as
    # ``parameters`` above and as all fragments of the inner wrapper).
    examples=textwrap.indent(_doc_quantity_output_examples, " " * 4)[4:],
    # doc_quantity_output_wrapped=textwrap.indent(
    #     _doc_quantity_output_wrapped, " " * 12 + "| "
    # ),
)
def quantity_output(
    function: T.Callable = None,
    *,
    unit: UnitableType = None,
    to_value: bool = False,
    equivalencies: T.Sequence = [],
    decompose: T.Union[bool, T.Sequence] = False,
):
    r"""Decorate functions for unit output.

    The decorator can be applied bare (``@quantity_output``) or with
    key-word options (``@quantity_output(unit=...)``).  The options given
    to the decorator become the defaults of the wrapped function, which
    accepts the additional key-word arguments
    `unit`, `to_value`, `equivalencies`, `decompose`
    so that they can be overridden on every call.

    Parameters
    ----------
    {parameters}

    Returns
    -------
    wrapper: Callable
        wrapped function
        with the unit operations performed by
        :func:`~astronat.units.quantity_return_`.
        When `function` is None, a partial of this decorator with the
        given options bound is returned instead.

    Raises
    ------
    {raises}

    Examples
    --------
    .. code-block:: python

        @quantity_output
        def func(x, y):
            return x + y

    is equivalent to

    .. code-block:: python

        def func(x, y, unit=None, to_value=False, equivalencies=[],
                 decompose=False):
            result = x + y
            return quantity_return_(result, unit, to_value, equivalencies,
                                    decompose)

    {examples}

    """
    # Called with options only (no function yet): hand back a decorator that
    # remembers the options and waits for the function.
    if function is None:
        return functools.partial(
            quantity_output,
            unit=unit,
            to_value=to_value,
            equivalencies=equivalencies,
            decompose=decompose,
        )

    # making decorator
    @functools.wraps(function)
    @format_doc(
        None,
        # _doc_quantity_output_wrapped=textwrap.indent(
        #     _doc_quantity_output_wrapped, " " * 8
        # ),
        parameters=textwrap.indent(_doc_base_params, " " * 8)[8:],
        raises=textwrap.indent(_doc_base_raises, " " * 8)[8:],
        examples=textwrap.indent(_doc_quantity_output_examples, " " * 8)[8:],
    )
    def wrapper(
        *args: T.Any,
        unit: T.Type[Unit] = unit,
        to_value: bool = to_value,
        equivalencies: T.Sequence = equivalencies,
        decompose: T.Union[bool, T.Sequence] = decompose,
        **kwargs: T.Any,
    ):
        """Wrapper docstring.

        Other Parameters
        ----------------
        {parameters}

        Raises
        ------
        {raises}

        Examples
        --------
        {examples}

        """
        # Evaluate the wrapped function with the caller's own arguments only;
        # the unit options are consumed here and applied to the result.
        return quantity_return_(
            function(*args, **kwargs),  # evaluated function
            unit=unit,
            to_value=to_value,
            equivalencies=equivalencies,
            decompose=decompose,
        )

    # /def

    return wrapper


# /def


###############################################################################


class QuantityInputOutput:
    """Decorator for validating the units of arguments to functions."""

    # `format_doc` is applied directly to the function (below `classmethod`)
    # so that it operates on the function's own docstring.
    @classmethod
    @format_doc(
        None,
        parameters=textwrap.indent(_doc_qio_params, " " * 8),
        notes=textwrap.indent(_doc_qio_notes, " " * 8),
    )
    def as_decorator(
        cls,
        function: T.Callable = None,
        unit: UnitableType = None,
        to_value: bool = False,
        equivalencies: T.Sequence = [],
        decompose: T.Union[bool, T.Sequence] = False,
        assumed_units: T.Dict = {},
        assume_annotation_units: bool = False,
        **decorator_kwargs,
    ):
        """Decorator for validating the units of arguments to functions.

        Builds an instance configured with the given options.  If `function`
        is given, the instance is applied to it immediately and the wrapped
        function is returned; otherwise the instance itself is returned, to
        be used as the decorator.

        Parameters
        ----------
        {parameters}

        See Also
        --------
        :class:`~astropy.units.quantity_input`

        Notes
        -----
        {notes}

        """
        # making instance from base class
        self = super().__new__(cls)

        # modifying docstring
        # NOTE(review): `__doc__` is not a local name here, so it resolves to
        # the module-level docstring rather than the class's -- confirm that
        # this is the intended template.
        _locals = locals()
        self.__doc__ = __doc__.format(
            **{k: _locals.get(k).__repr__() for k in set(_aioattrs)}
        )

        self.__init__(
            unit=unit,
            to_value=to_value,
            equivalencies=equivalencies,
            decompose=decompose,
            assumed_units=assumed_units,
            assume_annotation_units=assume_annotation_units,
            **decorator_kwargs,
        )

        if function is not None:
            return self(function)
        return self

    # /def

    # ------------------------------------------

    @format_doc(
        None,
        parameters=textwrap.indent(_doc_qio_params, " " * 4),
        notes=textwrap.indent(_doc_qio_notes, " " * 4),
    )
    def __init__(
        self,
        function: T.Callable = None,
        unit: UnitableType = None,
        to_value: bool = False,
        equivalencies: T.Sequence = [],
        decompose: T.Union[bool, T.Sequence] = False,
        assumed_units: dict = {},
        assume_annotation_units: bool = False,
        **decorator_kwargs,
    ):
        """Decorator for validating the units of arguments to functions.

        Stores the options on the instance; `function` is accepted but not
        used here.  Extra key-word arguments are kept as per-argument target
        units, keyed by argument name.

        Parameters
        ----------
        {parameters}

        See Also
        --------
        :class:`~astropy.units.quantity_input`

        Notes
        -----
        {notes}

        """
        super().__init__()

        self.unit = unit
        self.to_value = to_value
        self.equivalencies = equivalencies
        self.decompose = decompose

        self.assumed_units = assumed_units
        self.assume_annotation_units = assume_annotation_units

        self.decorator_kwargs = decorator_kwargs

    # /def

    # ------------------------------------------

    def __call__(self, wrapped_function: T.Callable):
        """Make decorator.

        The returned function accepts, on top of the arguments of
        `wrapped_function`, the key-word arguments `unit`, `to_value`,
        `equivalencies`, `decompose` and `assumed_units` (defaulting to the
        values stored on this instance) and `_skip_decorator`, which calls
        `wrapped_function` directly without any unit handling.

        For every non-variadic parameter the wrapper

        1. attaches the assumed unit (from `assumed_units`, or from the
           annotation if `assume_annotation_units` is True) when the value
           has no ``unit`` attribute,
        2. validates the value against the target units (decorator key-word
           argument for that name, else the annotation),

        then calls `wrapped_function` with the equivalencies enabled and
        passes the result through `quantity_return_`.  A return annotation
        is used as output unit when no `unit` is given.

        Parameters
        ----------
        wrapped_function : Callable
            function to wrap

        Returns
        -------
        wrapped: Callable
            wrapped function

        Raises
        ------
        ValueError
            (at call time) if the assumed unit of an argument is a
            non-string iterable

        """
        # Extract the function signature for the function we are wrapping.
        wrapped_signature = inspect.signature(wrapped_function)

        @functools.wraps(wrapped_function)
        def wrapped(
            *func_args: T.Any,
            unit: UnitableType = self.unit,
            to_value: bool = self.to_value,
            equivalencies: T.Sequence = self.equivalencies,
            decompose: T.Union[bool, T.Sequence] = self.decompose,
            assumed_units: dict = self.assumed_units,
            _skip_decorator: bool = False,
            **func_kwargs: T.Any,
        ):
            # skip the decorator
            if _skip_decorator:
                return wrapped_function(*func_args, **func_kwargs)

            # make func_args editable
            _func_args: list = list(func_args)

            # Bind the arguments to our new function to the signature of the
            # original.
            bound_args = wrapped_signature.bind(*_func_args, **func_kwargs)

            # Iterate through the parameters of the original signature
            for i, param in enumerate(wrapped_signature.parameters.values()):
                # We do not support variable arguments (*args, **kwargs)
                if param.kind in {
                    inspect.Parameter.VAR_KEYWORD,
                    inspect.Parameter.VAR_POSITIONAL,
                }:
                    continue

                # Catch the (never) case where bind relied on a default value.
                if (
                    param.name not in bound_args.arguments
                    and param.default is not param.empty
                ):
                    bound_args.arguments[param.name] = param.default

                # Get the value of this parameter (argument to new function)
                arg = bound_args.arguments[param.name]

                # +----------------------------------+
                # Get default unit or physical type,
                # either from decorator kwargs
                #   or annotations
                if param.name in assumed_units:
                    dfunit = assumed_units[param.name]
                elif self.assume_annotation_units is True:
                    dfunit = param.annotation
                # elif not assumed_units:
                #     dfunit = param.annotation
                else:
                    dfunit = inspect.Parameter.empty

                adjargbydfunit = True

                # If the dfunit is empty, then no target units or physical
                #   types were specified so we can continue to the next arg
                if dfunit is inspect.Parameter.empty:
                    adjargbydfunit = False

                # If the argument value is None, and the default value is
                #   None, pass through the None even if there is a dfunit unit
                elif arg is None and param.default is None:
                    adjargbydfunit = False

                # Here, we check whether multiple dfunit unit/physical type's
                #   were specified in the decorator/annotation, or whether a
                #   single string (unit or physical type) or a Unit object was
                #   specified
                elif isinstance(dfunit, str):
                    dfunit = _get_allowed_units([dfunit])[0]
                elif not isiterable(dfunit):
                    pass
                else:
                    raise ValueError("target must be one Unit, not list")

                if adjargbydfunit and not hasattr(arg, "unit"):
                    # Build the quantity once and hand the same object to
                    # both the validation below and the eventual call.
                    arg = arg * dfunit

                    # Only parameters that can be positional and were
                    # actually supplied positionally live in `_func_args`.
                    # Everything else (passed by key-word, keyword-only, or
                    # filled in from the signature default) goes through
                    # `func_kwargs`; plain assignment also covers names that
                    # are not yet present in that dict.
                    if (
                        param.kind is not param.KEYWORD_ONLY
                        and i < len(_func_args)
                    ):
                        _func_args[i] = arg
                    else:
                        func_kwargs[param.name] = arg

                # +----------------------------------+
                # Get target unit or physical type,
                # from decorator kwargs or annotations
                if param.name in self.decorator_kwargs:
                    targets = self.decorator_kwargs[param.name]
                else:
                    targets = param.annotation

                # If the targets is empty, then no target units or physical
                #   types were specified so we can continue to the next arg
                if targets is inspect.Parameter.empty:
                    continue

                # If the argument value is None, and the default value is
                #   None, pass through the None even if there is a target unit
                if arg is None and param.default is None:
                    continue

                # Here, we check whether multiple target unit/physical type's
                #   were specified in the decorator/annotation, or whether a
                #   single string (unit or physical type) or a Unit object was
                #   specified
                if isinstance(targets, str) or not isiterable(targets):
                    valid_targets = [targets]

                # Check for None in the supplied list of allowed units and, if
                #   present and the passed value is also None, ignore.
                elif None in targets:
                    if arg is None:
                        continue
                    valid_targets = [t for t in targets if t is not None]
                    # NOTE(review): nesting reconstructed from a flattened
                    # source -- dimensionless is only accepted when the value
                    # itself had to be promoted to a dimensionless quantity.
                    if not hasattr(arg, "unit"):
                        arg = arg * dimensionless_unscaled
                        valid_targets.append(dimensionless_unscaled)

                else:
                    valid_targets = targets

                # Now loop over the allowed units/physical types and validate
                #   the value of the argument.  The call-time `equivalencies`
                #   (which default to the decorator's) are used so that
                #   validation agrees with the evaluation and output
                #   conversion below.
                _validate_arg_value(
                    param.name,
                    wrapped_function.__name__,
                    arg,
                    valid_targets,
                    equivalencies,
                )

            # # evaluated wrapped_function
            with add_enabled_equivalencies(equivalencies):
                return_ = wrapped_function(*_func_args, **func_kwargs)
                # if func_kwargs:
                #     return_ = wrapped_function(*_func_args, **func_kwargs)
                # else:
                #     return_ = wrapped_function(*_func_args)

            # fall back on the return annotation for the output unit
            if (
                wrapped_signature.return_annotation
                not in (inspect.Signature.empty, None)
                and unit is None
            ):
                unit = wrapped_signature.return_annotation

            return quantity_return_(
                return_,
                unit=unit,
                to_value=to_value,
                equivalencies=equivalencies,
                decompose=decompose,
            )

        # /def

        # TODO dedent
        # wrapped.__doc__ = inspect.cleandoc(wrapped.__doc__ or "") + _funcdec
        wrapped.__doc__ = wrapped_function.__doc__

        return wrapped
# /def


quantity_io = QuantityInputOutput.as_decorator
# /class


###############################################################################


def from_amuse_decorator(
    function: T.Callable = None, *, arguments: list = []
) -> T.Callable:
    """Function decorator to convert inputs to Astropy quantities.

    Parameters
    ----------
    function : types.FunctionType or None, optional
        the function to be decorated.
        If None, a decorator with `arguments` already bound is returned.
    arguments : list, optional
        which arguments to pass through ``from_amuse`` before the call.
        Integers select a parameter by its position in the signature of
        `function`; strings select a parameter by name.

    Returns
    -------
    wrapper : types.FunctionType
        wrapper for `function`, built with ``functools.wraps``, that
        converts the selected arguments and then calls `function`.

    Raises
    ------
    TypeError
        if an element of `arguments` is neither int nor str

    """
    from .convert import from_amuse  # TODO, to prevent circular import

    for selector in arguments:
        if not isinstance(selector, (int, str)):
            raise TypeError("elements of `arguments` must be int or str")

    # allowing for optional arguments
    if function is None:
        return functools.partial(from_amuse_decorator, arguments=arguments)

    sig = inspect.signature(function)
    pnames = tuple(sig.parameters.keys())

    @functools.wraps(function)
    def wrapper(*args, **kw):
        """Wrapper docstring."""
        bound = sig.bind_partial(*args, **kw)
        bound.apply_defaults()

        for selector in arguments:
            # a string is already a parameter name, an int is looked up
            name = selector if isinstance(selector, str) else pnames[selector]
            bound.arguments[name] = from_amuse(bound.arguments[name])

        return function(*bound.args, **bound.kwargs)

    # /def

    return wrapper


# /def


###############################################################################
# END