Source code for qiskit_dynamics.array.wrap

# This code is part of Qiskit.
#
# (C) Copyright IBM 2017, 2020.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Functions for working with :class:`Array` dispatch."""

import functools
from types import FunctionType
from typing import Callable

from .array import Array


def wrap(
    func: Callable, wrap_return: bool = True, wrap_args: bool = True, decorator: bool = False
) -> Callable:
    r"""Wrap an array backend function to work with :class:`Array`\s.

    Args:
        func: A function to wrap.
        wrap_return: If ``True`` convert results that are registered array
                     backend types into :class:`Array` objects.
        wrap_args: If ``True`` also wrap function type args and kwargs of the
                   wrapped function.
        decorator: If ``True`` the wrapped decorator function ``func`` will
                   also wrap the decorated functions.

    Returns:
        Callable: The wrapped function.
    """
    # Both helpers share the same signature, so pick one and forward the flags.
    wrapper = _wrap_decorator if decorator else _wrap_function
    return wrapper(func, wrap_return=wrap_return, wrap_args=wrap_args)
def _unwrap_array_like(obj):
    """Return ``obj.__qiskit_array__().data`` if ``obj`` is Array-like, else ``obj``."""
    if hasattr(obj, "__qiskit_array__"):
        return obj.__qiskit_array__().data
    return obj


def _wrap_array_function(func: Callable) -> Callable:
    r"""Wrap a function to handle :class:`Array`\-like inputs and returns.

    Every positional argument, keyword argument and return value that exposes
    a ``__qiskit_array__`` method is replaced by ``obj.__qiskit_array__().data``.
    A ``tuple`` return value is unwrapped element-wise and returned as a plain
    ``tuple``.

    Args:
        func: The function to wrap.

    Returns:
        Callable: A function carrying the metadata of ``func`` that calls
        ``func`` with unwrapped inputs and returns unwrapped outputs.
    """

    @functools.wraps(func)
    def wrapped_function(*args, **kwargs):
        # Unwrap inputs
        args = tuple(_unwrap_array_like(x) for x in args)
        kwargs = {key: _unwrap_array_like(val) for key, val in kwargs.items()}

        # Evaluate function with unwrapped inputs
        result = func(*args, **kwargs)

        # Unwrap result (tuples element-wise, everything else as a whole)
        if isinstance(result, tuple):
            return tuple(_unwrap_array_like(x) for x in result)
        return _unwrap_array_like(result)

    return wrapped_function


def _wrap_args(args):
    """Return ``args`` as a tuple with every plain function array-wrapped."""
    return tuple(_wrap_array_function(x) if isinstance(x, FunctionType) else x for x in args)


def _wrap_kwargs(kwargs):
    """Return a copy of ``kwargs`` with every plain function value array-wrapped."""
    return {
        key: _wrap_array_function(val) if isinstance(val, FunctionType) else val
        for key, val in kwargs.items()
    }


def _wrap_function(func: Callable, wrap_return: bool = True, wrap_args: bool = True) -> Callable:
    r"""Wrap an array backend function to work with :class:`Array`\s.

    Args:
        func: A function to wrap.
        wrap_return: If ``True`` convert results that are registered array
                     backend types into :class:`Array` objects.
        wrap_args: If ``True`` also wrap function type args and kwargs of the
                   wrapped function.

    Returns:
        Callable: The wrapped function.
    """
    # Build the unwrapping layer once instead of on every call.
    unwrapped_func = _wrap_array_function(func)

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        if wrap_args:
            args = _wrap_args(args)
            kwargs = _wrap_kwargs(kwargs)
        result = unwrapped_func(*args, **kwargs)
        if wrap_return:
            return Array._wrap(result)
        return result

    return wrapped_func


def _wrap_decorator(func: Callable, wrap_return: bool = True, wrap_args: bool = True) -> Callable:
    r"""Wrap a function decorator to work with :class:`Array`\s.

    The returned decorator applies ``func`` with unwrapped inputs and then
    wraps the function it produces with :func:`_wrap_function`, forwarding
    ``wrap_return`` and ``wrap_args`` so that both flags are honoured by the
    decorated function as well.

    Args:
        func: A function to wrap.
        wrap_return: If ``True`` convert results that are registered array
                     backend types into :class:`Array` objects.
        wrap_args: If ``True`` also wrap function type args and kwargs of the
                   wrapped function.

    Returns:
        Callable: The wrapped function.
    """
    # Build the unwrapping layer once instead of on every decoration.
    unwrapped_decorator = _wrap_array_function(func)

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        if wrap_args:
            args = _wrap_args(args)
            kwargs = _wrap_kwargs(kwargs)
        decorated = unwrapped_decorator(*args, **kwargs)

        # Forward the flags explicitly: relying on the defaults of
        # ``_wrap_function`` would wrap args and returns even when the caller
        # asked for ``wrap_return=False`` or ``wrap_args=False``.
        # NOTE(review): with ``wrap_return=True`` results now pass through
        # ``Array._wrap`` once rather than twice; this assumes ``Array._wrap``
        # leaves ``Array`` objects unchanged -- confirm against ``Array``.
        call_decorated = _wrap_function(decorated, wrap_return=wrap_return, wrap_args=wrap_args)

        # Copy metadata from the decorated function when it was passed
        # positionally; otherwise (keyword-only use) fall back to the
        # decorator's output rather than raising ``IndexError``.
        template = args[0] if args else decorated

        @functools.wraps(template)
        def wrapped_decorated(*f_args, **f_kwargs):
            return call_decorated(*f_args, **f_kwargs)

        return wrapped_decorated

    return wrapped_func