Skip to content
Snippets Groups Projects
Commit 67abe0cc authored by Yvan's avatar Yvan
Browse files

Updating propagator

parent 9f7d2357
No related branches found
No related tags found
No related merge requests found
File deleted
......@@ -122,24 +122,28 @@ class Propagator:
# events to be tracked, use solve_ivp
if len(t_span) > 2: # solution on user-provided time grid
sol = sci.solve_ivp(self.odes, [t_span[0], t_span[-1]], state0, t_eval=t_span,
method=self.method, rtol=self.rtol, atol=self.atol, events=events)
sol = sci.solve_ivp(self.odes, [t_span[0], t_span[-1]], state0, method=self.method,
t_eval=t_span, events=events, rtol=self.rtol, atol=self.atol)
elif self.time_steps is None: # solution only at time interval boundaries
sol = sci.solve_ivp(self.odes, t_span, state0, t_eval=t_span, method=self.method,
rtol=self.rtol, atol=self.atol, events=events)
sol = sci.solve_ivp(self.odes, t_span, state0, method=self.method,
t_eval=t_span, events=events, rtol=self.rtol, atol=self.atol)
else: # solution on uniformly spaced time grid with predefined number of points
sol = sci.solve_ivp(self.odes, t_span, state0,
sol = sci.solve_ivp(self.odes, t_span, state0, method=self.method,
t_eval=np.linspace(t_span[0], t_span[-1], self.time_steps),
method=self.method, rtol=self.rtol, atol=self.atol,
events=events)
# in case of terminal events, append event state and time to solution state and time vectors
if callable(events): # For a single event
events = (events,)
for i, event in enumerate(events):
if hasattr(event, 'terminal') and event.terminal and sol.t_events[i].size > 0:
sol.y = np.append(sol.y, sol.y_events[i].T, axis=1)
sol.t = np.append(sol.t, sol.t_events[i])
return sol.t, sol.y.T, sol.t_events, sol.y_events
events=events, rtol=self.rtol, atol=self.atol)
t_vec, state_vec, t_event, state_event = sol.t, sol.y.T, sol.t_events, sol.y_events
# look for terminal events outside the time vector
t_evt_all = np.concatenate(t_event) # time vector for all event occurrences
if t_span[0] < t_span[-1] and t_evt_all.max(initial=t_span[0]) > t_vec[-1]:
idx_cum = t_evt_all.argmax() # index of the stopping event for forward propagation
elif t_span[0] > t_span[-1] and t_evt_all.min(initial=t_span[0]) < t_vec[-1]:
idx_cum = t_evt_all.argmin() # index of the stopping event for backward propagation
else: # no terminal events outside the time vector
return t_vec, state_vec, t_event, state_event
# append the stopping event to both time and state vectors
idx_evt = np.nonzero(np.asarray([i.size for i in t_event]).cumsum() == (idx_cum + 1))[0][0]
t_vec = np.concatenate((t_vec, t_event[idx_evt][-1:]))
state_vec = np.concatenate((state_vec, state_event[idx_evt][-1:]))
return t_vec, state_vec, t_event, state_event
No preview for this file type
No preview for this file type
No preview for this file type
File added
"""
Test the event detection feature of the Propagator class.
@author: Alberto FOSSA'
"""
import unittest
import numpy as np
from sempy.core.init.primary import Primary
from sempy.core.init.cr3bp import Cr3bp
from sempy.core.orbits.halo import Halo
from sempy.core.propagation.cr3bp_propagator import Cr3bpSynodicPropagator
from sempy.core.diffcorr.ode_event import generate_plane_event
class TestEventDetection(unittest.TestCase):
def test_terminal_events_fwd(self):
"""Test terminal event detection for forward propagation. """
cr3bp, halo = self.get_env_orbit()
prop = Cr3bpSynodicPropagator(cr3bp.mu)
xy_crossing = generate_plane_event('xy', 1, True)
xz_crossing = generate_plane_event('xz', 0, False)
yz_crossing = generate_plane_event('yz', 0, False)
events = [xy_crossing, xz_crossing, yz_crossing]
t_vec, state_vec, t_event, state_event = \
prop.propagate([0.0, 2 * halo.T], halo.state0, events=events)
self.assertEqual(t_event[0].size, 1)
self.assertEqual(t_event[1].size, 2)
self.assertEqual(t_event[2].size, 0)
np.testing.assert_equal(t_vec[-1], t_event[0][-1])
np.testing.assert_equal(state_vec[-1], state_event[0][-1])
def test_terminal_events_bwd(self):
"""Test terminal event detection for backward propagation. """
cr3bp, halo = self.get_env_orbit()
prop = Cr3bpSynodicPropagator(cr3bp.mu)
xy_crossing = generate_plane_event('xy', 0, False)
xz_crossing = generate_plane_event('xz', -1, True)
yz_crossing = generate_plane_event('yz', 0, False)
events = [yz_crossing, xy_crossing, xz_crossing]
t_vec, state_vec, t_event, state_event = \
prop.propagate([0.0, -halo.T], halo.state0, events=events)
self.assertEqual(t_event[0].size, 0)
self.assertEqual(t_event[1].size, 1)
self.assertEqual(t_event[2].size, 1)
np.testing.assert_equal(t_vec[-1], t_event[2][-1])
np.testing.assert_equal(state_vec[-1], state_event[2][-1])
@staticmethod
def get_env_orbit():
"""Get the CR3BP system and the orbit used to perform all tests. """
cr3bp = Cr3bp(Primary.EARTH, Primary.MOON)
halo = Halo(cr3bp, cr3bp.l1, Halo.Family.southern, Azdim=30e3)
halo.interpolation()
return cr3bp, halo
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment