"""The base models used to specify the LightCurveLynx computation graph.
The computation graph is composed of ParameterizedNodes which compute random variables
(called "model parameters" or "parameters" for short) and encode the dependencies between
them. Model parameters are different from object variables in that they are *not* stored
within the object, but rather programmatically set in an external graph_state dictionary by
sampling the graph. The user does not need to access the internals of graph_state directly,
but rather can treat it as an opaque object to pass around. The dictionary can hold either
individual values (often floats) or arrays of samples.
All model parameters (random variables in the probabilistic graph) must be added
using the ParameterizedNode.add_parameter() function. This allows the graph to track
which parameters to update and how to set them. A parameter's values can be set from a few
sources:
1) A constant (e.g., a given standard deviation for a noise model).
2) A static function or method (which does not have variables that are resampled).
3) The result of evaluating a FunctionNode, which provides a computation using other
parameters in the graph.
4) The parameters of another ParameterizedNode.
We say that parameter X is dependent on parameter Y if the value of Y is necessary
to compute the value of X. For example, if X is set by evaluating a FunctionNode that
uses parameter Y in the computation, X is dependent on Y. The dependencies impose an
ordering of model parameters in the graph. Y must be computed before X.
ParameterizedNodes provide semantic groupings of individual parameters. For example, we may
have a ParameterizedNode representing the information needed for a Type Ia supernova.
That node's parameters would include the variables needed to evaluate the supernova's
light curve. Each of these parameters might depend on parameters in other nodes, such
as those of the host galaxy.
The execution graph is processed by starting at the final node, examining each
model parameter for that node, and recursively proceeding 'up' the graph for any
of its parameters that has a dependency. For example, the functions::

    f(a, b) = x
    g(c) = y
    h(x, y) = z

would form the graph::

    a - \\
         x -- \\
    b - /      \\
                z
    c -- y --- /

where z is the 'bottom' node. Parameters a, b, and c would be at the 'top' of the
graph because they have no dependencies. Such parameters are set by constants or
static functions.
"""
from functools import partial
import numpy as np
from lightcurvelynx.graph_state import DependencyGraph, GraphState
class _ParameterSource:
    """Describe where a ParameterizedNode gets the value for one parameter.

    These objects track internal state. Users should not work with them directly.

    Attributes
    ----------
    parameter_name : str
        The name of the parameter within the node (short name).
    node_name : str
        The name of the parent node.
    source_type : int
        The type of source as defined by the class variables.
        Default: 0 (UNDEFINED)
    value : any
        The information that actually sets the parameter. Either a constant
        or the attribute name of a dependency node.
    dependency : ParameterizedNode or None
        The node on which this parameter is dependent.
    description : str (optional)
        A brief description of the parameter.
    allow_gradient : bool
        Allow gradients to be computed at this variable.
        Default: False
    """

    # Class variables for the source enum.
    UNDEFINED = 0
    CONSTANT = 1
    MODEL_PARAMETER = 2
    FUNCTION_NODE = 3
    COMPUTE_OUTPUT = 4

    def __init__(self, parameter_name, source_type=UNDEFINED, node_name="", description=None):
        self.parameter_name = parameter_name
        self.node_name = node_name
        self.source_type = source_type
        self.allow_gradient = False
        self.value = None
        self.dependency = None
        self.description = description

    def help(self):
        """Display help information about this parameter."""
        print(f"{self.parameter_name}:")
        if self.description is not None:
            print(f" Description: {self.description}")
        else:
            print(" Description: None")

        if self.source_type == self.UNDEFINED:
            print(" Source: UNDEFINED")
        elif self.source_type == self.CONSTANT:
            print(f" Source: CONSTANT with value = {self.value}")
        elif self.source_type == self.MODEL_PARAMETER:
            print(f" Source: MODEL_PARAMETER {self.value} of node {str(self.dependency)}")
        elif self.source_type == self.FUNCTION_NODE:
            print(f" Source: Result of FUNCTION_NODE {str(self.dependency)}")
        elif self.source_type == self.COMPUTE_OUTPUT:
            print(" Source: Result of computation within this node")

    def set_as_constant(self, value, allow_gradient=True):
        """Set the parameter as a constant value.

        Parameters
        ----------
        value : any
            The constant value to use.
        allow_gradient : bool
            Allow a gradient to be computed at this variable.
            Default: True

        Raises
        ------
        ValueError
            If ``value`` is callable.
        """
        if callable(value):
            raise ValueError(f"Using set_as_constant on callable {value}")

        self.source_type = _ParameterSource.CONSTANT
        self.allow_gradient = allow_gradient
        self.dependency = None
        self.value = value

    def set_as_parameter(self, dependency, param_name):
        """Set the parameter as a model parameter of another node. This is
        used for chaining, such as when an object's ra depends on its host's ra.

        Parameters
        ----------
        dependency : ParameterizedNode
            The node in which to access the attribute.
        param_name : str
            The name of the parameter to access.
        """
        self.source_type = _ParameterSource.MODEL_PARAMETER
        self.allow_gradient = False
        self.dependency = dependency
        self.value = param_name

    def set_as_function(self, dependency, param_name="function_node_result"):
        """Set the parameter as the result of a FunctionNode (that is not
        the current node).

        Parameters
        ----------
        dependency : ParameterizedNode
            The node in which to access the attribute.
        param_name : str
            The name of where the result is stored in the FunctionNode.
        """
        self.source_type = _ParameterSource.FUNCTION_NODE
        self.allow_gradient = False
        self.dependency = dependency
        self.value = param_name

    def set_as_compute_output(self, param_name="function_node_result"):
        """Set the parameter as the output of the current node's compute() method.

        This needs to be separate from FUNCTION_NODE type (set_as_function) because
        the sampling function needs to know to call the current node's compute()
        method after the other parameters have been sampled.

        Parameters
        ----------
        param_name : str
            The name of where the result is stored in the node.
        """
        self.source_type = _ParameterSource.COMPUTE_OUTPUT
        self.allow_gradient = False
        # A computed output is produced by the owning node itself, so drop any
        # dependency left over from an earlier set_as_parameter()/set_as_function().
        self.dependency = None
        self.value = param_name
