24 #include "pybind11/pybind11.h"
25 #include "pybind11/stl.h"
27 #include "numpy/arrayobject.h"
28 #include "ndarray/pybind11.h"
32 namespace py = pybind11;
33 using namespace pybind11::literals;
40 using Sampler = TruncatedGaussianSampler;
41 using Evaluator = TruncatedGaussianEvaluator;
42 using LogEvaluator = TruncatedGaussianLogEvaluator;
44 using PyTruncatedGaussian = py::class_<TruncatedGaussian, std::shared_ptr<TruncatedGaussian>>;
45 using PySampler = py::class_<Sampler, std::shared_ptr<Sampler>>;
46 using PyEvaluator = py::class_<Evaluator, std::shared_ptr<Evaluator>>;
47 using PyLogEvaluator = py::class_<LogEvaluator, std::shared_ptr<LogEvaluator>>;
/**
 * Create and populate the pybind11 wrapper for one TruncatedGaussian evaluator class.
 *
 * The Python class is registered in `mod` under the name "TruncatedGaussian" + `name`.
 * Bindings added, as far as this listing shows:
 *  - a constructor taking `TruncatedGaussian const &` (keyword "parent");
 *  - a const `Class::operator()` overload taking a 1-D `ndarray::Array<Scalar const, 1, 1>`
 *    and returning `Scalar`;
 *  - a "__call__" overload taking a 2-D `ndarray::Array<Scalar const, 2, 1>` ("alpha") and a
 *    1-D `ndarray::Array<Scalar, 1, 1>` ("output") and returning void.
 *
 * The function is templated so the same binding code can serve any evaluator class with this
 * interface. The module init in this file calls it for both LogEvaluator and Evaluator.
 *
 * @tparam Class    C++ evaluator class whose constructor and call operators are bound.
 * @tparam PyClass  pybind11 `py::class_` type to instantiate for `Class`.
 * @param mod       Module the new Python class is registered in.
 * @param name      Suffix appended to "TruncatedGaussian" to form the Python class name.
 * @return          The populated `PyClass` object. The return statement itself is not visible
 *                  in this listing.
 *
 * NOTE(review): this listing has dropped lines. The embedded numbering skips 56, 58, 61 and
 * everything after 62. As a result, the following are not visible here:
 *  - the Python name and keywords of the 1-D overload;
 *  - the member-pointer target of the 2-D overload;
 *  - the end of the function.
 */
52 template <
typename Class,
typename PyClass>
53 static PyClass declareEvaluator(
py::module &mod, std::string
const &
name) {
// Python-visible class name is "TruncatedGaussian" + name.
54 PyClass cls(mod, (
"TruncatedGaussian" + name).c_str());
// Constructor from the parent distribution (keyword "parent").
55 cls.def(py::init<TruncatedGaussian const &>(),
"parent"_a);
// 1-D overload: Scalar Class::operator()(Array<Scalar const, 1, 1> const &) const.
// NOTE(review): the opening of this cls.def(...) call (source line 56) and its keyword
// argument(s) (source line 58) are missing from the listing; presumably it is also bound
// as "__call__" — confirm against the original file.
57 (
Scalar (Class::*)(ndarray::Array<Scalar const, 1, 1>
const &)
const) & Class::operator(),
// 2-D overload bound as "__call__": takes a 2-D "alpha" array and a 1-D "output" array and
// returns void, so results presumably go into `output` — confirm.
// NOTE(review): the member-pointer target (source line 61) is not visible; presumably
// Class::operator().
59 cls.def(
"__call__", (
void (Class::*)(ndarray::Array<Scalar const, 2, 1>
const &,
60 ndarray::Array<Scalar, 1, 1>
const &)
const) &
62 "alpha"_a,
"output"_a);
/*
 * Module initialization for the `truncatedGaussian` extension module, written with
 * pybind11's PYBIND11_PLUGIN entry-point macro.
 *
 * Registers, in order:
 *  - the TruncatedGaussian class, with its nested SampleStrategy enum, static factories,
 *    sampling/evaluation methods and accessors;
 *  - the evaluator classes TruncatedGaussianLogEvaluator and TruncatedGaussianEvaluator,
 *    also attached as TruncatedGaussian.LogEvaluator and TruncatedGaussian.Evaluator;
 *  - TruncatedGaussianSampler, attached as TruncatedGaussian.Sampler.
 *
 * NOTE(review): PYBIND11_PLUGIN is deprecated since pybind11 2.2 in favour of
 * PYBIND11_MODULE. Consider migrating if the project's pybind11 version allows it.
 * NOTE(review): this listing has dropped lines. The embedded numbering skips 70-72, 75-77,
 * 82, 84, 86, 89, 99, 102, 105, 108-109, 112, 114 and everything after 115. In particular,
 * these are not visible here:
 *  - the declaration of `mod` used below;
 *  - the early return after the NumPy import failure;
 *  - the closing of the function.
 */
68 PYBIND11_PLUGIN(truncatedGaussian) {
// Import lsst.afw.math first. Presumably this ensures afw types used in the signatures below
// (e.g. afw::math::Random in TruncatedGaussianSampler.__call__) are already registered with
// pybind11 — confirm.
69 py::module::import(
"lsst.afw.math");
// Initialize the NumPy C API and set an ImportError if that fails.
// NOTE(review): the return statement that should follow PyErr_SetString (source lines 75-77)
// is missing from this listing.
73 if (_import_array() < 0) {
74 PyErr_SetString(PyExc_ImportError,
"numpy.core.multiarray failed to import");
// Main class, exposed to Python as "TruncatedGaussian".
78 PyTruncatedGaussian cls(mod,
"TruncatedGaussian");
// Nested enum TruncatedGaussian.SampleStrategy. It is taken by the sample(strategy) overload
// and by the TruncatedGaussianSampler constructor below. The statement terminator (source
// line 82) is not visible.
79 py::enum_<TruncatedGaussian::SampleStrategy>(cls,
"SampleStrategy")
80 .value(
"DIRECT_WITH_REJECTION", TruncatedGaussian::DIRECT_WITH_REJECTION)
81 .value(
"ALIGN_AND_WEIGHT", TruncatedGaussian::ALIGN_AND_WEIGHT)
// Static factory functions. The remaining keyword names are on lines 84 and 86, which are
// missing from this listing.
83 cls.def_static(
"fromSeriesParameters", &TruncatedGaussian::fromSeriesParameters,
"q0"_a,
"gradient"_a,
85 cls.def_static(
"fromStandardParameters", &TruncatedGaussian::fromStandardParameters,
"mean"_a,
// Two "sample" overloads, disambiguated by explicit member-function-pointer casts:
//  - one takes a SampleStrategy (its keyword is on missing line 89);
//  - the other takes a Scalar "minRejectionEfficiency" with default 0.1.
// Both return a Sampler.
87 cls.def(
"sample", (Sampler (TruncatedGaussian::*)(TruncatedGaussian::SampleStrategy)
const) &
88 TruncatedGaussian::sample,
90 cls.def(
"sample", (Sampler (TruncatedGaussian::*)(
Scalar)
const) & TruncatedGaussian::sample,
91 "minRejectionEfficiency"_a = 0.1);
// Evaluation and accessor methods, bound directly with no keyword names.
92 cls.def(
"evaluateLog", &TruncatedGaussian::evaluateLog);
93 cls.def(
"evaluate", &TruncatedGaussian::evaluate);
94 cls.def(
"getDim", &TruncatedGaussian::getDim);
95 cls.def(
"maximize", &TruncatedGaussian::maximize);
96 cls.def(
"getUntruncatedFraction", &TruncatedGaussian::getUntruncatedFraction);
97 cls.def(
"getLogPeakAmplitude", &TruncatedGaussian::getLogPeakAmplitude);
98 cls.def(
"getLogIntegral", &TruncatedGaussian::getLogIntegral);
// The evaluator classes are registered at module level by declareEvaluator, as
// "TruncatedGaussianLogEvaluator" and "TruncatedGaussianEvaluator". They are also attached
// as class attributes of TruncatedGaussian.
100 cls.attr(
"LogEvaluator") = declareEvaluator<LogEvaluator, PyLogEvaluator>(mod,
"LogEvaluator");
101 cls.attr(
"Evaluator") = declareEvaluator<Evaluator, PyEvaluator>(mod,
"Evaluator");
// Sampler class, exposed as "TruncatedGaussianSampler". Its constructor takes the parent
// distribution and a SampleStrategy; the second keyword is on missing line 105.
103 PySampler clsSampler(mod,
"TruncatedGaussianSampler");
104 clsSampler.def(py::init<TruncatedGaussian const &, TruncatedGaussian::SampleStrategy>(),
"parent"_a,
// 1-D "__call__" overload: takes an afw::math::Random and a 1-D array, and returns a Scalar.
// Its member-pointer target and keywords are on missing lines 108-109.
106 clsSampler.def(
"__call__",
107 (
Scalar (Sampler::*)(afw::math::Random &, ndarray::Array<Scalar, 1, 1>
const &)
const) &
// 2-D "__call__" overload. Arguments: rng, a 2-D "alpha" array, a 1-D "weights" array and a
// "multiplyWeights" flag (default false). It returns void. The member-pointer target is on
// missing line 112.
110 clsSampler.def(
"__call__", (
void (Sampler::*)(afw::math::Random &, ndarray::Array<Scalar, 2, 1>
const &,
111 ndarray::Array<Scalar, 1, 1>
const &,
bool)
const) &
113 "rng"_a,
"alpha"_a,
"weights"_a,
"multiplyWeights"_a =
false);
// Also expose the sampler class as TruncatedGaussian.Sampler.
115 cls.attr(
"Sampler") = clsSampler;
// `Scalar` is a typedef for `double`, used for probability and parameter values.