import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# ==============================
# Load data
# ==============================
# Three synthetic scenarios are loaded; per the labels used in the scenario
# analysis below, A = strong regional heterogeneity, B = baseline,
# C = weak regional heterogeneity. df_* are household populations,
# cf_* are the matching community (PSU) frames.
np.random.seed(42)  # fix the global NumPy RNG so the sampling below is reproducible

df_B = pd.read_csv('B_household_population.csv')
cf_B = pd.read_csv('B_community_frame.csv')
df_A = pd.read_csv('A_household_population.csv')
cf_A = pd.read_csv('A_community_frame(1).csv')
df_C = pd.read_csv('C_household_population.csv')
cf_C = pd.read_csv('C_community_frame.csv')

# ==============================
# True values for B
# ==============================
# Population ("true") parameters for scenario B — the targets every sampling
# design below tries to estimate.
TRUE_MEAN_B   = df_B['total_consume'].mean()
TRUE_ONLINE_B = df_B['online_share_true'].mean()
TRUE_HIGH_B   = (df_B['income_level'] == 'high').mean()  # proportion of high-income households
TRUE_OSHOP_B  = df_B['online_shop_flag'].mean()          # online-shopping participation rate
TRUE_DISTRICT_B = df_B.groupby('district_type')['total_consume'].mean()  # per-district means

print(f"True mean (B): {TRUE_MEAN_B:.2f}")
print(f"True online share (B): {TRUE_ONLINE_B:.4f}")
print(f"True high income rate (B): {TRUE_HIGH_B:.4f}")
print(f"District means (B):\n{TRUE_DISTRICT_B}")

# Simulation parameters shared by all designs.
N = len(df_B)  # population size of scenario B
n = 200  # total sample size
B = 500  # replications
# ==============================
# Helper
# ==============================
def metrics(estimates, truth):
    """Summarize a set of Monte-Carlo estimates against the true value.

    Returns a dict with the mean estimate, its bias relative to *truth*,
    the sample standard deviation (ddof=1) across replications, and the
    MSE decomposed as bias**2 + variance.
    """
    arr = np.array(estimates)
    avg = arr.mean()
    spread = arr.std(ddof=1)
    deviation = avg - truth
    return {
        'mean_est': avg,
        'bias': deviation,
        'sd': spread,
        'mse': deviation ** 2 + spread ** 2,
    }

# ==============================
# DESIGN 1: SRS
# ==============================
def srs_once(df, n):
    """Draw one simple random sample of size n (without replacement) and
    return unweighted estimates of the four target quantities."""
    drawn = df.sample(n, replace=False)
    high_flags = (drawn['income_level'] == 'high')
    return {
        'mean_consume': drawn['total_consume'].mean(),
        'online_share': drawn['online_share_true'].mean(),
        'high_rate':    high_flags.mean(),
        'oshop_rate':   drawn['online_shop_flag'].mean(),
    }

print("Running SRS...")
# B independent SRS replications on scenario B (each draw consumes the
# shared global RNG stream, so statement order matters for reproducibility).
srs_results = [srs_once(df_B, n) for _ in range(B)]

# ==============================
# DESIGN 2: Stratified Sampling (by district, proportional)
# ==============================
# Proportional allocation: n_h = n * N_h / N, rounded to integers.
districts = df_B['district_type'].unique()
N_h = df_B['district_type'].value_counts()  # stratum sizes (value_counts sorts descending)
n_h = (N_h / N * n).round().astype(int)
# ensure sum == n: the first (largest) stratum absorbs the rounding remainder
diff = n - n_h.sum()
n_h.iloc[0] += diff