class _AttributeIndicatorNode:
    """A lightweight handle for one named parameter of a node.

    Calling the handle with a graph state looks up the parameter's value for
    the owning node. These objects track internal state. Users should not
    work with these objects directly.

    Attributes
    ----------
    attr_name : str
        The name of the attribute to access.
    parent : ParameterizedNode
        The parent node that owns this attribute.
    """

    def __init__(self, attr_name, parent):
        self.parent = parent
        self.attr_name = attr_name

    def __call__(self, graph_state):
        """Look up this attribute's value in the graph state.

        Parameters
        ----------
        graph_state : GraphState
            The current graph state.

        Returns
        -------
        any
            The value stored under the parent's node_string and this attribute name.
        """
        node_values = graph_state[self.parent.node_string]
        return node_values[self.attr_name]
[docs]
class ParameterizedNode:
    """Any model that uses parameters that can be set by constants,
    functions, or other parameterized nodes.

    ParameterizedNodes do not store values directly, but rather provide a recipe
    for how to generate the parameters' values. The sampled values are read from
    and written to a GraphState object that stores all the parameter values for
    all the nodes.

    Attributes
    ----------
    node_label : str
        An optional human readable identifier (name) for the current node.
    node_string : str
        The full string used to identify a node. This is a combination of the node's position
        in the graph (if known), node_label (if provided), and class information. This is
        used to access the parameters for this node in the graph_state.
    setters : dict
        A dictionary mapping the parameters' names to information about the setters
        (_ParameterSource). The model parameters are stored in the order in which they
        need to be set.
    node_pos : int or None
        A unique ID number for each node in the graph indicating its position.
        Assigned during resampling or set_graph_positions(). This is required to resolve
        naming collisions so we do not overwrite parameters from other nodes.

    Parameters
    ----------
    node_label : str, optional
        An identifier (or name) for the current node.
    **kwargs : dict, optional
        Any additional keyword arguments.
    """

    def __init__(self, node_label=None, **kwargs):
        # The setter table and graph position must exist before the first call to
        # _update_node_string(), which reads both of them.
        self.setters = {}
        self.node_pos = None
        self.node_label = node_label
        self.node_string = None

        # Give the node a temporary name.
        self._update_node_string()

    def __str__(self):
        """Return the string representation of the node."""
        return self.node_string

    def _update_node_string(self, new_str=None):
        """Update the node's string.

        Parameters
        ----------
        new_str : str, optional
            The new node string. If not provided, the node_string
            is automatically computed from the other node information.
        """
        if self.node_label is not None:
            # If a label is given, just use that. It overrides even the new_str.
            self.node_string = self.node_label
        elif new_str is not None:
            self.node_string = new_str
        else:
            # Otherwise use a combination of the node's class and position.
            pos_string = f"_{self.node_pos}" if self.node_pos is not None else ""
            self.node_string = f"{self.__class__.__name__}{pos_string}"

        # Update the node_name of all node's parameter setters.
        for setter_info in self.setters.values():
            setter_info.node_name = self.node_string

    def set_graph_positions(self, seen_nodes=None):
        """Force an update of the graph structure (numbering of each node).

        Parameters
        ----------
        seen_nodes : set, optional
            A set of nodes that have already been processed to prevent infinite loops.
            Caller should not set.
        """
        # Make sure that we do not process the same nodes multiple times.
        if seen_nodes is None:
            seen_nodes = set()
        if self in seen_nodes:
            return
        seen_nodes.add(self)

        # Update the node's position in the graph and its string.
        self.node_pos = len(seen_nodes) - 1
        self._update_node_string()

        # Recursively update any direct dependencies.
        for setter_info in self.setters.values():
            if setter_info.dependency is not None and setter_info.dependency is not self:
                setter_info.dependency.set_graph_positions(seen_nodes)

    def list_params(self):
        """Return a list of this node's parameterized values.

        Returns
        -------
        names : list[str]
            The name of all of the parameterized values for this node.
        """
        return list(self.setters)

    def describe_params(self, names=None):
        """Print out a description of the node's parameters.

        Parameters
        ----------
        names : list[str], optional
            A list of parameter names to describe. If None, describes all parameters.
        """
        if names is None:
            names = self.list_params()

        if len(names) > 1:
            print(f"Parameters in {self.node_string}:\n")

        for name in names:
            if name in self.setters:
                self.setters[name].help()
                print()
            else:
                print(f"Parameter: {name} not found in node {self.node_string}.")

    def has_valid_param(self, name):
        """Check whether the node has a given parameterized value and that it is not
        always set to None.

        Parameters
        ----------
        name : str
            The name of the parameter.

        Returns
        -------
        contains : bool
            Whether the node contains a given parameter and it is not always None.
        """
        if name not in self.setters:
            return False

        setter = self.setters[name]
        return not (setter.source_type == _ParameterSource.CONSTANT and setter.value is None)

    def get_param(self, graph_state, name, default=None):
        """Get the value of a parameter stored in this node or a default value.

        Note
        ----
        This is an optional helper function that accesses the internals of graph_state
        using the node's information (its node_string).

        Parameters
        ----------
        graph_state : GraphState
            The graph state information, indexed by node string and then parameter name.
        name : str
            The parameter name to query.
        default : any
            The default value to return if the parameter is not in GraphState.

        Returns
        -------
        any
            The parameter value or the default.

        Raises
        ------
        ValueError
            If graph_state is None.
        """
        if graph_state is None:
            raise ValueError(f"Unable to look up parameter={name}. No graph_state given.")

        if self.node_string in graph_state and name in graph_state[self.node_string]:
            return graph_state[self.node_string][name]
        return default

    def get_local_params(self, graph_state):
        """Get a dictionary of all parameters local to this node.

        Note
        ----
        This is an optional helper function that accesses the internals of graph_state
        using the node's information (its node_string).

        Parameters
        ----------
        graph_state : GraphState
            An object mapping graph parameters to their values.

        Returns
        -------
        result : dict
            A dictionary mapping the parameter name to its value.

        Raises
        ------
        KeyError
            If no parameters have been set for this node.
        ValueError
            If graph_state is None.
        """
        if graph_state is None:
            raise ValueError("No graph_state given.")
        return graph_state[self.node_string]

    def set_parameter(self, name, value=None, **kwargs):
        """Set the source of a single *existing* parameter in the ParameterizedNode.

        Parameters within a node are actually a mapping of the parameter name
        to a _ParameterSource object that indicates how they are set during sampling.
        Some parameters may be set as a constant value while others may be set by
        evaluating a function that depends on other parameters.

        Note
        ----
        * Does NOT set an initial value for the model parameter. The user must
          sample the parameters for this to be set.
        * The model parameters are stored in the order in which they are added.

        Parameters
        ----------
        name : str
            The parameter name to set.
        value : any, optional
            The information to use to set the parameter. Can be a constant,
            function, FunctionNode, or a parameter indicator (other_node.param).
        **kwargs : dict, optional
            All other keyword arguments, possibly including the parameter setters.

        Raises
        ------
        KeyError
            If the parameter has not been added to this node.
        ValueError
            If the setter type is not supported.
        """
        # Set the node's position in the graph to None to indicate that the
        # structure might have changed. It needs to be updated with set_graph_positions().
        self.node_pos = None

        # All parameters must be added first with "add_parameter()".
        if name not in self.setters:
            raise KeyError(
                f"Tried to set parameter '{name}' that has not been added to node {self.node_string}."
            ) from None

        if value is None and name in kwargs:
            # The value wasn't set, but the name is in kwargs.
            value = kwargs[name]

        if callable(value):
            if isinstance(value, _AttributeIndicatorNode):
                # Case 1a: This is an attribute of a ParameterizedNode.
                # We set the parameter's value in this node as the extraction of the
                # parameter's value in the parent node.
                if value.attr_name in value.parent.setters:
                    self.setters[name].set_as_parameter(value.parent, value.attr_name)
                else:
                    raise ValueError(
                        f"Trying to set parameter '{name}' to the {type(value.parent)}.{value.attr_name}, "
                        f"but unable to find that parameter in class {type(value.parent)}."
                    )
            else:
                # Case 1b: This is a general function or callable method from another object.
                # We treat it as static (we don't resample the other object) and
                # wrap it in a FunctionNode.
                func_node = FunctionNode(value, **kwargs)
                self.setters[name].set_as_function(func_node)
        elif isinstance(value, FunctionNode):
            # Case 2: We are using the result of a computation of the function node.
            # If the FunctionNode has named outputs that match the variable, use that.
            output_name = name if name in value.outputs else "function_node_result"
            self.setters[name].set_as_function(value, output_name)
        elif isinstance(value, ParameterizedNode):
            # Case 3 [No longer supported]: We are trying to access a parameter of a
            # ParameterizedNode with the same name as the current parameter (implicit linking).
            # We removed this pattern because it increases the potential for user confusion.
            raise ValueError(
                f"Error setting parameter '{name}': Setting a parameter to a ParameterizedNode "
                f"and implicitly determining that parameter name (e.g. using '{name}=other_node' "
                f"to link other_node.{name}) is no longer supported. You must explicitly specify "
                f"the parameter name using the dot notation (e.g. '{name}=other_node.{name}')."
            )
        else:
            # Case 4: The value is constant (including None).
            self.setters[name].set_as_constant(value)

    def set_allow_gradient(self, name, allow_gradient):
        """Turn on or off the ability to compute a gradient for this variable.

        Parameters
        ----------
        name : str
            The parameter name to modify.
        allow_gradient : bool
            The new setting for allow_gradient.
        """
        self.setters[name].allow_gradient = allow_gradient

    def add_parameter(self, name, value=None, allow_gradient=None, description=None, **kwargs):
        """Add a single *new* parameter to the ParameterizedNode.

        Note
        ----
        * Checks multiple sources in the following order: Manually specified value,
          an entry in kwargs, or None.
        * Does NOT set an initial value for the model parameter. The user must
          sample the parameters for this to be set.
        * The model parameters are stored in the order in which they are added.

        Parameters
        ----------
        name : str
            The parameter name to add.
        value : any, optional
            The information to use to set the parameter. Can be a constant,
            function, FunctionNode, or a parameter indicator (other_node.param).
        allow_gradient : bool or None
            Allow gradients to be computed for this variable. If set to None uses the default
            for the setter type (True for constant and False for everything else).
            Default: None
        description : str, optional
            A brief description of the parameter.
        **kwargs : dict, optional
            All other keyword arguments, possibly including the parameter setters.

        Raises
        ------
        KeyError
            If there is a parameter collision or the parameter cannot be found.
        """
        # Check for parameter collision and add a place holder value to the 'setters' dictionary.
        if hasattr(self, name) and name not in self.setters:
            raise KeyError(
                f"Parameter name '{name}' conflicts with a predefined model parameter "
                f"or class attribute in {self.node_string}"
            )
        if self.setters.get(name, None) is not None:
            raise KeyError(f"Duplicate parameter set: '{name}' in {self.node_string}")

        # Add an entry for the setter function and fill in the remaining information using
        # set_parameter(). We add an initial (dummy) value here to indicate that this parameter
        # exists and was added via add_parameter().
        self.setters[name] = _ParameterSource(
            parameter_name=name,
            source_type=_ParameterSource.UNDEFINED,
            node_name=str(self),
            description=description,
        )
        self.set_parameter(name, value, **kwargs)

        # Check if we should override allow_gradient.
        if allow_gradient is not None:
            self.setters[name].allow_gradient = allow_gradient

        # Create an _AttributeIndicatorNode to represent this parameter.
        # This node allows us to reference the parameter as object.parameter_name
        # for chaining without copying the value. For example, if my_node_1, is a
        # ParameterizedNode with a parameter x, we can do:
        #     my_node_2 = ParameterizedNode(y=my_node_1.x)
        # and my_node_2 will know to use the sampled values of x from my_node_1
        # (as opposed to the setter for x).
        setattr(self, name, _AttributeIndicatorNode(name, self))

    def get_parameter_indicator(self, param_name):
        """Return the _AttributeIndicatorNode for a given parameter. This replicates
        the behavior of accessing the parameter as an attribute of the object
        (e.g., obj.param_name), but in functional form.

        Parameters
        ----------
        param_name : str
            The name of the parameter indicator to get.

        Returns
        -------
        getter : _AttributeIndicatorNode
            The indicator node for the given parameter.

        Raises
        ------
        KeyError
            If the parameter has not been added to this node.
        """
        if param_name not in self.setters:
            raise KeyError(f"Parameter name '{param_name}' not found in node {self.node_string}.")
        return getattr(self, param_name)

    def compute(self, graph_state, rng_info=None, **kwargs):
        """Placeholder for a general compute function, which is called at the end
        of the sampling process and can produce derived parameters. This function
        is the main processing step in a FunctionNode.

        Parameters
        ----------
        graph_state : GraphState
            An object mapping graph parameters to their values. This object is modified
            in place as it is sampled.
        rng_info : numpy.random._generator.Generator, optional
            A given numpy random number generator to use for this computation. If not
            provided, the function uses the node's random number generator.
        **kwargs : dict, optional
            Additional function arguments.
        """
        return None

    def _sample_helper(self, graph_state, seen_nodes, rng_info=None):
        """Internal recursive function to sample the model's underlying parameters
        if they are provided by a function or ParameterizedNode. All sampled
        parameters for all nodes are stored in the graph_state dictionary, which is
        modified in-place.

        Parameters
        ----------
        graph_state : GraphState
            An object mapping graph parameters to their values. This object is modified
            in place as it is sampled.
        seen_nodes : dict
            A dictionary mapping nodes strings seen during this sampling run to their object.
            Used to avoid sampling nodes multiple times and to validity check the graph.
        rng_info : numpy.random._generator.Generator, optional
            A given numpy random number generator to use for this computation. If not
            provided, the function uses the node's random number generator.

        Raises
        ------
        ValueError
            If two different nodes share a node string or a parameter has an
            invalid source type.
        """
        node_str = str(self)
        if node_str in seen_nodes:
            # The same string must always refer to the same node object.
            if seen_nodes[node_str] is not self:
                raise ValueError(
                    f"Duplicate node label '{node_str}'. Every node must have a unique label. "
                    "This most often happens when the node_label parameter is set directly."
                )
            return  # Nothing to do
        seen_nodes[node_str] = self

        # Run through each parameter and sample it based on the given recipe.
        # As of Python 3.7 dictionaries are guaranteed to preserve insertion ordering,
        # so this will iterate through model parameters in the order they were inserted.
        linked_sources = (_ParameterSource.MODEL_PARAMETER, _ParameterSource.FUNCTION_NODE)
        any_compute = False
        for name, setter in self.setters.items():
            # Check if we need to sample this parameter's dependency node.
            if setter.dependency is not None and setter.dependency is not self:
                setter.dependency._sample_helper(graph_state, seen_nodes, rng_info=rng_info)

            # Set the result from the correct source.
            if setter.source_type == _ParameterSource.CONSTANT:
                if graph_state.num_samples == 1:
                    graph_state.set(self.node_string, name, setter.value)
                else:
                    repeated_value = np.array([setter.value] * graph_state.num_samples)
                    graph_state.set(self.node_string, name, repeated_value)
            elif setter.source_type in linked_sources:
                # Both link types copy the already-sampled value out of the dependency node.
                graph_state.set(
                    self.node_string,
                    name,
                    graph_state[setter.dependency.node_string][setter.value],
                )
            elif setter.source_type == _ParameterSource.COMPUTE_OUTPUT:
                # Computed parameters are set only after all the other (input) parameters.
                any_compute = True
            else:
                raise ValueError(f"Invalid _ParameterSource type {setter.source_type}")

        # If this is a function node and the parameters depend on the result of its own computation
        # call the compute function to fill them in.
        if any_compute:
            self.compute(graph_state, rng_info)

    def sample_parameters(self, given_args=None, num_samples=1, rng_info=None, sample_offset=0):
        """Sample the model's underlying parameters if they are provided by a function
        or ParameterizedNode.

        Parameters
        ----------
        given_args : dict, optional
            A dictionary representing the given arguments for this sample run.
            This can be used as the JAX PyTree for differentiation.
        num_samples : int
            A count of the number of samples to compute.
            Default: 1
        rng_info : numpy.random._generator.Generator, optional
            A given numpy random number generator to use for this computation. If not
            provided, the function uses the node's random number generator.
        sample_offset : int
            An optional offset to add to the graph state for any stateful nodes.
            This allows the system to better support testing and parallelized sampling.
            Default: 0 (no offset)

        Returns
        -------
        graph_state : GraphState
            The newly created state holding, for each node and variable name, either a
            value or array of values.

        Raises
        ------
        ValueError
            If num_samples is less than one or the sampling encounters a problem
            with the graph.
        """
        # Check that the number of samples is valid.
        if num_samples < 1:
            raise ValueError(f"num_samples must be a positive integer. Got {num_samples}.")

        # If the graph structure has never been set, do that now.
        if self.node_pos is None:
            nodes = set()
            self.set_graph_positions(seen_nodes=nodes)

        # Create space for the results and set all the given_args as fixed parameters.
        results = GraphState(num_samples, sample_offset=sample_offset)
        if given_args is not None:
            results.update(given_args, all_fixed=True)

        # Resample the nodes. All information is stored in the returned results dictionary.
        seen_nodes = {}
        self._sample_helper(results, seen_nodes, rng_info=rng_info)
        return results

    def _dependency_graph_helper(self, dependency_graph):
        """Internal recursive function to build a directed acyclic graph (DAG) representing
        the parameters in the model and their dependencies. Each node in the graph is a parameter
        and each edge indicates nodes that parameter depends on.

        Parameters
        ----------
        dependency_graph : DependencyGraph
            An object tracking the dependencies between parameters. This object is modified
            in place to represent the current state.
        """
        # Check if we have already processed this node. The node name gets added
        # as soon as we add any of its parameters.
        node_name = str(self)
        if node_name in dependency_graph.all_nodes:
            return  # Nothing to do. We have already processed this node.

        # Add each parameter to the dependency graph.
        for param_name, setter in self.setters.items():
            full_name = GraphState.extended_param_name(node_name, param_name)
            dependency_graph.add_parameter(param_name, node_name)

            # Recursively process any dependencies first, including constants.
            dep_name = None
            if setter.dependency is not None and setter.dependency is not self:
                dep_name = GraphState.extended_param_name(setter.dependency.node_string, setter.value)
                setter.dependency._dependency_graph_helper(dependency_graph)
            elif setter.source_type == _ParameterSource.CONSTANT:
                dep_name = dependency_graph.add_constant(setter.value)

            # If we have a dependency, add an edge from the dependency to this parameter.
            # dep_name will be None for parameters that are the result of internal computations
            # (i.e., _ParameterSource.COMPUTE_OUTPUT) which is handled in the subclass.
            if dep_name is not None:
                dependency_graph.add_edge(dep_name, full_name)

    def build_dependency_graph(self):
        """Build a directed acyclic graph (DAG) representing the parameters in the model
        and their dependencies.

        Returns
        -------
        dependency_graph : DependencyGraph
            An object tracking the dependencies between parameters.
        """
        # If the graph structure has never been set, do that now.
        if self.node_pos is None:
            nodes = set()
            self.set_graph_positions(seen_nodes=nodes)

        # Recursively build the dependency graph starting from an empty one.
        dependency_graph = DependencyGraph()
        self._dependency_graph_helper(dependency_graph)
        return dependency_graph

    def build_pytree(self, graph_state, partial=None):
        """Build a JAX PyTree representation of the variables in this graph.

        Parameters
        ----------
        graph_state : GraphState
            The sampled state, indexed by node string and then variable name.
        partial : dict
            The partial results so far. This is modified in place by the function.
            A dictionary mapping node name to a dictionary mapping each variable's name
            to its value.
            Default: None

        Returns
        -------
        values : dict
            A dictionary mapping node name to a dictionary mapping each variable's name
            to its value.

        Raises
        ------
        ValueError
            If this node's graph position has not been set.
        """
        # Check if the node might have incomplete information.
        if self.node_pos is None:
            raise ValueError(
                f"Node {self.node_string} is missing position. You must call "
                "set_graph_positions() before building a pytree."
            )

        # Skip nodes that we have already seen.
        if partial is None:
            partial = {}
        if self.node_string in partial:
            return partial

        # Add new values to the pytree, recursively exploring dependencies.
        partial[self.node_string] = {}
        for name, setter_info in self.setters.items():
            if setter_info.allow_gradient:
                # Anything with allow_gradient == True goes in the PyTree.
                partial[self.node_string][name] = graph_state[self.node_string][name]
            elif setter_info.dependency is not None:
                # We only recursively check parameters above non-gradient nodes.
                partial = setter_info.dependency.build_pytree(graph_state, partial)
        return partial
