Skip to content

Commit 5db9036

Browse files
partially fix tests
1 parent 25947fa commit 5db9036

File tree

11 files changed

+354
-313
lines changed

11 files changed

+354
-313
lines changed

rocketpy/_encoders.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
from datetime import datetime
55
from importlib import import_module
6+
from inspect import Parameter, signature
67

78
import numpy as np
89

@@ -71,11 +72,12 @@ def default(self, o):
7172
encoding["signature"] = get_class_signature(o)
7273
return encoding
7374
elif hasattr(o, "to_dict"):
74-
encoding = o.to_dict(
75-
include_outputs=self.include_outputs,
76-
discretize=self.discretize,
77-
allow_pickle=self.allow_pickle,
78-
)
75+
call_kwargs = {
76+
"include_outputs": self.include_outputs,
77+
"discretize": self.discretize,
78+
"allow_pickle": self.allow_pickle,
79+
}
80+
encoding = _call_to_dict_with_supported_kwargs(o, call_kwargs)
7981
encoding = remove_circular_references(encoding)
8082

8183
encoding["signature"] = get_class_signature(o)
@@ -195,6 +197,23 @@ def set_minimal_flight_attributes(flight, obj):
195197
flight.t_initial = flight.initial_solution[0]
196198

197199

200+
def _call_to_dict_with_supported_kwargs(obj, candidate_kwargs):
201+
"""Call obj.to_dict passing only supported keyword arguments."""
202+
method = obj.to_dict
203+
sig = signature(method)
204+
params = list(sig.parameters.values())
205+
206+
if any(param.kind == Parameter.VAR_KEYWORD for param in params):
207+
return method(**candidate_kwargs)
208+
209+
supported_kwargs = {
210+
name: candidate_kwargs[name]
211+
for name in sig.parameters
212+
if name != "self" and name in candidate_kwargs
213+
}
214+
return method(**supported_kwargs)
215+
216+
198217
def get_class_signature(obj):
199218
"""Returns the signature of a class so it can be identified on
200219
decoding. The signature is a dictionary with the module and

rocketpy/mathutils/inertia.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""Utilities related to inertia tensor transformations.
2+
3+
This module centralizes dynamic helpers for applying the parallel axis
4+
theorem (PAT). It lives inside ``rocketpy.mathutils`` so that functionality
5+
depending on :class:`rocketpy.mathutils.function.Function` does not leak into
6+
generic utility modules such as ``rocketpy.tools``.
7+
"""
8+
9+
from rocketpy.mathutils.function import Function
10+
from rocketpy.mathutils.vector_matrix import Vector
11+
12+
13+
def _pat_dynamic_helper(com_inertia_moment, mass, distance_vec_3d, axes_term_lambda):
14+
"""Apply the PAT to inertia moments, supporting static and dynamic inputs."""
15+
16+
is_dynamic = (
17+
isinstance(com_inertia_moment, Function)
18+
or isinstance(mass, Function)
19+
or isinstance(distance_vec_3d, Function)
20+
)
21+
22+
def get_val(arg, t):
23+
return arg(t) if isinstance(arg, Function) else arg
24+
25+
if not is_dynamic:
26+
d_vec = Vector(distance_vec_3d)
27+
mass_term = mass * axes_term_lambda(d_vec)
28+
return com_inertia_moment + mass_term
29+
30+
def new_source(t):
31+
d_vec_t = get_val(distance_vec_3d, t)
32+
mass_t = get_val(mass, t)
33+
inertia_t = get_val(com_inertia_moment, t)
34+
mass_term = mass_t * axes_term_lambda(d_vec_t)
35+
return inertia_t + mass_term
36+
37+
return Function(new_source, inputs="t", outputs="Inertia (kg*m^2)")
38+
39+
40+
def _pat_dynamic_product_helper(
41+
com_inertia_product, mass, distance_vec_3d, product_term_lambda
42+
):
43+
"""Apply the PAT to inertia products, supporting static and dynamic inputs."""
44+
45+
is_dynamic = (
46+
isinstance(com_inertia_product, Function)
47+
or isinstance(mass, Function)
48+
or isinstance(distance_vec_3d, Function)
49+
)
50+
51+
def get_val(arg, t):
52+
return arg(t) if isinstance(arg, Function) else arg
53+
54+
if not is_dynamic:
55+
d_vec = Vector(distance_vec_3d)
56+
mass_term = mass * product_term_lambda(d_vec)
57+
return com_inertia_product + mass_term
58+
59+
def new_source(t):
60+
d_vec_t = get_val(distance_vec_3d, t)
61+
mass_t = get_val(mass, t)
62+
inertia_t = get_val(com_inertia_product, t)
63+
mass_term = mass_t * product_term_lambda(d_vec_t)
64+
return inertia_t + mass_term
65+
66+
return Function(new_source, inputs="t", outputs="Inertia (kg*m^2)")
67+
68+
69+
# --- Public functions for the Parallel Axis Theorem ---
70+
71+
72+
def parallel_axis_theorem_I11(com_inertia_moment, mass, distance_vec_3d):
73+
"""Apply PAT to the I11 inertia term.
74+
75+
Parameters
76+
----------
77+
com_inertia_moment : float or Function
78+
Inertia moment relative to the component center of mass.
79+
mass : float or Function
80+
Mass of the component. If a Function, it must map time to mass.
81+
distance_vec_3d : array-like or Function
82+
Displacement vector from the component COM to the reference COM.
83+
84+
Returns
85+
-------
86+
float or Function
87+
Updated I11 value referenced to the new axis.
88+
"""
89+
90+
return _pat_dynamic_helper(
91+
com_inertia_moment, mass, distance_vec_3d, lambda d_vec: d_vec.y**2 + d_vec.z**2
92+
)
93+
94+
95+
def parallel_axis_theorem_I22(com_inertia_moment, mass, distance_vec_3d):
96+
"""Apply PAT to the I22 inertia term.
97+
98+
Parameters
99+
----------
100+
com_inertia_moment : float or Function
101+
Inertia moment relative to the component center of mass.
102+
mass : float or Function
103+
Mass of the component. If a Function, it must map time to mass.
104+
distance_vec_3d : array-like or Function
105+
Displacement vector from the component COM to the reference COM.
106+
107+
Returns
108+
-------
109+
float or Function
110+
Updated I22 value referenced to the new axis.
111+
"""
112+
113+
return _pat_dynamic_helper(
114+
com_inertia_moment, mass, distance_vec_3d, lambda d_vec: d_vec.x**2 + d_vec.z**2
115+
)
116+
117+
118+
def parallel_axis_theorem_I33(com_inertia_moment, mass, distance_vec_3d):
119+
"""Apply PAT to the I33 inertia term.
120+
121+
Parameters
122+
----------
123+
com_inertia_moment : float or Function
124+
Inertia moment relative to the component center of mass.
125+
mass : float or Function
126+
Mass of the component. If a Function, it must map time to mass.
127+
distance_vec_3d : array-like or Function
128+
Displacement vector from the component COM to the reference COM.
129+
130+
Returns
131+
-------
132+
float or Function
133+
Updated I33 value referenced to the new axis.
134+
"""
135+
136+
return _pat_dynamic_helper(
137+
com_inertia_moment, mass, distance_vec_3d, lambda d_vec: d_vec.x**2 + d_vec.y**2
138+
)
139+
140+
141+
def parallel_axis_theorem_I12(com_inertia_product, mass, distance_vec_3d):
142+
"""Apply PAT to the I12 inertia product.
143+
144+
Parameters
145+
----------
146+
com_inertia_product : float or Function
147+
Product of inertia relative to the component center of mass.
148+
mass : float or Function
149+
Mass of the component. If a Function, it must map time to mass.
150+
distance_vec_3d : array-like or Function
151+
Displacement vector from the component COM to the reference COM.
152+
153+
Returns
154+
-------
155+
float or Function
156+
Updated I12 value referenced to the new axis.
157+
"""
158+
159+
return _pat_dynamic_product_helper(
160+
com_inertia_product, mass, distance_vec_3d, lambda d_vec: d_vec.x * d_vec.y
161+
)
162+
163+
164+
def parallel_axis_theorem_I13(com_inertia_product, mass, distance_vec_3d):
165+
"""Apply PAT to the I13 inertia product.
166+
167+
Parameters
168+
----------
169+
com_inertia_product : float or Function
170+
Product of inertia relative to the component center of mass.
171+
mass : float or Function
172+
Mass of the component. If a Function, it must map time to mass.
173+
distance_vec_3d : array-like or Function
174+
Displacement vector from the component COM to the reference COM.
175+
176+
Returns
177+
-------
178+
float or Function
179+
Updated I13 value referenced to the new axis.
180+
"""
181+
182+
return _pat_dynamic_product_helper(
183+
com_inertia_product, mass, distance_vec_3d, lambda d_vec: d_vec.x * d_vec.z
184+
)
185+
186+
187+
def parallel_axis_theorem_I23(com_inertia_product, mass, distance_vec_3d):
188+
"""Apply PAT to the I23 inertia product.
189+
190+
Parameters
191+
----------
192+
com_inertia_product : float or Function
193+
Product of inertia relative to the component center of mass.
194+
mass : float or Function
195+
Mass of the component. If a Function, it must map time to mass.
196+
distance_vec_3d : array-like or Function
197+
Displacement vector from the component COM to the reference COM.
198+
199+
Returns
200+
-------
201+
float or Function
202+
Updated I23 value referenced to the new axis.
203+
"""
204+
205+
return _pat_dynamic_product_helper(
206+
com_inertia_product, mass, distance_vec_3d, lambda d_vec: d_vec.y * d_vec.z
207+
)
208+
209+
210+
__all__ = [
211+
"parallel_axis_theorem_I11",
212+
"parallel_axis_theorem_I22",
213+
"parallel_axis_theorem_I33",
214+
"parallel_axis_theorem_I12",
215+
"parallel_axis_theorem_I13",
216+
"parallel_axis_theorem_I23",
217+
]

