24 #include "pybind11/pybind11.h"
25 #include "pybind11/stl.h"
29 #include "numpy/arrayobject.h"
30 #include "ndarray/pybind11.h"
31 #include "ndarray/eigen.h"
33 #include "lsst/utils/python.h"
// Short alias for the pybind11 namespace, used by every binding below.
36 namespace py = pybind11;
// Brings the `_a` literal suffix into scope; it is used below to name keyword
// arguments (e.g. "dim"_a) and to attach default values to them.
37 using namespace pybind11::literals;
// pybind11 wrapper type for MixtureComponent. No holder type is specified here,
// unlike the two aliases that follow.
44 using PyMixtureComponent = py::class_<MixtureComponent>;
// pybind11 wrapper type for MixtureUpdateRestriction, held via std::shared_ptr.
45 using PyMixtureUpdateRestriction =
46 py::class_<MixtureUpdateRestriction, std::shared_ptr<MixtureUpdateRestriction>>;
// pybind11 wrapper type for Mixture, held via std::shared_ptr. The trailing
// template arguments name PersistableFacade<Mixture> and Persistable
// (pybind11 interprets extra class arguments like these as base classes).
47 using PyMixture = py::class_<Mixture, std::shared_ptr<Mixture>, afw::table::io::PersistableFacade<Mixture>,
48 afw::table::io::Persistable>;
// Creates the Python class "MixtureComponent" inside `mod` and attaches its
// accessors, `project` overloads, constructors and string conversions.
//
// @param mod  pybind11 module the class is registered in.
//
// NOTE(review): parts of this definition are not visible in this view (the
// keyword arguments of the `project` overloads, most of the `streamStr` body,
// and the closing lines of the function). The return type is
// PyMixtureComponent; presumably `cls` is what gets returned -- confirm.
50 static PyMixtureComponent declareMixtureComponent(
py::module &mod) {
51 PyMixtureComponent cls(mod,
"MixtureComponent");
52 cls.def(
"getDimension", &MixtureComponent::getDimension);
// `weight` is exposed as a readable and writable attribute rather than through
// getter/setter methods.
53 cls.def_readwrite(
"weight", &MixtureComponent::weight);
54 cls.def(
"getMu", &MixtureComponent::getMu);
55 cls.def(
"setMu", &MixtureComponent::setMu);
56 cls.def(
"getSigma", &MixtureComponent::getSigma);
57 cls.def(
"setSigma", &MixtureComponent::setSigma);
// Two overloads are registered under the Python name "project". The casts pick
// the const member functions taking (int) and (int, int) respectively; both
// return a MixtureComponent by value.
58 cls.def(
"project", (MixtureComponent (MixtureComponent::*)(
int)
const) & MixtureComponent::project,
60 cls.def(
"project", (MixtureComponent (MixtureComponent::*)(
int,
int)
const) & MixtureComponent::project,
// Constructors: one from a single int ("dim"), one from a weight, a mean vector
// ("mu") and a matrix ("sigma").
62 cls.def(py::init<int>(),
"dim"_a);
63 cls.def(py::init<Scalar, Vector const &, Matrix const &>(),
"weight"_a,
"mu"_a,
"sigma"_a);
// Shared callable bound to both __str__ and __repr__ below. It declares a
// std::ostringstream; the remainder of its body is not visible in this view.
64 auto streamStr = [](MixtureComponent
const &
self) {
65 std::ostringstream os;
69 cls.def(
"__str__", streamStr);
70 cls.def(
"__repr__", streamStr);
// Creates the Python class "MixtureUpdateRestriction" inside `mod`, exposing
// getDimension() and a constructor that takes a single int named "dim".
//
// @param mod  pybind11 module the class is registered in.
//
// NOTE(review): the remaining lines of this function (including its return
// statement and closing brace) are not visible in this view.
74 static PyMixtureUpdateRestriction declareMixtureUpdateRestriction(
py::module &mod) {
75 PyMixtureUpdateRestriction cls(mod,
"MixtureUpdateRestriction");
76 cls.def(
"getDimension", &MixtureUpdateRestriction::getDimension);
77 cls.def(py::init<int>(),
"dim"_a);
// Creates the Python class "Mixture" inside `mod` and binds its container
// protocol (__iter__, __getitem__, __len__), projection, evaluation, drawing,
// EM-update, cloning, construction and string-conversion methods.
//
// @param mod  pybind11 module the class is registered in.
//
// NOTE(review): several lines of this definition are not visible in this view
// (some `cls.def(` openers and binding names, the member-function targets of a
// few overload casts, most of the `streamStr` body, and the function's closing
// lines). Comments below describe only what the visible lines show.
82 static PyMixture declareMixture(
py::module &mod) {
// Registers the PersistableFacade<Mixture> wrapper before the Mixture class
// itself, which names that facade among its class template arguments.
83 afw::table::io::python::declarePersistableFacade<Mixture>(mod,
"Mixture");
84 PyMixture cls(mod,
"Mixture");
// Iteration walks self.begin()..self.end(). keep_alive<0, 1> ties the lifetime
// of the Mixture (argument 1) to the returned iterator (index 0).
85 cls.def(
"__iter__", [](Mixture &
self) {
return py::make_iterator(
self.begin(),
self.end()); },
86 py::keep_alive<0, 1>());
// Indexing: the Python index is mapped through utils::python::cppIndex with the
// current size before being applied (presumably this normalizes/validates
// Python-style indices -- confirm in lsst/utils/python.h). The result is
// returned with the reference_internal policy.
87 cls.def(
"__getitem__",
88 [](Mixture &
self, std::ptrdiff_t i) {
return self[utils::python::cppIndex(
self.size(), i)]; },
89 py::return_value_policy::reference_internal);
90 cls.def(
"__len__", &Mixture::size);
91 cls.def(
"getComponentCount", &Mixture::getComponentCount);
// Two overloads under the Python name "project": the casts select the const
// member functions taking (int) and (int, int); both return
// std::shared_ptr<Mixture>.
92 cls.def(
"project", (std::shared_ptr<Mixture> (Mixture::*)(
int)
const) & Mixture::project,
"dim"_a);
93 cls.def(
"project", (std::shared_ptr<Mixture> (Mixture::*)(
int,
int)
const) & Mixture::project,
"dim1"_a,
// (The rest of the two-argument `project` binding is not visible in this view.)
95 cls.def(
"getDimension", &Mixture::getDimension);
96 cls.def(
"normalize", &Mixture::normalize);
97 cls.def(
"shift", &Mixture::shift,
"dim"_a,
"offset"_a);
// `threshold` defaults to 0.0 when omitted from Python.
98 cls.def(
"clip", &Mixture::clip,
"threshold"_a = 0.0);
99 cls.def(
"getDegreesOfFreedom", &Mixture::getDegreesOfFreedom);
// `df` defaults to positive infinity when omitted from Python.
100 cls.def(
"setDegreesOfFreedom", &Mixture::setDegreesOfFreedom,
101 "df"_a = std::numeric_limits<Scalar>::infinity());
// Lambda taking (self, component, array): converts the 1-d ndarray with
// asEigen() and forwards to self.evaluate(component, ...), returning a Scalar.
// The enclosing `cls.def(` line and binding name are not visible here.
103 [](Mixture
const &
self, MixtureComponent
const &component,
104 ndarray::Array<Scalar, 1, 0>
const &array) ->
Scalar {
105 return self.evaluate(component, array.asEigen());
107 "component"_a,
"x"_a);
// Lambda taking (self, array): same conversion, forwarding to the
// single-argument self.evaluate(...). Its enclosing `cls.def(` line and
// keyword arguments are not visible here.
109 [](Mixture
const &
self, ndarray::Array<Scalar, 1, 0>
const &array) ->
Scalar {
110 return self.evaluate(array.asEigen());
// "evaluate" overload selected by a cast to a const member function returning
// void and taking a 2-d const-element array plus a 1-d non-const-element array.
// The member-function target and keyword arguments are not visible in this view.
113 cls.def(
"evaluate", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
114 ndarray::Array<Scalar, 1, 0>
const &)
const) &
117 cls.def(
"evaluateComponents", &Mixture::evaluateComponents,
"x"_a,
"p"_a);
118 cls.def(
"evaluateDerivatives", &Mixture::evaluateDerivatives,
"x"_a,
"gradient"_a,
"hessian"_a);
119 cls.def(
"draw", &Mixture::draw,
"rng"_a,
"x"_a);
// Three "updateEM" overloads follow, each selected by a cast to a non-const
// member function returning void. In all three, tau1 defaults to 0.0 and tau2
// to 0.5. The member-function targets of the casts are not visible in this view.
// Overload 1: (x, w, tau1, tau2).
120 cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
121 ndarray::Array<Scalar const, 1, 0>
const &,
Scalar,
Scalar)) &
123 "x"_a,
"w"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
// Overload 2: (x, w, restriction, tau1, tau2).
124 cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
125 ndarray::Array<Scalar const, 1, 0>
const &,
126 MixtureUpdateRestriction
const &restriction,
Scalar,
Scalar)) &
128 "x"_a,
"w"_a,
"restriction"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
// Overload 3: (x, restriction, tau1, tau2) -- no weight array.
129 cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
130 MixtureUpdateRestriction
const &restriction,
Scalar,
Scalar)) &
132 "x"_a,
"restriction"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
133 cls.def(
"clone", &Mixture::clone);
// Constructor from a dimension, a component list and an optional `df` that
// defaults to positive infinity.
134 cls.def(py::init<int, Mixture::ComponentList &, Scalar>(),
"dim"_a,
"components"_a,
135 "df"_a = std::numeric_limits<Scalar>::infinity());
// Shared callable bound to both __str__ and __repr__ below. It declares a
// std::ostringstream; the remainder of its body is not visible in this view.
136 auto streamStr = [](Mixture
const &
self) {
137 std::ostringstream os;
142 cls.def(
"__str__", streamStr);
143 cls.def(
"__repr__", streamStr);
// Module entry point for the Python extension named "mixture".
//
// Visible steps: import lsst.afw.math, initialize the NumPy C API, declare the
// three wrapper classes, then expose MixtureComponent and
// MixtureUpdateRestriction as the attributes `Mixture.Component` and
// `Mixture.UpdateRestriction`.
//
// NOTE(review): the creation of `mod`, the remainder of the NumPy error branch
// and the final return are not visible in this view.
// NOTE(review): PYBIND11_PLUGIN appears to be the legacy pybind11 entry-point
// macro (superseded by PYBIND11_MODULE) -- confirm against the pybind11 version
// in use before migrating.
147 PYBIND11_PLUGIN(mixture) {
// Imported before the classes are declared; presumably this makes the afw base
// types referenced by the Mixture wrapper available -- confirm.
148 py::module::import(
"lsst.afw.math");
// _import_array() returns a negative value on failure; in that case an
// ImportError is set on the Python error indicator.
152 if (_import_array() < 0) {
153 PyErr_SetString(PyExc_ImportError,
"numpy.core.multiarray failed to import");
157 auto clsMixtureComponent = declareMixtureComponent(mod);
158 auto clsMixtureUpdateRestriction = declareMixtureUpdateRestriction(mod);
159 auto clsMixture = declareMixture(mod);
// Attach the helper classes as attributes of the Mixture class object.
160 clsMixture.attr(
"Component") = clsMixtureComponent;
161 clsMixture.attr(
"UpdateRestriction") = clsMixtureUpdateRestriction;
double Scalar
Typedefs to be used for probability and parameter values.