src.superphot_plus.samplers.numpyro_sampler

MCMC sampling using numpyro.

Module Contents

Classes

NumpyroSampler

MCMC sampling using numpyro.

Functions

prior_helper(priors[, aux_b])

Helper function to sample prior values. If aux_b is not None, appends aux_b to value names.

lax_helper_function(svi, svi_state, num_iters, *args, ...)

Helper function using LAX to speed up SVI state updates.

trunc_norm(fields[, low, high])

Provides keyword parameters to numpyro's TruncatedNormal, using the fields in PriorFields.

create_jax_model(priors[, t, obsflux, uncertainties, ...])

Create a JAX model for MCMC.

create_jax_guide(priors[, t, obsflux, uncertainties, ...])

JAX guide function for SVI.

_svi_helper_no_recompile(lc_single, max_flux, priors, ...)

Helper function to run SVI on a single light curve with an already compiled SVI sampler object.

run_mcmc(lc, rng_seed[, sampler, priors, ref_params])

Runs MCMC using numpyro on the light curve to get a set of equally weighted posterior samples (sets of fit parameters).

class NumpyroSampler(sampler='svi')

Bases: superphot_plus.samplers.sampler.Sampler

MCMC sampling using numpyro.

Parameters:

sampler (str, optional) – The numpyro sampler to use, either “NUTS” or “svi”. Defaults to “svi”.

run_single_curve(lightcurve: superphot_plus.lightcurve.Lightcurve, priors: superphot_plus.surveys.fitting_priors.MultibandPriors, rng_seed, ref_params=None, **kwargs) → superphot_plus.posterior_samples.PosteriorSamples

Run the sampler on a single light curve. The sampler itself is chosen when the NumpyroSampler is constructed.

Parameters:
  • lightcurve (Lightcurve) – The light curve to sample.

  • priors (MultibandPriors) – The curve priors to use.

  • rng_seed (int or None) – The random seed to use, mainly for reproducible tests. Pass None for a fully random run.

Returns:

eq_wt_samples – The resulting equally weighted posterior samples.

Return type:

PosteriorSamples

run_multi_curve(lightcurves, priors: superphot_plus.surveys.fitting_priors.MultibandPriors, rng_seed, sampler='svi', ref_params=None, **kwargs) → List[superphot_plus.posterior_samples.PosteriorSamples]

Not yet implemented.

prior_helper(priors, aux_b=None)

Helper function to sample prior values. If aux_b is not None, appends aux_b to value names.

Parameters:
  • priors (CurvePriors) – The priors for one band.

  • aux_b (str, optional) – The name of the auxiliary band, if this band is auxiliary. Defaults to None, meaning the base band.
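The naming rule above can be illustrated with a small pure-Python sketch. The parameter names and the "_" separator used below are hypothetical; this page does not state the actual sample-site names or how aux_b is joined to them.

```python
# Hypothetical per-band parameter names; the real names used by
# superphot_plus are not given on this page.
BASE_PARAM_NAMES = ["A", "beta", "gamma", "t_0", "tau_rise", "tau_fall"]


def site_names(param_names, aux_b=None):
    """Return sample-site names, suffixing aux_b for an auxiliary band.

    The "_" separator is an assumption made for this sketch only.
    """
    if aux_b is None:
        # Base band: value names are used unchanged.
        return list(param_names)
    # Auxiliary band: append the band name to every value name.
    return [f"{name}_{aux_b}" for name in param_names]


print(site_names(BASE_PARAM_NAMES))
print(site_names(BASE_PARAM_NAMES, aux_b="r"))
```

Keeping the base band unsuffixed means its sites keep their plain names, while each auxiliary band gets a distinct, non-colliding set.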

lax_helper_function(svi, svi_state, num_iters, *args, **kwargs)

Helper function using LAX to speed up SVI state updates.
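One common way to use LAX for this is jax.lax.fori_loop, which runs a fixed number of state updates inside a single compiled loop (in JAX, roughly lax.fori_loop(0, num_iters, body, svi_state)) instead of dispatching each SVI step from Python. Whether lax_helper_function uses fori_loop specifically is an assumption. The sketch below is pure Python so it runs without JAX: fori_loop follows the reference semantics given in the JAX documentation, and toy_update is a hypothetical stand-in for numpyro's SVI.update, which returns a (new_state, loss) pair.

```python
def fori_loop(lower, upper, body_fun, init_val):
    """Pure-Python reference semantics of jax.lax.fori_loop."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def toy_update(state, target, lr=0.1):
    """Hypothetical stand-in for SVI.update: one gradient step on
    loss(state) = (state - target) ** 2, returning (new_state, loss)."""
    loss = (state - target) ** 2
    grad = 2.0 * (state - target)
    return state - lr * grad, loss


def run_updates(update, state, num_iters, *args, **kwargs):
    """Apply `update` num_iters times, keeping only the final state."""
    def body(_, carry):
        # Per-step losses are discarded; only the state is carried.
        new_state, _loss = update(carry, *args, **kwargs)
        return new_state

    return fori_loop(0, num_iters, body, state)


final_state = run_updates(toy_update, 0.0, 100, target=3.0)
print(round(final_state, 6))
```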

trunc_norm(fields: superphot_plus.surveys.fitting_priors.PriorFields, low=None, high=None)

Provides keyword parameters to numpyro’s TruncatedNormal, using the fields in PriorFields.

Parameters:

fields (PriorFields) – The (low, high, mean, standard deviation) fields of the truncated normal distribution.

Returns:

A truncated normal distribution.

Return type:

numpyro.distributions.TruncatedDistribution
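As a rough illustration of the distribution trunc_norm configures, the pure-Python sketch below builds the keyword arguments accepted by numpyro.distributions.TruncatedNormal (loc, scale, low, high) from a PriorFields-like record and evaluates the resulting density. The record's field names, and the rule that an explicit low or high overrides the stored bound, are assumptions; this page does not define them.

```python
import math
from collections import namedtuple

# Hypothetical stand-in for PriorFields: (lower, upper, mean, std).
PriorFieldsSketch = namedtuple("PriorFieldsSketch", ["lower", "upper", "mean", "std"])


def trunc_norm_kwargs(fields, low=None, high=None):
    """Keyword arguments for numpyro's TruncatedNormal.

    Assumption: an explicit low/high overrides the bound stored in fields.
    """
    return {
        "loc": fields.mean,
        "scale": fields.std,
        "low": fields.lower if low is None else low,
        "high": fields.upper if high is None else high,
    }


def _std_normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def trunc_norm_pdf(x, loc, scale, low, high):
    """Density of Normal(loc, scale) truncated to [low, high]."""
    if x < low or x > high:
        return 0.0
    z = (x - loc) / scale
    phi = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    # Renormalize by the probability mass left inside [low, high].
    mass = _std_normal_cdf((high - loc) / scale) - _std_normal_cdf((low - loc) / scale)
    return phi / (scale * mass)


kw = trunc_norm_kwargs(PriorFieldsSketch(lower=0.0, upper=10.0, mean=2.0, std=3.0), high=5.0)
print(kw)  # {'loc': 2.0, 'scale': 3.0, 'low': 0.0, 'high': 5.0}
```

With numpyro installed, the same keywords could be passed directly as dist.TruncatedNormal(**kw).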

create_jax_model(priors, t=None, obsflux=None, uncertainties=None, max_flux=None, ref_params=None)

Create a JAX model for MCMC.

Parameters:
  • priors (MultibandPriors) – Priors for all bands in the light curve.

  • t (array-like, optional) – Time values. Defaults to None.

  • obsflux (array-like, optional) – Observed flux values. Defaults to None.

  • uncertainties (array-like, optional) – Flux uncertainties. Defaults to None.

  • max_flux (float, optional) – Maximum flux value. Defaults to None.
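A numpyro model of this kind typically samples the fit parameters from the priors and then conditions on the data with a statement such as numpyro.sample("obs", dist.Normal(pred_flux, sigma), obs=obsflux). The pure-Python sketch below computes the log-likelihood that such a Gaussian observation term contributes. The predicted fluxes and the optional extra_sigma scatter term added in quadrature are hypothetical; this page does not specify the flux model or noise model used by create_jax_model.

```python
import math


def gaussian_log_likelihood(obsflux, pred_flux, uncertainties, extra_sigma=0.0):
    """Sum of Normal log-densities of the observed fluxes.

    sigma_i = sqrt(uncertainty_i**2 + extra_sigma**2). extra_sigma is a
    hypothetical extra-scatter term; leave it at 0 to use the raw uncertainties.
    """
    if not (len(obsflux) == len(pred_flux) == len(uncertainties)):
        raise ValueError("obsflux, pred_flux and uncertainties must have equal length")
    total = 0.0
    for obs, pred, unc in zip(obsflux, pred_flux, uncertainties):
        sigma = math.sqrt(unc ** 2 + extra_sigma ** 2)
        resid = (obs - pred) / sigma
        # log N(obs | pred, sigma)
        total += -0.5 * resid ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    return total


print(gaussian_log_likelihood([1.0, 2.0], [1.1, 1.8], [0.1, 0.2]))
```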

create_jax_guide(priors, t=None, obsflux=None, uncertainties=None, max_flux=None, ref_params=None)

JAX guide function for SVI.

Parameters:

priors (MultibandPriors) – Priors for all bands in the light curve.
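In numpyro, a guide is the variational family that SVI fits: it takes the same arguments as the model and declares a sample site for each latent parameter. A common choice is a mean-field guide, in which each latent gets an independent Normal with its own learnable location and scale, sampled through the reparameterization z = loc + scale * eps. Whether create_jax_guide is mean-field, and the latent names below, are assumptions; this pure-Python sketch only shows that sampling step.

```python
import random


def mean_field_sample(variational_params, rng):
    """Draw one sample from a mean-field Normal guide.

    variational_params maps a latent name to (loc, scale); each latent is
    drawn independently as loc + scale * eps with eps ~ Normal(0, 1).
    """
    sample = {}
    for name, (loc, scale) in variational_params.items():
        if scale <= 0:
            raise ValueError(f"scale for {name!r} must be positive")
        eps = rng.gauss(0.0, 1.0)
        sample[name] = loc + scale * eps
    return sample


# Hypothetical latent names and variational parameters.
params = {"t_0": (10.0, 0.5), "tau_rise": (3.0, 0.2)}
rng = random.Random(0)
draws = [mean_field_sample(params, rng) for _ in range(5000)]
mean_t0 = sum(d["t_0"] for d in draws) / len(draws)
print(round(mean_t0, 1))
```

Writing the draw as a deterministic function of (loc, scale) plus independent noise is what lets SVI differentiate the ELBO estimate with respect to loc and scale.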

_svi_helper_no_recompile(lc_single, max_flux, priors, svi, svi_state, lax_jit, num_iter, seed, ref_params=None)

Helper function to run SVI on a single light curve with an already compiled SVI sampler object.

run_mcmc(lc, rng_seed, sampler='NUTS', priors=Survey.ZTF().priors, ref_params=None)

Runs MCMC using numpyro on the light curve to get a set of equally weighted posterior samples (sets of fit parameters).

Parameters:
  • lc (Lightcurve object) – The Lightcurve object on which to run MCMC.

  • rng_seed (int or None) – The random seed to use, mainly for reproducible tests. Pass None for a fully random run.

  • sampler (str, optional) – The MCMC sampler to use. Defaults to “NUTS”.

  • priors (MultibandPriors, optional) – The prior set to use for fitting. Defaults to ZTF’s priors.

Returns:

A set of equally weighted posterior samples (sets of fit parameters) as a NumPy array. If the light curve contains no valid points, None is returned.

Return type:

np.ndarray or None
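Because the returned posterior samples are equally weighted, every row counts once, so per-parameter summaries are plain column statistics with no weights. The sketch below assumes one row per sample and one column per fit parameter, and handles the documented None return. The row/column layout, and the use of nested lists instead of a NumPy array, are assumptions made so the example runs with the standard library only.

```python
import statistics


def summarize_posterior(samples):
    """Per-parameter (median, stdev) of equally weighted samples.

    samples: sequence of rows (assumed layout: one row per sample, one
    column per fit parameter), or None when no valid points existed.
    """
    if samples is None:
        return None
    # Transpose rows into per-parameter columns.
    columns = list(zip(*samples))
    return [(statistics.median(col), statistics.stdev(col)) for col in columns]


eq_wt_samples = [
    [1.0, 10.0],
    [2.0, 12.0],
    [3.0, 14.0],
]
print(summarize_posterior(eq_wt_samples))
print(summarize_posterior(None))
```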