Skip to content

Commit

Permalink
add2
Browse files Browse the repository at this point in the history
  • Loading branch information
DimasfromLavoisier committed Nov 12, 2024
1 parent 05a0f8b commit b2a5d1d
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions src/python/SimulationCycleCallbackWrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,36 @@

namespace py = pybind11;

class [[gnu::visibility("default")]] SimulationCycleCallbackWrap : public SimulationCycleCallback {
public:
using SimulationCycleCallback::SimulationCycleCallback;
SimulationCycleCallbackWrap(py::object obj) : SimulationCycleCallback(), py_obj(std::move(obj)) {}

void operator()(
std::vector<Measurement> &measurements,
std::vector<Trajectory> &trajectories,
const std::string &outpath) override {

py::gil_scoped_acquire acquire; // Acquire GIL before calling Python code

// Create a tuple with the required components
py::tuple output = py::make_tuple(
measurements,
trajectories,
outpath,
std::vector<std::string>{outpath},
false
);


py_obj(output);
}

private:
py::object py_obj;

};

class PySimulationCycleCallback : public SimulationCycleCallback {
public:
using SimulationCycleCallback::SimulationCycleCallback;
Expand All @@ -16,28 +46,22 @@ class PySimulationCycleCallback : public SimulationCycleCallback {
const std::string& outpath) override {

if (is_callback_in_progress) {
//std::cout << "Callback already in progress, skipping." << std::endl;
return;
}

is_callback_in_progress = true; // Set the flag to prevent recursion

py::gil_scoped_acquire acquire; // Acquire GIL before calling Python code
// std::cout << "EKFKOEKFEOF IT IS IN THE TRAMPOLINE" << std::endl;

py::object py_self = py::cast(this, py::return_value_policy::reference); // Reference to the Python object

if (py::hasattr(py_self, "__call__")) {
// std::cout << "EKFKOEKFEOF IT IS IN THE TRAMPOLINE 2" << std::endl;

if (py::hasattr(py_self, "__call__")) {
// Convert C++ vectors to Python lists
py::list measurements_list = py::cast(measurements);
py::list trajectories_list = py::cast(trajectories);

// Call the Python __call__ method with converted arguments
// std::cout << "EKFKOEKFEOF IT IS IN THE TRAMPOLINE 3" << std::endl;
py_self.attr("__call__")(measurements_list, trajectories_list, outpath);
// std::cout << "EKFKOEKFEOF IT IS IN THE TRAMPOLINE 4" << std::endl;
} else {
throw std::runtime_error("Python __call__ method is missing on SimulationCycleCallback.");
}
Expand Down

0 comments on commit b2a5d1d

Please sign in to comment.