import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import warnings
warnings.filterwarnings('ignore')

np.random.seed(42)

# ==============================
# Load all four scenarios
# ==============================
# Each scenario pairs a household population file with a community frame.
# NOTE(review): scenario A's community frame is read from a file literally
# named 'A_community_frame(1).csv' (looks like a duplicate-download suffix)
# while B-D use '<sc>_community_frame.csv'; confirm the intended filename.
_cf_overrides = {'A': 'A_community_frame(1).csv'}
scenarios = {
    scen: {
        'df': pd.read_csv(f'{scen}_household_population.csv'),
        'cf': pd.read_csv(_cf_overrides.get(scen, f'{scen}_community_frame.csv')),
    }
    for scen in ['A', 'B', 'C', 'D']
}

# ==============================
# True values per scenario
# ==============================
def get_truth(df):
    """Return the population (true) values of the four target parameters.

    Args:
        df: household population frame with columns ``total_consume``,
            ``online_share_true``, ``income_level`` and ``online_shop_flag``.

    Returns:
        dict with keys ``mean_consume``, ``online_share``, ``high_rate`` and
        ``oshop_rate``: column means, where ``high_rate`` is the share of rows
        whose ``income_level`` equals ``'high'``.
    """
    is_high = df['income_level'].eq('high')
    truth = {}
    truth['mean_consume'] = df['total_consume'].mean()
    truth['online_share'] = df['online_share_true'].mean()
    truth['high_rate'] = is_high.mean()
    truth['oshop_rate'] = df['online_shop_flag'].mean()
    return truth

# Attach the population truth to every scenario entry.
for entry in scenarios.values():
    entry['truth'] = get_truth(entry['df'])

# ==============================
# Helper
# ==============================
def metrics(estimates, truth):
    """Summarise replicate estimates against the known true value.

    Args:
        estimates: sequence of replicate estimates (one per simulation run).
        truth: the population value being estimated.

    Returns:
        dict with ``mean_est`` (average estimate), ``bias`` (mean minus truth),
        ``sd`` (sample standard deviation, ddof=1) and ``mse`` computed as
        ``bias**2 + sd**2``.
    """
    values = np.asarray(estimates)
    center = values.mean()
    spread = values.std(ddof=1)
    offset = center - truth
    return {
        'mean_est': center,
        'bias': offset,
        'sd': spread,
        'mse': offset ** 2 + spread ** 2,
    }

# ==============================
# Five sampling designs
# ==============================
# Simulation settings shared by every design:
#   n_total - target household sample size for SRS / stratified designs
#   m_psu   - number of PSUs (communities) for the cluster / PPS designs
#   k_ssu   - households drawn per selected PSU
#   B_rep   - Monte-Carlo replications per design
n_total, m_psu, k_ssu, B_rep = 200, 20, 10, 500

# --- Design 1: SRS ---
def srs_once(df, n):
    """Draw one simple random sample without replacement and return estimates.

    Every unit has the same inclusion probability ``n / len(df)``, so plain
    (unweighted) sample means are the estimators.

    Args:
        df: household population frame with columns ``total_consume``,
            ``online_share_true``, ``income_level`` and ``online_shop_flag``.
        n: sample size; ``DataFrame.sample`` raises if it exceeds ``len(df)``.

    Returns:
        dict with keys ``mean_consume``, ``online_share``, ``high_rate`` and
        ``oshop_rate``.
    """
    s = df.sample(n, replace=False)
    return {
        'mean_consume': s['total_consume'].mean(),
        'online_share': s['online_share_true'].mean(),
        'high_rate':    (s['income_level'] == 'high').mean(),
        'oshop_rate':   s['online_shop_flag'].mean(),
    }

# --- Design 2: Stratified proportional ---
def make_n_h(df, n):
    """Allocate ``n`` units to the ``district_type`` strata proportionally.

    Uses the largest-remainder (Hamilton) method. Each stratum first gets
    ``floor(n * N_h / N)``. The units still missing are then given one each to
    the strata with the largest fractional remainders. The result always sums
    to ``n``, and every stratum receives the floor or the ceiling of its exact
    proportional quota. A round-then-patch scheme can instead push the whole
    rounding error onto one stratum (quotas 6.8/6.6/6.6 for n=20 would become
    6/7/7, so the largest stratum gets the fewest units).

    Args:
        df: frame with a ``district_type`` column; rows where it is missing
            are not counted (``value_counts`` drops them).
        n: total number of units (households or PSUs) to allocate.

    Returns:
        Integer Series indexed by district type, in ``value_counts()`` order
        (largest stratum first). Strata whose quota is below one may get 0.
    """
    N_h = df['district_type'].value_counts()
    quota = N_h / N_h.sum() * n
    n_h = np.floor(quota).astype(int)
    shortfall = int(n - n_h.sum())
    if shortfall > 0:
        # A stable sort keeps value_counts order (larger strata first) on ties.
        remainder = (quota - n_h).sort_values(ascending=False, kind='stable')
        n_h.loc[remainder.index[:shortfall]] += 1
    return n_h