[docs]
class FunctionNode(ParameterizedNode):
"""A class to wrap functions and their argument settings.
The node can compute the result using a given function (the func
parameter) or through the compute() method. If func=None
then the user must override compute().
Attributes
----------
func : function or method or partial
The function to call during an evaluation. If this is None
you must override the compute() method directly.
arg_names : list
A list of argument names to pass to the function.
outputs : list of str
The output model parameters of this function.
Parameters
----------
func : function or method
The function to call during an evaluation.
node_label : str, optional
An identifier (or name) for the current node.
outputs : list of str, optional
The output model parameters of this function. If None, uses
a single model parameter result.
fixed_params : dict, optional
A dictionary mapping a parameter name in the function to its fixed value.
**kwargs
Any additional keyword arguments.
Examples
--------
>>> my_func = TDFunc(random.randint, a=1, b=10) # doctest: +SKIP
>>> value1 = my_func() # Sample from default range # doctest: +SKIP
>>> value2 = my_func(b=20) # Sample from extended range # doctest: +SKIP
Note
----
All the function's parameters that will be used need to be specified
in either the default_args dict, object_args list, or as a kwarg in the
constructor. Arguments cannot be first given during function call.
For example, the following will fail (because b is not defined in the
constructor)::
my_func = TDFunc(random.randint, a=1)
value1 = my_func(b=10.0)
"""
def __init__(self, func, node_label=None, outputs=None, fixed_params=None, **kwargs):
    """Wrap ``func`` and register its inputs and outputs as model parameters.

    Parameters
    ----------
    func : function or method
        The function to call during an evaluation.
    node_label : str, optional
        An identifier (or name) for the current node.
    outputs : list of str, optional
        The output model parameters of this function. If None or empty, uses
        the single output name "function_node_result".
    fixed_params : dict, optional
        A dictionary mapping a parameter name in the function to its fixed value.
    **kwargs
        Each keyword argument is added as an input parameter of the node.
    """
    # We set the function before calling the parent class so we can use
    # the function's name (if needed).
    if fixed_params is not None and len(fixed_params) > 0:
        # Create a partial function with some of the parameters fixed.
        self.func = partial(func, **fixed_params)

        # We need to set the __name__ parameter because it is not preserved by partial.
        self.func.__name__ = func.__name__
    else:
        # Use the function as-is.
        self.func = func
    super().__init__(node_label=node_label, **kwargs)

    # Add all of the parameters from the kwargs. The list of argument names must
    # exist before the loop appends to it; _build_inputs() reads it later.
    self.arg_names = []
    for key, value in kwargs.items():
        self.arg_names.append(key)
        self.add_parameter(key, value, description="Input argument for function.")

    # Add the output arguments. The names are stored on the node because
    # _save_results() and ParameterizedNode.set_parameter() look them up.
    if not outputs:
        outputs = ["function_node_result"]
    self.outputs = outputs
    for name in outputs:
        # For output parameters we add a placeholder of None to set up the basic data, such as
        # the getter function and the entry in parameters. Then we change the
        # type to point to own result.
        self.add_parameter(name, None, description="Output result of function.")
        self.setters[name].set_as_compute_output(param_name=name)