def stratified_once(df, n_h):
    """Draw one proportionally-allocated stratified sample and return
    design-weighted (Horvitz-Thompson style) estimates.

    Parameters
    ----------
    df : population DataFrame with a 'district_type' stratum column.
    n_h : mapping / Series of stratum label -> within-stratum sample size.

    Each sampled household gets weight N_h / n_h (stratum population size
    over stratum sample size); all four targets are weighted means.
    """
    samples = [
        df[df['district_type'] == d].sample(nh, replace=False)
        for d, nh in n_h.items()
    ]
    s = pd.concat(samples).copy()
    # Attach design weights via a vectorized map instead of the original
    # per-row iterrows loop — same values, without slow row iteration.
    N_h_all = df['district_type'].value_counts()
    s['w'] = s['district_type'].map(lambda d: N_h_all[d] / n_h[d])
    return {
        'mean_consume': np.average(s['total_consume'], weights=s['w']),
        'online_share': np.average(s['online_share_true'], weights=s['w']),
        'high_rate':    np.average((s['income_level'] == 'high').astype(float), weights=s['w']),
        'oshop_rate':   np.average(s['online_shop_flag'], weights=s['w']),
    }

print("Running Stratified...")
strat_results = [stratified_once(df_B, n_h) for _ in range(B)]

# ==============================
# DESIGN 3: Stratified + Two-Stage Sampling
# ==============================
# Within each stratum, select m_h PSUs, then k households per PSU
# m_h proportional to N_h; k = 10 per PSU
k = 10
m_total = n // k  # 20 PSUs total
m_h = (N_h / N * m_total).round().astype(int)
# repair rounding so the PSU counts sum to m_total (largest stratum absorbs the remainder)
diff2 = m_total - m_h.sum()
m_h.iloc[0] += diff2

def strat_two_stage_once(df, cf, m_h, k):
    """One stratified two-stage draw with approximate design weights.

    Stage 1: within each stratum, SRS of m_h communities (PSUs) from the
    frame cf. Stage 2: SRS of k households within each selected community
    (capped at the community's actual size). The weight uses the frame's
    community_size: w = (N_h / m_h) * (M_i / k).

    NOTE: superseded by strat_two_stage_once_v2 (exact inverse-inclusion-
    probability weights), which is what the main run actually uses; this
    version is kept for reference.

    Bug fix: the original returned np.dot(total_consume, normalized_w) * N
    for 'mean_consume' — i.e. population size times the weighted mean,
    inflating that estimate by a factor of N (and depending on the global
    N). All four estimates are now plain weighted means.
    """
    samples = []
    for d, mh in m_h.items():
        comm_d = cf[cf['district_type'] == d]
        if mh > len(comm_d):
            mh = len(comm_d)  # cannot select more PSUs than the stratum has
        sel_comm = comm_d.sample(mh, replace=False)
        for _, c in sel_comm.iterrows():
            hh_in_c = df[(df['district_type'] == d) & (df['community_id'] == c['community_id'])]
            actual_k = min(k, len(hh_in_c))
            samp = hh_in_c.sample(actual_k, replace=False).copy()
            # weight = (N_h / m_h) * (M_i / k) where M_i = community_size
            N_h_val = len(df[df['district_type'] == d])
            M_i = c['community_size']
            w = (N_h_val / mh) * (M_i / actual_k)
            samp['w'] = w
            samples.append(samp)
    s = pd.concat(samples)
    return {
        'mean_consume': np.average(s['total_consume'], weights=s['w']),
        'online_share': np.average(s['online_share_true'], weights=s['w']),
        'high_rate':    np.average((s['income_level'] == 'high').astype(float), weights=s['w']),
        'oshop_rate':   np.average(s['online_shop_flag'], weights=s['w']),
    }

# Fix: stratified two-stage
def strat_two_stage_once_v2(df, cf, m_h, k):
    """Stratified two-stage draw with exact inverse-inclusion-probability weights.

    Stage 1: within each stratum, SRS of mh communities (PSUs) out of the
    M_h communities in the frame. Stage 2: SRS of k households within each
    selected community, capped at the community's actual size. Each sampled
    household gets weight 1 / (pi_1 * pi_2) with pi_1 = mh / M_h and
    pi_2 = actual_k / M_i; the four targets are weighted means.
    """
    collected = []
    for stratum, mh in m_h.items():
        frame_d = cf[cf['district_type'] == stratum].copy()
        psu_count = len(frame_d)  # M_h: PSUs available in this stratum
        stratum_size = len(df[df['district_type'] == stratum])  # informational; not used in the weight
        if mh > psu_count:
            mh = psu_count
        chosen = frame_d.sample(mh, replace=False)
        for _, community in chosen.iterrows():
            members = df[(df['community_id'] == community['community_id'])].copy()
            size_i = len(members)
            take = min(k, size_i)
            sub = members.sample(take, replace=False).copy()
            # inverse inclusion probability: stage 1 is mh/psu_count,
            # stage 2 (given selection) is take/size_i
            incl = (mh / psu_count) * (take / size_i)
            sub['w'] = 1.0 / incl
            collected.append(sub)
    pooled = pd.concat(collected)
    w = pooled['w']
    return {
        'mean_consume': np.average(pooled['total_consume'], weights=w),
        'online_share': np.average(pooled['online_share_true'], weights=w),
        'high_rate':    np.average((pooled['income_level'] == 'high').astype(float), weights=w),
        'oshop_rate':   np.average(pooled['online_shop_flag'], weights=w),
    }

