from manim import *


class GradientDescent(Scene):
    """Animate gradient descent on f(x) = x², starting from x = X_START.

    Sequence: title card → axes and parabola (tick labels are plain ``Text``,
    so no LaTeX install is needed) → update rule, derivative and learning
    rate → ``STEPS`` iterations of ``x ← x − ALPHA · f′(x)``, each showing the
    tangent at the current iterate before moving the marker → highlight of
    the true minimum at x = 0 → closing caption.
    """

    # Descent parameters. ALPHA is also rendered in the on-screen label, so
    # the label and the animated updates always agree.
    ALPHA = 0.3
    X_START = 3.5
    STEPS = 7

    @staticmethod
    def _f(x):
        """Objective f(x) = x²; must match the "f(x) = x²" caption."""
        return x ** 2

    @staticmethod
    def _df(x):
        """Derivative f′(x) = 2x; must match the "f′(x) = 2x" caption."""
        return 2 * x

    @staticmethod
    def _tangent_x_range(x0, y0, slope, half_width, x_bounds, y_bounds):
        """Return ``(lo, hi)``, the x-interval over which to draw a tangent.

        The segment nominally spans ``x0 ± half_width``. It is trimmed to
        ``x_bounds`` and to the x-values where the line
        ``y = slope·(x − x0) + y0`` stays inside ``y_bounds``. This stops it
        poking out of the plotted window: without the y-trim, the tangent at
        x = 1.4 reaches y ≈ -1.4, below the axes' y minimum of -1.
        """
        lo = max(x0 - half_width, x_bounds[0])
        hi = min(x0 + half_width, x_bounds[1])
        if slope:
            # x-values where the tangent crosses the lower / upper y bound.
            x_at_ymin = x0 + (y_bounds[0] - y0) / slope
            x_at_ymax = x0 + (y_bounds[1] - y0) / slope
            lo = max(lo, min(x_at_ymin, x_at_ymax))
            hi = min(hi, max(x_at_ymin, x_at_ymax))
        return lo, hi

    def _play_title(self):
        """Show the title card, hold it briefly, then clear it."""
        title = Text("Gradient Descent", font_size=52, color=WHITE, weight=BOLD)
        subtitle = Text(
            "Finding the minimum of a function",
            font_size=28,
            color=GREY_B,
        )
        subtitle.next_to(title, DOWN, buff=0.3)
        self.play(Write(title), FadeIn(subtitle, shift=UP), run_time=1.5)
        self.wait(0.8)
        self.play(FadeOut(title), FadeOut(subtitle))

    def construct(self):
        """Build and play the full gradient-descent animation."""
        # ── Title ──────────────────────────────────────────────
        self._play_title()

        # ── Axes + parabola ────────────────────────────────────
        x_span = 3.8  # curve drawn on [-3.8, 3.8], inside the x-axis range
        y_bounds = (-1, 16)  # visible f(x) window; tangents are clipped to it
        axes = Axes(
            x_range=[-4, 4, 1],
            y_range=[y_bounds[0], y_bounds[1], 2],
            x_length=8,
            y_length=5,
            axis_config={"include_numbers": False, "font_size": 24},
        ).shift(DOWN * 0.5)

        # Tick labels built from Text (avoids LaTeX); the origin is left bare.
        x_tick_labels = VGroup(
            *(
                Text(str(v), font_size=18, color=GREY_B).next_to(
                    axes.c2p(v, 0), DOWN, buff=0.15
                )
                for v in range(-4, 5)
                if v != 0
            )
        )
        y_tick_labels = VGroup(
            *(
                Text(str(v), font_size=18, color=GREY_B).next_to(
                    axes.c2p(0, v), LEFT, buff=0.15
                )
                for v in range(4, 17, 4)
            )
        )

        x_label = Text("x", font_size=24).next_to(axes.x_axis, RIGHT, buff=0.2)
        y_label = Text("f(x)", font_size=24).next_to(axes.y_axis, UP, buff=0.2)

        graph = axes.plot(self._f, color=BLUE, x_range=[-x_span, x_span])

        equation = Text("f(x) = x\u00b2", font_size=36, color=BLUE)
        equation.to_corner(UR, buff=0.6)

        self.play(
            Create(axes),
            FadeIn(x_tick_labels),
            FadeIn(y_tick_labels),
            Write(x_label),
            Write(y_label),
            run_time=1.5,
        )
        self.play(Create(graph), Write(equation), run_time=1.5)
        self.wait(0.5)

        # ── Update rule ───────────────────────────────────────
        rule = Text(
            "x_new = x_old \u2212 \u03b1 \u00b7 f\u2032(x)",
            font_size=28,
            color=WHITE,
        )
        rule.to_corner(UL, buff=0.6)

        rule_label = Text("Update rule", font_size=20, color=GREY_B)
        rule_label.next_to(rule, DOWN, buff=0.15, aligned_edge=LEFT)

        self.play(Write(rule), FadeIn(rule_label, shift=UP), run_time=1.2)
        self.wait(0.6)

        # ── Derivative + learning rate ────────────────────────
        deriv = Text("f\u2032(x) = 2x", font_size=30, color=RED)
        deriv.next_to(equation, DOWN, buff=0.35, aligned_edge=RIGHT)

        # Derived from ALPHA so the label cannot drift from the real step size.
        alpha_label = Text(
            f"\u03b1 = {self.ALPHA}  (learning rate)", font_size=22, color=YELLOW
        )
        alpha_label.next_to(rule_label, DOWN, buff=0.15, aligned_edge=LEFT)

        self.play(Write(deriv), FadeIn(alpha_label), run_time=1)

        # ── Gradient descent animation ────────────────────────
        alpha = self.ALPHA
        x_val = self.X_START

        dot = Dot(axes.i2gp(x_val, graph), color=YELLOW, radius=0.1)
        dot_label = Text(
            f"x = {x_val:.2f}", font_size=24, color=YELLOW
        ).next_to(dot, UR, buff=0.15)

        self.play(GrowFromCenter(dot), FadeIn(dot_label), run_time=0.8)
        self.wait(0.4)

        step_text = None  # bottom-edge step caption; created on the first step
        for i in range(self.STEPS):
            y_cur = self._f(x_val)
            grad = self._df(x_val)  # also the slope of the tangent drawn below
            x_new = x_val - alpha * grad

            # Tangent at the current iterate, trimmed to the plotted window.
            x_lo, x_hi = self._tangent_x_range(
                x_val,
                y_cur,
                grad,
                half_width=1.2,
                x_bounds=(-x_span, x_span),
                y_bounds=y_bounds,
            )
            tangent = axes.plot(
                lambda x, s=grad, xv=x_val, yv=y_cur: s * (x - xv) + yv,
                x_range=[x_lo, x_hi],
                color=RED,
                stroke_width=2,
            )

            new_step = Text(
                f"Step {i + 1}:  slope = {grad:.2f},   x: {x_val:.2f} \u2192 {x_new:.2f}",
                font_size=22,
                color=WHITE,
            ).to_edge(DOWN, buff=0.4)

            if step_text is None:
                self.play(Create(tangent), FadeIn(new_step), run_time=0.8)
            else:
                self.play(
                    Create(tangent),
                    ReplacementTransform(step_text, new_step),
                    run_time=0.7,
                )
            step_text = new_step
            self.wait(0.3)

            # Move the marker to the new iterate and retire this tangent.
            x_val = x_new
            new_dot = Dot(axes.i2gp(x_val, graph), color=YELLOW, radius=0.1)
            new_label = Text(
                f"x = {x_val:.2f}", font_size=24, color=YELLOW
            ).next_to(new_dot, UR, buff=0.15)

            self.play(
                ReplacementTransform(dot, new_dot),
                FadeOut(dot_label),
                FadeIn(new_label),
                FadeOut(tangent),
                run_time=0.7,
            )
            dot = new_dot
            dot_label = new_label
            self.wait(0.2)

        # ── Highlight minimum ─────────────────────────────────
        if step_text is not None:  # None only when STEPS == 0
            self.play(FadeOut(step_text), run_time=0.4)

        min_dot = Dot(axes.i2gp(0, graph), color=GREEN, radius=0.14)
        min_label = Text("minimum at x = 0", font_size=28, color=GREEN)
        min_label.next_to(min_dot, DOWN, buff=0.3)

        self.play(
            Flash(axes.i2gp(0, graph), color=GREEN, flash_radius=0.4),
            GrowFromCenter(min_dot),
            run_time=1,
        )
        self.play(Write(min_label), run_time=0.8)
        self.wait(0.5)

        # ── Closing ───────────────────────────────────────────
        # The two-line caption does not fit at the bottom edge: it would be
        # drawn over the x tick labels, the minimum dot and "minimum at x = 0".
        # The formula panel at the top is finished, so clear it and use that
        # space instead.
        self.play(
            FadeOut(rule),
            FadeOut(rule_label),
            FadeOut(alpha_label),
            FadeOut(equation),
            FadeOut(deriv),
            run_time=0.6,
        )

        closing = Text(
            "Gradient descent iteratively follows the slope\n"
            "downhill to find the function\u2019s minimum.",
            font_size=26,
            color=GREY_A,
            line_spacing=1.4,
        ).to_edge(UP, buff=0.4)

        self.play(FadeIn(closing, shift=DOWN), run_time=1)
        self.wait(2)
