"""
Equity Momentum Backtest — bias-safe with walk-forward analysis.

Bias controls applied:
  1. Look-ahead: signals use only prices strictly before rebalance date
  2. Survivorship: delisted stocks (FTR, ETFC) included until delist date
  3. Transaction costs: $0.005/share commission (min $1) + 10 bps slippage
  4. Walk-forward: rolling 12m IS / 3m OOS windows, 3m step
  5. Position limits: 15% per name, 40% per sector

Usage:
    python run_backtest.py
"""

import os
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from dateutil.relativedelta import relativedelta
from backtest_config import *


# ── Data ─────────────────────────────────────────────────────

def load_data():
    """Load raw prices and per-ticker universe metadata from ``data/``.

    Returns:
        tuple ``(prices, sector_map, delist_map)``:
          * ``prices`` - ``data/prices.csv`` as read, with ``date`` parsed
            to datetimes;
          * ``sector_map`` - ticker -> ``sector`` from ``data/universe.csv``;
          * ``delist_map`` - ticker -> delist date; blank or unparseable
            entries become ``NaT`` (``construct_portfolio`` treats NaT as
            "never delisted").
    """
    prices = pd.read_csv("data/prices.csv", parse_dates=["date"])

    meta = pd.read_csv("data/universe.csv")
    # errors="coerce": invalid/blank delist dates become NaT instead of raising.
    meta["delist_date"] = pd.to_datetime(meta["delist_date"], errors="coerce")

    # Build both ticker-keyed lookups the same way, one per metadata column.
    sector_map, delist_map = (
        dict(zip(meta["ticker"], meta[field]))
        for field in ("sector", "delist_date")
    )
    return prices, sector_map, delist_map


def build_close_panel(prices):
    """Reshape long-format prices into a wide date x ticker close panel.

    Args:
        prices: DataFrame with ``date``, ``ticker`` and ``close`` columns.

    Returns:
        DataFrame indexed by ``date`` (sorted ascending), one column per
        ticker; (date, ticker) pairs absent from ``prices`` are NaN.
        Duplicate (date, ticker) rows raise, exactly as ``DataFrame.pivot``.
    """
    by_key = prices.set_index(["date", "ticker"])["close"]
    panel = by_key.unstack("ticker")
    return panel.sort_index()


# ── Momentum Signal ─────────────────────────────────────────

def compute_momentum_signal(close, as_of_date):
    """Cross-sectional momentum score per ticker as of ``as_of_date``.

    Only rows dated strictly before ``as_of_date`` are used, so the
    rebalance day's own prices cannot leak into the signal. Those rows are
    collapsed to one bar per calendar month (last non-missing price in the
    month); the newest bar may cover a partial month. With ``m`` the monthly
    bars:

        score = m[-1 - SKIP_MONTHS] / m[-1 - LOOKBACK_MONTHS] - 1

    i.e. the return from LOOKBACK_MONTHS bars back to SKIP_MONTHS bars back,
    skipping the most recent SKIP_MONTHS bar(s).

    Returns:
        Series indexed by ticker with NaN scores dropped, or an empty float
        Series when fewer than LOOKBACK_MONTHS + 1 monthly bars exist.
    """
    # Strict "<": the rebalance date's own close is never visible here.
    history = close[close.index < as_of_date]
    month_bars = history.resample("ME").last()

    if len(month_bars) < LOOKBACK_MONTHS + 1:
        return pd.Series(dtype=float)

    recent_px = month_bars.iloc[-1 - SKIP_MONTHS]
    base_px = month_bars.iloc[-1 - LOOKBACK_MONTHS]
    momentum = recent_px / base_px - 1
    return momentum.dropna()


# ── Portfolio Construction ───────────────────────────────────

def _apply_exposure_caps(w, sector_codes, max_position_pct, max_sector_pct):
    """Clip each name to the position cap, then shrink any over-limit sector.

    ``w`` is a float array of signed weights and ``sector_codes`` the
    matching non-negative integer sector labels. Sector exposure is gross
    (sum of absolute weights across long and short names). Both steps only
    shrink weights, so the position cap still holds after the sector step.
    """
    w = np.clip(w, -max_position_pct, max_position_pct)
    sector_gross = np.bincount(sector_codes, weights=np.abs(w))
    sector_scale = np.ones_like(sector_gross)
    over = sector_gross > max_sector_pct
    sector_scale[over] = max_sector_pct / sector_gross[over]
    return w * sector_scale[sector_codes]