print("Running Stratified+TwoStage...")
# Uses the corrected v2 estimator; strat_two_stage_once above is kept for reference.
strat2_results = [strat_two_stage_once_v2(df_B, cf_B, m_h, k) for _ in range(B)]

# ==============================
# DESIGN 4: Two-Stage Cluster Sampling
# ==============================
# Unstratified: communities are drawn from the whole frame.
m_cluster = 20  # PSUs
k_cluster = 10  # SSUs per PSU

def two_stage_cluster_once(df, cf, m, k):
    """One unstratified two-stage cluster draw.

    Stage 1: SRS of m communities (PSUs) from the frame cf. Stage 2: SRS of
    k households within each selected community (capped at the community's
    actual size). Each household is weighted by the inverse of its overall
    inclusion probability; the four targets are weighted means.
    """
    chosen = cf.sample(m, replace=False)
    frame_size = len(cf)
    parts = []
    for _, community in chosen.iterrows():
        members = df[df['community_id'] == community['community_id']].copy()
        size_i = len(members)
        take = min(k, size_i)
        sub = members.sample(take, replace=False).copy()
        # pi = P(community selected) * P(household selected | community)
        incl = (m / frame_size) * (take / size_i)
        sub['w'] = 1.0 / incl
        parts.append(sub)
    pooled = pd.concat(parts)
    w = pooled['w']
    return {
        'mean_consume': np.average(pooled['total_consume'], weights=w),
        'online_share': np.average(pooled['online_share_true'], weights=w),
        'high_rate':    np.average((pooled['income_level'] == 'high').astype(float), weights=w),
        'oshop_rate':   np.average(pooled['online_shop_flag'], weights=w),
    }

print("Running TwoStage Cluster...")
# B replications of the unstratified two-stage cluster design on scenario B.
cluster_results = [two_stage_cluster_once(df_B, cf_B, m_cluster, k_cluster) for _ in range(B)]

# ==============================
# DESIGN 5: PPS Sampling (with replacement, by community size)
# ==============================
def pps_once(df, cf, m, k):
    """One PPS-with-replacement draw (Hansen-Hurwitz style weighting).

    Stage 1: m communities drawn WITH replacement, with probability
    proportional to the frame's 'community_size'. Stage 2: SRS of k
    households within each drawn community (capped at the realized size).

    Fix: the expected number of selections of community i is
    m * size_i / sum(sizes), computed from the SAME frame sizes used for the
    selection probabilities. The original used the realized household count
    len(hh_in_c) here instead, which biases the weights whenever the frame
    size disagrees with the actual count. The within-community factor
    M_i / actual_k still correctly uses the realized count.
    """
    sizes = cf['community_size'].values
    total_size = sizes.sum()
    probs = sizes / total_size
    idx = np.random.choice(len(cf), size=m, replace=True, p=probs)
    sel_comm = cf.iloc[idx]
    samples = []
    for _, c in sel_comm.iterrows():
        hh_in_c = df[df['community_id'] == c['community_id']].copy()
        M_i = len(hh_in_c)
        actual_k = min(k, M_i)
        samp = hh_in_c.sample(actual_k, replace=False).copy()
        # expected hits under WR-PPS, based on the frame size used for selection
        pi_i = m * c['community_size'] / total_size
        w = (1.0 / pi_i) * (M_i / actual_k)
        samp['w'] = w
        samples.append(samp)
    s = pd.concat(samples)
    return {
        'mean_consume': np.average(s['total_consume'], weights=s['w']),
        'online_share': np.average(s['online_share_true'], weights=s['w']),
        'high_rate':    np.average((s['income_level']=='high').astype(float), weights=s['w']),
        'oshop_rate':   np.average(s['online_shop_flag'], weights=s['w']),
    }

