"""Daily product-usage pipeline — ingest app events, validate, transform,
and load warehouse tables for the exec metrics review.

Source: s3://company-data-lake/raw/app_events/ds={{ ds }}/
Targets:
  - analytics.product_usage_events_clean  (deduped event-level rows)
  - analytics.product_usage_daily         (feature × platform aggregates)

Schedule: 06:00 UTC daily, after the warehouse_health_check DAG (ordering by
convention; this DAG does not enforce it with a cross-DAG sensor).
"""
from datetime import datetime, timedelta, timezone

from airflow.decorators import dag, task
from airflow.operators.empty import EmptyOperator
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.providers.slack.notifications.slack import send_slack_notification
from airflow.utils.trigger_rule import TriggerRule

S3_BUCKET = "company-data-lake"
S3_EVENTS_PREFIX = "raw/app_events/ds={ds}/"
S3_STAGING_PREFIX = "staging/product_usage/ds={ds}/"
REDSHIFT_CONN = "redshift_default"
REDSHIFT_SCHEMA = "analytics"

MIN_ROW_COUNT = 100_000
NULL_RATE_MAX = 0.02
DUPLICATE_RATE_MAX = 0.001
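# Illustrative: at the 100,000-row minimum, DUPLICATE_RATE_MAX tolerates at
# most 100 duplicate event_ids and NULL_RATE_MAX up to 2,000 nulls per key
# column before the run fails.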

KEY_COLUMNS = [
    "event_id",
    "user_id",
    "event_type",
    "event_timestamp",
    "session_id",
    "platform",
    "feature_name",
]
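
# KEY_COLUMNS drives both the null-rate checks in validate_partition and the
# dropna() in transform_clean_events, keeping the two stages in agreement.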

def _on_failure(context):
    """Task-level failure callback — posts to #data-alerts."""
    ti = context["task_instance"]
    # send_slack_notification() returns a notifier object; it must be invoked
    # with the task context to actually post the message.
    notifier = send_slack_notification(
        slack_conn_id="slack_default",
        text=(
            f":red_circle: *Task failed* — `{ti.dag_id}.{ti.task_id}`\n"
            f"Execution date: `{context['ds']}`\n"
            f"Error: {context.get('exception', 'unknown')}\n"
            f"<{ti.log_url}|View logs>"
        ),
        channel="#data-alerts",
    )
    notifier(context)


default_args = {
    "owner": "data-team",
    "depends_on_past": False,
    "email": ["alex@company.com", "data-team@company.com"],
    "email_on_failure": True,
    "email_on_retry": False,
    "retries": 3,
    "retry_delay": timedelta(minutes=5),
    "retry_exponential_backoff": True,
    "max_retry_delay": timedelta(minutes=60),
    # Set here rather than on the DAG so it fires per failed task instance,
    # matching the docstring on _on_failure above.
    "on_failure_callback": _on_failure,
}


@dag(
    dag_id="daily_product_usage",
    default_args=default_args,
    description="Ingest app events, validate, transform, and load product usage metrics",
    schedule="0 6 * * *",
    start_date=datetime(2024, 6, 1),
    catchup=False,
    tags=["product-usage", "exec-metrics", "etl"],
    max_active_runs=1,
)
def daily_product_usage():

    start = EmptyOperator(task_id="start")

    wait_for_partition = S3KeySensor(
        task_id="wait_for_partition",
        bucket_name=S3_BUCKET,
        bucket_key=S3_EVENTS_PREFIX.format(ds="{{ ds }}") + "*",
        wildcard_match=True,  # the key is a prefix, not a single object
        aws_conn_id="aws_default",
        poke_interval=300,
        timeout=7200,
        mode="reschedule",
    )
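    # mode="reschedule" frees the worker slot between pokes; at a 300 s
    # interval the sensor checks roughly 24 times before the 2 h timeout.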

    # ── Validate ────────────────────────────────────────────────────────

    @task()
    def validate_partition(**context) -> dict:
        """Check row count, duplicate rate, and null rate on the raw partition."""
        import pyarrow as pa
        import pyarrow.parquet as pq
        from airflow.providers.amazon.aws.hooks.s3 import S3Hook

        ds = context["ds"]
        s3 = S3Hook(aws_conn_id="aws_default")
        prefix = S3_EVENTS_PREFIX.format(ds=ds)
        keys = s3.list_keys(bucket_name=S3_BUCKET, prefix=prefix)
        if not keys:
            raise FileNotFoundError(f"No files in s3://{S3_BUCKET}/{prefix}")

        # Read partition into Arrow table for fast validation
        frames = []
        for key in keys:
            local = s3.download_file(key=key, bucket_name=S3_BUCKET)
            frames.append(pq.read_table(local))

        table = pa.concat_tables(frames)
        total_rows = table.num_rows

        # Row-count check
        if total_rows < MIN_ROW_COUNT:
            raise ValueError(
                f"Partition {ds} has {total_rows} rows, below minimum {MIN_ROW_COUNT}"
            )

        # Duplicate event_id check
        event_ids = table.column("event_id").to_pylist()
        n_dupes = len(event_ids) - len(set(event_ids))
        dupe_rate = n_dupes / total_rows
        if dupe_rate > DUPLICATE_RATE_MAX:
            raise ValueError(
                f"Duplicate rate {dupe_rate:.4%} exceeds threshold {DUPLICATE_RATE_MAX:.4%}"
            )

        # Null-rate check on key columns
        for col in KEY_COLUMNS:
            null_count = table.column(col).null_count
            null_rate = null_count / total_rows
            if null_rate > NULL_RATE_MAX:
                raise ValueError(
                    f"Null rate for {col} is {null_rate:.4%}, exceeds {NULL_RATE_MAX:.4%}"
                )

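        # The returned summary goes through XCom; keep it to small
        # JSON-serializable scalars rather than the Arrow table itself.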
        return {
            "ds": ds,
            "total_rows": total_rows,
            "duplicate_count": n_dupes,
            "status": "passed",
        }

    # ── Transform ───────────────────────────────────────────────────────

    @task()
    def transform_clean_events(validation: dict, **context) -> str:
        """Deduplicate and clean raw events, write to staging S3 path.

        Returns the S3 key of the cleaned parquet file.
        """
        import pandas as pd
        from airflow.providers.amazon.aws.hooks.s3 import S3Hook

        ds = context["ds"]
        s3 = S3Hook(aws_conn_id="aws_default")
        prefix = S3_EVENTS_PREFIX.format(ds=ds)
        keys = s3.list_keys(bucket_name=S3_BUCKET, prefix=prefix)

        frames = []
        for key in keys:
            local = s3.download_file(key=key, bucket_name=S3_BUCKET)
            frames.append(pd.read_parquet(local))

        df = pd.concat(frames, ignore_index=True)

        # Deduplicate on event_id, keeping first occurrence
        df = df.drop_duplicates(subset=["event_id"], keep="first")

        # Drop rows with nulls in key columns
        df = df.dropna(subset=KEY_COLUMNS)

        # Add partition and ingestion-timestamp columns. A ds column is
        # assumed on the target table; the loader deletes on it before COPY.
        df.insert(0, "ds", ds)
        df["ingested_at"] = datetime.now(timezone.utc).isoformat()

        # Write cleaned parquet to staging
        staging_key = S3_STAGING_PREFIX.format(ds=ds) + "events_clean.parquet"
        local_path = f"/tmp/events_clean_{ds}.parquet"
        df.to_parquet(local_path, index=False)
        s3.load_file(
            filename=local_path,
            key=staging_key,
            bucket_name=S3_BUCKET,
            replace=True,
        )
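        # Return only the small staging key; the dataframe itself stays in
        # S3, keeping the XCom payload tiny.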
        return staging_key

    @task()
    def transform_daily_aggregates(clean_events_key: str, **context) -> str:
        """Aggregate clean events into daily feature × platform metrics.

        Returns the S3 key of the aggregated parquet file.
        """
        import pandas as pd
        from airflow.providers.amazon.aws.hooks.s3 import S3Hook

        ds = context["ds"]
        s3 = S3Hook(aws_conn_id="aws_default")
        local = s3.download_file(key=clean_events_key, bucket_name=S3_BUCKET)
        df = pd.read_parquet(local)

        df["event_timestamp"] = pd.to_datetime(df["event_timestamp"])

        # Session durations: time between first and last event per session
        session_stats = (
            df.groupby(["session_id", "feature_name", "platform"])["event_timestamp"]
            .agg(["min", "max"])
            .reset_index()
        )
        session_stats["duration_sec"] = (
            session_stats["max"] - session_stats["min"]
        ).dt.total_seconds()
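        # Single-event sessions have min == max and a 0 s duration, which
        # pulls the mean down; the p95 computed below is less sensitive to
        # them.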

        # Aggregate per feature × platform
        agg = (
            df.groupby(["feature_name", "platform"])
            .agg(
                unique_users=("user_id", "nunique"),
                total_events=("event_id", "count"),
            )
            .reset_index()
        )

        # Session duration stats per feature × platform
        dur_stats = (
            session_stats.groupby(["feature_name", "platform"])["duration_sec"]
            .agg(
                avg_session_duration_sec="mean",
                p95_session_duration_sec=lambda x: x.quantile(0.95),
            )
            .reset_index()
        )

        agg = agg.merge(dur_stats, on=["feature_name", "platform"], how="left")

        # New vs returning users. Only this partition is in memory, so a
        # user's true first event is unknown here; this flags users whose
        # first event *within the partition* falls on ds, which overcounts
        # "new" users unless first-seen dates are joined from history.
        first_seen = df.groupby("user_id")["event_timestamp"].min().reset_index()
        first_seen.columns = ["user_id", "first_event"]
        df = df.merge(first_seen, on="user_id", how="left")
        df["is_new"] = df["first_event"].dt.date == pd.to_datetime(ds).date()

        # Count users rather than events: collapse to one row per user per
        # feature × platform before summing the flags.
        user_flags = df.drop_duplicates(subset=["user_id", "feature_name", "platform"])
        new_returning = (
            user_flags.groupby(["feature_name", "platform"])
            .agg(
                new_users=("is_new", "sum"),
                returning_users=("is_new", lambda x: (~x).sum()),
            )
            .reset_index()
        )

        agg = agg.merge(new_returning, on=["feature_name", "platform"], how="left")
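        # Hypothetical exact alternative (not wired in): join each user's true
        # first-seen date from warehouse history instead of approximating from
        # a single partition, assuming product_usage_events_clean retains full
        # history, e.g.:
        #
        #   from airflow.providers.postgres.hooks.postgres import PostgresHook
        #   pg = PostgresHook(postgres_conn_id=REDSHIFT_CONN)
        #   first_seen = pg.get_pandas_df(
        #       f"SELECT user_id, MIN(event_timestamp) AS first_event "
        #       f"FROM {REDSHIFT_SCHEMA}.product_usage_events_clean "
        #       f"GROUP BY user_id"
        #   )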
        agg.insert(0, "ds", ds)

        staging_key = S3_STAGING_PREFIX.format(ds=ds) + "daily_aggregates.parquet"
        local_path = f"/tmp/daily_aggregates_{ds}.parquet"
        agg.to_parquet(local_path, index=False)

        s3.load_file(
            filename=local_path,
            key=staging_key,
            bucket_name=S3_BUCKET,
            replace=True,
        )
        return staging_key

    # ── Load ────────────────────────────────────────────────────────────

    @task()
    def load_clean_events(clean_events_key: str, **context):
        """COPY cleaned events from S3 staging into Redshift (incremental)."""
        from airflow.providers.postgres.hooks.postgres import PostgresHook

        ds = context["ds"]
        pg = PostgresHook(postgres_conn_id=REDSHIFT_CONN)

        # Delete existing partition then COPY — idempotent reload
        pg.run(
            f"DELETE FROM {REDSHIFT_SCHEMA}.product_usage_events_clean WHERE ds = %s",
            parameters=(ds,),
        )
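        # Redshift COPY ... FORMAT AS PARQUET maps parquet columns to table
        # columns by position, so the staged file's layout is assumed to
        # match the table definition.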
        pg.run(f"""
            COPY {REDSHIFT_SCHEMA}.product_usage_events_clean
            FROM 's3://{S3_BUCKET}/{clean_events_key}'
            IAM_ROLE default
            FORMAT AS PARQUET;
        """)

    @task()
    def load_daily_aggregates(daily_agg_key: str, **context):
        """COPY daily aggregates from S3 staging into Redshift (incremental)."""
        from airflow.providers.postgres.hooks.postgres import PostgresHook

        ds = context["ds"]
        pg = PostgresHook(postgres_conn_id=REDSHIFT_CONN)

        pg.run(
            f"DELETE FROM {REDSHIFT_SCHEMA}.product_usage_daily WHERE ds = %s",
            parameters=(ds,),
        )
        pg.run(f"""
            COPY {REDSHIFT_SCHEMA}.product_usage_daily
            FROM 's3://{S3_BUCKET}/{daily_agg_key}'
            IAM_ROLE default
            FORMAT AS PARQUET;
        """)

    # ── Notify ──────────────────────────────────────────────────────────

    @task()
    def notify_success(validation: dict, **context):
        """Post run summary to #data-pipeline-runs."""
        ds = context["ds"]
        # send_slack_notification() returns a notifier object; invoke it with
        # the task context to actually post the message.
        notifier = send_slack_notification(
            slack_conn_id="slack_default",
            text=(
                f":large_green_circle: *daily_product_usage* completed for `{ds}`\n"
                f"Rows ingested: {validation['total_rows']:,}\n"
                f"Duplicates removed: {validation['duplicate_count']:,}\n"
                f"Tables refreshed: `product_usage_events_clean`, `product_usage_daily`"
            ),
            channel="#data-pipeline-runs",
        )
        notifier(context)

    end = EmptyOperator(task_id="end")

    # ── Wire dependencies ───────────────────────────────────────────────

    validated = validate_partition()
    clean_key = transform_clean_events(validated)
    agg_key = transform_daily_aggregates(clean_key)
    loaded_clean = load_clean_events(clean_key)
    loaded_agg = load_daily_aggregates(agg_key)
    notified = notify_success(validated)

    start >> wait_for_partition >> validated
    # XCom data passing already wires validate → transforms → loads; make the
    # success notification (and `end`) wait on both load tasks.
    [loaded_clean, loaded_agg] >> notified >> end


daily_product_usage()