rocketpy/motors/cluster_motor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import numpy as np
44

55
from rocketpy.mathutils.function import Function
6-
from rocketpy.mathutils.vector_matrix import Vector
7-
from rocketpy.tools import (
6+
from rocketpy.mathutils.inertia import (
87
parallel_axis_theorem_I11,
98
parallel_axis_theorem_I12,
109
parallel_axis_theorem_I13,
1110
parallel_axis_theorem_I22,
1211
parallel_axis_theorem_I23,
1312
parallel_axis_theorem_I33,
1413
)
14+
from rocketpy.mathutils.vector_matrix import Vector
1515

1616

1717
class ClusterMotor:

rocketpy/motors/hybrid_motor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from functools import cached_property
22

33
from ..mathutils.function import Function, funcify_method, reset_funcified_methods
4-
from ..mathutils.vector_matrix import Vector
5-
from ..plots.hybrid_motor_plots import _HybridMotorPlots
6-
from ..tools import (
4+
from ..mathutils.inertia import (
75
parallel_axis_theorem_I11,
86
parallel_axis_theorem_I12,
97
parallel_axis_theorem_I13,
108
parallel_axis_theorem_I22,
119
parallel_axis_theorem_I23,
1210
parallel_axis_theorem_I33,
1311
)
12+
from ..mathutils.vector_matrix import Vector
13+
from ..plots.hybrid_motor_plots import _HybridMotorPlots
1414
from .motor import Motor
1515
from .solid_motor import SolidMotor
1616

rocketpy/motors/motor.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@
1212
import requests
1313

1414
from ..mathutils.function import Function, funcify_method
15-
from ..mathutils.vector_matrix import Matrix, Vector
16-
from ..plots.motor_plots import _MotorPlots
17-
from ..prints.motor_prints import _MotorPrints
18-
from ..tools import (
15+
from ..mathutils.inertia import (
1916
parallel_axis_theorem_I11,
2017
parallel_axis_theorem_I12,
2118
parallel_axis_theorem_I13,
2219
parallel_axis_theorem_I22,
2320
parallel_axis_theorem_I23,
2421
parallel_axis_theorem_I33,
25-
tuple_handler,
2622
)
23+
from ..mathutils.vector_matrix import Matrix, Vector
24+
from ..plots.motor_plots import _MotorPlots
25+
from ..prints.motor_prints import _MotorPrints
26+
from ..tools import tuple_handler
2727

2828
# pylint: disable=too-many-public-methods
2929
# ThrustCurve API cache
@@ -243,18 +243,52 @@ def __init__(
243243
self.burn_out_time = self.burn_time[1]
244244

245245
# Reshape thrust curve if needed
246-
self.reshape_thrust_curve = reshape_thrust_curve
247-
if reshape_thrust_curve:
248-
self._reshape_thrust_curve(*reshape_thrust_curve)
246+
reshape_thrust_curve_input = reshape_thrust_curve
247+
if reshape_thrust_curve_input:
248+
try:
249+
new_burn_time, desired_impulse = reshape_thrust_curve_input
250+
except (TypeError, ValueError) as exc:
251+
raise TypeError(
252+
"reshape_thrust_curve must be an iterable with two elements:"
253+
" (new_burn_time, total_impulse)."
254+
) from exc
255+
256+
self.thrust = self.reshape_thrust_curve(
257+
self.thrust, new_burn_time, desired_impulse
258+
)
249259

250260
# Basic calculations and attributes
251261
self.burn_duration = self.burn_out_time - self.burn_start_time
252262
self.total_impulse = self.thrust.integral(
253263
self.burn_start_time, self.burn_out_time
254264
)
255-
self.max_thrust = self.thrust.max
265+
266+
# Calculate max_thrust and max_thrust_time
267+
try:
268+
self.max_thrust = self.thrust.max
269+
# Find time of max thrust by evaluating thrust at multiple points
270+
if hasattr(self.thrust, "x_array"):
271+
max_thrust_index = np.argmax(self.thrust.y_array)
272+
self.max_thrust_time = self.thrust.x_array[max_thrust_index]
273+
else:
274+
# For lambda functions, sample over burn time
275+
time_samples = np.linspace(
276+
self.burn_start_time, self.burn_out_time, 1000
277+
)
278+
thrust_samples = [self.thrust(t) for t in time_samples]
279+
max_thrust_index = np.argmax(thrust_samples)
280+
self.max_thrust = thrust_samples[max_thrust_index]
281+
self.max_thrust_time = time_samples[max_thrust_index]
282+
except AttributeError:
283+
# If thrust is lambda-based, sample to find max
284+
time_samples = np.linspace(self.burn_start_time, self.burn_out_time, 1000)
285+
thrust_samples = [self.thrust(t) for t in time_samples]
286+
max_thrust_index = np.argmax(thrust_samples)
287+
self.max_thrust = thrust_samples[max_thrust_index]
288+
self.max_thrust_time = time_samples[max_thrust_index]
256289

257290
self.average_thrust = self.total_impulse / self.burn_duration
291+
self.reshape_thrust_curve_request = reshape_thrust_curve_input
258292

259293
# Abstract methods - must be implemented by subclasses
260294
self._propellant_initial_mass = 0

0 commit comments

Comments
 (0)