print("Running PPS...")
pps_results = [pps_once(df_B, cf_B, m_cluster, k_cluster) for _ in range(B)]

# ==============================
# Compile Results
# ==============================
# Map design label -> list of B per-replication estimate dicts.
# (Chinese labels: 分层抽样 = stratified, 分层+两阶段 = stratified two-stage,
# 两阶段整群 = two-stage cluster.)
designs = {
    'SRS': srs_results,
    '分层抽样': strat_results,
    '分层+两阶段': strat2_results,
    '两阶段整群': cluster_results,
    'PPS': pps_results,
}

# target key -> (true value, Chinese display label)
targets = {
    'mean_consume': (TRUE_MEAN_B, '月消费均值'),
    'online_share': (TRUE_ONLINE_B, '在线消费占比'),
    'high_rate':    (TRUE_HIGH_B, '高收入占比'),
    'oshop_rate':   (TRUE_OSHOP_B, '网购参与率'),
}

print("\n======= RESULTS =======")
# summary[target][design] = {'mean_est','bias','sd','mse'} across the B replications
summary = {}
for tname, (truth, label) in targets.items():
    print(f"\n--- {label} (真值={truth:.4f}) ---")
    summary[tname] = {}
    for dname, res in designs.items():
        ests = [r[tname] for r in res]
        m = metrics(ests, truth)
        summary[tname][dname] = m
        print(f"  {dname:10s}: mean={m['mean_est']:.4f}, bias={m['bias']:.4f}, sd={m['sd']:.4f}, mse={m['mse']:.6f}")

# ==============================
# DEFF computation (cluster vs SRS)
# ==============================
# Design effect: variance of the two-stage cluster estimator relative to SRS.
print("\n======= Design Effects =======")
for tname, (truth, label) in targets.items():
    srs_var = summary[tname]['SRS']['sd'] ** 2
    cluster_var = summary[tname]['两阶段整群']['sd'] ** 2
    if srs_var > 0:
        deff = cluster_var / srs_var
    else:
        deff = np.nan  # degenerate case: SRS variance is zero
    print(f"  DEFF ({label}): {deff:.4f}")

# ==============================
# SCENARIO ANALYSIS: A vs B vs C
# ==============================
# Re-run SRS vs proportional stratified sampling (200 replications each, not
# the global B=500) on all three scenarios to relate the efficiency gain of
# stratification to between-district heterogeneity.
print("\n======= SCENARIO ANALYSIS =======")
datasets = {
    'A（区域差异强）': (df_A, cf_A),
    'B（基准情景）': (df_B, cf_B),
    'C（区域差异弱）': (df_C, cf_C),
}

scenario_results = {}
for sname, (df_s, cf_s) in datasets.items():
    truth_s = df_s['total_consume'].mean()
    # proportional allocation per scenario, with the same rounding repair as for B
    N_h_s = df_s['district_type'].value_counts()
    n_h_s = (N_h_s / len(df_s) * n).round().astype(int)
    diff_s = n - n_h_s.sum()
    n_h_s.iloc[0] += diff_s

    srs_s   = [srs_once(df_s, n) for _ in range(200)]
    strat_s = [stratified_once(df_s, n_h_s) for _ in range(200)]

    mse_srs   = metrics([r['mean_consume'] for r in srs_s], truth_s)['mse']
    mse_strat = metrics([r['mean_consume'] for r in strat_s], truth_s)['mse']
    # ratio > 1 means stratification beats SRS on MSE for the consumption mean
    ratio = mse_srs / mse_strat if mse_strat > 0 else np.nan

    # compute between-district variance for this dataset
    # (unweighted mean of squared deviations of district means from the overall mean)
    dist_means = df_s.groupby('district_type')['total_consume'].mean()
    overall_mean = df_s['total_consume'].mean()
    between_var = ((dist_means - overall_mean)**2).mean()

    scenario_results[sname] = {
        'truth': truth_s,
        'mse_srs': mse_srs,
        'mse_strat': mse_strat,
        'ratio': ratio,
        'between_var': between_var,
    }
    print(f"  {sname}: between_var={between_var:.1f}, MSE_SRS={mse_srs:.2f}, MSE_STRAT={mse_strat:.2f}, ratio={ratio:.3f}")

