import pandas as pd
import numpy as np
from scipy import stats
from datetime import timedelta
import json
import warnings
# NOTE(review): this silences every warning for the rest of the run, including
# any raised by the libraries used in the analysis below — confirm that blanket
# suppression is intended.
warnings.filterwarnings('ignore')

# --- Load data ---
# Both inputs are CSV files resolved relative to the current working directory.
USERS_CSV = 'experiment_export.csv'
EVENTS_CSV = 'event_logs.csv'

users = pd.read_csv(USERS_CSV)
events = pd.read_csv(EVENTS_CSV)

# Report raw row counts before any filtering is applied.
n_raw_users = len(users)
n_raw_events = len(events)
print("=== RAW DATA ===")
print(f"Total users in export: {n_raw_users}")
print(f"Total events: {n_raw_events}")

# --- Step 1: Exclusions ---
# Rows flagged `is_internal` or `enterprise_manual_assign` are removed from the
# analysis population; the removed user IDs are printed for the audit trail.
internal_mask = users['is_internal'] == True
enterprise_mask = users['enterprise_manual_assign'] == True
internal = users[internal_mask]
enterprise_manual = users[enterprise_mask]

# .tolist() converts to built-in Python scalars so the printed lists stay
# readable (NumPy >= 2 renders its own scalars as e.g. `np.int64(17)`).
print(f"\nExcluded internal accounts: {internal['user_id'].tolist()}")
print(f"Excluded enterprise manual assigns: {enterprise_manual['user_id'].tolist()}")

# A user is kept only when BOTH flags are explicitly False.
clean_users = users[(users['is_internal'] == False) & (users['enterprise_manual_assign'] == False)].copy()
print(f"Clean users remaining: {len(clean_users)}")

# Rows where a flag is neither True nor False (e.g. an empty CSV cell) match
# neither the exclusion lists nor the keep filter. Surface them instead of
# letting them vanish silently.
unflagged = len(users) - len(clean_users) - int((internal_mask | enterprise_mask).sum())
if unflagged:
    print(f"WARNING: {unflagged} users dropped because an exclusion flag is missing or non-boolean")

# --- Step 2: Parse dates ---
# Convert the raw date/time columns in place so the later steps can compare
# them and add time offsets.
for frame, column in ((clean_users, 'signup_date'), (events, 'event_timestamp')):
    frame[column] = pd.to_datetime(frame[column])

# --- Step 3: Build activation flags per user ---
# A user is "activated" when their events up to the end of the activation
# window include onboarding completion AND a first project AND a social action
# (a teammate invite or a project share).
ACTIVATION_WINDOW = timedelta(days=14)


def _onboarding_duration(raw_properties):
    """Return ``duration_sec`` from a JSON ``event_properties`` payload.

    Returns None when the payload is not text (e.g. an empty CSV cell read as
    NaN), is not valid JSON, or does not decode to a JSON object, so a single
    bad log row cannot abort the whole analysis.
    """
    if not isinstance(raw_properties, (str, bytes, bytearray)):
        return None
    try:
        props = json.loads(raw_properties)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None
    if not isinstance(props, dict):
        return None
    return props.get('duration_sec')


# Split the event log by user once instead of re-scanning every event row for
# every user. Rows inside each group keep their original log order.
events_by_user = {
    key: rows for key, rows in events.groupby('user_id', sort=False)
}
no_events = events.iloc[0:0]  # empty frame with the same columns

results = []
for _, user in clean_users.iterrows():
    uid = user['user_id']
    signup = user['signup_date']

    # 14-day window; the end bound is inclusive.
    # NOTE(review): there is no lower bound, so events timestamped before the
    # signup date also count toward activation — confirm this is intended.
    window_end = signup + ACTIVATION_WINDOW
    own_events = events_by_user.get(uid, no_events)
    user_events = own_events[own_events['event_timestamp'] <= window_end]

    # Check each activation criterion
    seen_events = set(user_events['event_name'])
    onboarding_done = 'onboarding_completed' in seen_events
    project_created = 'first_project_created' in seen_events
    social_action = ('teammate_invited' in seen_events or
                     'project_shared' in seen_events)

    activated = onboarding_done and project_created and social_action

    # Onboarding duration, taken from the first logged completion event.
    ob_events = user_events[user_events['event_name'] == 'onboarding_completed']
    ob_duration = None
    if len(ob_events) > 0:
        ob_duration = _onboarding_duration(ob_events.iloc[0]['event_properties'])

    results.append({
        'user_id': uid,
        'group': user['group'],
        'platform': user['platform'],
        'region': user['region'],
        'signup_date': signup,
        'onboarding_completed': onboarding_done,
        'project_created': project_created,
        'social_action': social_action,
        'activated': activated,
        'onboarding_duration_sec': ob_duration,
    })

# One row per clean user; every later step reads from this frame.
df = pd.DataFrame(results)

# --- Step 4: Primary analysis ---
# Activation counts and rates per experiment arm, then the treatment-vs-control
# lift in absolute and relative terms.
print("\n=== GROUP SUMMARY ===")
for grp in ['treatment', 'control']:
    g = df[df['group'] == grp]
    n = len(g)
    activated = g['activated'].sum()
    rate = activated / n * 100
    print(f"{grp.upper()}: n={n}, activated={activated}, rate={rate:.1f}%")

treatment = df[df['group'] == 'treatment']
control = df[df['group'] == 'control']

# n_* = arm size, k_* = activated users, p_* = activation rate.
n_t, k_t = len(treatment), treatment['activated'].sum()
n_c, k_c = len(control), control['activated'].sum()
p_t = k_t / n_t
p_c = k_c / n_c

print(f"\nTreatment activation: {k_t}/{n_t} = {p_t:.1%}")
print(f"Control activation:   {k_c}/{n_c} = {p_c:.1%}")

# The `+` format flag prints the real sign, so a negative lift reads "-3.0%"
# rather than the "+-3.0%" a hard-coded "+" prefix would produce.
abs_lift = p_t - p_c
print(f"Absolute lift: {abs_lift:+.1%}")
if p_c > 0:
    print(f"Relative lift: {abs_lift / p_c * 100:+.0f}%")
else:
    # Relative lift is undefined without a positive control rate; keep the
    # label so the line is still identifiable in the output.
    print("Relative lift: N/A (control=0)")

# --- Step 5: Fisher's exact test ---
# 2x2 table: rows are (treatment, control), columns are (activated, not
# activated). Counts are cast to built-in int so the printed table reads
# [[5, 10], ...] instead of [[np.int64(5), ...]] under NumPy >= 2.
contingency = [
    [int(k_t), int(n_t - k_t)],
    [int(k_c), int(n_c - k_c)],
]
odds_ratio, p_value = stats.fisher_exact(contingency, alternative='two-sided')
print(f"\n=== FISHER'S EXACT TEST ===")
print(f"Contingency table: {contingency}")
print(f"Odds ratio: {odds_ratio:.2f}")
print(f"p-value: {p_value:.4f}")
print(f"Significant at alpha=0.05: {p_value < 0.05}")

# --- Step 6: Confidence interval for difference in proportions (Agresti-Caffo) ---
# Each arm is padded with one extra success and one extra failure (so +2
# observations) before the usual Wald interval is computed on the adjusted
# proportions.
Z_95 = 1.96  # normal critical value used for the 95% interval

n_t_adj, n_c_adj = n_t + 2, n_c + 2
p_t_adj = (k_t + 1) / n_t_adj
p_c_adj = (k_c + 1) / n_c_adj

# Per-arm variance of the adjusted proportion; the SE of the difference is the
# square root of their sum.
var_t = p_t_adj * (1 - p_t_adj) / n_t_adj
var_c = p_c_adj * (1 - p_c_adj) / n_c_adj
se = np.sqrt(var_t + var_c)

diff = p_t_adj - p_c_adj
margin = Z_95 * se
ci_low, ci_high = diff - margin, diff + margin

print(f"\n=== 95% CI FOR DIFFERENCE (Agresti-Caffo) ===")
print(f"Adjusted difference: {diff:.1%}")
print(f"95% CI: [{ci_low:.1%}, {ci_high:.1%}]")

# --- Step 7: Funnel analysis ---
# For each arm, show how many users reached each activation criterion and how
# many met all of them. Labels are pre-padded so the counts line up.
FUNNEL_STAGES = (
    ("Onboarding completed: ", 'onboarding_completed'),
    ("Project created:      ", 'project_created'),
    ("Social action:        ", 'social_action'),
    ("Fully activated:      ", 'activated'),
)

print("\n=== FUNNEL ANALYSIS ===")
for grp in ['treatment', 'control']:
    members = df[df['group'] == grp]
    size = len(members)
    print(f"\n{grp.upper()} (n={size}):")
    for label, column in FUNNEL_STAGES:
        reached = members[column].sum()
        print(f"  {label}{reached}/{size} ({reached/size:.0%})")

# --- Step 8: Onboarding duration ---
# Summarise the recorded onboarding durations per arm, then compare the two
# arms with a two-sided Mann-Whitney U test. Users without a recorded duration
# are left out.
print("\n=== ONBOARDING DURATION (seconds) ===")
has_duration = df['onboarding_duration_sec'].notna()
durations_by_group = {
    arm: df.loc[(df['group'] == arm) & has_duration, 'onboarding_duration_sec']
    for arm in ['treatment', 'control']
}

for grp, durations in durations_by_group.items():
    if len(durations) == 0:
        continue  # nothing to summarise for this arm
    print(f"{grp.upper()}: median={durations.median():.0f}s, mean={durations.mean():.0f}s, "
          f"min={durations.min():.0f}s, max={durations.max():.0f}s, n={len(durations)}")

# Mann-Whitney U test for onboarding duration (needs data in both arms)
t_dur = durations_by_group['treatment']
c_dur = durations_by_group['control']
if not t_dur.empty and not c_dur.empty:
    u_stat, u_pval = stats.mannwhitneyu(t_dur, c_dur, alternative='two-sided')
    print(f"Mann-Whitney U test: U={u_stat:.0f}, p={u_pval:.4f}")

# --- Step 9: Mobile bug impact (Nov 20-22, treatment, mobile) ---
# List the treatment-arm mobile users who signed up during the bug window,
# together with their onboarding and activation outcome.
print("\n=== MOBILE BUG IMPACT CHECK ===")
bug_start = pd.Timestamp('2025-11-20')
bug_end = pd.Timestamp('2025-11-22')
# `bug_end` is midnight at the START of Nov 22, so `signup_date <= bug_end`
# would miss anyone who signed up later that day. Use the start of the next
# day as an exclusive bound so all of Nov 22 is covered; for date-only signup
# values the selection is unchanged.
bug_end_exclusive = bug_end + pd.Timedelta(days=1)
bug_affected = df[
    (df['group'] == 'treatment') &
    (df['platform'] == 'mobile') &
    (df['signup_date'] >= bug_start) &
    (df['signup_date'] < bug_end_exclusive)
]
print(f"Bug-window treatment mobile users: {len(bug_affected)}")
for _, row in bug_affected.iterrows():
    print(f"  {row['user_id']}: activated={row['activated']}, "
          f"onboarding_completed={row['onboarding_completed']}")

# --- Step 10: Sensitivity analysis excluding bug-affected users ---
# Re-run the activation comparison and Fisher's exact test after dropping the
# users selected in `bug_affected`.
print("\n=== SENSITIVITY: EXCLUDING BUG-AFFECTED USERS ===")
bug_ids = set(bug_affected['user_id'].values)
is_bug_user = df['user_id'].isin(bug_ids)
df_nobug = df[~is_bug_user]

# Activation outcomes per arm in the reduced population.
outcomes_nb = {
    arm: df_nobug.loc[df_nobug['group'] == arm, 'activated']
    for arm in ('treatment', 'control')
}
n_t2, n_c2 = len(outcomes_nb['treatment']), len(outcomes_nb['control'])
k_t2, k_c2 = outcomes_nb['treatment'].sum(), outcomes_nb['control'].sum()
p_t2 = k_t2 / n_t2
p_c2 = k_c2 / n_c2
print(f"Treatment: {k_t2}/{n_t2} = {p_t2:.1%}")
print(f"Control:   {k_c2}/{n_c2} = {p_c2:.1%}")

# Rows: treatment, control. Columns: activated, not activated.
cont2 = [
    [k_t2, n_t2 - k_t2],
    [k_c2, n_c2 - k_c2],
]
or2, pv2 = stats.fisher_exact(cont2, alternative='two-sided')
print(f"Fisher's exact: OR={or2:.2f}, p={pv2:.4f}")

# --- Step 11: Subgroup check by platform ---
# Activation rate for every platform/arm combination that has at least one
# user; empty combinations are skipped.
print("\n=== SUBGROUP: PLATFORM ===")
for plat in ['web', 'mobile']:
    on_platform = df['platform'] == plat
    for grp in ['treatment', 'control']:
        outcomes = df.loc[on_platform & (df['group'] == grp), 'activated']
        if outcomes.empty:
            continue
        print(f"  {plat}/{grp}: {outcomes.sum()}/{len(outcomes)} = {outcomes.mean():.0%}")

# --- Step 12: Subgroup check by region ---
# Activation rate for every region/arm combination that has at least one user;
# empty combinations are skipped.
print("\n=== SUBGROUP: REGION ===")
for reg in ['US', 'EU', 'APAC']:
    in_region = df['region'] == reg
    for grp in ['treatment', 'control']:
        outcomes = df.loc[in_region & (df['group'] == grp), 'activated']
        if outcomes.empty:
            continue
        print(f"  {reg}/{grp}: {outcomes.sum()}/{len(outcomes)} = {outcomes.mean():.0%}")

# Final banner marking the end of the report.
print("\n=== ANALYSIS COMPLETE ===")
