diff --git a/examples/thrown_example.py b/examples/thrown_example.py new file mode 100644 index 00000000..162a561d --- /dev/null +++ b/examples/thrown_example.py @@ -0,0 +1,33 @@ +# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved. + +from prog_models.models.thrown_object import ThrownObject +from prog_algs import * +import matplotlib.pyplot as plt # For plotting +from prog_algs.visualize import plot_line + +def run_example(): + ## Setup + def future_loading(t, x = None): + return {} + m = ThrownObject(process_noise = 0) + + + ## Prediction - Predict EOD given current state + # Setup prediction + mc = predictors.MonteCarlo(m) + + # Predict with a step size of 0.1 + mean = m.initialize({}, {}) + dist = uncertain_data.MultivariateNormalDist(['x', 'v'], list(mean.values()), [[0.01, 0], [0, 1e-4]]) + samples = dist.sample(100) + print(samples) + print([m.event_state(x) for x in samples]) + (times, inputs, states, outputs, event_states, eol) = mc.predict(samples, future_loading, dt=0.1) + + # Plot result + plot_line(times[0], event_states) + plt.show() + +# This allows the module to be executed directly +if __name__ == '__main__': + run_example() diff --git a/src/prog_algs/visualize/__init__.py b/src/prog_algs/visualize/__init__.py index e00ed8c8..ea657192 100644 --- a/src/prog_algs/visualize/__init__.py +++ b/src/prog_algs/visualize/__init__.py @@ -1,4 +1,5 @@ # Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved. from .plot_scatter import plot_scatter -__all__ = ['plot_scatter'] +from .plot_line import plot_line +__all__ = ['plot_scatter', 'plot_line'] diff --git a/src/prog_algs/visualize/plot_line.py b/src/prog_algs/visualize/plot_line.py new file mode 100644 index 00000000..898d363e --- /dev/null +++ b/src/prog_algs/visualize/plot_line.py @@ -0,0 +1,58 @@ +# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved. +import matplotlib.pyplot as plt +from statistics import mean + +def plot_line(times, data, keys = None, fig = None): + """Plot Line Chart with Uncertainty Bounds + + Args: + times ([double]): Times that the data corresponds to. + data ([type]): [description] + keys ([type], optional): [description]. Defaults to None. + fig ([type], optional): [description]. Defaults to None. + + Raises: + TypeError: [description] + TypeError: [description] + """ + parameters = { # Default parameters + 'legend': True + } + + if fig is None: + fig = plt.figure() + else: + plt.figure(fig.number) + + if keys is not None: + try: + iter(keys) + except TypeError: + raise TypeError("Keys should be a list of strings (e.g., ['state1', 'state2'], was {}".format(type(keys))) + + for key in keys: + if key not in data[0][0].keys(): + raise TypeError("Key {} was not present in samples (keys: {})".format(key, list(data[0][0].keys()))) + else: + keys = data[0][0].keys() + + transposed_data = [data.snapshot(i) for i in range(len(times))] + + for key in keys: + specific_data = [[sample[key] for sample in snapshot if sample is not None] for snapshot in transposed_data] + means = [mean(d) for d in specific_data] + mins = [min(d) for d in specific_data] + maxs = [max(d) for d in specific_data] + line = plt.plot(times, means, label=key)[0] + color = line.get_color() + plt.fill_between(times, mins, maxs, color=color+"55") + + plt.xlabel('Time (s)') + plt.ylim(0, 1) + plt.xlim(times[0], times[-1]) + + # Set legend + if parameters['legend']: + plt.legend().remove() # Remove any existing legend - prevents "ghost effect" + plt.legend(loc='upper right') + \ No newline at end of file