diff --git a/manim/mobject/graphing/functions.py b/manim/mobject/graphing/functions.py index 970ddb0ce6..aae20c2e59 100644 --- a/manim/mobject/graphing/functions.py +++ b/manim/mobject/graphing/functions.py @@ -5,7 +5,7 @@ __all__ = ["ParametricFunction", "FunctionGraph", "ImplicitFunction"] -from typing import Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Callable, Iterable, Sequence import numpy as np from isosurfaces import plot_isoline @@ -14,6 +14,10 @@ from manim.mobject.graphing.scale import LinearBase, _ScaleBase from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.mobject.types.vectorized_mobject import VMobject + +if TYPE_CHECKING: + from manim.typing import Point2D, Point3D + from manim.utils.color import YELLOW @@ -23,9 +27,9 @@ class ParametricFunction(VMobject, metaclass=ConvertToOpenGL): Parameters ---------- function - The function to be plotted in the form of ``(lambda x: x**2)`` + The function to be plotted in the form of ``(lambda t: (x(t), y(t), z(t)))`` t_range - Determines the length that the function spans. By default ``[0, 1]`` + Determines the length that the function spans in the form of (t_min, t_max, step=0.01). By default ``[0, 1]`` scaling Scaling class applied to the points of the function. Default of :class:`~.LinearBase`. use_smoothing @@ -49,10 +53,10 @@ class ParametricFunction(VMobject, metaclass=ConvertToOpenGL): class PlotParametricFunction(Scene): def func(self, t): - return np.array((np.sin(2 * t), np.sin(3 * t), 0)) + return (np.sin(2 * t), np.sin(3 * t), 0) def construct(self): - func = ParametricFunction(self.func, t_range = np.array([0, TAU]), fill_opacity=0).set_color(RED) + func = ParametricFunction(self.func, t_range = (0, TAU), fill_opacity=0).set_color(RED) self.add(func.scale(3)) .. manim:: ThreeDParametricSpring @@ -61,11 +65,11 @@ def construct(self): class ThreeDParametricSpring(ThreeDScene): def construct(self): curve1 = ParametricFunction( - lambda u: np.array([ + lambda u: ( 1.2 * np.cos(u), 1.2 * np.sin(u), u * 0.05 - ]), color=RED, t_range = np.array([-3*TAU, 5*TAU, 0.01]) + ), color=RED, t_range = (-3*TAU, 5*TAU, 0.01) ).set_shade_in_3d(True) axes = ThreeDAxes() self.add(axes, curve1) @@ -97,8 +101,8 @@ def construct(self): def __init__( self, - function: Callable[[float, float], float], - t_range: Sequence[float] | None = None, + function: Callable[[float], Point3D], + t_range: Point2D | Point3D = (0, 1), scaling: _ScaleBase = LinearBase(), dt: float = 1e-8, discontinuities: Iterable[float] | None = None, @@ -107,7 +111,7 @@ def __init__( **kwargs, ): self.function = function - t_range = [0, 1, 0.01] if t_range is None else t_range + t_range = (0, 1, 0.01) if t_range is None else t_range if len(t_range) == 2: t_range = np.array([*t_range, 0.01])