def stratified_once(df, n_h):
    """Draw one stratified SRS (by ``district_type``) and return weighted estimates.

    Args:
        df: household population frame with ``district_type`` and the four
            outcome columns.
        n_h: Series mapping district type -> number of households to draw in
            that stratum (e.g. the output of ``make_n_h``).

    Returns:
        dict of weighted means; each sampled household in stratum h carries
        weight ``N_h / n_h``.
    """
    def _draw(label, size):
        # SRS without replacement inside one stratum, tagged with its design weight.
        pool = df[df['district_type'] == label]
        return pool.sample(int(size), replace=False).assign(w=len(pool) / size)

    s = pd.concat([_draw(label, size) for label, size in n_h.items()])
    w = s['w']
    is_high = (s['income_level'] == 'high').astype(float)
    return {
        'mean_consume': np.average(s['total_consume'], weights=w),
        'online_share': np.average(s['online_share_true'], weights=w),
        'high_rate': np.average(is_high, weights=w),
        'oshop_rate': np.average(s['online_shop_flag'], weights=w),
    }

# --- Design 3: Stratified + Two-Stage ---
def strat_two_stage_once(df, cf, m_h, k):
    """Stratified two-stage sample and its weighted estimates.

    Within each district type, communities are drawn by SRS without
    replacement; then up to ``k`` households are drawn by SRS inside each
    selected community.

    Args:
        df: household frame with ``community_id`` and the outcome columns.
        cf: community frame with ``district_type`` and ``community_id``.
        m_h: Series mapping district type -> communities to draw (clipped to
            the number of communities available in that stratum).
        k: households per selected community (clipped to the community size).

    Returns:
        dict of weighted means; each household is weighted by
        ``1 / ((m_h / M_h) * (k_i / M_i))``.
    """
    chunks = []
    for label, wanted in m_h.items():
        stratum_frame = cf[cf['district_type'] == label]
        n_comm = len(stratum_frame)
        n_take = min(int(wanted), n_comm)
        chosen = stratum_frame.sample(n_take, replace=False)
        for cid in chosen['community_id']:
            members = df[df['community_id'] == cid]
            n_members = len(members)
            per_comm = min(k, n_members)
            draw = members.sample(per_comm, replace=False)
            incl_prob = (n_take / n_comm) * (per_comm / n_members)
            chunks.append(draw.assign(w=1.0 / incl_prob))
    s = pd.concat(chunks)
    w = s['w']
    is_high = (s['income_level'] == 'high').astype(float)
    return {
        'mean_consume': np.average(s['total_consume'], weights=w),
        'online_share': np.average(s['online_share_true'], weights=w),
        'high_rate': np.average(is_high, weights=w),
        'oshop_rate': np.average(s['online_shop_flag'], weights=w),
    }

# --- Design 4: Two-Stage Cluster ---
def cluster_once(df, cf, m, k):
    """Unstratified two-stage cluster sample and its weighted estimates.

    ``m`` communities are drawn from the whole frame by SRS without
    replacement, then up to ``k`` households by SRS inside each one.

    Args:
        df: household frame with ``community_id`` and the outcome columns.
        cf: community frame with a ``community_id`` column.
        m: number of communities to draw.
        k: households per selected community (clipped to the community size).

    Returns:
        dict of weighted means; each household is weighted by
        ``1 / ((m / M) * (k_i / M_i))``.
    """
    chosen = cf.sample(m, replace=False)
    stage1 = m / len(cf)
    chunks = []
    for cid in chosen['community_id']:
        members = df[df['community_id'] == cid]
        n_members = len(members)
        take = min(k, n_members)
        weight = 1.0 / (stage1 * (take / n_members))
        chunks.append(members.sample(take, replace=False).assign(w=weight))
    s = pd.concat(chunks)
    w = s['w']
    is_high = (s['income_level'] == 'high').astype(float)
    return {
        'mean_consume': np.average(s['total_consume'], weights=w),
        'online_share': np.average(s['online_share_true'], weights=w),
        'high_rate': np.average(is_high, weights=w),
        'oshop_rate': np.average(s['online_shop_flag'], weights=w),
    }

