# Source code for lightcurvelynx.math_nodes.basic_math_node

"""Nodes that perform basic math operations that can be specified as strings.

The goal of this library is to save users from needing to create a bunch of
small FunctionNodes to perform basic math.
"""

import ast

from lightcurvelynx.base_models import FunctionNode


class BasicMathNode(FunctionNode):
    """A node that evaluates basic mathematical functions.

    The BasicMathNode wraps Python's eval() function to sanitize the input string
    and thus prevent the execution of arbitrary code. It also allows the user to
    write the expression once and execute using math, numpy, or JAX. The names of
    the variables in the expression must match the input variables provided by
    kwargs.

    Example::

        my_node = BasicMathNode(
            "redshift + 10.0 * sin(phase)",
            redshift=host.redshift,
            phase=source.phase,
        )

    Attributes
    ----------
    expression : str
        The expression to evaluate, rewritten to call the selected backend
        (for example ``sin`` becomes ``np.sin`` for the numpy backend).
    backend : str
        The math library to use. This is auto-converted to one of (math, np, or jnp)
        depending on the input parameter.
    backend_lib : module
        The imported math library corresponding to ``backend``.
    to_array : callable
        Converts each input parameter to the backend's array type (identity for math).

    Parameters
    ----------
    expression : str
        The expression to evaluate.
    backend : str
        The math library to use. Must be one of: math, numpy, np, jax, or jnp.
    node_label : str, optional
        An identifier (or name) for the current node.
    **kwargs : dict, optional
        Any additional keyword arguments. Every variable in the expression must
        be included as a kwarg.

    Notes
    -----
    ``max`` and ``min`` map to the builtins for the math backend but to
    ``np.max``/``np.min`` (``jnp.max``/``jnp.min``) for numpy/JAX. Those are array
    reductions whose second positional argument is ``axis``, so ``max(a, b)`` does
    not mean the same thing across backends.
    """

    # A list of supported Python operations. Used to prevent eval from
    # running arbitrary python expressions. The Call and Name types are special
    # cased so we can do checks and translations.
    _supported_ast_nodes = (
        ast.Module,  # Top level object when parsed as exec.
        ast.Expression,  # Top level object when parsed as eval.
        ast.Expr,  # Math expressions.
        ast.Constant,  # Constant values.
        ast.Load,  # Load a variable - must come from an approved function or variable.
        ast.Store,  # Store value - must come from an approved function or variable.
        ast.BinOp,  # Binary operations
        ast.Add,
        ast.Sub,
        ast.Mult,
        ast.Div,
        ast.FloorDiv,
        ast.Mod,
        ast.Pow,
        ast.UnaryOp,  # Unary operations
        ast.UAdd,
        ast.USub,
        ast.Invert,
    )

    # A map from a very limited set of supported math constant/function names to
    # the corresponding names in [math, numpy, jax]. This is needed because
    # a very few functions have different names in different libraries.
    # The numpy/JAX inverse-trig entries use the arc* spellings: they are the same
    # functions as the acos-style aliases, but the aliases only exist in NumPy >= 2.0.
    _math_map = {
        "abs": ["abs", "np.abs", "jnp.abs"],  # Special handling for math.
        "acos": ["math.acos", "np.arccos", "jnp.arccos"],
        "acosh": ["math.acosh", "np.arccosh", "jnp.arccosh"],
        "asin": ["math.asin", "np.arcsin", "jnp.arcsin"],
        "asinh": ["math.asinh", "np.arcsinh", "jnp.arcsinh"],
        "atan": ["math.atan", "np.arctan", "jnp.arctan"],
        "atan2": ["math.atan2", "np.arctan2", "jnp.arctan2"],
        "cos": ["math.cos", "np.cos", "jnp.cos"],
        "cosh": ["math.cosh", "np.cosh", "jnp.cosh"],
        "ceil": ["math.ceil", "np.ceil", "jnp.ceil"],
        "degrees": ["math.degrees", "np.degrees", "jnp.degrees"],
        "deg2rad": ["math.radians", "np.deg2rad", "jnp.deg2rad"],  # Special handling for math
        "e": ["math.e", "np.e", "jnp.e"],
        "exp": ["math.exp", "np.exp", "jnp.exp"],
        "fabs": ["math.fabs", "np.fabs", "jnp.fabs"],
        "floor": ["math.floor", "np.floor", "jnp.floor"],
        "log": ["math.log", "np.log", "jnp.log"],
        "log10": ["math.log10", "np.log10", "jnp.log10"],
        "log2": ["math.log2", "np.log2", "jnp.log2"],
        "max": ["max", "np.max", "jnp.max"],  # Special handling for math
        "min": ["min", "np.min", "jnp.min"],  # Special handling for math
        "pi": ["math.pi", "np.pi", "jnp.pi"],
        "pow": ["math.pow", "np.power", "jnp.power"],  # Special handling for numpy
        "power": ["math.pow", "np.power", "jnp.power"],  # Special handling for math
        "radians": ["math.radians", "np.radians", "jnp.radians"],
        "rad2deg": ["math.degrees", "np.rad2deg", "jnp.rad2deg"],  # Special handling for math
        "sin": ["math.sin", "np.sin", "jnp.sin"],
        "sinh": ["math.sinh", "np.sinh", "jnp.sinh"],
        "sqrt": ["math.sqrt", "np.sqrt", "jnp.sqrt"],
        "tan": ["math.tan", "np.tan", "jnp.tan"],
        "tanh": ["math.tanh", "np.tanh", "jnp.tanh"],
        "trunc": ["math.trunc", "np.trunc", "jnp.trunc"],
    }

    # Column of ``_math_map`` to use for each normalized backend alias.
    _backend_index = {"math": 0, "np": 1, "jnp": 2}

    def __init__(self, expression, backend="numpy", node_label=None, **kwargs):
        # Set the backend and the corresponding math library.
        self._set_backend(backend)

        # Check the expression is pure math and translate it into the correct backend.
        self._prepare(expression, **kwargs)
        super().__init__(self.eval, node_label=node_label, **kwargs)

    def _set_backend(self, backend):
        """Select the math library and array converter for ``backend``.

        Parameters
        ----------
        backend : str
            One of: math, numpy, np, jax, or jnp. Stored normalized as
            ``"math"``, ``"np"``, or ``"jnp"``.

        Raises
        ------
        ImportError
            If the JAX backend is requested but JAX is not installed.
        ValueError
            If ``backend`` is not a recognized name.
        """
        if backend == "jax" or backend == "jnp":
            try:
                import jax.numpy as jnp
            except ImportError as err:  # pragma: no cover
                raise ImportError(
                    "JAX is required to use the BasicMathNode with backend='jax', please "
                    "install with `pip install jax` or `conda install conda-forge::jax`"
                ) from err
            self.backend = "jnp"
            self.backend_lib = jnp
            self.to_array = jnp.asarray
        elif backend == "numpy" or backend == "np":
            import numpy as np

            self.backend = "np"
            self.backend_lib = np
            self.to_array = np.asarray
        elif backend == "math":
            import math

            self.backend = "math"
            self.backend_lib = math
            self.to_array = lambda x: x  # No conversion
        else:
            raise ValueError(
                f"Unsupported math backend '{backend}'. Must be one of: math, numpy, np, jax, or jnp."
            )

    def __getstate__(self):
        """Return the state to pickle.

        The backend module and the ``to_array`` function (a lambda for the math
        backend) are not picklable, so they are dropped here and rebuilt from
        ``backend`` in ``__setstate__``.
        """
        state = self.__dict__.copy()
        state.pop("backend_lib", None)
        state.pop("to_array", None)
        return state

    def __setstate__(self, state):
        """Restore the pickled state and rebuild the non-picklable attributes."""
        self.__dict__.update(state)
        # ``backend`` is already normalized ("math", "np", or "jnp"), all of which
        # ``_set_backend`` accepts.
        self._set_backend(self.backend)

    def eval(self, **kwargs):
        """Evaluate the (sanitized) expression for the given parameter values.

        Parameters
        ----------
        **kwargs : dict
            Values for every variable in the expression.

        Returns
        -------
        result
            The value of the expression computed with the selected backend.

        Raises
        ------
        Exception
            Any error raised while evaluating is re-raised as the same type with
            the expression and arguments added to the message when possible.
        """
        params = self._prepare_params(**kwargs)
        # The rewritten expression refers to the library by its alias (e.g.
        # "np.sin"), so expose the module under that alias. ``_prepare`` rejects
        # user variables with this name so nothing is silently overwritten.
        params[self.backend] = self.backend_lib
        try:
            # Builtin eval (the method name does not shadow it inside the body).
            # Safe only because ``_prepare`` restricted the expression to the
            # allow-listed AST nodes and names.
            return eval(self.expression, globals(), params)
        except Exception as problem:  # pragma: no cover
            # Provide more detailed logging, including the expression and parameters
            # used, when we encounter a math error like divide by zero.
            new_message = (
                f"Error during math operation '{self.expression}' with args={kwargs}. "
                f"Original error: {problem}"
            )
            try:
                new_error = type(problem)(new_message)
            except Exception:
                # Some exception types cannot be built from a single message string;
                # fall back to the original error instead of masking it.
                new_error = None
            if new_error is None:
                raise
            raise new_error from problem

    @staticmethod
    def list_functions():
        """Return a list of the supported function and constant names.

        Returns
        -------
        list
            A list of the supported names (including constants such as ``pi`` and ``e``).
        """
        return list(BasicMathNode._math_map.keys())

    def _prepare_params(self, **kwargs):
        """Convert all of the incoming parameters into the correct type, such as numpy arrays.

        Parameters
        ----------
        **kwargs : dict, optional
            The keyword arguments, including every variable in the expression.

        Returns
        -------
        params : dict
            The converted parameters, keyed by variable name.
        """
        return {name: self.to_array(value) for name, value in kwargs.items()}

    def _prepare(self, expression, **kwargs):
        """Rewrite a python expression that consists of only basic math to use the
        prespecified math library. Sanitizes the string to prevent arbitrary code
        execution.

        Parameters
        ----------
        expression : str
            The expression to evaluate. Must only contain basic math operations,
            functions on the allow list, and variables provided in kwargs.
        **kwargs : dict, optional
            Any additional keyword arguments, including the variable assignments.

        Raises
        ------
        ValueError
            If the expression contains an unsupported construct, calls a function
            that is not on the allow list, references an unknown name, or uses a
            variable whose name collides with the backend alias.
        """
        tree = ast.parse(expression)
        backend_index = self._backend_index[self.backend]

        # Walk the tree and confirm that it only contains the basic math.
        for node in ast.walk(tree):
            if isinstance(node, self._supported_ast_nodes):
                # Nothing to do, this is a valid operation for the ast.
                continue
            elif isinstance(node, ast.Call):
                # Only direct calls to allow-listed names are permitted. Calls through
                # attributes, subscripts, lambdas, etc. have no ``id`` to check.
                if not isinstance(node.func, ast.Name):
                    raise ValueError(f"Unsupported function {ast.unparse(node.func)}")
                if node.func.id not in self._math_map:
                    raise ValueError(f"Unsupported function {node.func.id}")
            elif isinstance(node, ast.Name):
                if node.id in kwargs:
                    # This is a user supplied variable. ``eval`` binds the backend
                    # module under ``self.backend``, which would overwrite it.
                    if node.id == self.backend:
                        raise ValueError(
                            f"Variable name '{node.id}' conflicts with the '{self.backend}' "
                            "math backend. Please rename the variable."
                        )
                    continue
                elif node.id in self._math_map:
                    # This is a math function or constant. Overwrite it with the
                    # backend-qualified name.
                    node.id = self._math_map[node.id][backend_index]
                else:
                    raise ValueError(
                        f"Unrecognized named variable or function {node.id}. "
                        "This could be because the function is not supported or "
                        "you forgot to include the variable as an argument."
                    )
            else:
                raise ValueError(f"Invalid part of expression {type(node)}")

        # Convert the expression back into a string.
        self.expression = ast.unparse(tree)