def _non_func(self):
"""This function does nothing. This is used for FunctionNodes where the actual computation
happens in an overloaded compute() function."""
pass
def _update_node_string(self, new_str=None):
    """Update the node's string. A FunctionNode's string includes
    the function name in addition to the class name.

    Parameters
    ----------
    new_str : str, optional
        The new node string. If not provided, one is built from the class name,
        the wrapped function's name (if any), and the node's position (if known).
    """
    if new_str is None:
        pos_string = f"_{self.node_pos}" if self.node_pos is not None else ""
        if self.func is None:
            fn_str = ""
        else:
            # Callables such as functools.partial objects have no __name__;
            # fall back to the callable's type name instead of failing.
            fn_name = getattr(self.func, "__name__", type(self.func).__name__)
            fn_str = f":{fn_name}"
        new_str = f"{self.__class__.__name__}{fn_str}{pos_string}"
    super()._update_node_string(new_str)
def _build_inputs(self, graph_state, **kwargs):
    """Collect the keyword arguments to pass to the node's function.

    Only names listed in ``self.arg_names`` are included. A value supplied in
    ``kwargs`` takes precedence; otherwise the value is read from this node's
    entry in ``graph_state``.

    Parameters
    ----------
    graph_state : GraphState
        An object mapping graph parameters to their values. It is indexed first
        by this node's string and then by the argument name.
    **kwargs : dict, optional
        Overrides for individual arguments. Names that are not in
        ``self.arg_names`` are ignored.

    Returns
    -------
    args : dict
        A dictionary mapping each input argument's name to its value.
    """
    return {
        name: kwargs[name] if name in kwargs else graph_state[self.node_string][name]
        for name in self.arg_names
    }