# --- Design 5: PPS ---
def pps_once(df, cf, m, k):
    """PPS-with-replacement two-stage sample and its weighted estimates.

    ``m`` communities are drawn with replacement, each with probability
    proportional to ``community_size``. For every draw, up to ``k``
    households are taken by SRS inside that community. A community drawn
    twice is subsampled independently each time.

    Args:
        df: household frame with ``community_id`` and the outcome columns.
        cf: community frame with ``community_id`` and ``community_size``.
        m: number of PPS draws.
        k: households per draw (clipped to the community's household count).

    Returns:
        dict of weighted means; each household is weighted by
        ``(1 / (m * size_i / sum(size))) * (M_i / k_i)``, where ``M_i`` is the
        community's household count in ``df``.
    """
    size_arr = cf['community_size'].values.astype(float)
    total_size = size_arr.sum()
    draws = np.random.choice(len(cf), size=m, replace=True, p=size_arr / total_size)
    chunks = []
    for pos in draws:
        cid = cf.iloc[pos]['community_id']
        members = df[df['community_id'] == cid]
        n_members = len(members)
        take = min(k, n_members)
        picked = members.sample(take, replace=False)
        expected_hits = m * size_arr[pos] / total_size
        chunks.append(picked.assign(w=(1.0 / expected_hits) * (n_members / take)))
    s = pd.concat(chunks)
    w = s['w']
    is_high = (s['income_level'] == 'high').astype(float)
    return {
        'mean_consume': np.average(s['total_consume'], weights=w),
        'online_share': np.average(s['online_share_true'], weights=w),
        'high_rate': np.average(is_high, weights=w),
        'oshop_rate': np.average(s['online_shop_flag'], weights=w),
    }

# ==============================
# Run simulations for all 4 scenarios
# ==============================
design_names = ['SRS', 'Stratified', 'Strat+TwoStage', 'TwoStageCluster', 'PPS']
param_names  = ['mean_consume','online_share','high_rate','oshop_rate']
param_labels = ['Monthly Consume Mean','Online Share','High Income Rate','Online Shop Rate']

# all_results[scenario] holds:
#   'raw'     - design name -> list of B_rep estimate dicts,
#   'metrics' - design name -> parameter name -> metrics() output,
#   'truth'   - population values from get_truth(),
#   'n_h'     - household allocation used by the stratified SRS design.
all_results = {}

for sc, d in scenarios.items():
    df_s = d['df']
    cf_s = d['cf']
    truth_s = d['truth']
    n_h_s = make_n_h(df_s, n_total)
    # PSUs are allocated to strata in proportion to household counts in df_s
    # (not to the number of communities per stratum in cf_s).
    m_h_s = make_n_h(df_s, m_psu)

    print(f"Running scenario {sc}...")
    # The designs run in this fixed order and all draw from NumPy's global
    # RNG (seeded at the top of the script; DataFrame.sample without a
    # random_state falls back to it), so reordering them changes every draw.
    res = {
        'SRS':            [srs_once(df_s, n_total) for _ in range(B_rep)],
        'Stratified':     [stratified_once(df_s, n_h_s) for _ in range(B_rep)],
        'Strat+TwoStage': [strat_two_stage_once(df_s, cf_s, m_h_s, k_ssu) for _ in range(B_rep)],
        'TwoStageCluster':[cluster_once(df_s, cf_s, m_psu, k_ssu) for _ in range(B_rep)],
        'PPS':            [pps_once(df_s, cf_s, m_psu, k_ssu) for _ in range(B_rep)],
    }

    # Bias / SD / MSE of every parameter under every design.
    sc_metrics = {
        dname: {
            pname: metrics([r[pname] for r in res[dname]], truth_s[pname])
            for pname in param_names
        }
        for dname in design_names
    }

    all_results[sc] = {
        'raw': res,
        'metrics': sc_metrics,
        'truth': truth_s,
        'n_h': n_h_s,
    }