# ==============================
# BIAS ANALYSIS: Non-response
# ==============================
print("\n======= BIAS ANALYSIS (Non-response) =======")
def simulate_nonresponse(df, n, B_rep=300):
    """Monte-Carlo comparison of non-response adjustments under SRS.

    For each of B_rep replications: draw an SRS of size n, simulate each
    unit's response as Bernoulli(response_prob), then estimate the mean
    consumption two ways — the naive respondent mean, and the mean weighted
    by 1/response_prob (inverse-propensity adjustment). Replications with
    zero respondents are skipped.

    Returns (metrics_for_naive, metrics_for_weighted) against the
    population mean.
    """
    naive_ests = []
    adjusted_ests = []
    for _ in range(B_rep):
        drawn = df.sample(n, replace=False).copy()
        # simulate response based on response_prob
        drawn['responded'] = np.random.binomial(1, drawn['response_prob'])
        resp = drawn[drawn['responded'] == 1]
        if len(resp) == 0:
            continue
        naive_ests.append(resp['total_consume'].mean())
        # weight adjustment: inverse of response prob
        resp = resp.copy()
        resp['adj_w'] = 1.0 / resp['response_prob']
        adjusted_ests.append(np.average(resp['total_consume'], weights=resp['adj_w']))
    truth = df['total_consume'].mean()
    return metrics(naive_ests, truth), metrics(adjusted_ests, truth)

# Run the non-response experiment on scenario B at the global sample size.
m_raw, m_wt = simulate_nonresponse(df_B, n)
print(f"  未加权: bias={m_raw['bias']:.4f}, sd={m_raw['sd']:.4f}, mse={m_raw['mse']:.4f}")
print(f"  加权后: bias={m_wt['bias']:.4f},  sd={m_wt['sd']:.4f},  mse={m_wt['mse']:.4f}")

# ==============================
# COST ANALYSIS
# ==============================
# Unit fieldwork costs (yuan).
C_ENTER = 100   # enter new community
C_HH    = 20    # per household
C_OUTER_EXTRA = 15  # extra for outer district
C_FOLLOWUP = 10 # follow-up per HH (NOTE(review): defined but never used in this file)

def compute_cost(design_name, df, n=200, m=20, k=10, n_h=None):
    """Rough fieldwork cost (yuan) and realized sample size for a design.

    Cost model: C_HH per interviewed household, C_OUTER_EXTRA per household
    in an 'outer' district, plus C_ENTER per community visited for the
    cluster-based designs.

    Parameters
    ----------
    design_name : one of 'SRS', '分层抽样', '分层+两阶段', '两阶段整群', 'PPS'.
    df : population DataFrame (used for the outer-district share).
    n : total household sample size (SRS / stratified).
    m, k : PSU count and households per PSU (cluster designs).
    n_h : per-stratum allocation; only its 'outer' entry is used.

    Returns (cost, effective_sample_size).
    Raises ValueError for an unknown design name instead of silently
    returning None as the original did.
    """
    if design_name == 'SRS':
        # expected number of outer-district households in an SRS of size n
        n_outer = int(n * (df['district_type'] == 'outer').mean())
        return n * C_HH + n_outer * C_OUTER_EXTRA, n
    if design_name == '分层抽样':
        # exact outer-stratum allocation (original had a redundant nested
        # fallback: n_h.get('outer', n_h.get('outer', 0)))
        n_outer = n_h.get('outer', 0)
        return n * C_HH + int(n_outer) * C_OUTER_EXTRA, n
    if design_name in ('分层+两阶段', '两阶段整群', 'PPS'):
        # identical cost structure for all cluster-based designs (the
        # original duplicated this branch verbatim for PPS)
        n_outer_hh = int(k * (df['district_type'] == 'outer').mean() * m)
        return m * C_ENTER + m * k * C_HH + n_outer_hh * C_OUTER_EXTRA, m * k
    raise ValueError(f"unknown design: {design_name}")

