lsst.meas.modelfit  13.0-10-g4e34388+1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Pages
mixture.cc
Go to the documentation of this file.
1 // -*- lsst-c++ -*-
2 /*
3  * LSST Data Management System
4  * Copyright 2008-2013 LSST Corporation.
5  *
6  * This product includes software developed by the
7  * LSST Project (http://www.lsst.org/).
8  *
9  * This program is free software: you can redistribute it and/or modify
10  * it under the terms of the GNU General Public License as published by
11  * the Free Software Foundation, either version 3 of the License, or
12  * (at your option) any later version.
13  *
14  * This program is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17  * GNU General Public License for more details.
18  *
19  * You should have received a copy of the LSST License Statement and
20  * the GNU General Public License along with this program. If not,
21  * see <http://www.lsstcorp.org/LegalNotices/>.
22  */
23 
24 #include "pybind11/pybind11.h"
25 #include "pybind11/stl.h"
26 
27 #include <sstream> // Python.h must come before even system headers
28 
29 #include "numpy/arrayobject.h"
30 #include "ndarray/pybind11.h"
31 #include "ndarray/eigen.h"
32 
33 #include "lsst/utils/python.h"
35 
36 namespace py = pybind11;
37 using namespace pybind11::literals;
38 
39 namespace lsst {
40 namespace meas {
41 namespace modelfit {
42 namespace {
43 
44 using PyMixtureComponent = py::class_<MixtureComponent>;
45 using PyMixtureUpdateRestriction =
46  py::class_<MixtureUpdateRestriction, std::shared_ptr<MixtureUpdateRestriction>>;
47 using PyMixture = py::class_<Mixture, std::shared_ptr<Mixture>, afw::table::io::PersistableFacade<Mixture>,
48  afw::table::io::Persistable>;
49 
50 static PyMixtureComponent declareMixtureComponent(py::module &mod) {
51  PyMixtureComponent cls(mod, "MixtureComponent");
52  cls.def("getDimension", &MixtureComponent::getDimension);
53  cls.def_readwrite("weight", &MixtureComponent::weight);
54  cls.def("getMu", &MixtureComponent::getMu);
55  cls.def("setMu", &MixtureComponent::setMu);
56  cls.def("getSigma", &MixtureComponent::getSigma);
57  cls.def("setSigma", &MixtureComponent::setSigma);
58  cls.def("project", (MixtureComponent (MixtureComponent::*)(int) const) & MixtureComponent::project,
59  "dim"_a);
60  cls.def("project", (MixtureComponent (MixtureComponent::*)(int, int) const) & MixtureComponent::project,
61  "dim1"_a, "dim2"_a);
62  cls.def(py::init<int>(), "dim"_a);
63  cls.def(py::init<Scalar, Vector const &, Matrix const &>(), "weight"_a, "mu"_a, "sigma"_a);
64  auto streamStr = [](MixtureComponent const &self) {
65  std::ostringstream os;
66  os << self;
67  return os.str();
68  };
69  cls.def("__str__", streamStr);
70  cls.def("__repr__", streamStr);
71  return cls;
72 }
73 
74 static PyMixtureUpdateRestriction declareMixtureUpdateRestriction(py::module &mod) {
75  PyMixtureUpdateRestriction cls(mod, "MixtureUpdateRestriction");
76  cls.def("getDimension", &MixtureUpdateRestriction::getDimension);
77  cls.def(py::init<int>(), "dim"_a);
78  // The rest of this interface isn't usable in Python, and doesn't need to be.
79  return cls;
80 }
81 
82 static PyMixture declareMixture(py::module &mod) {
83  afw::table::io::python::declarePersistableFacade<Mixture>(mod, "Mixture");
84  PyMixture cls(mod, "Mixture");
85  cls.def("__iter__", [](Mixture &self) { return py::make_iterator(self.begin(), self.end()); },
86  py::keep_alive<0, 1>());
87  cls.def("__getitem__",
88  [](Mixture &self, std::ptrdiff_t i) { return self[utils::python::cppIndex(self.size(), i)]; },
89  py::return_value_policy::reference_internal);
90  cls.def("__len__", &Mixture::size);
91  cls.def("getComponentCount", &Mixture::getComponentCount);
92  cls.def("project", (std::shared_ptr<Mixture> (Mixture::*)(int) const) & Mixture::project, "dim"_a);
93  cls.def("project", (std::shared_ptr<Mixture> (Mixture::*)(int, int) const) & Mixture::project, "dim1"_a,
94  "dim2"_a);
95  cls.def("getDimension", &Mixture::getDimension);
96  cls.def("normalize", &Mixture::normalize);
97  cls.def("shift", &Mixture::shift, "dim"_a, "offset"_a);
98  cls.def("clip", &Mixture::clip, "threshold"_a = 0.0);
99  cls.def("getDegreesOfFreedom", &Mixture::getDegreesOfFreedom);
100  cls.def("setDegreesOfFreedom", &Mixture::setDegreesOfFreedom,
101  "df"_a = std::numeric_limits<Scalar>::infinity());
102  cls.def("evaluate",
103  [](Mixture const &self, MixtureComponent const &component,
104  ndarray::Array<Scalar, 1, 0> const &array) -> Scalar {
105  return self.evaluate(component, array.asEigen());
106  },
107  "component"_a, "x"_a);
108  cls.def("evaluate",
109  [](Mixture const &self, ndarray::Array<Scalar, 1, 0> const &array) -> Scalar {
110  return self.evaluate(array.asEigen());
111  },
112  "x"_a);
113  cls.def("evaluate", (void (Mixture::*)(ndarray::Array<Scalar const, 2, 1> const &,
114  ndarray::Array<Scalar, 1, 0> const &) const) &
115  Mixture::evaluate,
116  "x"_a, "p"_a);
117  cls.def("evaluateComponents", &Mixture::evaluateComponents, "x"_a, "p"_a);
118  cls.def("evaluateDerivatives", &Mixture::evaluateDerivatives, "x"_a, "gradient"_a, "hessian"_a);
119  cls.def("draw", &Mixture::draw, "rng"_a, "x"_a);
120  cls.def("updateEM", (void (Mixture::*)(ndarray::Array<Scalar const, 2, 1> const &,
121  ndarray::Array<Scalar const, 1, 0> const &, Scalar, Scalar)) &
122  Mixture::updateEM,
123  "x"_a, "w"_a, "tau1"_a = 0.0, "tau2"_a = 0.5);
124  cls.def("updateEM", (void (Mixture::*)(ndarray::Array<Scalar const, 2, 1> const &,
125  ndarray::Array<Scalar const, 1, 0> const &,
126  MixtureUpdateRestriction const &restriction, Scalar, Scalar)) &
127  Mixture::updateEM,
128  "x"_a, "w"_a, "restriction"_a, "tau1"_a = 0.0, "tau2"_a = 0.5);
129  cls.def("updateEM", (void (Mixture::*)(ndarray::Array<Scalar const, 2, 1> const &,
130  MixtureUpdateRestriction const &restriction, Scalar, Scalar)) &
131  Mixture::updateEM,
132  "x"_a, "restriction"_a, "tau1"_a = 0.0, "tau2"_a = 0.5);
133  cls.def("clone", &Mixture::clone);
134  cls.def(py::init<int, Mixture::ComponentList &, Scalar>(), "dim"_a, "components"_a,
135  "df"_a = std::numeric_limits<Scalar>::infinity());
136  auto streamStr = [](Mixture const &self) {
137  std::ostringstream os;
138  os << self;
139  return os.str();
140  };
141  return cls;
142  cls.def("__str__", streamStr);
143  cls.def("__repr__", streamStr);
144  return cls;
145 }
146 
147 PYBIND11_PLUGIN(mixture) {
148  py::module::import("lsst.afw.math");
149 
150  py::module mod("mixture");
151 
152  if (_import_array() < 0) {
153  PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import");
154  return nullptr;
155  }
156 
157  auto clsMixtureComponent = declareMixtureComponent(mod);
158  auto clsMixtureUpdateRestriction = declareMixtureUpdateRestriction(mod);
159  auto clsMixture = declareMixture(mod);
160  clsMixture.attr("Component") = clsMixtureComponent;
161  clsMixture.attr("UpdateRestriction") = clsMixtureUpdateRestriction;
162 
163  return mod.ptr();
164 }
165 }
166 }
167 }
168 } // namespace lsst::meas::modelfit::anonymous
double Scalar
Typedefs to be used for probability and parameter values.
Definition: common.h:44