24 #include "pybind11/pybind11.h"
25 #include "pybind11/stl.h"
27 #include "numpy/arrayobject.h"
28 #include "ndarray/pybind11.h"
34 namespace py = pybind11;
35 using namespace pybind11::literals;
42 using PyModel = py::class_<Model, std::shared_ptr<Model>>;
44 PYBIND11_PLUGIN(model) {
45 py::module::import(
"lsst.shapelet");
46 py::module::import(
"lsst.meas.modelfit.priors");
47 py::module::import(
"lsst.meas.modelfit.unitSystem");
51 if (_import_array() < 0) {
52 PyErr_SetString(PyExc_ImportError,
"numpy.core.multiarray failed to import");
56 PyModel cls(mod,
"Model");
58 py::enum_<Model::CenterEnum>(cls,
"CenterEnum")
59 .value(
"FIXED_CENTER", Model::FIXED_CENTER)
60 .value(
"SINGLE_CENTER", Model::SINGLE_CENTER)
61 .value(
"MULTI_CENTER", Model::MULTI_CENTER)
64 cls.def_static(
"make", (std::shared_ptr<Model>(*)(Model::BasisVector, Model::NameVector
const &,
67 "basisVector"_a,
"prefixes"_a,
"center"_a);
68 cls.def_static(
"make", (std::shared_ptr<Model>(*)(std::shared_ptr<shapelet::MultiShapeletBasis>,
71 "basis"_a,
"center"_a);
72 cls.def_static(
"makeGaussian", &Model::makeGaussian,
"center"_a,
"radius"_a = 1.0);
73 cls.def(
"getNonlinearDim", &Model::getNonlinearDim);
74 cls.def(
"getAmplitudeDim", &Model::getAmplitudeDim);
75 cls.def(
"getFixedDim", &Model::getFixedDim);
76 cls.def(
"getBasisCount", &Model::getBasisCount);
77 cls.def(
"getNonlinearNames", &Model::getNonlinearNames, py::return_value_policy::copy);
78 cls.def(
"getAmplitudeNames", &Model::getAmplitudeNames, py::return_value_policy::copy);
79 cls.def(
"getFixedNames", &Model::getFixedNames, py::return_value_policy::copy);
80 cls.def(
"getBasisVector", &Model::getBasisVector, py::return_value_policy::copy);
81 cls.def(
"makeShapeletFunction", &Model::makeShapeletFunction);
82 cls.def(
"adaptPrior", &Model::adaptPrior);
83 cls.def(
"makeEllipseVector", &Model::makeEllipseVector);
84 cls.def(
"writeEllipses",
85 (Model::EllipseVector (Model::*)(ndarray::Array<Scalar const, 1, 1>
const &,
86 ndarray::Array<Scalar const, 1, 1>
const &)
const) &
88 "nonlinear"_a,
"fixed"_a);
89 cls.def(
"readEllipses",
90 (
void (Model::*)(Model::EllipseVector
const &, ndarray::Array<Scalar, 1, 1>
const &,
91 ndarray::Array<Scalar, 1, 1>
const &)
const) &
93 "ellipses"_a,
"nonlinear"_a,
"fixed"_a);
94 cls.def(
"transformParameters", &Model::transformParameters,
"transform"_a,
"nonlinear"_a,
"amplitudes"_a,