diff --git a/manim/mobject/three_d/three_dimensions.py b/manim/mobject/three_d/three_dimensions.py index accb077746..bcdde3e188 100644 --- a/manim/mobject/three_d/three_dimensions.py +++ b/manim/mobject/three_d/three_dimensions.py @@ -32,7 +32,7 @@ from manim.mobject.mobject import * from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.mobject.opengl.opengl_mobject import OpenGLMobject -from manim.mobject.types.vectorized_mobject import VGroup, VMobject +from manim.mobject.types.vectorized_mobject import VectorizedPoint, VGroup, VMobject from manim.utils.color import ( ManimColor, ParsableManimColor, @@ -616,17 +616,18 @@ def __init__( **kwargs, ) # used for rotations + self.new_height = height self._current_theta = 0 self._current_phi = 0 - + self.base_circle = Circle( + radius=base_radius, + color=self.fill_color, + fill_opacity=self.fill_opacity, + stroke_width=0, + ) + self.base_circle.shift(height * IN) + self._set_start_and_end_attributes(direction) if show_base: - self.base_circle = Circle( - radius=base_radius, - color=self.fill_color, - fill_opacity=self.fill_opacity, - stroke_width=0, - ) - self.base_circle.shift(height * IN) self.add(self.base_circle) self._rotate_to_direction() @@ -656,6 +657,12 @@ def func(self, u: float, v: float) -> np.ndarray: ], ) + def get_start(self) -> np.ndarray: + return self.start_point.get_center() + + def get_end(self) -> np.ndarray: + return self.end_point.get_center() + def _rotate_to_direction(self) -> None: x, y, z = self.direction @@ -710,6 +717,15 @@ def get_direction(self) -> np.ndarray: """ return self.direction + def _set_start_and_end_attributes(self, direction): + normalized_direction = direction * np.linalg.norm(direction) + + start = self.base_circle.get_center() + end = start + normalized_direction * self.new_height + self.start_point = VectorizedPoint(start) + self.end_point = VectorizedPoint(end) + self.add(self.start_point, self.end_point) + class Cylinder(Surface): """A cylinder, defined by its height, radius and direction, @@ -1150,14 +1166,20 @@ def __init__( self.end - height * self.direction, **kwargs, ) - self.cone = Cone( - direction=self.direction, base_radius=base_radius, height=height, **kwargs + direction=self.direction, + base_radius=base_radius, + height=height, + **kwargs, ) self.cone.shift(end) - self.add(self.cone) + self.end_point = VectorizedPoint(end) + self.add(self.end_point, self.cone) self.set_color(color) + def get_end(self) -> np.ndarray: + return self.end_point.get_center() + class Torus(Surface): """A torus. diff --git a/tests/test_graphical_units/test_threed.py b/tests/test_graphical_units/test_threed.py index 022201f4c8..b6079e5e4c 100644 --- a/tests/test_graphical_units/test_threed.py +++ b/tests/test_graphical_units/test_threed.py @@ -33,6 +33,18 @@ def test_Cone(scene): scene.add(Cone(resolution=16)) +def test_Cone_get_start_and_get_end(): + cone = Cone().shift(RIGHT).rotate(PI / 4, about_point=ORIGIN, about_edge=OUT) + start = [0.70710678, 0.70710678, -1.0] + end = [0.70710678, 0.70710678, 0.0] + assert np.allclose( + cone.get_start(), start, atol=0.01 + ), "start points of Cone do not match" + assert np.allclose( + cone.get_end(), end, atol=0.01 + ), "end points of Cone do not match" + + @frames_comparison(base_scene=ThreeDScene) def test_Cylinder(scene): scene.add(Cylinder()) @@ -149,3 +161,14 @@ def param_surface(u, v): axes=axes, colorscale=[(RED, -0.4), (YELLOW, 0), (GREEN, 0.4)], axis=1 ) scene.add(axes, surface_plane) + + +def test_get_start_and_end_Arrow3d(): + start, end = ORIGIN, np.array([2, 1, 0]) + arrow = Arrow3D(start, end) + assert np.allclose( + arrow.get_start(), start, atol=0.01 + ), "start points of Arrow3D do not match" + assert np.allclose( + arrow.get_end(), end, atol=0.01 + ), "end points of Arrow3D do not match"