def _save_results(self, results, graph_state):
    """Save the results to the graph state.

    With a single output the results object is stored as-is under that output's
    name. With several outputs, ``results[i]`` is stored under the i-th output
    name.

    Parameters
    ----------
    results : any or iterable
        The function's results. When the node has more than one output this must
        support ``len()`` and integer indexing.
    graph_state : GraphState
        An object mapping graph parameters to their values. This object is modified
        in place as it is sampled.

    Raises
    ------
    ValueError
        If the node has several outputs and the number of results does not
        match the number of outputs.
    """
    if len(self.outputs) == 1:
        graph_state.set(self.node_string, self.outputs[0], results)
        return

    if len(results) != len(self.outputs):
        # self.func may be None (when compute() is overridden) or may lack a
        # __name__, so fall back to the node's string rather than failing with an
        # AttributeError while building the error message.
        source_name = getattr(self.func, "__name__", None) or self.node_string
        raise ValueError(
            f"Incorrect number of results returned by {source_name}. "
            f"Expected {len(self.outputs)}, but got {results}."
        )

    for idx, output_name in enumerate(self.outputs):
        graph_state.set(self.node_string, output_name, results[idx])
def _dependency_graph_helper(self, dependency_graph):
    """Recursively add this node's parameters and their dependencies to the
    directed acyclic graph (DAG) of model parameters.

    Parameters
    ----------
    dependency_graph : DependencyGraph
        An object tracking the dependencies between parameters. This object is modified
        in place to represent the current state.
    """
    this_node = str(self)
    if this_node in dependency_graph.all_nodes:
        # This node was already processed, so there is nothing left to add.
        return

    # Let the parent class handle the dependencies of the input features.
    super()._dependency_graph_helper(dependency_graph)

    # Every computed output depends on every input argument, so add an edge for
    # each (input, output) pair.
    for output_name, setter in self.setters.items():
        if setter.source_type != _ParameterSource.COMPUTE_OUTPUT:
            continue

        target = GraphState.extended_param_name(this_node, output_name)
        for arg_name in self.arg_names:
            source = GraphState.extended_param_name(this_node, arg_name)
            dependency_graph.add_edge(source, target)
