diff --git a/manim/animation/composition.py b/manim/animation/composition.py index 20b23d828c..ee26eb504f 100644 --- a/manim/animation/composition.py +++ b/manim/animation/composition.py @@ -3,11 +3,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Sequence +import types +from typing import TYPE_CHECKING, Callable, Iterable, Sequence import numpy as np from manim.mobject.opengl.opengl_mobject import OpenGLGroup +from manim.utils.parameter_parsing import flatten_iterable_parameters from .._config import config from ..animation.animation import Animation, prepare_animation @@ -54,14 +56,15 @@ class AnimationGroup(Animation): def __init__( self, - *animations: Animation, + *animations: Animation | Iterable[Animation] | types.GeneratorType[Animation], group: Group | VGroup | OpenGLGroup | OpenGLVGroup = None, run_time: float | None = None, rate_func: Callable[[float], float] = linear, lag_ratio: float = 0, **kwargs, ) -> None: - self.animations = [prepare_animation(anim) for anim in animations] + arg_anim = flatten_iterable_parameters(animations) + self.animations = [prepare_animation(anim) for anim in arg_anim] self.rate_func = rate_func self.group = group if self.group is None: diff --git a/manim/renderer/cairo_renderer.py b/manim/renderer/cairo_renderer.py index f11b7ea7ca..e9d79b109f 100644 --- a/manim/renderer/cairo_renderer.py +++ b/manim/renderer/cairo_renderer.py @@ -1,7 +1,6 @@ from __future__ import annotations import typing -from typing import Any import numpy as np @@ -15,6 +14,10 @@ from ..utils.iterables import list_update if typing.TYPE_CHECKING: + import types + from typing import Any, Iterable + + from manim.animation.animation import Animation from manim.scene.scene import Scene @@ -51,7 +54,12 @@ def init_scene(self, scene): scene.__class__.__name__, ) - def play(self, scene, *args, **kwargs): + def play( + self, + scene: Scene, + *args: Animation | Iterable[Animation] | types.GeneratorType[Animation], + **kwargs, + ): # Reset skip_animations to the original state. # Needed when rendering only some animations, and skipping others. self.skip_animations = self._original_skipping_status diff --git a/manim/scene/scene.py b/manim/scene/scene.py index 5f3fb99247..2a3b37d76f 100644 --- a/manim/scene/scene.py +++ b/manim/scene/scene.py @@ -2,6 +2,8 @@ from __future__ import annotations +from manim.utils.parameter_parsing import flatten_iterable_parameters + __all__ = ["Scene"] import copy @@ -13,7 +15,6 @@ import time import types from queue import Queue -from typing import Callable import srt @@ -25,6 +26,8 @@ dearpygui_imported = True except ImportError: dearpygui_imported = False +from typing import TYPE_CHECKING + import numpy as np from tqdm import tqdm from watchdog.events import FileSystemEventHandler @@ -48,6 +51,9 @@ from ..utils.file_ops import open_media_file from ..utils.iterables import list_difference_update, list_update +if TYPE_CHECKING: + from typing import Callable, Iterable + class RerunSceneHandler(FileSystemEventHandler): """A class to handle rerunning a Scene after the input file is modified.""" @@ -865,7 +871,11 @@ def get_moving_and_static_mobjects(self, animations): ) return all_moving_mobject_families, static_mobjects - def compile_animations(self, *args: Animation, **kwargs): + def compile_animations( + self, + *args: Animation | Iterable[Animation] | types.GeneratorType[Animation], + **kwargs, + ): """ Creates _MethodAnimations from any _AnimationBuilders and updates animation kwargs with kwargs passed to play(). @@ -883,7 +893,9 @@ def compile_animations(self, *args: Animation, **kwargs): Animations to be played. """ animations = [] - for arg in args: + arg_anims = flatten_iterable_parameters(args) + # Allow passing a generator to self.play instead of comma separated arguments + for arg in arg_anims: try: animations.append(prepare_animation(arg)) except TypeError: @@ -1027,7 +1039,7 @@ def get_run_time(self, animations: list[Animation]): def play( self, - *args, + *args: Animation | Iterable[Animation] | types.GeneratorType[Animation], subcaption=None, subcaption_duration=None, subcaption_offset=0, @@ -1157,7 +1169,11 @@ def wait_until(self, stop_condition: Callable[[], bool], max_time: float = 60): """ self.wait(max_time, stop_condition=stop_condition) - def compile_animation_data(self, *animations: Animation, **play_kwargs): + def compile_animation_data( + self, + *animations: Animation | Iterable[Animation] | types.GeneratorType[Animation], + **play_kwargs, + ): """Given a list of animations, compile the corresponding static and moving mobjects, and gather the animation durations. diff --git a/manim/utils/parameter_parsing.py b/manim/utils/parameter_parsing.py new file mode 100644 index 0000000000..458885a5b3 --- /dev/null +++ b/manim/utils/parameter_parsing.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from types import GeneratorType +from typing import Iterable, TypeVar + +T = TypeVar("T") + + +def flatten_iterable_parameters( + args: Iterable[T | Iterable[T] | GeneratorType], +) -> list[T]: + """Flattens an iterable of parameters into a list of parameters. + + Parameters + ---------- + args + The iterable of parameters to flatten. + [(generator), [], (), ...] + + Returns + ------- + :class:`list` + The flattened list of parameters. + """ + flattened_parameters = [] + for arg in args: + if isinstance(arg, (Iterable, GeneratorType)): + flattened_parameters.extend(arg) + else: + flattened_parameters.append(arg) + return flattened_parameters