print("\n======= COST ANALYSIS =======")
n_h_outer = n_h['outer']  # NOTE(review): assigned but never used afterwards
# cost_data collects a row per design; it is not referenced later in this file
cost_data = []
for d in ['SRS', '分层抽样', '分层+两阶段', '两阶段整群', 'PPS']:
    c, eff_n = compute_cost(d, df_B, n=n, m=m_cluster, k=k_cluster, n_h=n_h)
    mse_val = summary['mean_consume'].get(d, {}).get('mse', np.nan)
    cost_data.append({'设计': d, '成本(元)': c, '实际样本量': eff_n, 'MSE(月消费均值)': round(mse_val, 2)})
    print(f"  {d}: 成本={c}元, n={eff_n}, MSE={mse_val:.2f}")

# ==============================
# Plotting
# ==============================
# Six-panel summary figure. NOTE(review): titles/labels are Chinese; rendering
# them requires a CJK-capable matplotlib font — confirm the runtime font config.
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
fig.suptitle('五种抽样设计性能比较（情景B，500次重复模拟）', fontsize=14, fontweight='bold')

colors = ['#2196F3', '#4CAF50', '#FF9800', '#E91E63', '#9C27B0']  # one color per design
design_names = list(designs.keys())

# Plot 1: Boxplot of mean_consume estimates
ax = axes[0, 0]
box_data = [[r['mean_consume'] for r in designs[d]] for d in design_names]
bp = ax.boxplot(box_data, patch_artist=True, notch=False)
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)
ax.axhline(TRUE_MEAN_B, color='red', linestyle='--', linewidth=1.5, label=f'真值={TRUE_MEAN_B:.1f}')
ax.set_xticklabels(design_names, rotation=15, fontsize=8)
ax.set_title('月消费均值估计分布')
ax.set_ylabel('估计值（元）')
ax.legend(fontsize=8)

# Plot 2: MSE comparison bar chart (grouped bars: one group per design, one bar per target)
ax = axes[0, 1]
target_names = ['mean_consume', 'online_share', 'high_rate', 'oshop_rate']
target_labels = ['月消费均值', '在线消费占比', '高收入占比', '网购参与率']
x = np.arange(len(design_names))
bar_width = 0.2
for i, (tn, tl) in enumerate(zip(target_names, target_labels)):
    mses = [summary[tn][d]['mse'] for d in design_names]
    # normalize to SRS=1 for visualization (mses[0] is SRS, the first key of designs)
    mses_norm = [m/mses[0] if mses[0] > 0 else 1 for m in mses]
    ax.bar(x + i*bar_width, mses_norm, bar_width, label=tl, alpha=0.8)
ax.set_xticks(x + bar_width*1.5)  # center the tick under the 4-bar group
ax.set_xticklabels(design_names, rotation=15, fontsize=8)
ax.set_title('各设计相对MSE（SRS=1）')
ax.set_ylabel('相对MSE')
ax.axhline(1.0, color='gray', linestyle=':', alpha=0.7)
ax.legend(fontsize=7)

# Plot 3: Bias comparison (grouped by target, one bar per design)
ax = axes[0, 2]
biases = {d: [summary[tn][d]['bias'] for tn in target_names] for d in design_names}
x = np.arange(len(target_labels))
bar_width = 0.15
for i, (d, c) in enumerate(zip(design_names, colors)):
    ax.bar(x + i*bar_width, biases[d], bar_width, label=d, color=c, alpha=0.8)
ax.set_xticks(x + bar_width*2)  # center the tick under the 5-bar group
ax.set_xticklabels(target_labels, rotation=10, fontsize=8)
ax.set_title('各设计偏差（Bias）比较')
ax.set_ylabel('偏差')
ax.axhline(0, color='black', linewidth=0.8)
ax.legend(fontsize=7)