# ==============================
# Print summary tables
# ==============================
# Per-scenario table of bias / SD / MSE for the monthly-consumption mean.
for scen in 'ABCD':
    scen_block = all_results[scen]
    print(f"\n========== SCENARIO {scen} ==========")
    print(f"True monthly mean: {scen_block['truth']['mean_consume']:.2f}")
    print("Design            | Bias      | SD       | MSE")
    for design in design_names:
        stats = scen_block['metrics'][design]['mean_consume']
        print(f"  {design:18s}| {stats['bias']:+8.2f}  | {stats['sd']:8.2f} | {stats['mse']:10.2f}")

# ==============================
# DEFF: cluster vs SRS per scenario
# ==============================
print("\n========== DEFF (TwoStageCluster vs SRS) ==========")
# DEFF = Var(two-stage cluster) / Var(SRS) for every parameter, reported as
# NaN when the SRS variance is not positive.
for scen in 'ABCD':
    scen_metrics = all_results[scen]['metrics']
    pieces = []
    for pname, plabel in zip(param_names, param_labels):
        var_base = scen_metrics['SRS'][pname]['sd']**2
        var_cluster = scen_metrics['TwoStageCluster'][pname]['sd']**2
        if var_base > 0:
            ratio = var_cluster / var_base
        else:
            ratio = np.nan
        pieces.append(f'{plabel}={ratio:.3f}')
    print(f"Scenario {scen}: {', '.join(pieces)}")

# ==============================
# Non-response bias analysis (scenario B)
# ==============================
print("\n========== NONRESPONSE BIAS ANALYSIS (Scenario B) ==========")
df_B = scenarios['B']['df']
raw_ests, wt_ests = [], []
# Use the shared replication count (previously a hard-coded 500) so this study
# stays in step with the design simulations.
for _ in range(B_rep):
    s = df_B.sample(n_total, replace=False).copy()
    # Bernoulli response indicator for each sampled household.
    s['responded'] = np.random.binomial(1, s['response_prob'])
    resp = s[s['responded'] == 1].copy()
    # Skip replicates with too few respondents to form a usable mean.
    if len(resp) < 5:
        continue
    raw_ests.append(resp['total_consume'].mean())
    # IPW using the population file's own response_prob values, i.e. the
    # propensities that generated the response indicator above (an oracle
    # adjustment; in a real survey they would have to be estimated).
    resp['adj_w'] = 1.0 / resp['response_prob']
    wt_ests.append(np.average(resp['total_consume'], weights=resp['adj_w']))
truth_B = all_results['B']['truth']['mean_consume']
m_raw = metrics(raw_ests, truth_B)
m_wt  = metrics(wt_ests, truth_B)
print(f"  Unweighted: bias={m_raw['bias']:.2f}, sd={m_raw['sd']:.2f}, mse={m_raw['mse']:.2f}")
print(f"  IPW weighted: bias={m_wt['bias']:.2f}, sd={m_wt['sd']:.2f}, mse={m_wt['mse']:.2f}")

# ==============================
# Cost analysis
# ==============================
def compute_cost(design, df, n_h=None, m=20, k=10, n=None):
    """Return the approximate field cost of one sample drawn under ``design``.

    Cost model:
      * ``'SRS'`` / ``'Stratified'``: 20 per household plus a surcharge of 15
        per household located in an ``'outer'`` district.
      * any other design (clustered): 100 per PSU, 20 per household, plus the
        same 15 outer-district surcharge.
    The number of outer-district households is the expected count: the
    population share of ``'outer'`` rows times the household sample size,
    rounded to the nearest integer.

    Args:
        design: design name; ``'SRS'`` and ``'Stratified'`` are costed as
            household samples, every other name as a clustered sample.
        df: population frame with a ``district_type`` column.
        n_h: unused; kept for backward compatibility.
        m: number of PSUs (clustered designs).
        k: households per PSU (clustered designs).
        n: household sample size for unclustered designs; ``None`` uses the
            module-level ``n_total``.

    Returns:
        int total cost.
    """
    if n is None:
        n = n_total
    outer_share = (df['district_type'] == 'outer').mean()
    if design in ('SRS', 'Stratified'):
        # round(), not int(): e.g. a 0.29 share times 200 is the float
        # 57.99999999999999, which int() would truncate to 57.
        return n * 20 + int(round(outer_share * n)) * 15
    return m * 100 + m * k * 20 + int(round(outer_share * m * k)) * 15

