Creating Custom Function Nodes
In this tutorial we do a deep dive into FunctionNode objects and how users can create their own nodes for various functions. Users should already be familiar with the concepts covered in the Sampling Parameters notebook.
[1]:
from lightcurvelynx.base_models import FunctionNode
from lightcurvelynx.math_nodes.np_random import NumpyRandomFunc
Function Node Overview
Function nodes provide users the ability to wrap arbitrary computations during the parameter generation stage. The name FunctionNode is a bit of a misnomer as these nodes can wrap any Python callable object. For simplicity, we will use the term function throughout this notebook, but users should understand the more general behavior.
The basic flow of the FunctionNode is wrapped in the base class’s compute method:
1. Assemble the wrapped function’s input values from the GraphState object (model parameters) and keyword arguments,
2. Call the wrapped function with the assembled input values,
3. Capture the function’s output, and
4. Write those values to the GraphState object.
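The four steps above can be sketched in plain Python. This is only a conceptual illustration, not the library’s implementation: the nested dictionary stands in for the GraphState object, and toy_compute and its arguments are hypothetical names invented for this example.

```python
def linear_eq_function(x, m, b):
    """Compute y = m * x + b."""
    return m * x + b


def toy_compute(func, arg_names, node_name, state, **overrides):
    """Hypothetical stand-in for the compute flow described above.

    `state` is a nested dict {node_name: {param_name: value}} that plays
    the role of the GraphState object.
    """
    # 1. Assemble the inputs from the stored parameters and keyword arguments.
    inputs = {name: state[node_name][name] for name in arg_names}
    inputs.update(overrides)
    # 2. Call the wrapped function and 3. capture its output.
    result = func(**inputs)
    # 4. Write the output back under the default result name.
    state[node_name]["function_node_result"] = result
    return result


state = {"linear_eq_function_0": {"x": 2.0, "m": 5.0, "b": -2.0}}
toy_compute(linear_eq_function, ["x", "m", "b"], "linear_eq_function_0", state)
print(state["linear_eq_function_0"]["function_node_result"])  # 8.0
```

The real class additionally discovers the argument names and handles arrays of samples; the sketch omits both.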
By default each function node stores its result in a parameter called function_node_result. Since model parameters are indexed by a combination of node name and parameter name, it will often be the case that multiple nodes in the model will generate function_node_result values. As we will see later in this notebook, we can override the name of the outputs to be more user friendly.
There are two ways to use the FunctionNode class: as a standalone wrapper or as a parent class.
FunctionNode as a Standalone Wrapper
Users can wrap a function directly by passing the function and its arguments into the FunctionNode constructor. The node calls the provided function and uses the function’s return value as its output.
As a concrete example, let’s create a FunctionNode that wraps an existing function that computes y = m * x + b. We need to pass this function and values for each of its parameters to the constructor.
[2]:
# This is the function we would like to wrap.
def linear_eq_function(x, m, b):
    """Compute y = m * x + b"""
    return m * x + b


# This is how we wrap linear_eq_function.
func_node = FunctionNode(
    linear_eq_function,  # First argument is the function to call.
    # The function's parameters are given as keyword arguments to the FunctionNode.
    x=NumpyRandomFunc("uniform", low=0.0, high=10.0),  # Random value
    m=5.0,  # Constant value
    b=-2.0,  # Constant value
)
The first argument to the FunctionNode constructor is the function to evaluate, such as our linear equation above (linear_eq_function). Each input to that function must be included as a keyword argument in the FunctionNode definition, such as x, m, and b above. If any of the input parameters are missing, the code will raise an error. The FunctionNode class handles all the internal bookkeeping: determining the names of the function’s arguments, creating internal parameters, and assembling those arguments whenever the function is called.
Here we provide constants for m and b so we use the same linear formulation for each sample. Only the value of x changes. However, we could have also used a whole tree of function nodes, including sampling functions, to set m and b. In that case it is important to remember that each of our results is a consistent sampling and computation over all the parameters in the model.
[3]:
state = func_node.sample_parameters(num_samples=5)
print(state)
NumpyRandomFunc:uniform_1:
low: [0. 0. 0. 0. 0.]
high: [10. 10. 10. 10. 10.]
function_node_result: [6.16710321 4.30270113 7.87894201 0.83858047 2.93566364]
FunctionNode:linear_eq_function_0:
x: [6.16710321 4.30270113 7.87894201 0.83858047 2.93566364]
m: [5. 5. 5. 5. 5.]
b: [-2. -2. -2. -2. -2.]
function_node_result: [28.83551603 19.51350565 37.39471003 2.19290235 12.67831818]
As described above, both of the nodes (the numpy sampler and the linear function) create function_node_result parameters to store their intermediate results.
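As a quick sanity check, the printed results can be recomputed by hand from the printed x values. The snippet below is plain Python that does not use LightCurveLynx; the numbers are copied from the output above.

```python
import math

# x values and results copied from the printed output above.
x_values = [6.16710321, 4.30270113, 7.87894201, 0.83858047, 2.93566364]
printed_results = [28.83551603, 19.51350565, 37.39471003, 2.19290235, 12.67831818]

m, b = 5.0, -2.0
recomputed = [m * x + b for x in x_values]

# The printed values are rounded to 8 decimals, so compare with a tolerance.
matches = [math.isclose(y, p, abs_tol=1e-6) for y, p in zip(recomputed, printed_results)]
print(matches)
```

Every entry agrees to within the rounding of the printed values, confirming that each result was computed from the x sampled in the same row.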
The nodes can be chained by using one FunctionNode as the value for a parameter of another. When a FunctionNode is passed as a parameter, LightCurveLynx will automatically link that parameter to the FunctionNode’s function_node_result value. Below you can see that the input (x) of our increment function corresponds directly to the output (function_node_result) of the linear equation function.
[4]:
def increment(x):
    """Increment x by 1."""
    return x + 1


# This is how we wrap the increment function.
inc_node = FunctionNode(
    increment,  # First argument is the function to call.
    # The function's parameters are given as keyword arguments to the FunctionNode.
    x=func_node,  # Use the output of func_node as input to increment.
)
state = inc_node.sample_parameters(num_samples=5)
print(state)
NumpyRandomFunc:uniform_2:
low: [0. 0. 0. 0. 0.]
high: [10. 10. 10. 10. 10.]
function_node_result: [0.16460675 8.42603063 1.88543862 5.3767179 4.12702928]
FunctionNode:linear_eq_function_1:
x: [0.16460675 8.42603063 1.88543862 5.3767179 4.12702928]
m: [5. 5. 5. 5. 5.]
b: [-2. -2. -2. -2. -2.]
function_node_result: [-1.17696625 40.13015316 7.4271931 24.88358952 18.63514638]
FunctionNode:increment_0:
x: [-1.17696625 40.13015316 7.4271931 24.88358952 18.63514638]
function_node_result: [-0.17696625 41.13015316 8.4271931 25.88358952 19.63514638]
We could make the linking of parameters more explicit by using the dot notation and the parameter name. But the behavior is identical.
[5]:
# This is how we wrap the increment function.
inc_node = FunctionNode(
    increment,  # First argument is the function to call.
    # The function's parameters are given as keyword arguments to the FunctionNode.
    x=func_node.function_node_result,  # Named parameter
)
state = inc_node.sample_parameters(num_samples=5)
print(state)
NumpyRandomFunc:uniform_2:
low: [0. 0. 0. 0. 0.]
high: [10. 10. 10. 10. 10.]
function_node_result: [0.55896745 5.68462003 6.60622722 6.11956424 4.7820393 ]
FunctionNode:linear_eq_function_1:
x: [0.55896745 5.68462003 6.60622722 6.11956424 4.7820393 ]
m: [5. 5. 5. 5. 5.]
b: [-2. -2. -2. -2. -2.]
function_node_result: [ 0.79483726 26.42310013 31.03113608 28.59782122 21.91019652]
FunctionNode:increment_0:
x: [ 0.79483726 26.42310013 31.03113608 28.59782122 21.91019652]
function_node_result: [ 1.79483726 27.42310013 32.03113608 29.59782122 22.91019652]
More realistically, users will want to wrap functions that perform complex astronomical calculations.
FunctionNode Subclasses
When function nodes need to carry around additional data, users can create subclasses of the FunctionNode class. For example, when computing the distance modulus (distmod) from the redshift, we need to load the cosmology. While we could load the cosmology each time the function is called, it is more efficient to load it once and reuse it across computations.
[6]:
from astropy.cosmology import FlatLambdaCDM


class DistModFromRedshift(FunctionNode):
    """A wrapper class for the _distmod_from_redshift() function.

    Parameters
    ----------
    redshift : function or constant
        The function or constant providing the redshift value.
    H0 : constant
        The Hubble constant.
    Omega_m : constant
        The matter density Omega_m.
    **kwargs : dict, optional
        Any additional keyword arguments.
    """

    def __init__(self, redshift, H0=73.0, Omega_m=0.3, **kwargs):
        # Create the cosmology once for this node. This is constructed ONCE for all samples.
        if not isinstance(H0, float) or not isinstance(Omega_m, float):
            raise ValueError("H0 and Omega_m must be constants.")
        self.cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)

        # Call the super class's constructor with the needed information.
        super().__init__(
            func=self._distmod_from_redshift,  # "Function" being wrapped
            redshift=redshift,
            **kwargs,
        )

    def _distmod_from_redshift(self, redshift):
        """Compute distance modulus given redshift and cosmology.

        Parameters
        ----------
        redshift : float or numpy.ndarray
            The redshift value(s).

        Returns
        -------
        distmod : float or numpy.ndarray
            The distance modulus (in mag)
        """
        return self.cosmo.distmod(redshift).value
There are a few things to note from the implementation above.
First, since the cosmology is created on a per-object basis, it will be the same for every evaluation. Its parameters, H0 and Omega_m, are fixed for all samples. Only the input redshift changes.
Second, the “function” being wrapped by the function node is actually an object method. As we noted earlier, the FunctionNode can actually wrap any Python callable object. By wrapping an internal method, the computation has access to the object’s attributes via self.
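The second point can be illustrated without LightCurveLynx. The sketch below uses a hypothetical ScaledOffset class (not part of the library) to show that a bound method is an ordinary callable that carries its object’s state, so setup done in the constructor happens once no matter how many times the method is called.

```python
class ScaledOffset:
    """Hypothetical example class; not part of LightCurveLynx."""

    setup_count = 0  # Counts how many times the setup in __init__ runs.

    def __init__(self, scale, offset):
        # Stand-in for a costly one-time step, such as loading a cosmology.
        ScaledOffset.setup_count += 1
        self.scale = scale
        self.offset = offset

    def apply(self, x):
        """Use the stored attributes via self."""
        return self.scale * x + self.offset


obj = ScaledOffset(scale=2.0, offset=1.0)
wrapped = obj.apply  # A bound method: callable, and it remembers obj.

results = [wrapped(x) for x in (0.0, 1.0, 2.0)]
print(callable(wrapped), results, ScaledOffset.setup_count)
```

The same reasoning applies to DistModFromRedshift: the cosmology lives on the object, and the wrapped method reads it through self on every call.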
Supporting Multiple Outputs
If the wrapped function produces multiple outputs, the user can assign names to each output via the outputs constructor argument. This argument takes a list of strings with the same length as the number of outputs produced. Each result is stored separately in a corresponding named parameter (instead of the default function_node_result parameter). These parameters are added automatically to the object.
[7]:
# A function that returns two values.
def _linear_pair(x, m, b):
    """Compute y1 = m * x + b and y2 = -1/m * x - b"""
    return (m * x + b, -1.0 / m * x - b)


# A function node that returns two values. The outputs are named "y1" and "y2".
func_node2 = FunctionNode(
    _linear_pair,  # First parameter is the function to call.
    x=NumpyRandomFunc("uniform", low=0.0, high=10.0),
    m=5.0,
    b=-2.0,
    outputs=["y1", "y2"],  # The output names.
)
print(func_node2.sample_parameters(num_samples=5))
NumpyRandomFunc:uniform_1:
low: [0. 0. 0. 0. 0.]
high: [10. 10. 10. 10. 10.]
function_node_result: [9.04452102 0.94670644 6.67752155 6.35896395 6.09439661]
FunctionNode:_linear_pair_0:
x: [9.04452102 0.94670644 6.67752155 6.35896395 6.09439661]
m: [5. 5. 5. 5. 5.]
b: [-2. -2. -2. -2. -2.]
y1: [43.22260511 2.7335322 31.38760775 29.79481973 28.47198307]
y2: [0.1910958 1.81065871 0.66449569 0.72820721 0.78112068]
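Conceptually, naming the outputs pairs each returned value with the name at the same position in the outputs list. The sketch below shows that pairing in plain Python for the first sample printed above. Here name_outputs is a hypothetical helper written for this illustration, not a LightCurveLynx function, and the library’s internal handling may differ.

```python
import math


def _linear_pair(x, m, b):
    """Compute y1 = m * x + b and y2 = -1/m * x - b"""
    return (m * x + b, -1.0 / m * x - b)


def name_outputs(output_names, results):
    """Hypothetical helper: pair each result with its output name."""
    if len(output_names) != len(results):
        raise ValueError("Need exactly one name per output.")
    return dict(zip(output_names, results))


# First sampled x from the printed output above.
named = name_outputs(["y1", "y2"], _linear_pair(x=9.04452102, m=5.0, b=-2.0))

# Compare against the printed y1 and y2 (rounded to 8 decimals).
print(math.isclose(named["y1"], 43.22260511, abs_tol=1e-6))
print(math.isclose(named["y2"], 0.1910958, abs_tol=1e-6))
```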
The outputs can be referenced individually using the dot notation with their given name. Below we reimplement the increment function using just the y2 output as the function’s input.
[8]:
# This is how we wrap the increment function.
inc_node2 = FunctionNode(
    increment,  # First argument is the function to call.
    # The function's parameters are given as keyword arguments to the FunctionNode.
    x=func_node2.y2,  # Use the named output.
)
print(inc_node2.sample_parameters(num_samples=10))
NumpyRandomFunc:uniform_2:
low: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
high: [10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
function_node_result: [0.49435854 5.82445724 2.43280594 6.34442967 5.59811224 5.22455135
2.3225619 1.55242886 9.94308959 5.02614475]
FunctionNode:_linear_pair_1:
x: [0.49435854 5.82445724 2.43280594 6.34442967 5.59811224 5.22455135
2.3225619 1.55242886 9.94308959 5.02614475]
m: [5. 5. 5. 5. 5. 5. 5. 5. 5. 5.]
b: [-2. -2. -2. -2. -2. -2. -2. -2. -2. -2.]
y1: [ 0.47179269 27.12228622 10.16402968 29.72214837 25.9905612 24.12275673
9.61280948 5.7621443 47.71544795 23.13072373]
y2: [1.90112829 0.83510855 1.51343881 0.73111407 0.88037755 0.95508973
1.53548762 1.68951423 0.01138208 0.99477105]
FunctionNode:increment_0:
x: [1.90112829 0.83510855 1.51343881 0.73111407 0.88037755 0.95508973
1.53548762 1.68951423 0.01138208 0.99477105]
function_node_result: [2.90112829 1.83510855 2.51343881 1.73111407 1.88037755 1.95508973
2.53548762 2.68951423 1.01138208 1.99477105]
The named output is most often used in nodes that produce a combination of correlated values, such as (RA, Dec). See the Sampling Object Positions notebook for examples.
Randomization
Care must be taken when creating new function nodes that use randomization. To be consistent, users will want the nodes to be completely random by default, but have the ability to use a provided random number generator. The difficulty is that FunctionNode.compute() does not pass along the random number generator to the function. It can’t because not all wrapped functions can even take a random number generator parameter.
Instead there are two supported approaches to enable random behavior.
Use Random Parameters (RECOMMENDED)
Users can add new parameters in their class’s constructor that correspond to the random values they would like to generate. For example, if we wanted to implement a noisy linear function, y = m * x + b + noise, we could add a noise parameter. We set this parameter using a NumpyRandomFunc or other random node. This approach takes care of the internal bookkeeping.
[9]:
class NoisyLinear(FunctionNode):
    """A noisy linear function node."""

    def __init__(self, x, m, b, **kwargs):
        # Create the noise node. It is constructed once, but queried for each sample.
        self.noise_func = NumpyRandomFunc("normal", loc=0.0, scale=1.0)

        # Call the super class's constructor with the needed information.
        super().__init__(
            func=self._noisy_linear_eq,  # "Function" being wrapped
            x=x,
            m=m,
            b=b,
            noise=self.noise_func,
            **kwargs,
        )

    def _noisy_linear_eq(self, x, m, b, noise):
        """Compute y = m * x + b + noise."""
        return m * x + b + noise


my_node = NoisyLinear(
    x=NumpyRandomFunc("uniform", low=0.0, high=10.0),
    m=5.0,
    b=-2.0,
)
print(my_node.sample_parameters(num_samples=5))
NumpyRandomFunc:uniform_1:
low: [0. 0. 0. 0. 0.]
high: [10. 10. 10. 10. 10.]
function_node_result: [7.66580875 9.3376184 1.97323225 2.79892872 3.10108938]
NoisyLinear:_noisy_linear_eq_0:
x: [7.66580875 9.3376184 1.97323225 2.79892872 3.10108938]
m: [5. 5. 5. 5. 5.]
b: [-2. -2. -2. -2. -2.]
noise: [-0.80769063 -1.21502062 0.17056517 0.13885519 0.34306928]
function_node_result: [35.52135312 43.47307139 8.03672644 12.13349876 13.84851619]
NumpyRandomFunc:normal_2:
loc: [0. 0. 0. 0. 0.]
scale: [1. 1. 1. 1. 1.]
function_node_result: [-0.80769063 -1.21502062 0.17056517 0.13885519 0.34306928]
As you can see, the noise parameter is sampled first and then passed to the function like any other input parameter.
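Because the noise arrives as an ordinary input, the wrapped method itself is deterministic: given the stored x, m, b, and noise values, the result can be recomputed. The plain-Python check below does this for the values printed above, using a standalone copy of the formula.

```python
import math


def noisy_linear_eq(x, m, b, noise):
    """Compute y = m * x + b + noise (same formula as _noisy_linear_eq)."""
    return m * x + b + noise


# Values copied from the printed output above.
x_values = [7.66580875, 9.3376184, 1.97323225, 2.79892872, 3.10108938]
noise_values = [-0.80769063, -1.21502062, 0.17056517, 0.13885519, 0.34306928]
printed_results = [35.52135312, 43.47307139, 8.03672644, 12.13349876, 13.84851619]

recomputed = [noisy_linear_eq(x, 5.0, -2.0, n) for x, n in zip(x_values, noise_values)]

# The printed values are rounded to 8 decimals, so compare with a tolerance.
matches = [math.isclose(y, p, abs_tol=1e-6) for y, p in zip(recomputed, printed_results)]
print(matches)
```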
Custom Compute Function
If users need more control over how the randomness is used, they can override the compute function, which does take a random number generator. However, the compute function contains other logic that will need to be replicated, including assembling the function’s parameters and writing the results to the GraphState. We recommend this approach only for experienced users. For examples of this approach, see the code for the NumpyRandomFunc class itself.
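The design goal stated at the start of this section (random by default, but reproducible when a generator is supplied) is a common pattern. The sketch below illustrates it with the standard library’s random.Random as a stand-in for the generator that compute receives. The noisy_value function and its rng argument are hypothetical and do not reflect the actual compute signature.

```python
import random


def noisy_value(x, rng=None):
    """Hypothetical helper: add unit Gaussian noise to x.

    Uses the provided generator if one is given. Otherwise it creates a
    fresh, unseeded generator so that each call is independently random.
    """
    if rng is None:
        rng = random.Random()
    return x + rng.gauss(0.0, 1.0)


# Two generators with the same seed reproduce the same sequence of draws.
rng_a = random.Random(42)
rng_b = random.Random(42)
run_a = [noisy_value(1.0, rng=rng_a) for _ in range(3)]
run_b = [noisy_value(1.0, rng=rng_b) for _ in range(3)]
print(run_a == run_b)  # True

# Without a generator, the call still works and draws fresh noise.
print(noisy_value(1.0))
```

A custom compute function would apply the same idea: draw from the supplied generator when there is one, and fall back to a default source of randomness otherwise.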