OpenMDAO / dymos

Open Source Optimization of Dynamic Multidisciplinary Systems
Apache License 2.0
208 stars 66 forks source link

Trajectory parameters connected from an upstream analysis with src_indices fail to simulate. #887

Closed caksland closed 1 year ago

caksland commented 1 year ago

Description

An analysis upstream of a dymos trajectory computes a vector of parameters, which are passed into the trajectory as scalars using `connect(src, tgt, src_indices=X)`. When the trajectory is simulated, OpenMDAO cannot assign the parameter values correctly. See the example below for a mass-spring-damper system: the upstream analysis computes a vector containing the damping ratio and the natural frequency, and the trajectory contains the mass-spring-damper ODE.

Example

import openmdao.api as om
import dymos as dm
import numpy as np
import matplotlib.pyplot as plt

class MSD(om.ExplicitComponent):
    """First-order ODE of a damped linear oscillator, evaluated at every node.

    Rates produced by ``compute``::

        x1_dot = x2
        x2_dot = -2*zeta*wn*x2 - wn**2*x1

    ``x1`` and ``x2`` carry one entry per node; ``zeta`` and ``wn`` are
    single-valued inputs (default 1) applied to all nodes.
    """

    def initialize(self):
        # Number of evaluation nodes; sizes the x1/x2 inputs and rate outputs.
        self.options.declare('num_nodes', types=int, default=2)

    def setup(self):
        num_nodes = self.options['num_nodes']
        vec_shape = (num_nodes,)

        # Per-node states first, then the two coefficients.
        for state_name in ('x1', 'x2'):
            self.add_input(state_name, shape=vec_shape)
        for coef_name in ('zeta', 'wn'):
            self.add_input(coef_name, val=1)

        # One rate output per state, with the same per-node shape.
        for rate_name in ('x1_dot', 'x2_dot'):
            self.add_output(rate_name, shape=vec_shape)

    def setup_partials(self):
        # All derivatives are approximated by finite differences.
        self.declare_partials('*', '*', method='fd')

    def compute(self, inputs, outputs):
        disp = inputs['x1']
        vel = inputs['x2']
        zeta, wn = inputs['zeta'], inputs['wn']

        outputs['x1_dot'] = vel
        outputs['x2_dot'] = -2 * zeta * wn * vel - wn ** 2 * disp

class DampFreq(om.ExplicitComponent):
    """Convert mass, stiffness and damping into the MSD ODE parameters.

    Output ``p`` packs ``[zeta, wn]`` with

        zeta = c / (2 * sqrt(m * k))   (damping ratio)
        wn   = sqrt(k / m)             (natural frequency)

    so that ``2*zeta*wn == c/m`` and ``wn**2 == k/m``, which is what the
    rate ``x2_dot = -2*zeta*wn*x2 - wn**2*x1`` requires.
    """

    def initialize(self):
        # This component has no options.
        pass

    def setup(self):
        self.add_input('m', shape=(1,))  # mass
        self.add_input('k', shape=(1,))  # spring stiffness
        self.add_input('c', shape=(1,))  # damping coefficient

        self.add_output('p', shape=(2,))  # [zeta, wn]

    def setup_partials(self):
        self.declare_partials('*', '*', method='fd')

    def compute(self, inputs, outputs):
        m = inputs['m']
        k = inputs['k']
        c = inputs['c']

        # The damping ratio DIVIDES c by 2*sqrt(m*k); writing c/2*(m*k)**.5
        # would multiply by sqrt(m*k) and only be right when m*k == 1.
        zeta = c / (2.0 * np.sqrt(m * k))
        wn = np.sqrt(k / m)

        # Both values are shape (1,); concatenate them into the flat (2,)
        # vector declared for 'p' (np.array([...]) would give shape (2, 1)).
        outputs['p'] = np.concatenate((zeta, wn))

# create problem
p = om.Problem()
p.driver = om.ScipyOptimizeDriver()

# add model to compute damping ratio and natural frequency
p.model.add_subsystem('params', subsys=DampFreq())

# create phase with mass spring damper
phase = dm.Phase(ode_class=MSD, transcription=dm.Radau(num_segments=10))
phase.set_time_options(fix_initial=True, fix_duration=True)

# add states
phase.add_state('x1', fix_initial=True, rate_source='x1_dot', targets=['x1'])
phase.add_state('x2', fix_initial=True, rate_source='x2_dot', targets=['x2'])

# the driver needs an objective; initial time and duration are fixed above,
# so this objective is constant
phase.add_objective('time', loc='final')

# add trajectory and parameters to model
traj = p.model.add_subsystem('traj', dm.Trajectory())
traj.add_phase('phase0', phase)
traj.add_parameter('zeta', targets={'phase0': ['zeta']}, units=None, opt=False, static_target=True)
traj.add_parameter('wn', targets={'phase0': ['wn']}, units=None, opt=False, static_target=True)

# connect individual elements of the upstream vector params.p = [zeta, wn]
# to the scalar trajectory parameters
p.model.connect('params.p', 'traj.parameters:zeta', src_indices=[0])
p.model.connect('params.p', 'traj.parameters:wn', src_indices=[1])

# setup problem and generate an N2 diagram to inspect the connections
p.setup()
om.n2(p)

# initialize values
p.set_val('traj.phase0.t_initial', 0.0)
p.set_val('traj.phase0.t_duration', 15.0)
p.set_val('traj.phase0.states:x1', 10.0)
p.set_val('traj.phase0.states:x2', 0.0)
p.set_val('params.m', 1)
p.set_val('params.k', 1)
p.set_val('params.c', 0.5)

# run the driver
p.run_driver()

# simulate the trajectory; with dymos 1.2.0 this step failed to assign the
# src_indices-connected parameter values (issue #887), fixed in later releases
sim_out = traj.simulate(times_per_seg=50)

# plot results
t_sol = p.get_val('traj.phase0.timeseries.time')
t_sim = sim_out.get_val('traj.phase0.timeseries.time')

states = ['x1', 'x2']
fig, axes = plt.subplots(len(states), 1)
for ax, state in zip(axes, states):
    sol = ax.plot(t_sol, p.get_val(f'traj.phase0.timeseries.states:{state}'), 'o')
    sim = ax.plot(t_sim, sim_out.get_val(f'traj.phase0.timeseries.states:{state}'), '-')
    ax.set_ylabel(state)
axes[-1].set_xlabel('time (s)')
# pass loc by keyword: extra positional arguments after handles/labels are
# deprecated (and rejected by recent Matplotlib releases)
fig.legend((sol[0], sim[0]), ('solution', 'simulation'), loc='lower right', ncol=2)
plt.tight_layout()
plt.show()

Dymos Version

1.2.0

Relevant environment information

No response

robfalck commented 1 year ago

Could you please try this with the latest version of dymos and see if you still encounter the problem? I think this bug has been fixed.

caksland commented 1 year ago

I'll try to get to it this week. Thanks for looking into this.

caksland commented 1 year ago

Yes, it has been resolved. I'll make sure to test issues on the latest version before posting. Thanks!