print("\n========== COST ANALYSIS ==========")
# Pair each design's modelled cost with its MSE for the monthly-consume mean.
for scen in 'ABCD':
    pop = scenarios[scen]['df']
    print(f"Scenario {scen}:")
    for design in design_names:
        cost = compute_cost(design, pop, m=m_psu, k=k_ssu)
        mse_val = all_results[scen]['metrics'][design]['mean_consume']['mse']
        print(f"  {design:18s}: cost={cost}, MSE={mse_val:.2f}")

# ==============================
# Scenario classification table
# ==============================
print("\n========== SCENARIO CLASSIFICATION ==========")
# Structural characteristics per scenario, hard-coded here.
# NOTE(review): these values are not computed in this script; confirm where
# they come from and keep them in sync with the scenario data.
sc_chars = {
    'A': {'between_SD': 806.5, 'ICC': 0.2579, 'size_CV': 0.280},
    'B': {'between_SD': 431.5, 'ICC': 0.1133, 'size_CV': 0.271},
    'C': {'between_SD': 551.4, 'ICC': 0.3402, 'size_CV': 0.289},
    'D': {'between_SD': 606.0, 'ICC': 0.1916, 'size_CV': 0.610},
}
for scen in 'ABCD':
    mse_by_design = {
        dn: all_results[scen]['metrics'][dn]['mean_consume']['mse']
        for dn in design_names
    }
    winner = min(mse_by_design, key=mse_by_design.get)
    # sorted() is stable, so ties keep design_names order.
    order = sorted(mse_by_design, key=mse_by_design.get)
    print(f"Scenario {scen}: best={winner}, ranking={order}")
    print(f"  Chars: {sc_chars[scen]}")

# ==============================
# Main figure: 4 scenarios x 3 panels (estimate boxplots, relative MSE vs SRS, bias ± SD)
# ==============================
# Per-design colours (boxplots and bias bars) and short tick labels; both are
# listed in the same order as design_names.
colors = ['#2196F3','#4CAF50','#FF9800','#E91E63','#9C27B0']
design_short = ['SRS','Stratified','Strat+2Stage','2StageCluster','PPS']

# Grid: one row per scenario, three panels per row.
# NOTE(review): the title hard-codes "500 Replications, n=200"; keep it in
# sync with B_rep and n_total if those change.
fig, axes = plt.subplots(4, 3, figsize=(18, 22))
fig.suptitle('Sampling Design Performance: Scenarios A-D (500 Replications, n=200)',
             fontsize=14, fontweight='bold', y=0.98)

for row, sc in enumerate(['A','B','C','D']):
    truth_mc = all_results[sc]['truth']['mean_consume']
    res = all_results[sc]['raw']
    met = all_results[sc]['metrics']

    # Col 0: Boxplot of monthly consume estimates
    # One box per design over its B_rep replicate estimates; the red dashed
    # line marks the population value.
    ax = axes[row, 0]
    box_data = [[r['mean_consume'] for r in res[d]] for d in design_names]
    bp = ax.boxplot(box_data, patch_artist=True, notch=False, widths=0.55)
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color); patch.set_alpha(0.75)
    ax.axhline(truth_mc, color='red', linestyle='--', linewidth=1.5, label=f'True={truth_mc:.0f}')
    ax.set_xticklabels(design_short, rotation=20, fontsize=8, ha='right')
    ax.set_title(f'Scenario {sc}: Monthly Consume Estimates', fontsize=10)
    ax.set_ylabel('Estimate (yuan)')
    ax.legend(fontsize=8)

    # Col 1: MSE bar chart (all 4 params, relative to SRS=1)
    # Four bars per design (one per parameter), each divided by the SRS MSE of
    # that parameter; an SRS MSE of 0 falls back to a divisor of 1.
    ax = axes[row, 1]
    x = np.arange(len(design_names))
    bar_w = 0.18
    param_colors = ['#1a1a2e','#16213e','#0f3460','#533483']
    for i, (pname, plabel) in enumerate(zip(param_names, param_labels)):
        mses = [met[d][pname]['mse'] for d in design_names]
        srs_mse = mses[0] if mses[0]>0 else 1
        rel = [v/srs_mse for v in mses]
        ax.bar(x + i*bar_w, rel, bar_w, label=plabel, color=param_colors[i], alpha=0.8)
    # Bars sit at offsets 0..3*bar_w, so 1.5*bar_w centres each tick in its group.
    ax.set_xticks(x + bar_w*1.5)
    ax.set_xticklabels(design_short, rotation=20, fontsize=8, ha='right')
    ax.axhline(1.0, color='gray', linestyle=':', alpha=0.8)
    ax.set_title(f'Scenario {sc}: Relative MSE (SRS=1)', fontsize=10)
    ax.set_ylabel('Relative MSE')
    ax.legend(fontsize=6)

    # Col 2: Bias comparison for monthly consume
    # Bar = bias of the monthly-consume estimate; error bar = +/- one
    # replicate SD.
    ax = axes[row, 2]
    biases = [met[d]['mean_consume']['bias'] for d in design_names]
    sds    = [met[d]['mean_consume']['sd'] for d in design_names]
    x2 = np.arange(len(design_names))
    # NOTE(review): `bars` is never used afterwards.
    bars = ax.bar(x2, biases, color=colors, alpha=0.8, zorder=3)
    ax.errorbar(x2, biases, yerr=sds, fmt='none', color='black', capsize=4, linewidth=1.5, zorder=4)
    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_xticks(x2); ax.set_xticklabels(design_short, rotation=20, fontsize=8, ha='right')
    ax.set_title(f'Scenario {sc}: Bias ± SD (Monthly Consume)', fontsize=10)
    ax.set_ylabel('Bias (yuan)')