# Plot 4: Scenario analysis - between_var vs MSE ratio
ax = axes[1, 0]
sc_names = list(scenario_results.keys())
bvars = [scenario_results[s]['between_var'] for s in sc_names]
ratios = [scenario_results[s]['ratio'] for s in sc_names]
ax.scatter(bvars, ratios, s=120, c=['#2196F3', '#4CAF50', '#FF9800'], zorder=5)
for i, s in enumerate(sc_names):
    ax.annotate(s, (bvars[i], ratios[i]), textcoords='offset points', xytext=(5,5), fontsize=8)
ax.axhline(1.0, color='gray', linestyle='--', alpha=0.7, label='SRS=分层')
ax.set_xlabel('区域间方差')
ax.set_ylabel('MSE(SRS)/MSE(分层)')
ax.set_title('情景分析：区域异质性 vs 分层效率')
ax.legend(fontsize=8)

# Plot 5: Non-response bias analysis
ax = axes[1, 1]
# NOTE(review): categories / bias_vals / sd_vals / mse_vals below are unused —
# the bars are drawn from raw_vals / wt_vals instead.
categories = ['未加权', '加权后']
bias_vals  = [m_raw['bias'], m_wt['bias']]
sd_vals    = [m_raw['sd'], m_wt['sd']]
mse_vals   = [m_raw['mse'], m_wt['mse']]
x = np.arange(3)
labels = ['Bias', 'SD', 'MSE']
raw_vals = [m_raw['bias'], m_raw['sd'], m_raw['mse']]
wt_vals  = [m_wt['bias'],  m_wt['sd'],  m_wt['mse']]
ax.bar(x-0.2, raw_vals, 0.35, label='未加权', color='#E91E63', alpha=0.8)
ax.bar(x+0.2, wt_vals,  0.35, label='加权后', color='#4CAF50', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_title('非应答偏差分析（SRS+加权调整）')
ax.set_ylabel('绝对值')
ax.legend()

# Plot 6: Cost vs MSE scatter (cost-precision trade-off)
ax = axes[1, 2]
costs = []
mse_list = []
design_list = []
for d in design_names:
    c, _ = compute_cost(d, df_B, n=n, m=m_cluster, k=k_cluster, n_h=n_h)
    costs.append(c)
    mse_list.append(summary['mean_consume'][d]['mse'])
    design_list.append(d)
ax.scatter(costs, mse_list, s=150, c=colors, zorder=5)
for i, d in enumerate(design_list):
    ax.annotate(d, (costs[i], mse_list[i]), textcoords='offset points', xytext=(5,5), fontsize=8)
ax.set_xlabel('调查成本（元）')
ax.set_ylabel('MSE（月消费均值）')
ax.set_title('成本—精度权衡图')

plt.tight_layout()
plt.savefig('sampling_comparison.png', dpi=150, bbox_inches='tight')
print("\nFigure saved to sampling_comparison.png")

# ==============================
# Print final table for markdown
# ==============================
# CSV-style summary of the consumption-mean estimates per design.
print("\n====== FINAL TABLE ======")
print("设计,真值,均值估计,Bias,SD,MSE")
for d in design_names:
    r = summary['mean_consume'][d]
    print(f"{d},{TRUE_MEAN_B:.2f},{r['mean_est']:.2f},{r['bias']:.2f},{r['sd']:.2f},{r['mse']:.2f}")

print("\n====== DEFF TABLE ======")
for tname, (truth, label) in targets.items():
    var_srs = summary[tname]['SRS']['sd']**2
    var_clu = summary[tname]['两阶段整群']['sd']**2
    # Guard against a degenerate zero SRS variance, consistent with the
    # earlier DEFF computation (which already used this guard).
    deff = var_clu / var_srs if var_srs > 0 else np.nan
    print(f"  {label}: DEFF={deff:.3f}")

print("\n====== SCENARIO TABLE ======")
for sname, sr in scenario_results.items():
    print(f"{sname}: between_var={sr['between_var']:.1f}, ratio={sr['ratio']:.3f}")

print("\n====== NONRESPONSE TABLE ======")
print(f"未加权: bias={m_raw['bias']:.2f}, sd={m_raw['sd']:.2f}, mse={m_raw['mse']:.2f}")
print(f"加权后: bias={m_wt['bias']:.2f}, sd={m_wt['sd']:.2f},  mse={m_wt['mse']:.2f}")

print("\nDone!")