[docs]
def compute(self, graph_state, rng_info=None, **kwargs):
    """Run the wrapped function and record its output.

    The inputs are assembled by ``_build_inputs()`` from ``graph_state`` and
    ``kwargs``, the function is called with them as keyword arguments, and the
    result is written back through ``_save_results()``.

    Parameters
    ----------
    graph_state : GraphState
        An object mapping graph parameters to their values. This object is modified
        in place as it is sampled.
    rng_info : numpy.random._generator.Generator, optional
        A given numpy random number generator to use for this computation.
        This implementation does not read it.
    **kwargs : dict, optional
        Additional function arguments, forwarded to ``_build_inputs()``.

    Returns
    -------
    results : any
        The value returned by the wrapped function. It is returned so that
        testing functions can easily access the results.

    Raises
    ------
    ValueError
        If func attribute is None.
    """
    if self.func is None:
        message = (
            f"The FunctionNode {self.node_string}'s 'func' parameter is None. "
            "You need to either set func or override compute()."
        )
        raise ValueError(message)

    # Gather the inputs, evaluate the function, then store what it produced.
    call_args = self._build_inputs(graph_state, **kwargs)
    output = self.func(**call_args)
    self._save_results(output, graph_state)
    return output
[docs]
def generate(self, given_args=None, num_samples=1, rng_info=None, **kwargs):
    """A helper function that regenerates the parameters for this node and the
    nodes above it, then returns the output for this individual node.

    This is used both for testing and for computing JAX gradients.

    Parameters
    ----------
    given_args : dict, optional
        A dictionary representing the given arguments for this sample run.
        This can be used as the JAX PyTree for differentiation.
    num_samples : int
        A count of the number of samples to compute.
        Default: 1
    rng_info : numpy.random._generator.Generator, optional
        A given numpy random number generator to use for this computation.
        It is passed on to ``sample_parameters()``.
    **kwargs : dict, optional
        Additional function arguments. They are not read by this method.

    Returns
    -------
    results : any or list
        The value of the single output parameter when the node has exactly one
        output, otherwise a list with one value per output (in output order).
    """
    sampled_state = self.sample_parameters(given_args, num_samples, rng_info)

    # A single output is returned bare; multiple outputs are returned as a list.
    if len(self.outputs) == 1:
        (only_output,) = self.outputs
        return self.get_param(sampled_state, only_output)
    return [self.get_param(sampled_state, name) for name in self.outputs]