def construct_portfolio(signal, as_of_date, sector_map, delist_map, *,
                        long_weight=None, short_weight=None,
                        max_position_pct=None, max_sector_pct=None,
                        max_iter=100, tol=1e-10):
    """Build long/short target weights from a momentum signal.

    Tickers already delisted on ``as_of_date`` (per ``delist_map``; NaT or a
    missing entry means still listed) are dropped. The top tercile by signal
    is held long and the bottom tercile short, equal-weighted within each leg
    at ``long_weight`` / ``short_weight`` gross. Two limits are enforced:

    * per name:   ``|w| <= max_position_pct``
    * per sector: sum of ``|w|`` in the sector ``<= max_sector_pct``
      (tickers missing from ``sector_map`` or mapped to NaN count as
      ``"Other"``).

    The caps and a per-leg rescale toward target gross are alternated until
    both legs hit target (within ``tol``) or ``max_iter`` passes elapse. The
    caps are then applied once more, so the returned weights always satisfy
    both limits. If the limits make the target infeasible, the legs end up
    below target gross rather than breaking a cap.

    The keyword-only overrides default to the module config constants
    LONG_WEIGHT, SHORT_WEIGHT, MAX_POSITION_PCT and MAX_SECTOR_PCT.

    Returns:
        Series of non-zero signed weights indexed by ticker (longs positive),
        or an empty Series when fewer than three active tickers have a signal.
    """
    long_weight = LONG_WEIGHT if long_weight is None else long_weight
    short_weight = SHORT_WEIGHT if short_weight is None else short_weight
    max_position_pct = MAX_POSITION_PCT if max_position_pct is None else max_position_pct
    max_sector_pct = MAX_SECTOR_PCT if max_sector_pct is None else max_sector_pct

    # Keep only active stocks (not yet delisted as of rebalance date).
    active = [t for t in signal.index
              if pd.isna(delist_map.get(t)) or as_of_date < delist_map[t]]
    signal = signal.reindex(active).dropna()

    if len(signal) < 3:
        return pd.Series(dtype=float)

    ranked = signal.sort_values(ascending=False)
    n_tercile = max(1, len(ranked) // 3)
    long_tkrs = ranked.index[:n_tercile]
    short_tkrs = ranked.index[-n_tercile:]

    weights = pd.Series(0.0, index=ranked.index)
    weights[long_tkrs] = long_weight / len(long_tkrs)
    weights[short_tkrs] = -short_weight / len(short_tkrs)
    weights = weights[weights != 0]

    labels = [sector_map.get(t, "Other") for t in weights.index]
    labels = ["Other" if pd.isna(s) else s for s in labels]
    sector_codes, _ = pd.Index(labels).factorize()
    w = weights.to_numpy(dtype=float)

    # Re-normalizing the legs after capping (without re-capping) would
    # re-inflate capped sectors, so alternate until a fixed point is reached.
    for _ in range(max_iter):
        w = _apply_exposure_caps(w, sector_codes, max_position_pct, max_sector_pct)
        long_gross = w[w > 0].sum()
        short_gross = -w[w < 0].sum()
        long_scale = long_weight / long_gross if long_gross > 0 else 1.0
        short_scale = short_weight / short_gross if short_gross > 0 else 1.0
        if abs(long_scale - 1.0) <= tol and abs(short_scale - 1.0) <= tol:
            break
        w = np.where(w > 0, w * long_scale, w * short_scale)

    # Caps are idempotent; re-applying guarantees they hold on exit.
    w = _apply_exposure_caps(w, sector_codes, max_position_pct, max_sector_pct)
    return pd.Series(w, index=weights.index)


# ── Transaction Costs ────────────────────────────────────────

def transaction_costs(old_w, new_w, port_value, prices_row):
    """Dollar cost of trading from weights ``old_w`` to ``new_w``.

    For every ticker in either weight Series whose weight changes by at
    least 1e-10 (traded notional = |weight change| * ``port_value``):

    * slippage   = notional * SLIPPAGE_BPS / 1e4
    * commission = max(shares * COMMISSION_PER_SHARE, MIN_COMMISSION),
      where shares = notional / price. A ticker absent from ``prices_row``
      is priced at a placeholder 100.0; a non-positive price gives zero
      shares, so only the minimum commission is charged.

    Returns:
        Total cost (float, same currency as ``port_value``).
    """
    total = 0.0
    for ticker in set(old_w.index) | set(new_w.index):
        traded = abs(new_w.get(ticker, 0.0) - old_w.get(ticker, 0.0))
        if traded < 1e-10:
            continue
        notional = traded * port_value
        total += notional * SLIPPAGE_BPS / 1e4
        price = prices_row.get(ticker, 100.0)
        n_shares = notional / price if price > 0 else 0
        total += max(n_shares * COMMISSION_PER_SHARE, MIN_COMMISSION)
    return total


# ── Backtest Engine ──────────────────────────────────────────

def run_backtest(close, sector_map, delist_map, start=None, end=None):
    """Run the momentum backtest over ``[start, end]`` (both inclusive).

    The book is rebalanced on the first trading day of each calendar month
    in the window (which includes the window's first day). Each day, held
    positions first earn the close-to-close return (cash, including short
    proceeds, earns nothing). If a rebalance is due, new weights from
    ``construct_portfolio`` are then adopted at that day's close and
    ``transaction_costs`` is deducted from portfolio value.

    Between rebalances the weights drift with prices (buy-and-hold), so each
    rebalance's turnover and cost are measured against the positions actually
    held. Keeping weights fixed would amount to a free daily rebalance.

    Args:
        close: date x ticker close panel (see ``build_close_panel``).
        sector_map, delist_map: ticker lookups (see ``load_data``).
        start, end: optional window bounds, converted with ``pd.Timestamp``.

    Returns:
        ``(daily_df, trades)``: a frame indexed by date with a
        ``portfolio_value`` column, and one dict per executed rebalance
        (date, leg counts, one-way turnover, costs, post-cost value).
    """
    idx = close.index
    if start is not None:
        idx = idx[idx >= pd.Timestamp(start)]
    if end is not None:
        idx = idx[idx <= pd.Timestamp(end)]
    if len(idx) == 0:
        return pd.DataFrame(columns=["portfolio_value"]), []

    # First trading day of each month -> rebalance dates.
    rebal_set = set()
    cur_ym = None
    for d in idx:
        ym = (d.year, d.month)
        if ym != cur_ym:
            rebal_set.add(d)
            cur_ym = ym

    pv = INITIAL_CAPITAL
    weights = pd.Series(dtype=float)
    prev_d = None
    daily_rows = []
    trades = []

    for d in idx:
        # ── Daily P&L from held positions ──
        if prev_d is not None and len(weights) > 0:
            # NaN returns (missing prices) are treated as a 0% move.
            day_ret = (close.loc[d] / close.loc[prev_d] - 1)
            day_ret = day_ret.reindex(weights.index).fillna(0.0)
            growth = 1.0 + float((weights * day_ret).sum())
            pv *= growth
            # Drift: each position's value scales by (1 + r_i) while the
            # book scales by `growth`. Skip if the book was wiped out.
            if growth > 0:
                weights = weights * (1.0 + day_ret) / growth

        # ── Rebalance ──
        if d in rebal_set:
            sig = compute_momentum_signal(close, d)
            if len(sig) > 0:
                new_w = construct_portfolio(sig, d, sector_map, delist_map)
                if len(new_w) > 0:
                    px = close.loc[d].dropna()
                    cost = transaction_costs(weights, new_w, pv, px)
                    pv -= cost
                    all_idx = new_w.index.union(weights.index)
                    turnover = (
                        new_w.reindex(all_idx, fill_value=0)
                        - weights.reindex(all_idx, fill_value=0)
                    ).abs().sum() / 2
                    trades.append({
                        "date": d,
                        "n_long": int((new_w > 0).sum()),
                        "n_short": int((new_w < 0).sum()),
                        "turnover": turnover,
                        "costs": cost,
                        "portfolio_value": pv,
                    })
                    weights = new_w

        daily_rows.append({"date": d, "portfolio_value": pv})
        prev_d = d

    return pd.DataFrame(daily_rows).set_index("date"), trades


# ── Metrics ──────────────────────────────────────────────────

def compute_metrics(dv):
    if len(dv) < 2:
        return {}
    v = dv["portfolio_value"]
    r = v.pct_change().dropna()
    days = (v.index[-1] - v.index[0]).days
    if days <= 0:
        return {}
    total_ret = v.iloc[-1] / v.iloc[0]
    cagr = total_ret ** (365.25 / days) - 1
    vol = r.std() * np.sqrt(252)
    sharpe = (r.mean() / r.std() * np.sqrt(252)) if r.std() > 0 else 0.0
    peak = v.cummax()
    dd = ((v - peak) / peak).min()
    mr = v.resample("ME").last().pct_change().dropna()
    return {
        "CAGR": cagr,
        "Sharpe": sharpe,
        "Max Drawdown": dd,
        "Win Rate (Monthly)": (mr > 0).mean() if len(mr) else 0.0,
        "Avg Monthly Return": mr.mean() if len(mr) else 0.0,
        "Annualized Vol": vol,
        "Calmar": cagr / abs(dd) if dd != 0 else 0.0,
        "Total Return": total_ret - 1,
        "Days": len(r),
    }


# ── Walk-Forward ─────────────────────────────────────────────

def walk_forward(close, sector_map, delist_map):
    """Rolling walk-forward evaluation of the fixed momentum strategy.

    The first window starts on the first date at least
    LOOKBACK_MONTHS + SKIP_MONTHS months after the first date in ``close``.
    Each window:

    * trains on dates in [train_start, train_start + TRAIN_WINDOW_MONTHS);
    * tests from the first date >= that boundary through the last date
      <= boundary + TEST_WINDOW_MONTHS (capped at the last available date).

    Windows advance by STEP_MONTHS and stop once a test window reaches the
    end of the data. IS and OOS are separate ``run_backtest`` calls, each
    starting from fresh capital with no positions.

    Returns:
        list of dicts: window number, ISO train/test start/end dates, and the
        ``compute_metrics`` dicts for the train and test runs.
    """
    dates = close.index
    warmup = relativedelta(months=LOOKBACK_MONTHS + SKIP_MONTHS)
    eligible = dates[dates >= dates[0] + warmup]
    if len(eligible) == 0:
        return []
    last_day = dates[-1]

    def _daily_values(first, last):
        # Only the daily value frame is needed; the trade log is discarded.
        daily, _ = run_backtest(close, sector_map, delist_map, start=first, end=last)
        return daily

    windows = []
    train_start = eligible[0]
    window_no = 1

    while True:
        test_start = pd.Timestamp(train_start + relativedelta(months=TRAIN_WINDOW_MONTHS))
        test_stop = pd.Timestamp(test_start + relativedelta(months=TEST_WINDOW_MONTHS))

        train_days = dates[dates < test_start]
        test_days = dates[dates >= test_start]
        test_days_capped = dates[dates <= min(test_stop, last_day)]

        if len(train_days) == 0 or len(test_days) == 0 or test_days[0] > last_day:
            break

        train_last = train_days[-1]
        test_first = test_days[0]
        test_last = test_days_capped[-1] if len(test_days_capped) else last_day

        in_sample = _daily_values(train_start, train_last)
        out_sample = _daily_values(test_first, test_last)

        windows.append({
            "window": window_no,
            "train_start": train_start.strftime("%Y-%m-%d"),
            "train_end": train_last.strftime("%Y-%m-%d"),
            "test_start": test_first.strftime("%Y-%m-%d"),
            "test_end": test_last.strftime("%Y-%m-%d"),
            "train_metrics": compute_metrics(in_sample),
            "test_metrics": compute_metrics(out_sample),
        })

        if test_stop >= last_day:
            break
        train_start += relativedelta(months=STEP_MONTHS)
        window_no += 1

    return windows


# ── Report ───────────────────────────────────────────────────

def _fmt(v, kind):
    if kind == "pct":
        return f"{v:.2%}"
    if kind == "f4":
        return f"{v:.4f}"
    if kind == "d":
        return f"{v:.0f}"
    return str(v)


_PCT_KEYS = {"CAGR", "Max Drawdown", "Win Rate (Monthly)",
             "Avg Monthly Return", "Annualized Vol", "Total Return"}


def generate_report(full_daily, trades, wf, train_m, val_m, test_m):
    """Render the plain-text backtest summary.

    Args:
        full_daily: full-period frame with a ``portfolio_value`` column.
        trades: per-rebalance dicts from ``run_backtest`` (uses ``costs``,
            ``turnover`` and post-cost ``portfolio_value``).
        wf: window dicts from ``walk_forward``.
        train_m, val_m, test_m: ``compute_metrics`` dicts for the three
            slices (may be empty; missing metrics print as 0).

    Returns:
        The report as one newline-joined string.

    The bias-control section is formatted from the same config constants the
    engine uses, so it cannot drift from the parameters actually simulated.
    """
    L = []
    L.append("=" * 80)
    L.append("EQUITY MOMENTUM BACKTEST — SUMMARY REPORT")
    L.append(f"Generated: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')}")
    L.append("=" * 80)

    L.append("\nBIAS CONTROLS APPLIED")
    L.append("-" * 40)
    L.append("1. Look-ahead bias:   Signals computed only from prices available before")
    L.append("                      each rebalance date (strict < as_of_date filter).")
    L.append("2. Survivorship bias: Universe includes delisted stocks FTR (delist 2024-04-30)")
    L.append("                      and ETFC (delist 2024-10-15). Both participate in signal")
    L.append("                      ranking and portfolios until their respective delist dates.")
    L.append(f"3. Transaction costs: ${COMMISSION_PER_SHARE:g}/share commission"
             f" (min ${MIN_COMMISSION:.2f}/order) + {SLIPPAGE_BPS:g} bps")
    L.append("                      slippage applied at every monthly rebalance.")
    L.append(f"4. Walk-forward:      Rolling {TRAIN_WINDOW_MONTHS}-month IS /"
             f" {TEST_WINDOW_MONTHS}-month OOS windows, {STEP_MONTHS}-month step.")
    L.append("                      No parameter optimization — strategy is fixed"
             f" {LOOKBACK_MONTHS}-{SKIP_MONTHS} momentum.")
    L.append(f"5. Position limits:   {MAX_POSITION_PCT:.0%} max per name,"
             f" {MAX_SECTOR_PCT:.0%} max per sector.")

    # ── Overall ──
    fm = compute_metrics(full_daily)
    L.append("\nOVERALL PERFORMANCE (Full Period)")
    L.append("-" * 40)
    for k in ["CAGR", "Sharpe", "Max Drawdown", "Win Rate (Monthly)",
              "Avg Monthly Return", "Annualized Vol", "Calmar", "Total Return", "Days"]:
        v = fm.get(k, 0)
        kind = "pct" if k in _PCT_KEYS else ("d" if k == "Days" else "f4")
        L.append(f"  {k:30s}: {_fmt(v, kind):>10s}")

    # ── Train / Validation / Test ──
    L.append("\nTRAIN vs VALIDATION vs TEST")
    L.append("-" * 70)
    L.append(f"  {'Metric':28s} {'Train':>12s} {'Validation':>12s} {'Test':>12s}")
    L.append("  " + "-" * 64)
    for k in ["CAGR", "Sharpe", "Max Drawdown", "Win Rate (Monthly)",
              "Annualized Vol", "Calmar"]:
        kind = "pct" if k in _PCT_KEYS else "f4"
        L.append(f"  {k:28s} {_fmt(train_m.get(k, 0), kind):>12s}"
                 f" {_fmt(val_m.get(k, 0), kind):>12s}"
                 f" {_fmt(test_m.get(k, 0), kind):>12s}")

    # ── Walk-Forward ──
    L.append("\nWALK-FORWARD ANALYSIS")
    L.append("-" * 95)
    L.append(f"  {'Win':>3s} | {'Train Period':>25s} | {'Test Period':>25s}"
             f" | {'IS Sharpe':>10s} | {'OOS Sharpe':>10s} | {'OOS CAGR':>10s}")
    L.append("  " + "-" * 92)
    for r in wf:
        tp = f"{r['train_start']} to {r['train_end']}"
        op = f"{r['test_start']} to {r['test_end']}"
        L.append(f"  {r['window']:>3d} | {tp:>25s} | {op:>25s}"
                 f" | {r['train_metrics'].get('Sharpe', 0):>10.4f}"
                 f" | {r['test_metrics'].get('Sharpe', 0):>10.4f}"
                 f" | {_fmt(r['test_metrics'].get('CAGR', 0), 'pct'):>10s}")

    if wf:
        is_s = [r["train_metrics"].get("Sharpe", 0) for r in wf]
        oos_s = [r["test_metrics"].get("Sharpe", 0) for r in wf]
        avg_is = np.mean(is_s)
        avg_oos = np.mean(oos_s)
        L.append("\n  Walk-Forward Summary:")
        L.append(f"    Avg IS Sharpe:     {avg_is:.4f}")
        L.append(f"    Avg OOS Sharpe:    {avg_oos:.4f}")
        # Decay relative to a non-positive IS Sharpe has no meaningful sign.
        if avg_is > 0:
            decay = 1 - avg_oos / avg_is
            L.append(f"    Sharpe Decay:      {decay:.1%}")
        L.append(f"    OOS Sharpe > 0:    {sum(1 for s in oos_s if s > 0)}/{len(oos_s)} windows")

    # ── Trading Summary ──
    if trades:
        tc = sum(t["costs"] for t in trades)
        n_yrs = len(trades) / 12  # one rebalance per month
        # Each rebalance's cost as a fraction of the value it was paid from
        # (pre-cost PV = recorded post-cost PV + cost), not starting capital.
        cost_frac = sum(t["costs"] / (t["portfolio_value"] + t["costs"])
                        for t in trades
                        if t["portfolio_value"] + t["costs"] > 0)
        L.append("\nTRADING SUMMARY")
        L.append("-" * 40)
        L.append(f"  Total rebalances:    {len(trades)}")
        L.append(f"  Total costs:         ${tc:,.2f}")
        L.append(f"  Avg cost/rebalance:  ${tc / len(trades):,.2f}")
        L.append(f"  Avg one-way turnover:{np.mean([t['turnover'] for t in trades]):>8.2%}")
        if n_yrs > 0:
            L.append(f"  Annualized cost drag:{cost_frac / n_yrs:>8.2%}")

    # ── Conclusion ──
    ts = test_m.get("Sharpe", 0)
    L.append("\nCONCLUSION")
    L.append("-" * 40)
    if ts > 1.0:
        verdict = "Strategy shows credible OOS performance."
    elif ts > 0.3:
        verdict = "Strategy shows moderate OOS performance — warrants further study."
    elif ts > 0:
        verdict = "Strategy shows weak positive OOS performance."
    else:
        verdict = "Strategy does not demonstrate positive OOS performance."
    L.append(f"  {verdict}")
    L.append(f"  Test Sharpe: {ts:.4f} | Test CAGR: {_fmt(test_m.get('CAGR', 0), 'pct')}"
             f" | Test MaxDD: {_fmt(test_m.get('Max Drawdown', 0), 'pct')}")
    if wf:
        oos_s = [r["test_metrics"].get("Sharpe", 0) for r in wf]
        L.append(f"  Walk-forward avg OOS Sharpe: {np.mean(oos_s):.4f}"
                 f" across {len(oos_s)} windows")
    L.append("\n  NOTE: Results based on synthetic GBM data. Validate on live")
    L.append("  market data before any deployment or capital allocation.")
    L.append("=" * 80)

    return "\n".join(L)


# ── Charts ───────────────────────────────────────────────────

def generate_charts(full_daily, wf):
    """Save a three-panel performance chart to ``results/performance.png``.

    Panels:
      1. equity curve with the initial-capital line and the TRAIN_END /
         VALIDATION_END boundaries;
      2. drawdown from the running peak;
      3. in-sample vs out-of-sample Sharpe per walk-forward window (left
         empty, without a legend, when ``wf`` is empty).
    """
    fig, axes = plt.subplots(3, 1, figsize=(14, 12))
    v = full_daily["portfolio_value"]

    # 1. Equity curve with regime lines
    ax = axes[0]
    ax.plot(v.index, v.values, "k-", lw=1.2)
    # Label derived from the plotted level so it stays right if config changes.
    ax.axhline(INITIAL_CAPITAL, color="gray", ls="--", alpha=0.5,
               label=f"Initial ${INITIAL_CAPITAL / 1e6:g}M")
    for lbl, dt, c in [("Train end", TRAIN_END, "blue"),
                       ("Val end", VALIDATION_END, "orange")]:
        ax.axvline(pd.Timestamp(dt), color=c, ls="--", alpha=0.6, label=lbl)
    ax.set_title("Equity Curve — 12-1 Momentum Long/Short", fontsize=13)
    ax.set_ylabel("Portfolio Value ($)")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)
    ax.ticklabel_format(style="plain", axis="y")

    # 2. Drawdown
    ax = axes[1]
    running_peak = v.cummax()
    dd = (v - running_peak) / running_peak
    ax.fill_between(dd.index, dd.values, 0, color="red", alpha=0.4)
    ax.set_title("Drawdown", fontsize=13)
    ax.set_ylabel("Drawdown")
    ax.grid(True, alpha=0.3)
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))

    # 3. Walk-forward IS vs OOS Sharpe
    ax = axes[2]
    if wf:
        x = np.arange(len(wf))
        w = 0.35
        ax.bar(x - w / 2,
               [r["train_metrics"].get("Sharpe", 0) for r in wf],
               w, label="In-Sample", color="steelblue", alpha=0.7)
        ax.bar(x + w / 2,
               [r["test_metrics"].get("Sharpe", 0) for r in wf],
               w, label="Out-of-Sample", color="coral", alpha=0.7)
        ax.axhline(0, color="k", lw=0.8)
        ax.set_xticks(x)
        ax.set_xticklabels([f"W{r['window']}" for r in wf])
        # Only add a legend when labelled artists exist (avoids a warning).
        ax.legend()
    ax.set_title("Walk-Forward: In-Sample vs Out-of-Sample Sharpe", fontsize=13)
    ax.set_ylabel("Sharpe Ratio")
    ax.grid(True, alpha=0.3, axis="y")

    fig.tight_layout()
    fig.savefig("results/performance.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print("  Saved results/performance.png")


# ── Main ─────────────────────────────────────────────────────

def main():
    """Run the whole pipeline and write every artifact under ``results/``.

    Steps: load data -> full-period backtest -> train/validation/test
    metric splits (by TRAIN_END / VALIDATION_END) -> walk-forward analysis
    -> text report, performance chart, and CSVs (daily values, trade log,
    walk-forward table with IS_/OOS_-prefixed metric columns).
    """
    os.makedirs("results", exist_ok=True)

    print("Loading data...")
    prices, sector_map, delist_map = load_data()
    close = build_close_panel(prices)
    n_days, n_tickers = close.shape
    first_day, last_day = close.index[0].date(), close.index[-1].date()
    print(f"  {n_days} days, {n_tickers} tickers, {first_day} to {last_day}")

    # ── Full backtest ──
    print("\nRunning full-period backtest...")
    full_daily, trades = run_backtest(close, sector_map, delist_map)
    final_pv = full_daily["portfolio_value"].iloc[-1]
    print(f"  Final PV: ${final_pv:,.2f}  ({len(trades)} rebalances)")

    # ── Split metrics (label slices share their boundary row as base value) ──
    print("\nComputing train/validation/test splits...")
    train_m = compute_metrics(full_daily.loc[:TRAIN_END])
    val_m = compute_metrics(full_daily.loc[TRAIN_END:VALIDATION_END])
    test_m = compute_metrics(full_daily.loc[VALIDATION_END:])
    for label, metrics in (("Train Sharpe:  ", train_m),
                           ("Val Sharpe:    ", val_m),
                           ("Test Sharpe:   ", test_m)):
        print(f"  {label}{metrics.get('Sharpe', 0):.4f}")

    # ── Walk-forward ──
    print("\nRunning walk-forward analysis...")
    wf = walk_forward(close, sector_map, delist_map)
    print(f"  {len(wf)} walk-forward windows completed")

    # ── Report ──
    report = generate_report(full_daily, trades, wf, train_m, val_m, test_m)
    with open("results/backtest_report.txt", "w") as fh:
        fh.write(report)
    print("\n" + report)

    # ── Charts ──
    print("\nGenerating charts...")
    generate_charts(full_daily, wf)

    # ── CSVs ──
    full_daily.to_csv("results/daily_values.csv")
    pd.DataFrame(trades).to_csv("results/trade_log.csv", index=False)

    id_keys = ("window", "train_start", "train_end", "test_start", "test_end")
    wf_table = [
        {**{key: r[key] for key in id_keys},
         **{f"IS_{key}": val for key, val in r["train_metrics"].items()},
         **{f"OOS_{key}": val for key, val in r["test_metrics"].items()}}
        for r in wf
    ]
    pd.DataFrame(wf_table).to_csv("results/walk_forward.csv", index=False)

    print("\nAll results saved to results/")


# Script entry point: run the full pipeline only when executed directly,
# not when this module is imported (e.g. to reuse the backtest functions).
if __name__ == "__main__":
    main()
