Skip to content

Commit

Permalink
Merge pull request #695 from RocketPy-Team/enh/sim-decoding
Browse files Browse the repository at this point in the history
ENH: Simulation Save and Load in JSON Files
  • Loading branch information
phmbressan authored Nov 12, 2024
2 parents e1a1aa2 + 5d1c585 commit 328a25d
Show file tree
Hide file tree
Showing 30 changed files with 1,388 additions and 138 deletions.
158 changes: 150 additions & 8 deletions rocketpy/_encoders.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
"""Defines a custom JSON encoder for RocketPy objects."""

import base64
import json
from datetime import datetime
from importlib import import_module

import dill
import numpy as np


class RocketPyEncoder(json.JSONEncoder):
"""Custom JSON encoder for RocketPy objects. It defines how to encode
different types of objects to a JSON supported format."""

def __init__(self, *args, **kwargs):
self.include_outputs = kwargs.pop("include_outputs", True)
super().__init__(*args, **kwargs)

def default(self, o):
if isinstance(
o,
Expand All @@ -33,17 +40,152 @@ def default(self, o):
elif isinstance(o, np.ndarray):
return o.tolist()
elif isinstance(o, datetime):
return o.isoformat()
return [o.year, o.month, o.day, o.hour]
elif hasattr(o, "__iter__") and not isinstance(o, str):
return list(o)
elif hasattr(o, "to_dict"):
return o.to_dict()
encoding = o.to_dict(self.include_outputs)
encoding = remove_circular_references(encoding)

encoding["signature"] = get_class_signature(o)

return encoding

elif hasattr(o, "__dict__"):
exception_set = {"prints", "plots"}
return {
key: value
for key, value in o.__dict__.items()
if key not in exception_set
}
encoding = remove_circular_references(o.__dict__)

if "rocketpy" in o.__class__.__module__:
encoding["signature"] = get_class_signature(o)

return encoding
else:
return super().default(o)


class RocketPyDecoder(json.JSONDecoder):
"""Custom JSON decoder for RocketPy objects. It defines how to decode
different types of objects from a JSON supported format."""

def __init__(self, *args, **kwargs):
super().__init__(object_hook=self.object_hook, *args, **kwargs)

def object_hook(self, obj):
if "signature" in obj:
signature = obj.pop("signature")

try:
class_ = get_class_from_signature(signature)

if hasattr(class_, "from_dict"):
return class_.from_dict(obj)
else:
# Filter keyword arguments
kwargs = {
key: value
for key, value in obj.items()
if key in class_.__init__.__code__.co_varnames
}

return class_(**kwargs)
except (ImportError, AttributeError):
return obj
else:
return obj


def get_class_signature(obj):
"""Returns the signature of a class in the form of a string.
The signature is an importable string that can be used to import
the class by its module.
Parameters
----------
obj : object
Object to get the signature from.
Returns
-------
str
Signature of the class.
"""
class_ = obj.__class__
name = getattr(class_, '__qualname__', class_.__name__)

return {"module": class_.__module__, "name": name}


def get_class_from_signature(signature):
"""Returns the class by importing its signature.
Parameters
----------
signature : str
Signature of the class.
Returns
-------
type
Class defined by the signature.
"""
module = import_module(signature["module"])
inner_class = None

for class_ in signature["name"].split("."):
inner_class = getattr(module, class_)

return inner_class


def remove_circular_references(obj_dict):
"""Removes circular references from a dictionary.
Parameters
----------
obj_dict : dict
Dictionary to remove circular references from.
Returns
-------
dict
Dictionary without circular references.
"""
obj_dict.pop("prints", None)
obj_dict.pop("plots", None)

return obj_dict


def to_hex_encode(obj, encoder=base64.b85encode):
"""Converts an object to hex representation using dill.
Parameters
----------
obj : object
Object to be converted to hex.
encoder : callable, optional
Function to encode the bytes. Default is base64.b85encode.
Returns
-------
bytes
Object converted to bytes.
"""
return encoder(dill.dumps(obj)).hex()


def from_hex_decode(obj_bytes, decoder=base64.b85decode):
"""Converts an object from hex representation using dill.
Parameters
----------
obj_bytes : str
Hex string to be converted to object.
decoder : callable, optional
Function to decode the bytes. Default is base64.b85decode.
Returns
-------
object
Object converted from bytes.
"""
return dill.loads(decoder(bytes.fromhex(obj_bytes)))
90 changes: 90 additions & 0 deletions rocketpy/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2743,6 +2743,96 @@ def decimal_degrees_to_arc_seconds(angle):
arc_seconds = (remainder * 60 - arc_minutes) * 60
return degrees, arc_minutes, arc_seconds

def to_dict(self, include_outputs=True):
env_dict = {
"gravity": self.gravity,
"date": self.date,
"latitude": self.latitude,
"longitude": self.longitude,
"elevation": self.elevation,
"datum": self.datum,
"timezone": self.timezone,
"_max_expected_height": self.max_expected_height,
"atmospheric_model_type": self.atmospheric_model_type,
"pressure": self.pressure,
"barometric_height": self.barometric_height,
"temperature": self.temperature,
"wind_velocity_x": self.wind_velocity_x,
"wind_velocity_y": self.wind_velocity_y,
"wind_heading": self.wind_heading,
"wind_direction": self.wind_direction,
"wind_speed": self.wind_speed,
}

if include_outputs:
env_dict["density"] = self.density
env_dict["speed_of_sound"] = self.speed_of_sound
env_dict["dynamic_viscosity"] = self.dynamic_viscosity

return env_dict

@classmethod
def from_dict(cls, data): # pylint: disable=too-many-statements
env = cls(
gravity=data["gravity"],
date=data["date"],
latitude=data["latitude"],
longitude=data["longitude"],
elevation=data["elevation"],
datum=data["datum"],
timezone=data["timezone"],
max_expected_height=data["_max_expected_height"],
)
atmospheric_model = data["atmospheric_model_type"]

if atmospheric_model == "standard_atmosphere":
env.set_atmospheric_model("standard_atmosphere")
elif atmospheric_model == "custom_atmosphere":
env.set_atmospheric_model(
type="custom_atmosphere",
pressure=data["pressure"],
temperature=data["temperature"],
wind_u=data["wind_velocity_x"],
wind_v=data["wind_velocity_y"],
)
else:
env.__set_pressure_function(data["pressure"])
env.__set_barometric_height_function(data["barometric_height"])
env.__set_temperature_function(data["temperature"])
env.__set_wind_velocity_x_function(data["wind_velocity_x"])
env.__set_wind_velocity_y_function(data["wind_velocity_y"])
env.__set_wind_heading_function(data["wind_heading"])
env.__set_wind_direction_function(data["wind_direction"])
env.__set_wind_speed_function(data["wind_speed"])
env.elevation = data["elevation"]
env.max_expected_height = data["_max_expected_height"]

if atmospheric_model in ["windy", "forecast", "reanalysis", "ensemble"]:
env.atmospheric_model_init_date = data["atmospheric_model_init_date"]
env.atmospheric_model_end_date = data["atmospheric_model_end_date"]
env.atmospheric_model_interval = data["atmospheric_model_interval"]
env.atmospheric_model_init_lat = data["atmospheric_model_init_lat"]
env.atmospheric_model_end_lat = data["atmospheric_model_end_lat"]
env.atmospheric_model_init_lon = data["atmospheric_model_init_lon"]
env.atmospheric_model_end_lon = data["atmospheric_model_end_lon"]

if atmospheric_model == "ensemble":
env.level_ensemble = data["level_ensemble"]
env.height_ensemble = data["height_ensemble"]
env.temperature_ensemble = data["temperature_ensemble"]
env.wind_u_ensemble = data["wind_u_ensemble"]
env.wind_v_ensemble = data["wind_v_ensemble"]
env.wind_heading_ensemble = data["wind_heading_ensemble"]
env.wind_direction_ensemble = data["wind_direction_ensemble"]
env.wind_speed_ensemble = data["wind_speed_ensemble"]
env.num_ensemble_members = data["num_ensemble_members"]

env.calculate_density_profile()
env.calculate_speed_of_sound_profile()
env.calculate_dynamic_viscosity()

return env


if __name__ == "__main__":
import doctest
Expand Down
5 changes: 4 additions & 1 deletion rocketpy/environment/environment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,10 @@ def __check_coordinates_inside_grid(
or lat_index > len(lat_array) - 1
):
raise ValueError(
f"Latitude and longitude pair {(self.latitude, self.longitude)} is outside the grid available in the given file, which is defined by {(lat_array[0], lon_array[0])} and {(lat_array[-1], lon_array[-1])}."
f"Latitude and longitude pair {(self.latitude, self.longitude)} "
"is outside the grid available in the given file, which "
f"is defined by {(lat_array[0], lon_array[0])} and "
f"{(lat_array[-1], lon_array[-1])}."
)

def __localize_input_dates(self):
Expand Down
19 changes: 8 additions & 11 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@
carefully as it may impact all the rest of the project.
"""

import base64
import warnings
import zlib
from bisect import bisect_left
from collections.abc import Iterable
from copy import deepcopy
from functools import cached_property
from inspect import signature
from pathlib import Path

import dill
import matplotlib.pyplot as plt
import numpy as np
from scipy import integrate, linalg, optimize
Expand All @@ -25,6 +22,8 @@
RBFInterpolator,
)

from rocketpy._encoders import from_hex_decode, to_hex_encode

# Numpy 1.x compatibility,
# TODO: remove these lines when all dependencies support numpy>=2.0.0
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
Expand Down Expand Up @@ -712,9 +711,9 @@ def set_discrete(
if func.__dom_dim__ == 1:
xs = np.linspace(lower, upper, samples)
ys = func.get_value(xs.tolist()) if one_by_one else func.get_value(xs)
func.set_source(np.concatenate(([xs], [ys])).transpose())
func.set_interpolation(interpolation)
func.set_extrapolation(extrapolation)
func.__interpolation__ = interpolation
func.__extrapolation__ = extrapolation
func.set_source(np.column_stack((xs, ys)))
elif func.__dom_dim__ == 2:
lower = 2 * [lower] if isinstance(lower, NUMERICAL_TYPES) else lower
upper = 2 * [upper] if isinstance(upper, NUMERICAL_TYPES) else upper
Expand Down Expand Up @@ -3390,7 +3389,7 @@ def __validate_extrapolation(self, extrapolation):
extrapolation = "natural"
return extrapolation

def to_dict(self):
def to_dict(self, _):
"""Serializes the Function instance to a dictionary.
Returns
Expand All @@ -3401,7 +3400,7 @@ def to_dict(self):
source = self.source

if callable(source):
source = zlib.compress(base64.b85encode(dill.dumps(source))).hex()
source = to_hex_encode(source)

return {
"source": source,
Expand All @@ -3423,9 +3422,7 @@ def from_dict(cls, func_dict):
"""
source = func_dict["source"]
if func_dict["interpolation"] is None and func_dict["extrapolation"] is None:
source = dill.loads(
base64.b85decode(zlib.decompress(bytes.fromhex(source)))
)
source = from_hex_decode(source)

return cls(
source=source,
Expand Down
Loading

0 comments on commit 328a25d

Please sign in to comment.