# Leave the top 3% of the canvas for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.97])
# NOTE(review): the figure is not closed after saving (plt.close(fig)).
plt.savefig('sampling_comparison_v2.png', dpi=150, bbox_inches='tight')
print("\nMain figure saved: sampling_comparison_v2.png")

# ==============================
# Summary figure: cross-scenario comparison
# ==============================
# Three side-by-side panels comparing the scenarios.
fig2, axes2 = plt.subplots(1, 3, figsize=(18, 6))
fig2.suptitle('Cross-Scenario Summary: Structural Characteristics vs. Design Efficiency',
              fontsize=13, fontweight='bold')

# Multi-line tick labels, in scenario order A-D.
sc_labels = ['A\n(High\nDistrict Het.)', 'B\n(Baseline)', 'C\n(High ICC)', 'D\n(High Size CV)']
# NOTE(review): duplicates sc_chars from the classification section, with ICC
# rounded to three decimals; keep the two tables in sync.
sc_chars_vals = {
    'A': {'between_SD': 806.5, 'ICC': 0.258, 'size_CV': 0.280},
    'B': {'between_SD': 431.5, 'ICC': 0.113, 'size_CV': 0.271},
    'C': {'between_SD': 551.4, 'ICC': 0.340, 'size_CV': 0.289},
    'D': {'between_SD': 606.0, 'ICC': 0.192, 'size_CV': 0.610},
}

# Plot 1: efficiency ratios MSE(SRS)/MSE(design) for Stratified, TwoStageCluster and PPS across scenarios
ax = axes2[0]
ratios_strat = []
ratios_cluster = []
ratios_pps = []
for sc in ['A','B','C','D']:
    mse_srs   = all_results[sc]['metrics']['SRS']['mean_consume']['mse']
    mse_strat = all_results[sc]['metrics']['Stratified']['mean_consume']['mse']
    mse_clu   = all_results[sc]['metrics']['TwoStageCluster']['mean_consume']['mse']
    mse_pps   = all_results[sc]['metrics']['PPS']['mean_consume']['mse']
    ratios_strat.append(mse_srs/mse_strat)
    ratios_cluster.append(mse_srs/mse_clu)
    ratios_pps.append(mse_srs/mse_pps)

# One x position per scenario; three grouped bars at offsets -0.25 / 0 / +0.25.
x3 = np.arange(4)
ax.bar(x3-0.25, ratios_strat,  0.22, label='SRS/Stratified',     color='#4CAF50', alpha=0.85)
ax.bar(x3,      ratios_cluster,0.22, label='SRS/2StageCluster',   color='#E91E63', alpha=0.85)
ax.bar(x3+0.25, ratios_pps,    0.22, label='SRS/PPS',             color='#9C27B0', alpha=0.85)
ax.axhline(1.0, color='gray', linestyle='--', linewidth=1.2)
ax.set_xticks(x3); ax.set_xticklabels(sc_labels, fontsize=9)
ax.set_title('Efficiency Ratio vs SRS\n(>1 = better than SRS)')
ax.set_ylabel('MSE(SRS)/MSE(design)')
ax.legend(fontsize=8)

