Skip to content

Commit

Permalink
Fix Peter's plotting code to handle no instance constraints.
Browse files Browse the repository at this point in the history
  • Loading branch information
moorepants committed Sep 14, 2024
1 parent 92fc5e1 commit df995bb
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 49 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ jobs:
- name: Build Docs
run: |
cd docs
make html
# fail on any warnings
make html SPHINXOPTS="-W"
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
#html_static_path = ['_static']


# -- Options for HTMLHelp output ------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions examples-gallery/plot_vyasarayani.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def obj_grad(free):
print(identified_msg)
print(divider)

# %%
# Plot constraint violations.
prob.plot_constraint_violations(solution)

# %%
# Simulate with the identified parameter.
y_sim = odeint(eval_f, y0, time, args=(p_sol,))
Expand Down
95 changes: 48 additions & 47 deletions opty/direct_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,6 @@ def plot_trajectories(self, vector, axes=None):

return axes



@_optional_plt_dep
def plot_constraint_violations(self, vector, axes=None):
"""Returns an axis with the state constraint violations plotted versus
Expand Down Expand Up @@ -457,10 +455,10 @@ def plot_constraint_violations(self, vector, axes=None):
bars_per_plot = None
rotation = -45

# find the number of bars per plot, so the bars per plot are arroximately
# the same on each plot.
# find the number of bars per plot, so the bars per plot are
# aproximately the same on each plot
hilfs = []
len_constr = len(self.collocator.instance_constraints)
len_constr = self.collocator.num_instance_constraints
for i in range(6, 11):
hilfs.append((i, i - len_constr % i))
if len_constr % i == 0:
Expand All @@ -481,77 +479,81 @@ def plot_constraint_violations(self, vector, axes=None):
else:
num_plots = len_constr // bars_per_plot + 1


# ensure that len(axes) is correct, raise ValuError otherwise
if axes is not None:
len_axes = len(axes.ravel())
len_constr = len(self.collocator.instance_constraints)
len_constr = self.collocator.num_instance_constraints
if (len_constr <= bars_per_plot) and (len_axes < 2):
raise ValueError('len(axes) must be equal to 2')

elif (len_constr % bars_per_plot == 0) and (len_axes < len_constr // bars_per_plot + 1):
raise ValueError(f'len(axes) must be equal to {len_constr//bars_per_plot+1}')
elif ((len_constr % bars_per_plot == 0) and
(len_axes < len_constr // bars_per_plot + 1)):
msg = (f'len(axes) must be equal to '
f'{len_constr//bars_per_plot+1}')
raise ValueError(msg)

elif ((len_constr % bars_per_plot != 0) and
(len_axes < len_constr // bars_per_plot + 2)):
raise ValueError(f'len(axes) must be equal to {len_constr//bars_per_plot+2}')
msg = (f'len(axes) must be equal to '
f'{len_constr//bars_per_plot+2}')
raise ValueError(msg)

else:
pass

N = self.collocator.num_collocation_nodes
con_violations = self.con(vector)
state_violations = con_violations[
:(N - 1) * len(self.collocator.state_symbols)]
:(N - 1) * self.collocator.num_states]
instance_violations = con_violations[len(state_violations):]
state_violations = state_violations.reshape(
(len(self.collocator.state_symbols), N - 1))
(self.collocator.num_states, N - 1))
con_nodes = range(1, self.collocator.num_collocation_nodes)

plot_inst_viols = self.collocator.instance_constraints is not None
num_inst_viols = self.collocator.num_instance_constraints

if axes is None:
fig, axes = plt.subplots(1 + num_plots, 1,
figsize=(6.4, 1.50*(1 + num_plots)),
layout='compressed')
figsize=(6.4, 1.50*(1 + num_plots)),
layout='compressed')

axes = axes.ravel()
axes = np.asarray(axes).ravel()

axes[0].plot(con_nodes, state_violations.T)
axes[0].set_title('Constraint violations')
axes[0].set_xlabel('Node Number')
axes[0].set_ylabel('EoM violation')

# reduce the instance constrtaints to 2 digits after the decimal point.
# give the time in tha variables with 2 digits after the decimal point.
# if variable h is used, use the result for h in the time.
instance_constr_plot = []
a_before = ''
a_before_before = ''
for exp1 in self.collocator.instance_constraints:
for a in sm.preorder_traversal(exp1):
if ((isinstance(a_before, sm.Integer) or
isinstance(a_before, sm.Float)) and
(a == self.collocator.node_time_interval)):
a_before = float(a_before)
hilfs = a_before * vector[-1]
exp1 = exp1.subs(a_before_before, sm.Float(round(hilfs, 2)))
elif (isinstance(a_before, sm.Float) and
(a != self.collocator.node_time_interval)):
exp1 = exp1.subs(a_before, round(a_before, 2))
a_before_before = a_before
a_before = a
instance_constr_plot.append(exp1)

if plot_inst_viols:
if self.collocator.instance_constraints is not None:
# reduce the instance constrtaints to 2 digits after the decimal
# point. give the time in tha variables with 2 digits after the
# decimal point. if variable h is used, use the result for h in
# the time.
num_inst_viols = self.collocator.num_instance_constraints
instance_constr_plot = []
a_before = ''
a_before_before = ''
for exp1 in self.collocator.instance_constraints:
for a in sm.preorder_traversal(exp1):
if ((isinstance(a_before, sm.Integer) or
isinstance(a_before, sm.Float)) and
(a == self.collocator.node_time_interval)):
a_before = float(a_before)
hilfs = a_before * vector[-1]
exp1 = exp1.subs(a_before_before,
sm.Float(round(hilfs, 2)))
elif ((isinstance(a_before, sm.Float) and
(a != self.collocator.node_time_interval))):
exp1 = exp1.subs(a_before, round(a_before, 2))
a_before_before = a_before
a_before = a
instance_constr_plot.append(exp1)

for i in range(num_plots):
num_ticks = bars_per_plot
if i == num_plots - 1:
beginn = i * bars_per_plot
endd = num_inst_viols
num_ticks = num_inst_viols % bars_per_plot
if(num_inst_viols % bars_per_plot == 0):
if (num_inst_viols % bars_per_plot == 0):
num_ticks = bars_per_plot
else:
endd = (i + 1) * bars_per_plot
Expand All @@ -561,13 +563,12 @@ def plot_constraint_violations(self, vector, axes=None):
inst_constr = instance_constr_plot[beginn: endd]

width = [0.06*num_ticks for _ in range(num_ticks)]
axes[i+1].bar(
range(num_ticks), inst_viol,
tick_label=[sm.latex(s, mode='inline')
for s in inst_constr], width=width)
axes[i+1].bar(range(num_ticks), inst_viol,
tick_label=[sm.latex(s, mode='inline') for s in
inst_constr], width=width)
axes[i+1].set_ylabel('Instance')
axes[i+1].set_xticklabels(axes[i+1].get_xticklabels(),
rotation=rotation)
rotation=rotation)

return axes

Expand Down Expand Up @@ -625,7 +626,7 @@ def __init__(self, equations_of_motion, state_symbols,
explicit. They can be ordinary differential equations or
differential algebraic equations.
state_symbols : iterable
An iterable containing all ``n` of the SymPy functions of time
An iterable containing all ``n`` of the SymPy functions of time
which represent the states in the equations of motion.
num_collocation_nodes : integer
The number of collocation nodes, ``N``. All known trajectory arrays
Expand Down

0 comments on commit df995bb

Please sign in to comment.