lightcurvelynx.base_models

The base models used to specify the LightCurveLynx computation graph.

The computation graph is composed of ParameterizedNodes which compute random variables (called “model parameters” or “parameters” for short) and encode the dependencies between them. Model parameters are different from object variables in that they are not stored within the object, but rather programmatically set in an external graph_state dictionary by sampling the graph. The user does not need to access the internals of graph_state directly, but rather can treat it as an opaque object to pass around. The dictionary can hold either individual values (often floats) or arrays of samples.

All model parameters (random variables in the probabilistic graph) must be added using the ParameterizedNode.add_parameter() function. This allows the graph to track which parameters to update and how to set them. A parameter’s values can be set from a few sources:

  1. A constant (e.g., a given standard deviation for a noise model).

  2. A static function or method (which does not have variables that are resampled).

  3. The result of evaluating a FunctionNode, which provides a computation using other parameters in the graph.

  4. The parameters of another ParameterizedNode.
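A minimal, library-independent sketch of how these four kinds of sources could be resolved into values. Everything here is hypothetical: `ParamRef`, `FuncSetter`, `resolve`, and the dict-of-dicts `state` are illustrative stand-ins, not LightCurveLynx classes or functions.

```python
import math
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ParamRef:
    """Hypothetical reference to a parameter owned by another node."""
    node: str
    name: str


@dataclass
class FuncSetter:
    """Hypothetical stand-in for a FunctionNode: a function plus the
    parameter references that supply its arguments."""
    func: Callable[..., Any]
    inputs: dict = field(default_factory=dict)  # argument name -> ParamRef


def resolve(setter: Any, state: dict) -> Any:
    """Turn one setter into a value, reading other parameters from `state`."""
    if isinstance(setter, ParamRef):
        # Source 4: copy an (already sampled) parameter of another node.
        return state[setter.node][setter.name]
    if isinstance(setter, FuncSetter):
        # Source 3: evaluate a function whose inputs are other parameters.
        args = {arg: resolve(ref, state) for arg, ref in setter.inputs.items()}
        return setter.func(**args)
    if callable(setter):
        # Source 2: a static function with no resampled inputs.
        return setter()
    # Source 1: a constant.
    return setter


# Hypothetical graph state holding one already-sampled host-galaxy parameter.
state = {"host": {"distance": 100.0}}

noise_sigma = resolve(0.1, state)                        # constant
pi_value = resolve(lambda: math.pi, state)               # static function
distance = resolve(ParamRef("host", "distance"), state)  # another node
doubled = resolve(FuncSetter(lambda d: 2 * d, {"d": ParamRef("host", "distance")}), state)
print(noise_sigma, distance, doubled)  # 0.1 100.0 200.0
```

The order of the checks matters only because a function-based setter is itself callable; the sketch tests for the more specific kinds first.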

We say that parameter X is dependent on parameter Y if the value of Y is necessary to compute the value of X. For example, if X is set by evaluating a FunctionNode that uses parameter Y in the computation, X is dependent on Y. The dependencies impose an ordering of model parameters in the graph. Y must be computed before X.
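One standard way to turn such dependencies into an evaluation order is a topological sort. The sketch below uses Python’s standard-library graphlib module (Python 3.9+) on a hypothetical chain in which X depends on Y and Y depends on W. It only illustrates the ordering constraint; it is not a description of how LightCurveLynx orders parameters internally.

```python
from graphlib import CycleError, TopologicalSorter

# Map each parameter to the set of parameters it depends on.
# Hypothetical names: X is computed from Y, and Y is computed from W.
dependencies = {"X": {"Y"}, "Y": {"W"}}

order = list(TopologicalSorter(dependencies).static_order())
print(order)  # ['W', 'Y', 'X']

# A cycle (X needs Y and Y needs X) has no valid evaluation order.
try:
    list(TopologicalSorter({"X": {"Y"}, "Y": {"X"}}).static_order())
except CycleError:
    cycle_detected = True
else:
    cycle_detected = False
print(cycle_detected)  # True
```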

ParameterizedNodes provide semantic groupings of individual parameters. For example, we may have a ParameterizedNode representing the information needed for a Type Ia supernova. That node’s parameters would include the variables needed to evaluate the supernova’s light curve. Each of these parameters might depend on parameters in other nodes, such as those of the host galaxy.

The execution graph is processed by starting at the final node, examining each model parameter for that node, and recursively proceeding ‘up’ the graph for any of its parameters that have a dependency. For example, the functions:

f(a, b) = x
g(c) = y
h(x, y) = z

would form the graph:

a - \
     x -- \
b - /      \
            z
c -- y --- /

where z is the ‘bottom’ node. Parameters a, b, and c would be at the ‘top’ of the graph because they have no dependencies. Such parameters are set by constants or static functions.
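The text does not define f, g, or h, so the sketch below picks arbitrary stand-ins (addition, scaling by 10, and multiplication) and constant values for a, b, and c. It is a hypothetical toy, not library code: evaluation starts at the bottom node z and recurses ‘up’ into each input, while the `cache` dictionary plays the role of graph_state so that every value is computed only once.

```python
# Each entry maps a name to (function, names of its inputs).
# a, b, and c have no inputs, so they sit at the top of the graph.
graph = {
    "a": (lambda: 1.0, []),
    "b": (lambda: 2.0, []),
    "c": (lambda: 3.0, []),
    "x": (lambda a, b: a + b, ["a", "b"]),  # f(a, b) = x
    "y": (lambda c: 10.0 * c, ["c"]),       # g(c) = y
    "z": (lambda x, y: x * y, ["x", "y"]),  # h(x, y) = z
}


def evaluate(name, graph, cache):
    """Compute `name`, first recursing 'up' into any inputs it depends on."""
    if name not in cache:
        func, inputs = graph[name]
        args = [evaluate(dep, graph, cache) for dep in inputs]
        cache[name] = func(*args)
    return cache[name]


cache = {}
z = evaluate("z", graph, cache)  # start at the 'bottom' node
print(z)  # 90.0
```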

Classes

ParameterizedNode

Any model that uses parameters that can be set by constants, functions, or other parameterized nodes.

FunctionNode

A class to wrap functions and their argument settings.

Module Contents

class ParameterizedNode(node_label=None, **kwargs)

Any model that uses parameters that can be set by constants, functions, or other parameterized nodes.

ParameterizedNodes do not store values directly, but rather provide a recipe for how to generate the parameters’ values. The sampled values are read from and written to a GraphState object that stores all the parameter values for all the nodes.

node_label

An optional human-readable identifier (name) for the current node.

Type:

str

node_string

The full string used to identify a node. This is a combination of the node’s position in the graph (if known), node_label (if provided), and class information. This is used to access the parameters for this node in the graph_state.

Type:

str

setters

A dictionary mapping the parameters’ names to information about the setters (_ParameterSource). The model parameters are stored in the order in which they need to be set.

Type:

dict

node_pos

A unique ID number for each node in the graph indicating its position. Assigned during resampling or by set_graph_positions(). This is required to resolve naming collisions so we do not overwrite parameters from other nodes.

Type:

int or None

Parameters:
  • node_label (str, optional) – An identifier (or name) for the current node.

  • **kwargs (dict, optional) – Any additional keyword arguments.

setters
node_label = None
node_pos = None
node_string = None
__str__()

Return the string representation of the node.

set_graph_positions(seen_nodes=None)

Force an update of the graph structure (numbering of each node).

Parameters:

seen_nodes (set, optional) – A set of nodes that have already been processed, used to prevent infinite loops. Callers should not set this.
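A toy version of this numbering, assuming a depth-first traversal from the node on which the method is called (the real traversal order is not documented here). `ToyNode` and this `set_graph_positions` function are hypothetical; the point is how the seen_nodes set keeps a node that is reachable along two paths from being numbered twice.

```python
class ToyNode:
    """Hypothetical node that only knows its label and the nodes it depends on."""

    def __init__(self, label, parents=()):
        self.label = label
        self.parents = list(parents)
        self.node_pos = None


def set_graph_positions(node, seen_nodes=None):
    """Give every node reachable from `node` a unique position number."""
    if seen_nodes is None:
        seen_nodes = set()
    if node in seen_nodes:
        return  # already numbered: avoids duplicates and infinite loops
    node.node_pos = len(seen_nodes)
    seen_nodes.add(node)
    for parent in node.parents:
        set_graph_positions(parent, seen_nodes)


host = ToyNode("host")
source = ToyNode("source", parents=[host])
observer = ToyNode("observer", parents=[source, host])  # host is reachable twice
set_graph_positions(observer)
print([(n.label, n.node_pos) for n in (observer, source, host)])
# [('observer', 0), ('source', 1), ('host', 2)]
```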

list_params()

Return a list of this node’s parameterized values.

Returns:

names – The names of all the parameterized values for this node.

Return type:

list[str]

describe_params(names=None)

Print out a description of the node’s parameters.

Parameters:

names (list[str], optional) – A list of parameter names to describe. If None, describes all parameters.

has_valid_param(name)

Check whether the node has a given parameterized value and that it is not always set to None.

Parameters:

name (str) – The name of the parameter.

Returns:

contains – Whether the node contains a given parameter and it is not always None.

Return type:

bool

get_param(graph_state, name, default=None)

Get the value of a parameter stored in this node or a default value.

Note

This is an optional helper function that accesses the internals of graph_state using the node’s information (e.g. its hash value).

Parameters:
  • graph_state (dict) – The dictionary of graph state information.

  • name (str) – The parameter name to query.

  • default (any) – The default value to return if the parameter is not in GraphState.

Returns:

The parameter value or the default.

Return type:

any

Raises:

ValueError – If graph_state is None.

get_local_params(graph_state)

Get a dictionary of all parameters local to this node.

Note

This is an optional helper function that accesses the internals of graph_state using the node’s information (e.g. its hash value).

Parameters:

graph_state (GraphState) – An object mapping graph parameters to their values.

Returns:

result – A dictionary mapping the parameter name to its value.

Return type:

dict

Raises:
  • KeyError – If no parameters have been set for this node.

  • ValueError – If graph_state is None.
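The documented lookup behavior of get_param() and get_local_params() can be mimicked with two small functions. This is a sketch under stated assumptions: a plain dict of dicts stands in for GraphState, the node string "source" is made up, and the real methods take the node from `self` rather than a `node_string` argument.

```python
def get_param(graph_state, node_string, name, default=None):
    """Return one parameter for a node, or `default` if it is missing."""
    if graph_state is None:
        raise ValueError("graph_state cannot be None")
    return graph_state.get(node_string, {}).get(name, default)


def get_local_params(graph_state, node_string):
    """Return a copy of every parameter stored for a node."""
    if graph_state is None:
        raise ValueError("graph_state cannot be None")
    if node_string not in graph_state:
        raise KeyError(f"No parameters have been set for {node_string}")
    return dict(graph_state[node_string])


# Hypothetical node string and values.
state = {"source": {"t0": 60000.0, "amplitude": 2.5}}

print(get_param(state, "source", "t0"))             # 60000.0
print(get_param(state, "source", "redshift", 0.0))  # 0.0 (the default)
print(get_local_params(state, "source"))            # {'t0': 60000.0, 'amplitude': 2.5}
```

The asymmetry mirrors the documentation above: a missing single parameter falls back to a default, while asking for a node with no parameters at all raises KeyError.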

set_parameter(name, value=None, **kwargs)

Set the source of a single existing parameter in the ParameterizedNode.

Parameters within a node are actually a mapping of the parameter name to a _ParameterSource object that indicates how they are set during sampling. Some parameters may be set as a constant value while others may be set by evaluating a function that depends on other parameters.

Note

  • Does NOT set an initial value for the model parameter. The user must sample the parameters for this to be set.

  • The model parameters are stored in the order in which they are added.

Parameters:
  • name (str) – The name of the existing parameter to modify.

  • value (any, optional) – The information to use to set the parameter. Can be a constant, function, ParameterizedNode, or self.

  • **kwargs (dict, optional) – All other keyword arguments, possibly including the parameter setters.

Raises:
  • KeyError – If there is a parameter collision or the parameter cannot be found.

  • ValueError – If the setter type is not supported.

set_allow_gradient(name, allow_gradient)

Turn on or off the ability to compute a gradient for this variable.

Parameters:
  • name (str) – The parameter name to modify.

  • allow_gradient (bool) – The new setting for allow_gradient.

add_parameter(name, value=None, allow_gradient=None, description=None, **kwargs)

Add a single new parameter to the ParameterizedNode.

Note

  • Checks multiple sources in the following order: a manually specified value, an entry in kwargs, then None.

  • Does NOT set an initial value for the model parameter. The user must sample the parameters for this to be set.

  • The model parameters are stored in the order in which they are added.

Parameters:
  • name (str) – The parameter name to add.

  • value (any, optional) – The information to use to set the parameter. Can be a constant, function, ParameterizedNode, or self.

  • allow_gradient (bool or None) – Allow gradients to be computed for this variable. If None, uses the default for the setter type (True for constants and False for everything else). Default: None

  • description (str, optional) – A brief description of the parameter.

  • **kwargs (dict, optional) – All other keyword arguments, possibly including the parameter setters.

Raises:

KeyError – If there is a parameter collision or the parameter cannot be found.
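A simplified sketch of the documented precedence rules. `ToyNode` is hypothetical, and treating "callable" as "not a constant" is a simplification of the real setter types. It shows the explicit value winning over the kwargs entry, the fallback to None, the allow_gradient default, insertion order, and the KeyError on a collision.

```python
class ToyNode:
    """Hypothetical, simplified node that records parameter setters."""

    def __init__(self):
        self.setters = {}  # dicts keep insertion order, like the documented setters

    def add_parameter(self, name, value=None, allow_gradient=None, **kwargs):
        if name in self.setters:
            raise KeyError(f"Parameter collision: {name}")
        # Check sources in order: explicit value, then kwargs[name], then None.
        setter = value if value is not None else kwargs.get(name)
        if allow_gradient is None:
            # Simplified default: constants allow gradients, callables do not.
            allow_gradient = not callable(setter)
        self.setters[name] = (setter, allow_gradient)


# Hypothetical settings, e.g. the **kwargs a subclass constructor received.
settings = {"redshift": 0.05, "amplitude": 1.0, "t0": lambda: 60000.0}

node = ToyNode()
node.add_parameter("redshift", **settings)              # taken from kwargs
node.add_parameter("amplitude", value=2.0, **settings)  # explicit value wins
node.add_parameter("t0", **settings)                    # a function: no gradient by default
node.add_parameter("distance", **settings)              # no source at all: None
print(list(node.setters))  # ['redshift', 'amplitude', 't0', 'distance']
```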

get_parameter_indicator(param_name)

Return the _AttributeIndicatorNode for a given parameter. This replicates the behavior of accessing the parameter as an attribute of the object (e.g., obj.param_name), but in functional form.

Parameters:

param_name (str) – The name of the parameter whose indicator to return.

Returns:

getter – The indicator node for the given parameter.

Return type:

_AttributeIndicatorNode

compute(graph_state, rng_info=None, **kwargs)

Placeholder for a general compute function, which is called at the end of the sampling process and can produce derived parameters. This function is the main processing step in a FunctionNode.

Parameters:
  • graph_state (GraphState) – An object mapping graph parameters to their values. This object is modified in place as it is sampled.

  • rng_info (numpy.random._generator.Generator, optional) – A given numpy random number generator to use for this computation. If not provided, the function uses the node’s random number generator.

  • **kwargs (dict, optional) – Additional function arguments.

sample_parameters(given_args=None, num_samples=1, rng_info=None, sample_offset=0)

Sample the model’s underlying parameters if they are provided by a function or ParameterizedNode.

Parameters:
  • given_args (dict, optional) – A dictionary representing the given arguments for this sample run. This can be used as the JAX PyTree for differentiation.

  • num_samples (int) – A count of the number of samples to compute. Default: 1

  • rng_info (numpy.random._generator.Generator, optional) – A given numpy random number generator to use for this computation. If not provided, the function uses the node’s random number generator.

  • sample_offset (int) – An optional offset to add to the graph state for any stateful nodes. This allows the system to better support testing and parallelized sampling. Default: 0 (no offset)

Returns:

graph_state – A dictionary of dictionaries, keyed first by node (via its hash) and then by parameter name, mapping to either a single value or an array of values. This data structure is modified in place to represent the model’s state(s).

Return type:

GraphState

Raises:

ValueError – If the sampling encounters a problem with the order of dependencies.
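A toy illustration of what num_samples and rng_info mean for the values stored in the state. The function, its two parameters, and the convention of storing a single sample as a plain float are all assumptions of this sketch; the documentation only says that values may be individual values or arrays. Passing a seeded numpy Generator makes the draws repeatable.

```python
import numpy as np


def sample_parameters(num_samples=1, rng_info=None):
    """Sample a toy two-parameter model in dependency order."""
    rng = rng_info if rng_info is not None else np.random.default_rng()
    state = {}
    # Root parameter: no dependencies, so it is sampled first.
    state["t0"] = rng.uniform(60000.0, 60100.0, size=num_samples)
    # Dependent parameter: computed from t0, so it must come afterwards.
    state["t_peak"] = state["t0"] + 15.0
    if num_samples == 1:
        # Store a single sample as a plain float instead of a length-1 array.
        state = {name: float(values[0]) for name, values in state.items()}
    return state


batch = sample_parameters(num_samples=5, rng_info=np.random.default_rng(42))
print(batch["t0"].shape)  # (5,)

single = sample_parameters(rng_info=np.random.default_rng(42))
print(type(single["t0"]).__name__)  # float
```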

build_dependency_graph()

Build a directed acyclic graph (DAG) representing the parameters in the model and their dependencies.

Returns:

dependency_graph – An object tracking the dependencies between parameters.

Return type:

DependencyGraph

build_pytree(graph_state, partial=None)

Build a JAX PyTree representation of the variables in this graph.

Parameters:
  • graph_state (dict) – A dictionary of dictionaries, keyed first by node (via its hash) and then by parameter name, mapping to each value. This data structure is modified in place to represent the current state.

  • partial (dict, optional) – The partial results so far: a dictionary mapping each node name to a dictionary mapping each variable’s name to its value. This is modified in place by the function. Default: None

Returns:

values – A dictionary mapping node name to a dictionary mapping each variable’s name to its value.

Return type:

dict
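A toy sketch of the nested structure and of how a `partial` dictionary can be filled in place while recursing up the graph. `ToyNode` and its `build_pytree` method are hypothetical and do not import JAX; nested dictionaries of numbers or arrays like the result below are the kind of container JAX treats as a PyTree.

```python
class ToyNode:
    """Hypothetical node that knows its identifier and the nodes it depends on."""

    def __init__(self, node_string, parents=()):
        self.node_string = node_string
        self.parents = list(parents)

    def build_pytree(self, graph_state, partial=None):
        """Collect {node: {param: value}} for this node and every node above it."""
        if partial is None:
            partial = {}
        if self.node_string in partial:
            return partial  # already collected through another path
        partial[self.node_string] = dict(graph_state[self.node_string])
        for parent in self.parents:
            parent.build_pytree(graph_state, partial)  # fills `partial` in place
        return partial


state = {
    "host": {"redshift": 0.05},
    "source": {"t0": 60000.0, "amplitude": 2.5},
}
host = ToyNode("host")
source = ToyNode("source", parents=[host])
tree = source.build_pytree(state)
print(tree)
# {'source': {'t0': 60000.0, 'amplitude': 2.5}, 'host': {'redshift': 0.05}}
```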

class FunctionNode(func, node_label=None, outputs=None, fixed_params=None, **kwargs)

Bases: ParameterizedNode

A class to wrap functions and their argument settings.

The node can compute the result using a given function (the func parameter) or through the compute() method. If func is None, the user must override compute().

func

The function to call during an evaluation. If this is None you must override the compute() method directly.

Type:

function or method or partial

arg_names

A list of argument names to pass to the function.

Type:

list

outputs

The output model parameters of this function.

Type:

list of str

Parameters:
  • func (function or method) – The function to call during an evaluation.

  • node_label (str, optional) – An identifier (or name) for the current node.

  • outputs (list of str, optional) – The names of the output model parameters of this function. If None, the result is stored as a single model parameter.

  • fixed_params (dict, optional) – A dictionary mapping a parameter name in the function to its fixed value.

  • **kwargs – Any additional keyword arguments.

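Examples

>>> import random
>>> from lightcurvelynx.base_models import FunctionNode
>>> my_func = FunctionNode(random.randint, a=1, b=10)
>>> value1 = my_func.generate()      # Sample from the default range
>>> value2 = my_func.generate(b=20)  # Sample from an extended range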

Note

All the function’s parameters that will be used need to be specified in the constructor, either as keyword arguments or through fixed_params. Arguments cannot be given for the first time when the node is evaluated. For example, the following will fail because b is not defined in the constructor:

my_func = FunctionNode(random.randint, a=1)
value1 = my_func.generate(b=10)

arg_names = []
outputs = None
compute(graph_state, rng_info=None, **kwargs)

Execute the wrapped function.

The input arguments are taken from the current graph_state and the outputs are written to graph_state.

Parameters:
  • graph_state (GraphState) – An object mapping graph parameters to their values. This object is modified in place as it is sampled.

  • rng_info (numpy.random._generator.Generator, optional) – A given numpy random number generator to use for this computation. If not provided, the function uses the node’s random number generator.

  • **kwargs (dict, optional) – Additional function arguments.

Returns:

results – The result of the computation. This return value is provided so that testing functions can easily access the results.

Return type:

any

Raises:

ValueError – If func attribute is None.
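A stand-alone sketch of the read, call, and write cycle described above. The `compute` function, the dict-of-dicts state, and the node string "splitter" are hypothetical; the sketch assumes that a function with several outputs returns them as a tuple in the same order as `outputs`.

```python
def compute(func, arg_names, outputs, graph_state, node_string):
    """Read inputs from graph_state, call func, and write its outputs back."""
    if func is None:
        raise ValueError("func is None; override compute() instead")
    local = graph_state[node_string]
    results = func(**{name: local[name] for name in arg_names})
    if len(outputs) == 1:
        local[outputs[0]] = results
    else:
        if len(results) != len(outputs):
            raise ValueError("Number of results does not match the outputs")
        for name, value in zip(outputs, results):
            local[name] = value
    return results  # also returned so tests can inspect it directly


def split(total, fraction):
    return total * fraction, total * (1.0 - fraction)


state = {"splitter": {"total": 10.0, "fraction": 0.25}}
compute(split, ["total", "fraction"], ["part_a", "part_b"], state, "splitter")
print(state["splitter"])
# {'total': 10.0, 'fraction': 0.25, 'part_a': 2.5, 'part_b': 7.5}
```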

generate(given_args=None, num_samples=1, rng_info=None, **kwargs)

A helper function that regenerates the parameters for this node and the nodes above it, then returns the output for this individual node.

This is used both for testing and for computing JAX gradients.

Parameters:
  • given_args (dict, optional) – A dictionary representing the given arguments for this sample run. This can be used as the JAX PyTree for differentiation.

  • num_samples (int) – A count of the number of samples to compute. Default: 1

  • rng_info (numpy.random._generator.Generator, optional) – A given numpy random number generator to use for this computation. If not provided, the function uses the node’s random number generator.

  • **kwargs (dict, optional) – Additional function arguments.
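The two steps of generate() (resample what lies above the node, then evaluate only this node) can be sketched as a small helper. `generate` and the `samplers` dictionary here are hypothetical; the upstream parameters are modeled as independent numpy draws. Reusing the same seed reproduces the result, which is what makes a helper like this convenient for tests.

```python
import numpy as np


def generate(func, samplers, num_samples=1, rng_info=None):
    """Resample every input, then return func evaluated on the new samples."""
    rng = rng_info if rng_info is not None else np.random.default_rng()
    # Step 1: regenerate the upstream parameters (here, independent draws).
    inputs = {name: draw(rng, num_samples) for name, draw in samplers.items()}
    # Step 2: evaluate only this node's function on those samples.
    return func(**inputs)


samplers = {
    "a": lambda rng, n: rng.uniform(0.0, 1.0, size=n),
    "b": lambda rng, n: rng.uniform(1.0, 2.0, size=n),
}
result = generate(lambda a, b: a + b, samplers, num_samples=4,
                  rng_info=np.random.default_rng(0))
print(result.shape)  # (4,)
```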