Skip to content

Commit 328a25d

Browse files
authored
Merge pull request #695 from RocketPy-Team/enh/sim-decoding
ENH: Simulation Save and Load in JSON Files
2 parents e1a1aa2 + 5d1c585 commit 328a25d

30 files changed

+1388
-138
lines changed

rocketpy/_encoders.py

Lines changed: 150 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
"""Defines a custom JSON encoder for RocketPy objects."""
22

3+
import base64
34
import json
45
from datetime import datetime
6+
from importlib import import_module
57

8+
import dill
69
import numpy as np
710

811

912
class RocketPyEncoder(json.JSONEncoder):
1013
"""Custom JSON encoder for RocketPy objects. It defines how to encode
1114
different types of objects to a JSON supported format."""
1215

16+
def __init__(self, *args, **kwargs):
17+
self.include_outputs = kwargs.pop("include_outputs", True)
18+
super().__init__(*args, **kwargs)
19+
1320
def default(self, o):
1421
if isinstance(
1522
o,
@@ -33,17 +40,152 @@ def default(self, o):
3340
elif isinstance(o, np.ndarray):
3441
return o.tolist()
3542
elif isinstance(o, datetime):
36-
return o.isoformat()
43+
return [o.year, o.month, o.day, o.hour]
3744
elif hasattr(o, "__iter__") and not isinstance(o, str):
3845
return list(o)
3946
elif hasattr(o, "to_dict"):
40-
return o.to_dict()
47+
encoding = o.to_dict(self.include_outputs)
48+
encoding = remove_circular_references(encoding)
49+
50+
encoding["signature"] = get_class_signature(o)
51+
52+
return encoding
53+
4154
elif hasattr(o, "__dict__"):
42-
exception_set = {"prints", "plots"}
43-
return {
44-
key: value
45-
for key, value in o.__dict__.items()
46-
if key not in exception_set
47-
}
55+
encoding = remove_circular_references(o.__dict__)
56+
57+
if "rocketpy" in o.__class__.__module__:
58+
encoding["signature"] = get_class_signature(o)
59+
60+
return encoding
4861
else:
4962
return super().default(o)
63+
64+
65+
class RocketPyDecoder(json.JSONDecoder):
66+
"""Custom JSON decoder for RocketPy objects. It defines how to decode
67+
different types of objects from a JSON supported format."""
68+
69+
def __init__(self, *args, **kwargs):
70+
super().__init__(object_hook=self.object_hook, *args, **kwargs)
71+
72+
def object_hook(self, obj):
73+
if "signature" in obj:
74+
signature = obj.pop("signature")
75+
76+
try:
77+
class_ = get_class_from_signature(signature)
78+
79+
if hasattr(class_, "from_dict"):
80+
return class_.from_dict(obj)
81+
else:
82+
# Filter keyword arguments
83+
kwargs = {
84+
key: value
85+
for key, value in obj.items()
86+
if key in class_.__init__.__code__.co_varnames
87+
}
88+
89+
return class_(**kwargs)
90+
except (ImportError, AttributeError):
91+
return obj
92+
else:
93+
return obj
94+
95+
96+
def get_class_signature(obj):
97+
"""Returns the signature of a class in the form of a string.
98+
The signature is an importable string that can be used to import
99+
the class by its module.
100+
101+
Parameters
102+
----------
103+
obj : object
104+
Object to get the signature from.
105+
106+
Returns
107+
-------
108+
str
109+
Signature of the class.
110+
"""
111+
class_ = obj.__class__
112+
name = getattr(class_, '__qualname__', class_.__name__)
113+
114+
return {"module": class_.__module__, "name": name}
115+
116+
117+
def get_class_from_signature(signature):
118+
"""Returns the class by importing its signature.
119+
120+
Parameters
121+
----------
122+
signature : str
123+
Signature of the class.
124+
125+
Returns
126+
-------
127+
type
128+
Class defined by the signature.
129+
"""
130+
module = import_module(signature["module"])
131+
inner_class = None
132+
133+
for class_ in signature["name"].split("."):
134+
inner_class = getattr(module, class_)
135+
136+
return inner_class
137+
138+
139+
def remove_circular_references(obj_dict):
140+
"""Removes circular references from a dictionary.
141+
142+
Parameters
143+
----------
144+
obj_dict : dict
145+
Dictionary to remove circular references from.
146+
147+
Returns
148+
-------
149+
dict
150+
Dictionary without circular references.
151+
"""
152+
obj_dict.pop("prints", None)
153+
obj_dict.pop("plots", None)
154+
155+
return obj_dict
156+
157+
158+
def to_hex_encode(obj, encoder=base64.b85encode):
159+
"""Converts an object to hex representation using dill.
160+
161+
Parameters
162+
----------
163+
obj : object
164+
Object to be converted to hex.
165+
encoder : callable, optional
166+
Function to encode the bytes. Default is base64.b85encode.
167+
168+
Returns
169+
-------
170+
bytes
171+
Object converted to bytes.
172+
"""
173+
return encoder(dill.dumps(obj)).hex()
174+
175+
176+
def from_hex_decode(obj_bytes, decoder=base64.b85decode):
177+
"""Converts an object from hex representation using dill.
178+
179+
Parameters
180+
----------
181+
obj_bytes : str
182+
Hex string to be converted to object.
183+
decoder : callable, optional
184+
Function to decode the bytes. Default is base64.b85decode.
185+
186+
Returns
187+
-------
188+
object
189+
Object converted from bytes.
190+
"""
191+
return dill.loads(decoder(bytes.fromhex(obj_bytes)))

rocketpy/environment/environment.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2743,6 +2743,96 @@ def decimal_degrees_to_arc_seconds(angle):
27432743
arc_seconds = (remainder * 60 - arc_minutes) * 60
27442744
return degrees, arc_minutes, arc_seconds
27452745

2746+
def to_dict(self, include_outputs=True):
2747+
env_dict = {
2748+
"gravity": self.gravity,
2749+
"date": self.date,
2750+
"latitude": self.latitude,
2751+
"longitude": self.longitude,
2752+
"elevation": self.elevation,
2753+
"datum": self.datum,
2754+
"timezone": self.timezone,
2755+
"_max_expected_height": self.max_expected_height,
2756+
"atmospheric_model_type": self.atmospheric_model_type,
2757+
"pressure": self.pressure,
2758+
"barometric_height": self.barometric_height,
2759+
"temperature": self.temperature,
2760+
"wind_velocity_x": self.wind_velocity_x,
2761+
"wind_velocity_y": self.wind_velocity_y,
2762+
"wind_heading": self.wind_heading,
2763+
"wind_direction": self.wind_direction,
2764+
"wind_speed": self.wind_speed,
2765+
}
2766+
2767+
if include_outputs:
2768+
env_dict["density"] = self.density
2769+
env_dict["speed_of_sound"] = self.speed_of_sound
2770+
env_dict["dynamic_viscosity"] = self.dynamic_viscosity
2771+
2772+
return env_dict
2773+
2774+
@classmethod
2775+
def from_dict(cls, data): # pylint: disable=too-many-statements
2776+
env = cls(
2777+
gravity=data["gravity"],
2778+
date=data["date"],
2779+
latitude=data["latitude"],
2780+
longitude=data["longitude"],
2781+
elevation=data["elevation"],
2782+
datum=data["datum"],
2783+
timezone=data["timezone"],
2784+
max_expected_height=data["_max_expected_height"],
2785+
)
2786+
atmospheric_model = data["atmospheric_model_type"]
2787+
2788+
if atmospheric_model == "standard_atmosphere":
2789+
env.set_atmospheric_model("standard_atmosphere")
2790+
elif atmospheric_model == "custom_atmosphere":
2791+
env.set_atmospheric_model(
2792+
type="custom_atmosphere",
2793+
pressure=data["pressure"],
2794+
temperature=data["temperature"],
2795+
wind_u=data["wind_velocity_x"],
2796+
wind_v=data["wind_velocity_y"],
2797+
)
2798+
else:
2799+
env.__set_pressure_function(data["pressure"])
2800+
env.__set_barometric_height_function(data["barometric_height"])
2801+
env.__set_temperature_function(data["temperature"])
2802+
env.__set_wind_velocity_x_function(data["wind_velocity_x"])
2803+
env.__set_wind_velocity_y_function(data["wind_velocity_y"])
2804+
env.__set_wind_heading_function(data["wind_heading"])
2805+
env.__set_wind_direction_function(data["wind_direction"])
2806+
env.__set_wind_speed_function(data["wind_speed"])
2807+
env.elevation = data["elevation"]
2808+
env.max_expected_height = data["_max_expected_height"]
2809+
2810+
if atmospheric_model in ["windy", "forecast", "reanalysis", "ensemble"]:
2811+
env.atmospheric_model_init_date = data["atmospheric_model_init_date"]
2812+
env.atmospheric_model_end_date = data["atmospheric_model_end_date"]
2813+
env.atmospheric_model_interval = data["atmospheric_model_interval"]
2814+
env.atmospheric_model_init_lat = data["atmospheric_model_init_lat"]
2815+
env.atmospheric_model_end_lat = data["atmospheric_model_end_lat"]
2816+
env.atmospheric_model_init_lon = data["atmospheric_model_init_lon"]
2817+
env.atmospheric_model_end_lon = data["atmospheric_model_end_lon"]
2818+
2819+
if atmospheric_model == "ensemble":
2820+
env.level_ensemble = data["level_ensemble"]
2821+
env.height_ensemble = data["height_ensemble"]
2822+
env.temperature_ensemble = data["temperature_ensemble"]
2823+
env.wind_u_ensemble = data["wind_u_ensemble"]
2824+
env.wind_v_ensemble = data["wind_v_ensemble"]
2825+
env.wind_heading_ensemble = data["wind_heading_ensemble"]
2826+
env.wind_direction_ensemble = data["wind_direction_ensemble"]
2827+
env.wind_speed_ensemble = data["wind_speed_ensemble"]
2828+
env.num_ensemble_members = data["num_ensemble_members"]
2829+
2830+
env.calculate_density_profile()
2831+
env.calculate_speed_of_sound_profile()
2832+
env.calculate_dynamic_viscosity()
2833+
2834+
return env
2835+
27462836

27472837
if __name__ == "__main__":
27482838
import doctest

rocketpy/environment/environment_analysis.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,10 @@ def __check_coordinates_inside_grid(
423423
or lat_index > len(lat_array) - 1
424424
):
425425
raise ValueError(
426-
f"Latitude and longitude pair {(self.latitude, self.longitude)} is outside the grid available in the given file, which is defined by {(lat_array[0], lon_array[0])} and {(lat_array[-1], lon_array[-1])}."
426+
f"Latitude and longitude pair {(self.latitude, self.longitude)} "
427+
"is outside the grid available in the given file, which "
428+
f"is defined by {(lat_array[0], lon_array[0])} and "
429+
f"{(lat_array[-1], lon_array[-1])}."
427430
)
428431

429432
def __localize_input_dates(self):

rocketpy/mathutils/function.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,14 @@
55
carefully as it may impact all the rest of the project.
66
"""
77

8-
import base64
98
import warnings
10-
import zlib
119
from bisect import bisect_left
1210
from collections.abc import Iterable
1311
from copy import deepcopy
1412
from functools import cached_property
1513
from inspect import signature
1614
from pathlib import Path
1715

18-
import dill
1916
import matplotlib.pyplot as plt
2017
import numpy as np
2118
from scipy import integrate, linalg, optimize
@@ -25,6 +22,8 @@
2522
RBFInterpolator,
2623
)
2724

25+
from rocketpy._encoders import from_hex_decode, to_hex_encode
26+
2827
# Numpy 1.x compatibility,
2928
# TODO: remove these lines when all dependencies support numpy>=2.0.0
3029
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
@@ -712,9 +711,9 @@ def set_discrete(
712711
if func.__dom_dim__ == 1:
713712
xs = np.linspace(lower, upper, samples)
714713
ys = func.get_value(xs.tolist()) if one_by_one else func.get_value(xs)
715-
func.set_source(np.concatenate(([xs], [ys])).transpose())
716-
func.set_interpolation(interpolation)
717-
func.set_extrapolation(extrapolation)
714+
func.__interpolation__ = interpolation
715+
func.__extrapolation__ = extrapolation
716+
func.set_source(np.column_stack((xs, ys)))
718717
elif func.__dom_dim__ == 2:
719718
lower = 2 * [lower] if isinstance(lower, NUMERICAL_TYPES) else lower
720719
upper = 2 * [upper] if isinstance(upper, NUMERICAL_TYPES) else upper
@@ -3390,7 +3389,7 @@ def __validate_extrapolation(self, extrapolation):
33903389
extrapolation = "natural"
33913390
return extrapolation
33923391

3393-
def to_dict(self):
3392+
def to_dict(self, _):
33943393
"""Serializes the Function instance to a dictionary.
33953394
33963395
Returns
@@ -3401,7 +3400,7 @@ def to_dict(self):
34013400
source = self.source
34023401

34033402
if callable(source):
3404-
source = zlib.compress(base64.b85encode(dill.dumps(source))).hex()
3403+
source = to_hex_encode(source)
34053404

34063405
return {
34073406
"source": source,
@@ -3423,9 +3422,7 @@ def from_dict(cls, func_dict):
34233422
"""
34243423
source = func_dict["source"]
34253424
if func_dict["interpolation"] is None and func_dict["extrapolation"] is None:
3426-
source = dill.loads(
3427-
base64.b85decode(zlib.decompress(bytes.fromhex(source)))
3428-
)
3425+
source = from_hex_decode(source)
34293426

34303427
return cls(
34313428
source=source,

0 commit comments

Comments
 (0)