# Plot 2: DEFF across scenarios
# DEFF = Var(two-stage cluster) / Var(SRS) for the monthly-consume mean.
ax = axes2[1]
deffs_all = []
for sc in ['A','B','C','D']:
    var_srs = all_results[sc]['metrics']['SRS']['mean_consume']['sd']**2
    var_clu = all_results[sc]['metrics']['TwoStageCluster']['mean_consume']['sd']**2
    deffs_all.append(var_clu/var_srs)
ax.bar(x3, deffs_all, color='#FF9800', alpha=0.85)
ax.axhline(1.0, color='red', linestyle='--', linewidth=1.5, label='DEFF=1')
# Value labels just above each bar.
for i, v in enumerate(deffs_all): ax.text(i, v+0.03, f'{v:.2f}', ha='center', fontsize=10)
ax.set_xticks(x3); ax.set_xticklabels(sc_labels, fontsize=9)
ax.set_title('Design Effect (DEFF)\n2-Stage Cluster vs SRS\n(Monthly Consume)')
ax.set_ylabel('DEFF'); ax.legend(fontsize=9)

# Plot 3: Structural characteristics radar-like bar
# Each characteristic is divided by its maximum across scenarios, so the
# largest scenario for that characteristic has height 1.
ax = axes2[2]
chars_between = [sc_chars_vals[sc]['between_SD']/max(sc_chars_vals[s]['between_SD'] for s in 'ABCD') for sc in 'ABCD']
chars_icc     = [sc_chars_vals[sc]['ICC']/max(sc_chars_vals[s]['ICC'] for s in 'ABCD') for sc in 'ABCD']
chars_sizeCV  = [sc_chars_vals[sc]['size_CV']/max(sc_chars_vals[s]['size_CV'] for s in 'ABCD') for sc in 'ABCD']
ax.bar(x3-0.25, chars_between, 0.22, label='District heterogeneity (norm.)', color='#2196F3', alpha=0.85)
ax.bar(x3,      chars_icc,     0.22, label='Intra-cluster correlation (norm.)', color='#FF9800', alpha=0.85)
ax.bar(x3+0.25, chars_sizeCV,  0.22, label='Community size CV (norm.)', color='#9C27B0', alpha=0.85)
ax.set_xticks(x3); ax.set_xticklabels(sc_labels, fontsize=9)
ax.set_title('Normalized Structural Characteristics\nby Scenario')
ax.set_ylabel('Normalized value (0-1)'); ax.legend(fontsize=8)

plt.tight_layout()
# NOTE(review): fig2 is not closed after saving (plt.close(fig2)).
plt.savefig('sampling_summary_v2.png', dpi=150, bbox_inches='tight')
print("Summary figure saved: sampling_summary_v2.png")

# ==============================
# Print all numbers for markdown
# ==============================
print("\n====== ALL METRICS FOR MARKDOWN ======")
# Monthly-consume bias / SD / MSE for every design, per scenario.
for scen in 'ABCD':
    scen_block = all_results[scen]
    print(f"\nScenario {scen} (true mean={scen_block['truth']['mean_consume']:.2f}):")
    for design in design_names:
        row_stats = scen_block['metrics'][design]['mean_consume']
        print(f"  {design:20s} bias={row_stats['bias']:+7.2f} sd={row_stats['sd']:7.2f} mse={row_stats['mse']:9.2f}")

print("\n====== DEFF TABLE ======")
# Variance ratios of the cluster and PPS designs against SRS.
for scen in 'ABCD':
    sd_of = {
        dn: all_results[scen]['metrics'][dn]['mean_consume']['sd']
        for dn in ('SRS', 'TwoStageCluster', 'PPS')
    }
    base_var = sd_of['SRS']**2
    print(f"  {scen}: DEFF_cluster={sd_of['TwoStageCluster']**2/base_var:.3f}, DEFF_pps={sd_of['PPS']**2/base_var:.3f}")

print("\n====== NONRESPONSE (Scenario B) ======")
print(f"  Unweighted: bias={m_raw['bias']:.2f}, sd={m_raw['sd']:.2f}, mse={m_raw['mse']:.2f}")
print(f"  IPW: bias={m_wt['bias']:.2f}, sd={m_wt['sd']:.2f}, mse={m_wt['mse']:.2f}")

print("\nDone!")
