csu-hmc / opty

A library for using direct collocation in the optimization of dynamic systems.
http://opty.readthedocs.io

Problem with: prob.plot_constraint_violations(solution) #123

Closed: Peter230655 closed this 3 months ago

Peter230655 commented 4 months ago

I am playing around with an opty simulation.

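import numpy as np
from opty.direct_collocation import Problem

# obj, obj_grad, EOM, state_symbols, num_nodes, interval_value, par_map,
# instance_constraints and bounds are defined earlier in the notebook.
prob = Problem(obj, obj_grad, EOM, state_symbols, num_nodes, interval_value,
               known_parameter_map=par_map,
               instance_constraints=instance_constraints,
               bounds=bounds)

initial_guess = np.random.randn(prob.num_free)

solution, info = prob.solve(initial_guess)
print('message from optimizer:', info['status_msg'])

prob.plot_constraint_violations(solution)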

The last line (prob.plot_constraint_violations(solution)) gives this error:

ValueError                                Traceback (most recent call last)
Cell In[4], line 14
     11 solution, info = prob.solve(initial_guess)
     12 print('message from optimizer:', info['status_msg'])
---> 14 prob.plot_constraint_violations(solution)
     15 prob.plot_objective_value()

File c:\Users\Peter\anaconda3\envs\sympy-dev\Lib\site-packages\opty\utils.py:172, in _optional_plt_dep.<locals>.wrapper(*args, **kwargs)
    170     raise ImportError('Install matplotlib for plotting features.')
    171 else:
--> 172     func(*args, **kwargs)

File c:\Users\Peter\anaconda3\envs\sympy-dev\Lib\site-packages\opty\direct_collocation.py:386, in Problem.plot_constraint_violations(self, vector)
    383 axes[-2].set_xlabel('Node Number')
    385 left = range(len(con_violations[self.collocator.num_states * N:]))
--> 386 axes[-1].bar(left, con_violations[self.collocator.num_states * N:],
    387              tick_label=[sm.latex(s, mode='inline')
    388                          for s in self.collocator.instance_constraints])
    389 axes[-1].set_ylabel('Instance')
    390 axes[-1].set_xticklabels(axes[-1].get_xticklabels(), rotation=-10)

File c:\Users\Peter\anaconda3\envs\sympy-dev\Lib\site-packages\matplotlib\__init__.py:1465, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1462 @functools.wraps(func)
   1463 def inner(ax, *args, data=None, **kwargs):
   1464     if data is None:
-> 1465         return func(ax, *map(sanitize_sequence, args), **kwargs)
   1467     bound = new_sig.bind(ax, *args, **kwargs)
   1468     auto_label = (bound.arguments.get(label_namer)
   1469                   or bound.kwargs.get(label_namer))

File c:\Users\Peter\anaconda3\envs\sympy-dev\Lib\site-packages\matplotlib\axes\_axes.py:2569, in Axes.bar(self, x, height, width, bottom, align, **kwargs)
   2566 self.add_container(bar_container)
   2568 if tick_labels is not None:
-> 2569     tick_labels = np.broadcast_to(tick_labels, len(patches))
   2570     tick_label_axis.set_ticks(tick_label_position)
   2571     tick_label_axis.set_ticklabels(tick_labels)

File c:\Users\Peter\anaconda3\envs\sympy-dev\Lib\site-packages\numpy\lib\stride_tricks.py:413, in broadcast_to(array, shape, subok)
    367 @array_function_dispatch(_broadcast_to_dispatcher, module='numpy')
    368 def broadcast_to(array, shape, subok=False):
    369     """Broadcast an array to a new shape.
    370 
    371     Parameters
   (...)
    411            [1, 2, 3]])
    412     """
--> 413     return _broadcast_to(array, shape, subok=subok, readonly=True)

File c:\Users\Peter\anaconda3\envs\sympy-dev\Lib\site-packages\numpy\lib\stride_tricks.py:349, in _broadcast_to(array, shape, subok, readonly)
    346     raise ValueError('all elements of broadcast shape must be non-'
    347                      'negative')
    348 extras = []
--> 349 it = np.nditer(
    350     (array,), flags=['multi_index', 'refs_ok', 'zerosize_ok'] + extras,
    351     op_flags=['readonly'], itershape=shape, order='C')
    352 with it:
    353     # never really has writebackifcopy semantics
    354     broadcast = it.itviews[0]

ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (8,)  and requested shape (16,)

However, it does still produce some output: [Screenshot 2024-02-13 210319]

What am I doing wrong? Thanks for any help!
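The last frames of the traceback show matplotlib's Axes.bar passing the tick_label list to np.broadcast_to with the number of bars. Under NumPy's broadcasting rules a list of 8 labels can only become length 8 (a single label could be repeated to any length), so requesting 16 bars fails. A minimal reproduction of just that NumPy step, independent of opty, with the 8 and 16 taken from the error message:

```python
import numpy as np

labels = [f"c{i}" for i in range(8)]  # one tick label per instance constraint
n_bars = 16                           # number of bars requested by the slice

message = ""
try:
    # This mirrors what Axes.bar does with tick_label in the traceback.
    np.broadcast_to(labels, n_bars)
    broadcast_ok = True
except ValueError as err:
    broadcast_ok = False
    message = str(err)

print(broadcast_ok)  # False
print(message)
```

So the plot itself is not the problem; the bar heights and the labels simply disagree in length.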

tjstienstra commented 4 months ago

Yeah, I suspect this is actually a bug. The constraint-violation plot should never crash after a problem has been solved successfully. I also remember running into the same problem and writing my own version of this function.

tjstienstra commented 4 months ago

Just gave the source code a quick read. The problem is in the slicing that selects the instance-constraint values from the free vector: it uses con_violations[self.collocator.num_states * N:] and forgets the unknown input trajectories.

Peter230655 commented 4 months ago

Thanks!

  1. It does seem to work sometimes: with the inverted pendulum I did not get this error. The only difference I see compared with my (copied) simulation is the larger state space (4 states instead of 2)?
  2. Despite the error, it still created the plot in my simulation.
  3. Are these constraint violations included in the dictionary returned by problem.solve(initial_guess)?
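The 2-state versus 4-state observation can be checked by counting. The sketch below assumes (inferred from how the plotting code slices the vector and from the sizes in the error, not from opty documentation) that the constraint vector holds num_states * (num_nodes - 1) defect values, stored state by state, followed by the instance constraints, and that the installed plotting code uses the node range range(num_states, num_nodes + 1). The leftover tail then has num_states * (num_states - 2) + num_instance entries regardless of num_nodes, which equals num_instance only when there are exactly 2 states. The helper name plot_slice_lengths is hypothetical:

```python
def plot_slice_lengths(num_states, num_nodes, num_instance):
    """Compare the plotting slice with the assumed constraint layout.

    Assumption: num_states * (num_nodes - 1) defect constraints (state by
    state) followed by num_instance instance constraints.
    """
    total = num_states * (num_nodes - 1) + num_instance
    # Node range assumed to be used by the failing plotting code.
    con_nodes = range(num_states, num_nodes + 1)
    N = len(con_nodes)
    # Length of con_violations[num_states * N:], i.e. the number of bars.
    tail = total - num_states * N
    return N, tail


for num_states in (2, 4):
    N, tail = plot_slice_lengths(num_states, num_nodes=100, num_instance=8)
    print(num_states, N, tail)  # prints "2 99 8", then "4 97 16"
```

With 4 states and 8 instance constraints this gives 16 bars for 8 labels, the same (8,) versus (16,) mismatch as in the traceback; with 2 states the lengths agree, which would explain why the inverted pendulum example works.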
tjstienstra commented 4 months ago

> Just gave the source code a quick read. The problem is in the slicing that selects the instance-constraint values from the free vector: it uses con_violations[self.collocator.num_states * N:] and forgets the unknown input trajectories.

Oh wait, this statement is incorrect: we are dealing with the constraints vector here, not the free vector. I'll have to check how this array is actually laid out.
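To make the layout question concrete, here is a hypothetical helper (split_constraints is not part of opty) that partitions a flat constraint vector, assuming the layout the existing plotting code implies: one equal-length block per state, stored one state after another, followed by the instance constraints. Taking the instance part from the end and deriving the block length from the vector length avoids any dependence on a node range. The sizes in the usage lines are made up:

```python
import numpy as np


def split_constraints(con, num_states, num_instance):
    """Split a flat constraint vector into per-state defects and instance parts.

    Assumed layout: num_states equal-length blocks, then num_instance
    instance-constraint values at the end.
    """
    con = np.asarray(con)
    n_defects = con.size - num_instance
    if n_defects < 0 or n_defects % num_states != 0:
        raise ValueError("vector length does not match the assumed layout")
    per_state = n_defects // num_states
    # Row i holds the defect constraints of state i, one per node.
    defects = con[:n_defects].reshape(num_states, per_state)
    instance = con[n_defects:]
    return defects, instance


# 4 states, 10 nodes (9 defects per state), 8 instance constraints.
con = np.arange(4 * 9 + 8, dtype=float)
defects, instance = split_constraints(con, num_states=4, num_instance=8)
print(defects.shape, instance.shape)  # (4, 9) (8,)
```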

moorepants commented 4 months ago

There is this probably related issue #53.

tjstienstra commented 4 months ago

> There is this probably related issue #53.

I don't think it is exactly related (unless it really has to do with the initial slicing), but it can easily be fixed at the same time. I would propose something like the following, which also adds an option to plot all state violations on a single axis, since that is much more readable when there are many states.

import matplotlib.pyplot as plt
import sympy as sm


# Proposed replacement for the Problem.plot_constraint_violations method.
def plot_constraint_violations(self, vector, separate_state_axes=True):
    con_violations = self.con(vector)
    # Each state has one defect constraint per node after the first
    # (nodes 1 .. num_collocation_nodes - 1), stored state by state.
    con_nodes = range(1, self.collocator.num_collocation_nodes)
    N = len(con_nodes)
    n_state_axes = len(self.collocator.state_symbols) if separate_state_axes else 1
    # Handles both None and an empty collection of instance constraints.
    plot_instance_viols = bool(self.collocator.instance_constraints)
    # squeeze=False keeps axes indexable even when only one subplot is created.
    fig, axes = plt.subplots(n_state_axes + int(plot_instance_viols),
                             squeeze=False)
    axes = axes[:, 0]

    for i, symbol in enumerate(self.collocator.state_symbols):
        state_violations = con_violations[i * N:i * N + N]
        state_label = sm.latex(symbol, mode='inline')
        if separate_state_axes:
            axes[i].plot(con_nodes, state_violations)
            axes[i].set_ylabel(state_label)
        else:
            axes[0].plot(con_nodes, state_violations, label=state_label)
    if not separate_state_axes:
        axes[0].legend()

    axes[0].set_title('Constraint Violations')
    axes[n_state_axes - 1].set_xlabel('Node Number')

    if plot_instance_viols:
        instance_constraints = self.collocator.instance_constraints
        n_instance = len(instance_constraints)
        # Instance constraints sit at the end of the constraint vector.
        axes[-1].bar(range(n_instance), con_violations[-n_instance:],
                     tick_label=[sm.latex(s, mode='inline')
                                 for s in instance_constraints])
        axes[-1].set_ylabel('Instance')
        axes[-1].set_xticklabels(axes[-1].get_xticklabels(), rotation=-10)

    return axes

PS: a more future-proof argument could be state_style="separate axes".
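A string-valued option can later grow new layouts (for example a grid) without breaking existing calls, which a boolean cannot. A minimal sketch of how such an argument might be validated; state_style is only the suggestion above, and the accepted strings and the helper name state_axes_count are hypothetical, not part of opty:

```python
# Hypothetical values for the suggested state_style argument.
VALID_STATE_STYLES = ("separate axes", "single axis")


def state_axes_count(state_style, num_states):
    """Return how many axes the state violations would need for a style."""
    if state_style == "separate axes":
        return num_states  # one subplot per state
    if state_style == "single axis":
        return 1  # all states share one subplot with a legend
    raise ValueError(
        f"state_style must be one of {VALID_STATE_STYLES}, got {state_style!r}")


print(state_axes_count("separate axes", 4))  # 4
print(state_axes_count("single axis", 4))    # 1
```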

moorepants commented 4 months ago

Any improvements to the plots are welcome. I didn't spend much time making them nice when I drafted this years ago.

Peter230655 commented 3 months ago

Dumb question: does prob.plot_constraint_violations(solution) plot the contents of info['g'], where info is the dictionary returned by solve? Thanks!
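One way to check this empirically: the plotting method evaluates self.con(vector) on the vector passed in, and if solve returns the IPOPT info dictionary from cyipopt (an assumption here), its 'g' entry should hold the constraint values at the returned point. A hedged comparison helper (con_matches_info_g is hypothetical), demonstrated with a stand-in object so it runs without opty; with a real problem one would call con_matches_info_g(prob, solution, info):

```python
import numpy as np


def con_matches_info_g(prob, solution, info, atol=1e-8):
    """Return True if prob.con(solution) equals info['g'] within atol.

    prob.con is what the plotting method evaluates; info['g'] is assumed
    to be the constraint vector reported by the IPOPT interface.
    """
    con = np.asarray(prob.con(solution), dtype=float)
    g = np.asarray(info["g"], dtype=float)
    return con.shape == g.shape and bool(np.allclose(con, g, atol=atol))


class _StandInProblem:
    """Hypothetical stand-in with a con() method, only for this demo."""

    def con(self, x):
        return np.array([x[0] - 1.0, x[1] + x[0]])


prob = _StandInProblem()
solution = np.array([1.0, -1.0])
info = {"g": np.array([0.0, 0.0])}
print(con_matches_info_g(prob, solution, info))  # True
```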