Creating Custom Function Nodes

In this tutorial, we take a deep dive into FunctionNode objects and show how users can create their own nodes to wrap various functions. Users should already be familiar with the concepts covered in the Sampling Parameters notebook.

[1]:
from lightcurvelynx.base_models import FunctionNode
from lightcurvelynx.math_nodes.np_random import NumpyRandomFunc

Function Node Overview

Function nodes let users wrap arbitrary computations during the parameter generation stage. The name FunctionNode is a bit of a misnomer, as these nodes can wrap any Python callable object. For simplicity, we use the term function throughout this notebook, but users should keep the more general behavior in mind.

The basic flow of a FunctionNode is implemented in the base class’s compute() method:

  • Assemble the wrapped function’s input values from the GraphState object (model parameters) and keyword arguments,

  • Call the wrapped function with the assembled input values,

  • Capture the function’s output, and

  • Write those values to the GraphState object.
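The four steps above can be sketched in plain Python. This is a toy illustration, not LightCurveLynx code: `compute_sketch` is a hypothetical name, a dict keyed by (node name, parameter name) stands in for the GraphState, and the rule that keyword arguments take precedence over stored values is an assumption.

```python
import inspect


def compute_sketch(node_name, func, state, **kwargs):
    """Hypothetical toy version of the compute() flow (not the library's code).

    `state` is a plain dict keyed by (node_name, parameter_name), standing in
    for a GraphState. Keyword arguments override stored values (an assumption).
    """
    # 1. Assemble the wrapped function's inputs from keyword arguments or the state.
    arg_names = inspect.signature(func).parameters
    args = {
        name: kwargs[name] if name in kwargs else state[(node_name, name)]
        for name in arg_names
    }
    # 2. Call the wrapped function with the assembled inputs.
    result = func(**args)
    # 3-4. Capture the output and write it back under the default name.
    state[(node_name, "function_node_result")] = result
    return result


def linear_eq_function(x, m, b):
    return m * x + b


state = {("lin", "x"): 2.0, ("lin", "m"): 5.0, ("lin", "b"): -2.0}
print(compute_sketch("lin", linear_eq_function, state))  # 8.0
print(compute_sketch("lin", linear_eq_function, state, x=3.0))  # 13.0
```

Keying the store by both node name and parameter name is what lets several nodes each hold their own function_node_result.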

By default, each function node stores its result in a parameter called function_node_result. Since model parameters are indexed by the combination of node name and parameter name, multiple nodes in a model will often produce their own function_node_result values without conflict. As we will see later in this notebook, we can override the names of the outputs to be more user friendly.

There are two ways to use the FunctionNode class: as a standalone wrapper or as a parent class.

FunctionNode as a Standalone Wrapper

Users can wrap a function directly by passing the function and its arguments to the FunctionNode constructor. The node then uses the function’s return value as its output.

As a concrete example, let’s create a FunctionNode that wraps an existing function computing y = m * x + b. We pass this function, along with values for each of its parameters, to the constructor.

[2]:
# This is the function we would like to wrap.
def linear_eq_function(x, m, b):
    """Compute y = m * x + b"""
    return m * x + b


# This is how we wrap linear_eq_function.
func_node = FunctionNode(
    linear_eq_function,  # First argument is the function to call.
    # The function's parameters are given as keyword arguments to the FunctionNode.
    x=NumpyRandomFunc("uniform", low=0.0, high=10.0),  # Random value
    m=5.0,  # Constant value
    b=-2.0,  # Constant value
)

The first argument of the FunctionNode constructor is the function to evaluate, such as linear_eq_function above. Each input to that function must be given as a named (keyword) argument in the FunctionNode definition, such as x, m, and b above; if any input is missing, the code raises an error. The FunctionNode class handles all of the internal bookkeeping: determining the names of the function’s arguments, creating internal parameters, and assembling those arguments whenever the function is called.
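One way to picture the argument bookkeeping is through Python's `inspect` module. The sketch below is hypothetical: `check_arguments` is not a LightCurveLynx function, it treats every argument as required, and the exception type the real class raises may differ.

```python
import inspect


def check_arguments(func, **kwargs):
    """Hypothetical helper: return the argument names, or raise if any is missing."""
    arg_names = list(inspect.signature(func).parameters)
    missing = [name for name in arg_names if name not in kwargs]
    if missing:
        raise ValueError(f"Missing values for arguments: {missing}")
    return arg_names


def linear_eq_function(x, m, b):
    return m * x + b


names = check_arguments(linear_eq_function, x=1.0, m=5.0, b=-2.0)  # ['x', 'm', 'b']

error_message = None
try:
    check_arguments(linear_eq_function, x=1.0, m=5.0)  # b is missing
except ValueError as err:
    error_message = str(err)
print(error_message)  # Missing values for arguments: ['b']
```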

Here we provide constants for m and b so we use the same linear formulation for each sample. Only the value of x changes. However, we could have also used a whole tree of function nodes, including sampling functions, to set m and b. In that case it is important to remember that each of our results is a consistent sampling and computation over all the parameters in the model.
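To make "consistent sampling" concrete, here is a plain-Python sketch (no LightCurveLynx calls) in which m is random as well. It assumes, as the printed states in this notebook suggest, that each parameter holds one value per sample; the list-based storage is purely illustrative.

```python
import random

rng = random.Random(0)
num_samples = 5

# One draw per sample for each random input (x and, in this variant, m).
x = [rng.uniform(0.0, 10.0) for _ in range(num_samples)]
m = [rng.uniform(4.0, 6.0) for _ in range(num_samples)]
b = [-2.0] * num_samples

# Every downstream computation indexes its inputs by the same sample number,
# so sample i always combines x[i], m[i], and b[i].
y = [m_i * x_i + b_i for x_i, m_i, b_i in zip(x, m, b)]
y_plus_one = [y_i + 1 for y_i in y]
```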

[3]:
state = func_node.sample_parameters(num_samples=5)
print(state)
NumpyRandomFunc:uniform_1:
    low: [0. 0. 0. 0. 0.]
    high: [10. 10. 10. 10. 10.]
    function_node_result: [8.68548435 1.75274506 6.0863476  1.13596911 7.60445169]
FunctionNode:linear_eq_function_0:
    x: [8.68548435 1.75274506 6.0863476  1.13596911 7.60445169]
    m: [5. 5. 5. 5. 5.]
    b: [-2. -2. -2. -2. -2.]
    function_node_result: [41.42742173  6.76372531 28.43173799  3.67984553 36.02225844]

As described above, both nodes (the NumPy sampler and the linear function) create function_node_result parameters to store their intermediate results.

Nodes can be chained by using one FunctionNode as the value for a parameter of another. When a FunctionNode is passed as a parameter, LightCurveLynx automatically links that parameter to the FunctionNode’s function_node_result value. Below you can see that the input (x) of our increment function corresponds directly to the output (function_node_result) of the linear equation function.

[4]:
def increment(x):
    """Increment x by 1."""
    return x + 1


# This is how we wrap the increment function.
inc_node = FunctionNode(
    increment,  # First argument is the function to call.
    # The function's parameters are given as keyword arguments to the FunctionNode.
    x=func_node,  # Use the output of func_node as input to increment.
)

state = inc_node.sample_parameters(num_samples=5)
print(state)
NumpyRandomFunc:uniform_2:
    low: [0. 0. 0. 0. 0.]
    high: [10. 10. 10. 10. 10.]
    function_node_result: [4.35653998 8.34570126 5.78434961 6.28502969 9.79076755]
FunctionNode:linear_eq_function_1:
    x: [4.35653998 8.34570126 5.78434961 6.28502969 9.79076755]
    m: [5. 5. 5. 5. 5.]
    b: [-2. -2. -2. -2. -2.]
    function_node_result: [19.78269991 39.7285063  26.92174807 29.42514845 46.95383777]
FunctionNode:increment_0:
    x: [19.78269991 39.7285063  26.92174807 29.42514845 46.95383777]
    function_node_result: [20.78269991 40.7285063  27.92174807 30.42514845 47.95383777]
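The chaining rule can be modeled with a toy node that, when one of its inputs is another node, evaluates that node first and uses its result. `MiniNode` is a hypothetical class, not the LightCurveLynx implementation; its per-evaluation cache mirrors the idea that every consumer of a node sees the same sampled value.

```python
class MiniNode:
    """Hypothetical toy node (not the LightCurveLynx FunctionNode)."""

    def __init__(self, func, **inputs):
        self.func = func
        self.inputs = inputs

    def evaluate(self, cache=None):
        # The cache ensures each node runs once per evaluation, so every
        # consumer of a node sees the same result.
        cache = {} if cache is None else cache
        if self not in cache:
            args = {
                # A node-valued input is replaced by that node's result.
                name: value.evaluate(cache) if isinstance(value, MiniNode) else value
                for name, value in self.inputs.items()
            }
            cache[self] = self.func(**args)
        return cache[self]


def linear_eq_function(x, m, b):
    return m * x + b


def increment(x):
    return x + 1


lin = MiniNode(linear_eq_function, x=3.0, m=5.0, b=-2.0)
inc = MiniNode(increment, x=lin)
print(inc.evaluate())  # 14.0
```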

We can make the parameter linking more explicit by using dot notation with the parameter name. The behavior is identical.

[5]:
# This is how we wrap the increment function.
inc_node = FunctionNode(
    increment,  # First argument is the function to call.
    # The function's parameters are given as keyword arguments to the FunctionNode.
    x=func_node.function_node_result,  # named parameter
)

state = inc_node.sample_parameters(num_samples=5)
print(state)
NumpyRandomFunc:uniform_2:
    low: [0. 0. 0. 0. 0.]
    high: [10. 10. 10. 10. 10.]
    function_node_result: [6.28242412 1.4054691  2.71360123 0.14566048 1.50508898]
FunctionNode:linear_eq_function_1:
    x: [6.28242412 1.4054691  2.71360123 0.14566048 1.50508898]
    m: [5. 5. 5. 5. 5.]
    b: [-2. -2. -2. -2. -2.]
    function_node_result: [29.4121206   5.02734548 11.56800614 -1.27169758  5.52544488]
FunctionNode:increment_0:
    x: [29.4121206   5.02734548 11.56800614 -1.27169758  5.52544488]
    function_node_result: [30.4121206   6.02734548 12.56800614 -0.27169758  6.52544488]
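Why the bare node and `func_node.function_node_result` behave the same can be pictured with attribute-based references. Everything below (`ParamRef`, `RefNode`, `resolve`) is hypothetical and only illustrates one way dot notation could map to a (node, parameter) pair; LightCurveLynx's internals may work differently.

```python
class ParamRef:
    """Hypothetical reference to one named parameter of one node."""

    def __init__(self, node, name):
        self.node = node
        self.name = name


class RefNode:
    """Hypothetical node that exposes its parameters via dot notation."""

    def __init__(self, param_names):
        self._param_names = set(param_names) | {"function_node_result"}

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails.
        if attr in self.__dict__.get("_param_names", ()):
            return ParamRef(self, attr)
        raise AttributeError(attr)


def resolve(value):
    """Link a bare node to its default output; pass explicit references through."""
    if isinstance(value, RefNode):
        return ParamRef(value, "function_node_result")
    return value


node = RefNode(["x", "m", "b"])
implicit = resolve(node)
explicit = resolve(node.function_node_result)
print(implicit.node is explicit.node, implicit.name == explicit.name)  # True True
```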

More realistically, users will want to wrap functions that perform complex astronomical calculations.

FunctionNode Subclasses

When users want function nodes that carry additional data, they can create subclasses of the FunctionNode class. For example, computing the distance modulus (distmod) from the redshift requires a cosmology. While we could construct the cosmology each time the function is called, it is more efficient to construct it once and reuse it across computations.

[6]:
from astropy.cosmology import FlatLambdaCDM


class DistModFromRedshift(FunctionNode):
    """A wrapper class for the _distmod_from_redshift() function.

    Parameters
    ----------
    redshift : function or constant
        The function or constant providing the redshift value.
    H0 : float, optional
        The Hubble constant (in km/s/Mpc). Default: 73.0.
    Omega_m : float, optional
        The matter density parameter Omega_m. Default: 0.3.
    **kwargs : dict, optional
        Any additional keyword arguments.
    """

    def __init__(self, redshift, H0=73.0, Omega_m=0.3, **kwargs):
        # Create the cosmology once for this node. This is constructed ONCE for all samples.
        # Accept ints as well as floats so that, e.g., H0=70 is not rejected.
        if not isinstance(H0, (int, float)) or not isinstance(Omega_m, (int, float)):
            raise ValueError("H0 and Omega_m must be numeric constants.")
        self.cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)

        # Call the super class's constructor with the needed information.
        super().__init__(
            func=self._distmod_from_redshift,  # "Function" being wrapped
            redshift=redshift,
            **kwargs,
        )

    def _distmod_from_redshift(self, redshift):
        """Compute distance modulus given redshift and cosmology.

        Parameters
        ----------
        redshift : float or numpy.ndarray
            The redshift value(s).

        Returns
        -------
        distmod : float or numpy.ndarray
            The distance modulus (in mag)
        """
        return self.cosmo.distmod(redshift).value

There are a few things to note from the implementation above.

First, since the cosmology is created once per object, it is the same for every evaluation: its parameters, H0 and Omega_m, are fixed for all samples. Only the input redshift changes.
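The "set up once, evaluate many times" pattern does not depend on astropy. Below is a stdlib-only sketch with a hypothetical `LowZDistMod` callable. It deliberately swaps in the low-redshift Hubble-law approximation d_L ≈ c z / H0, which is only reasonable for z ≪ 1 and is not what astropy's distmod() computes.

```python
import math


class LowZDistMod:
    """Hypothetical callable that performs its setup once and reuses it."""

    SPEED_OF_LIGHT_KM_S = 299792.458
    setup_count = 0  # counts constructions, to show setup runs only once

    def __init__(self, H0=73.0):
        # Stand-in for an expensive setup step, such as building a cosmology.
        LowZDistMod.setup_count += 1
        self.hubble_distance_mpc = self.SPEED_OF_LIGHT_KM_S / H0

    def __call__(self, redshift):
        # Distance modulus: 5 * log10(d_L / 10 pc) = 5 * log10(d_L in Mpc) + 25.
        d_l_mpc = self.hubble_distance_mpc * redshift
        return 5.0 * math.log10(d_l_mpc) + 25.0


distmod = LowZDistMod(H0=70.0)
values = [distmod(z) for z in (0.01, 0.02, 0.05)]
print(LowZDistMod.setup_count)  # 1: three evaluations, one setup
```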

Second, the “function” being wrapped by the function node is actually an object method. As we noted earlier, the FunctionNode can actually wrap any Python callable object. By wrapping an internal method, the computation has access to the object’s attributes via self.
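Wrapping a method works smoothly with signature introspection because a bound method already carries `self`. The sketch assumes FunctionNode discovers argument names in an `inspect.signature`-like way (the notebook only says it determines them); `Scaler` is a hypothetical class.

```python
import inspect


class Scaler:
    """Hypothetical class whose instance and method are both callables."""

    def __init__(self, factor):
        self.factor = factor

    def scale(self, x):
        return self.factor * x

    def __call__(self, x, offset):
        return self.factor * x + offset


s = Scaler(3.0)
# A bound method already carries `self`, so only `x` is reported.
method_args = list(inspect.signature(s.scale).parameters)  # ['x']
# A callable instance is described by its __call__, again without `self`.
call_args = list(inspect.signature(s).parameters)  # ['x', 'offset']
```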

Supporting Multiple Outputs

If the wrapped function produces multiple outputs, the user can assign a name to each output via the outputs constructor argument. This argument takes a list of strings with the same length as the number of outputs produced. Each result is stored separately in the corresponding named parameter (instead of the default function_node_result parameter). These parameters are added to the node automatically.

[7]:
# A function that returns two values.
def _linear_pair(x, m, b):
    """Compute y1 = m * x + b and y2 = -1/m * x - b"""
    return (m * x + b, -1.0 / m * x - b)


# A function node that returns two values. The outputs are named "y1" and "y2".
func_node2 = FunctionNode(
    _linear_pair,  # First parameter is the function to call.
    x=NumpyRandomFunc("uniform", low=0.0, high=10.0),
    m=5.0,
    b=-2.0,
    outputs=["y1", "y2"],  # The output names.
)

print(func_node2.sample_parameters(num_samples=5))
NumpyRandomFunc:uniform_1:
    low: [0. 0. 0. 0. 0.]
    high: [10. 10. 10. 10. 10.]
    function_node_result: [7.18505077 5.28280778 7.5259607  8.64060858 9.4201836 ]
FunctionNode:_linear_pair_0:
    x: [7.18505077 5.28280778 7.5259607  8.64060858 9.4201836 ]
    m: [5. 5. 5. 5. 5.]
    b: [-2. -2. -2. -2. -2.]
    y1: [33.92525386 24.41403889 35.62980348 41.20304292 45.10091802]
    y2: [0.56298985 0.94343844 0.49480786 0.27187828 0.11596328]
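The write-back step for named outputs can be sketched as follows. `store_outputs` is a hypothetical helper operating on a plain dict that stands in for the GraphState; the length check reflects the requirement that the outputs list match the number of returned values.

```python
def store_outputs(node_name, result, state, outputs=None):
    """Hypothetical write-back step for single or multiple outputs."""
    if outputs is None:
        # Single output: use the default parameter name.
        state[(node_name, "function_node_result")] = result
        return
    if len(outputs) != len(result):
        raise ValueError(f"Expected {len(outputs)} outputs, got {len(result)}")
    # Multiple outputs: store each value under its own name.
    for name, value in zip(outputs, result):
        state[(node_name, name)] = value


def _linear_pair(x, m, b):
    return (m * x + b, -1.0 / m * x - b)


state = {}
store_outputs("pair", _linear_pair(2.0, 5.0, -2.0), state, outputs=["y1", "y2"])
```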

The outputs can be referenced individually using the dot notation with their given name. Below we reimplement the increment function using just the y2 output as the function’s input.

[8]:
# This is how we wrap the increment function.
inc_node2 = FunctionNode(
    increment,  # First argument is the function to call.
    # The function's parameters are given as keyword arguments to the FunctionNode.
    x=func_node2.y2,  # Use the named output.
)

print(inc_node2.sample_parameters(num_samples=10))
NumpyRandomFunc:uniform_2:
    low: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
    high: [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
    function_node_result: [2.21278639 5.24137228 7.25162529 5.43297324 9.39768288 0.38619844
 9.21635687 3.90274874 0.99282193 7.69141292]
FunctionNode:_linear_pair_1:
    x: [2.21278639 5.24137228 7.25162529 5.43297324 9.39768288 0.38619844
 9.21635687 3.90274874 0.99282193 7.69141292]
    m: [5. 5. 5. 5. 5. 5. 5. 5. 5. 5.]
    b: [-2. -2. -2. -2. -2. -2. -2. -2. -2. -2.]
    y1: [ 9.06393194 24.20686141 34.25812646 25.1648662  44.98841442 -0.06900778
 44.08178436 17.51374369  2.96410965 36.45706458]
    y2: [1.55744272 0.95172554 0.54967494 0.91340535 0.12046342 1.92276031
 0.15672863 1.21945025 1.80143561 0.46171742]
FunctionNode:increment_0:
    x: [1.55744272 0.95172554 0.54967494 0.91340535 0.12046342 1.92276031
 0.15672863 1.21945025 1.80143561 0.46171742]
    function_node_result: [2.55744272 1.95172554 1.54967494 1.91340535 1.12046342 2.92276031
 1.15672863 2.21945025 2.80143561 1.46171742]

The named output is most often used in nodes that produce a combination of correlated values, such as (RA, Dec). See the Sampling Object Positions notebook for examples.

Randomization

Care must be taken when creating new function nodes that use randomization. For consistency, users will want the nodes to be fully random by default, while still being able to use a provided random number generator. The difficulty is that FunctionNode.compute() does not pass the random number generator to the wrapped function. It cannot, because not all wrapped functions accept a random number generator parameter.
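Both halves of this problem fit in a few lines of stdlib Python. `noisy_value` is a hypothetical helper showing the "random by default, reproducible with a provided generator" pattern, and the `increment` call shows why a generic caller cannot simply hand a generator to every wrapped function.

```python
import random


def increment(x):
    return x + 1


def noisy_value(base, rng=None):
    """Hypothetical helper: fresh randomness by default, reproducible if seeded."""
    rng = random.Random() if rng is None else rng
    return base + rng.gauss(0.0, 1.0)


# A seeded generator makes the draw reproducible.
same = noisy_value(10.0, rng=random.Random(7)) == noisy_value(10.0, rng=random.Random(7))

# A function without an rng parameter rejects the extra keyword.
try:
    increment(1, rng=random.Random(7))
    accepts_rng = True
except TypeError:
    accepts_rng = False
print(same, accepts_rng)  # True False
```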

Instead, there are two supported approaches to enable random behavior.

Use Random Parameters (RECOMMENDED)

Users can add new parameters in their class’s constructor that correspond to the random values they would like to generate. For example, to implement a noisy linear function, y = m * x + b + noise, we can add a noise parameter and set it using a NumpyRandomFunc or another random node. Because the noise is sampled as part of the model like any other parameter, this approach takes care of the internal bookkeeping.

[9]:
class NoisyLinear(FunctionNode):
    """A noisy linear function node."""

    def __init__(self, x, m, b, **kwargs):
        # Create the noise node once; it is sampled anew for each sample.
        self.noise_func = NumpyRandomFunc("normal", loc=0.0, scale=1.0)

        # Call the super class's constructor with the needed information.
        super().__init__(
            func=self._noisy_linear_eq,  # "Function" being wrapped
            x=x,
            m=m,
            b=b,
            noise=self.noise_func,
            **kwargs,
        )

    def _noisy_linear_eq(self, x, m, b, noise):
        """Compute y = m * x + b + noise."""
        return m * x + b + noise


my_node = NoisyLinear(
    x=NumpyRandomFunc("uniform", low=0.0, high=10.0),
    m=5.0,
    b=-2.0,
)
print(my_node.sample_parameters(num_samples=5))
NumpyRandomFunc:uniform_1:
    low: [0. 0. 0. 0. 0.]
    high: [10. 10. 10. 10. 10.]
    function_node_result: [7.18417095 8.99691895 1.70320779 2.63393127 7.13002592]
NoisyLinear:_noisy_linear_eq_0:
    x: [7.18417095 8.99691895 1.70320779 2.63393127 7.13002592]
    m: [5. 5. 5. 5. 5.]
    b: [-2. -2. -2. -2. -2.]
    noise: [ 3.64949918  0.83589374 -0.03042521  0.69863663  0.87485232]
    function_node_result: [37.57035395 43.82048852  6.48561372 11.86829297 34.52498191]
NumpyRandomFunc:normal_2:
    loc: [0. 0. 0. 0. 0.]
    scale: [1. 1. 1. 1. 1.]
    function_node_result: [ 3.64949918  0.83589374 -0.03042521  0.69863663  0.87485232]

As you can see, the noise parameter is sampled first and then passed to the function like any other input value.
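The idea behind this approach can be sketched without LightCurveLynx: the wrapped function stays deterministic, and every random input is drawn from a single shared generator before the function runs, so one seed reproduces the whole run. `sample_noisy_linear` is a hypothetical stand-in for the sampling machinery.

```python
import random


def noisy_linear_eq(x, m, b, noise):
    # The wrapped function stays deterministic; randomness arrives as an input.
    return m * x + b + noise


def sample_noisy_linear(num_samples, rng):
    """Hypothetical sampler: every random input comes from one shared generator."""
    results = []
    for _ in range(num_samples):
        x = rng.uniform(0.0, 10.0)
        noise = rng.gauss(0.0, 1.0)  # drawn before the function is called
        results.append(noisy_linear_eq(x, m=5.0, b=-2.0, noise=noise))
    return results


run1 = sample_noisy_linear(5, random.Random(123))
run2 = sample_noisy_linear(5, random.Random(123))
print(run1 == run2)  # True: same seed, same samples
```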

Custom Compute Function

If users need more control over how the randomness is used, they can override the compute() method, which does receive a random number generator. However, compute() contains other logic that must then be replicated, including assembling the function’s parameters and writing the results to the GraphState. We recommend this approach only for experienced users. For an example, see the code for the NumpyRandomFunc class itself.
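The shape of such an override can be shown on a toy framework. `MiniFunctionNode` and `JitteredScale` are hypothetical classes with an assumed `compute(state, rng=None)` signature; the real FunctionNode.compute() may take different arguments. The point is that the override must redo input assembly and write-back itself.

```python
import inspect
import random


class MiniFunctionNode:
    """Hypothetical base class; the real FunctionNode.compute() may differ."""

    def __init__(self, name, func, **inputs):
        self.name = name
        self.func = func
        self.inputs = inputs

    def compute(self, state, rng=None):
        # The generator is accepted but never forwarded to the wrapped function.
        args = {k: self.inputs[k] for k in inspect.signature(self.func).parameters}
        result = self.func(**args)
        state[(self.name, "function_node_result")] = result
        return result


class JitteredScale(MiniFunctionNode):
    """Hypothetical subclass that overrides compute() to use the generator."""

    def __init__(self, name, x, **kwargs):
        super().__init__(name, self._scale, x=x, **kwargs)

    def _scale(self, x):
        return 2.0 * x

    def compute(self, state, rng=None):
        rng = random.Random() if rng is None else rng
        # The override must repeat the base-class work: assemble the inputs...
        x = self.inputs["x"]
        # ...call the function and apply a random factor...
        result = self._scale(x) * rng.uniform(0.9, 1.1)
        # ...and write the result back to the state.
        state[(self.name, "function_node_result")] = result
        return result


state = {}
node = JitteredScale("jitter", x=10.0)
value = node.compute(state, random.Random(1))
```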