#!/usr/bin/env python3
"""
Replication Script for "Mapping the Structural Divide"
Kyle Saunders, Colorado State University

v1.4 — April 2026

This script reproduces the composite scores, quadrant assignments,
correlation diagnostics, factor analysis, sensitivity analysis, z-score
robustness check, Carnegie tier stratification, and AI exposure summary
reported in the working paper and supplementary materials.

Requirements: pandas, numpy, scipy (standard Anaconda/pip install)
Input: university_mapping_dataset.csv (provided in the data download)
"""

import pandas as pd
import numpy as np
from scipy import linalg as la
from scipy.stats import chi2 as chi2_dist, rankdata, norm
import sys
import os

# ============================================================
# CONFIGURATION
# ============================================================

# Component columns averaged into the Resilience composite.
RESILIENCE_VARS = [
    'R_ENDOW',
    'R_REVDIV',
    'R_ENROLL',
    'R_SELECT',
]

# Component columns averaged into the Market Position composite.
# R_COMPLETION belongs to this axis even though it carries an R_ prefix.
MARKET_VARS = [
    'R_COMPLETION',
    'L_EARNDEBT',
    'L_AIEXP',
    'L_DEMO',
]

# All eight components, resilience first, then market position.
ALL_COMPONENTS = [*RESILIENCE_VARS, *MARKET_VARS]

# Human-readable label for each component column (used in printed tables).
COMPONENT_LABELS = dict(
    R_ENDOW='Endowment per FTE student',
    R_REVDIV='Revenue diversification',
    R_ENROLL='Enrollment trajectory (5-year)',
    R_SELECT='Selectivity (1 - admission rate)',
    R_COMPLETION='Six-year completion rate',
    L_EARNDEBT='Earnings-to-debt ratio',
    L_AIEXP='AI exposure (inverted)',
    L_DEMO='Demographic trajectory (WICHE)',
)

# Quadrant name keyed by
# (resilience >= resilience median, market >= market median).
QUADRANT_LABELS = {
    (True, True): 'High Capacity',
    (True, False): 'Market Misaligned',
    (False, True): 'Structurally Exposed',
    (False, False): 'High Stress',
}

# A composite is set to NaN when fewer than this many of its components
# are non-missing.
MIN_COMPONENTS = 2

# ============================================================
# LOAD DATA
# ============================================================

def load_data(filepath='university_mapping_dataset.csv'):
    """Read the institution-level CSV at ``filepath`` into a DataFrame.

    Prints the number of rows loaded.  When the file does not exist,
    prints download instructions and terminates the process with exit
    status 1.
    """
    if os.path.exists(filepath):
        frame = pd.read_csv(filepath)
        print(f"Loaded {len(frame)} institutions from {filepath}")
        return frame

    # Missing input: explain how to obtain the data, then stop.
    for message in (
        f"Error: {filepath} not found.",
        "Please place the dataset in the current directory.",
        "Download from: https://kylesaunders.com/university-map (Data & Downloads tab)",
    ):
        print(message)
    sys.exit(1)

# ============================================================
# STEP 1: COMPOSITE SCORES AND QUADRANT ASSIGNMENT
# ============================================================

def compute_composites(df):
    """Add equal-weight composite columns for both axes.

    Works on a copy of ``df`` and adds ``REPL_RESILIENCE`` (row mean of
    RESILIENCE_VARS) and ``REPL_MARKET`` (row mean of MARKET_VARS).  A
    composite is NaN for rows with fewer than MIN_COMPONENTS non-missing
    components on that axis.  Prints summary statistics and returns the
    augmented copy.
    """
    out = df.copy()

    axes = (
        ('REPL_RESILIENCE', RESILIENCE_VARS),
        ('REPL_MARKET', MARKET_VARS),
    )
    for target, members in axes:
        components = out[members]
        enough = components.notna().sum(axis=1) >= MIN_COMPONENTS
        # Row mean skips missing values; blank it where coverage is too thin.
        out[target] = components.mean(axis=1).where(enough)

    has_both = out['REPL_RESILIENCE'].notna() & out['REPL_MARKET'].notna()
    print("\nStep 1: Composite Scores")
    print(f"  Institutions with both scores: {has_both.sum()}")
    print(f"  Resilience median: {out['REPL_RESILIENCE'].median():.4f}")
    print(f"  Market Position median: {out['REPL_MARKET'].median():.4f}")
    return out

def assign_quadrants(df, r_col='REPL_RESILIENCE', m_col='REPL_MARKET', q_col='REPL_QUADRANT'):
    """Assign each row to a quadrant via a median split on both axes.

    Parameters
    ----------
    df : DataFrame containing ``r_col`` and ``m_col``.
    r_col, m_col : names of the resilience and market-position score columns.
    q_col : name of the quadrant column to create.

    Returns
    -------
    A copy of ``df`` with ``q_col`` added.  Rows missing either score get
    ``None``; all others get the QUADRANT_LABELS entry for
    ``(score >= median)`` on each axis.

    When ``q_col`` is ``'REPL_QUADRANT'`` and the frame has a ``QUADRANT``
    column, the share of matching labels is printed as a replication check.
    """
    df = df.copy()
    r_med = df[r_col].median()
    m_med = df[m_col].median()

    missing = df[r_col].isna() | df[m_col].isna()
    r_high = df[r_col] >= r_med
    m_high = df[m_col] >= m_med
    # Built positionally rather than with df.apply(axis=1): this avoids
    # constructing a Series per row and also works on an empty frame.
    df[q_col] = [
        None if is_missing else QUADRANT_LABELS[(bool(hi_r), bool(hi_m))]
        for is_missing, hi_r, hi_m in zip(missing, r_high, m_high)
    ]

    print(f"\nQuadrant Distribution ({q_col}):")
    print(df[q_col].value_counts().to_string())

    if 'QUADRANT' in df.columns and q_col == 'REPL_QUADRANT':
        valid = df[q_col].notna() & df['QUADRANT'].notna()
        if valid.any():
            match = df.loc[valid, q_col] == df.loc[valid, 'QUADRANT']
            agreement = match.mean()
            print(f"\n  Agreement with published quadrants: {agreement:.1%}")
            if agreement >= 0.99:
                print("  Replication matches published values.")
            else:
                # Make a failed replication visible instead of passing silently.
                n_diff = int((~match).sum())
                print(f"  WARNING: {n_diff} of {int(valid.sum())} institutions "
                      f"differ from the published quadrants.")
        else:
            # No row has both labels, so an agreement rate is undefined.
            print("\n  Agreement with published quadrants: n/a "
                  "(no rows with both assignments)")

    return df

# ============================================================
# STEP 2: PRE-FACTOR CORRELATION DIAGNOSTICS
# ============================================================

def correlation_diagnostics(df):
    """Print and return the correlation matrix of the eight components.

    Uses only rows with all ALL_COMPONENTS present.  Lists every pair
    whose absolute correlation exceeds 0.4, tagged as within-axis or
    cross-axis, then prints the determinant of the matrix.  Returns the
    correlation DataFrame.
    """
    complete = df[ALL_COMPONENTS].dropna()
    corr = complete.corr()

    print(f"\nStep 2: Pre-Factor Correlation Diagnostics (N = {len(complete)})")
    print("\n  Pairwise correlations |r| > 0.4:")

    # Short display name: label text before any parenthesis, max 20 chars.
    short = [COMPONENT_LABELS[v].split('(')[0].strip()[:20] for v in ALL_COMPONENTS]
    axis_of = ['R' if v in RESILIENCE_VARS else 'M' for v in ALL_COMPONENTS]

    # Strict upper triangle, row-major: every unordered pair exactly once.
    upper_rows, upper_cols = np.triu_indices(len(ALL_COMPONENTS), k=1)
    for i, j in zip(upper_rows, upper_cols):
        r = corr.iloc[i, j]
        if not abs(r) > 0.4:
            continue
        cross = "within-axis" if axis_of[i] == axis_of[j] else "CROSS-AXIS"
        print(f"    {short[i]:>20} <-> {short[j]:<20} r = {r:+.3f} [{cross}]")

    det = la.det(corr.values)
    print(f"\n  Determinant of correlation matrix: {det:.6f}")
    return corr

# ============================================================
# STEP 3: FACTOR ANALYSIS (Principal Components)
# ============================================================

def extract_factors_pca(corr_matrix, n_factors):
    """Principal-component extraction from a correlation matrix.

    ``corr_matrix`` may be a DataFrame or an array.  Returns
    ``(loadings, eigenvalues)`` where ``loadings`` holds the first
    ``n_factors`` eigenvectors, each scaled by the square root of its
    eigenvalue, and ``eigenvalues`` is the full spectrum in descending
    order.
    """
    if hasattr(corr_matrix, 'values'):
        matrix = corr_matrix.values.copy()
    else:
        matrix = corr_matrix.copy()

    eigvals, eigvecs = la.eigh(matrix)
    # Reorder from largest to smallest eigenvalue.
    order = np.argsort(eigvals)[::-1]
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]

    scale = np.sqrt(eigvals[:n_factors])
    return eigvecs[:, :n_factors] * scale, eigvals

def _normalize_signs(loadings):
    """Flip columns so each column's largest-magnitude loading is positive.

    Gives a deterministic sign convention: every column is multiplied by
    the sign of its entry with the greatest absolute value.
    """
    n_cols = loadings.shape[1]
    dominant_rows = np.abs(loadings).argmax(axis=0)
    dominant_vals = loadings[dominant_rows, np.arange(n_cols)]
    return loadings * np.sign(dominant_vals)

def varimax_rotation(loadings, max_iter=1000, tol=1e-8):
    """Varimax (orthogonal) rotation of a loading matrix.

    Parameters
    ----------
    loadings : (variables x factors) array of unrotated loadings.
    max_iter : maximum number of SVD iterations.
    tol : stop once the sum of singular values changes by less than this.

    Returns
    -------
    (rotated, R) where ``rotated`` follows the sign convention that each
    column's largest-magnitude loading is positive, and ``R`` is the
    orthogonal matrix satisfying ``loadings @ R == rotated``.

    The sign flips are folded into ``R``.  Previously ``R`` was returned
    without them, so ``loadings @ R`` differed from ``rotated`` in every
    flipped column and callers reusing ``R`` got inconsistent results.
    """
    n, k = loadings.shape
    R = np.eye(k)
    d = 0
    for _ in range(max_iter):
        old_d = d
        Lambda = loadings @ R
        # Gradient of the varimax criterion, projected back onto the
        # orthogonal group via the SVD.
        u, s, vt = la.svd(
            loadings.T @ (Lambda**3 - (1.0/n) * Lambda @ np.diag(np.sum(Lambda**2, axis=0)))
        )
        R = u @ vt
        d = np.sum(s)
        if abs(d - old_d) < tol:
            break

    unsigned = loadings @ R
    # Same convention as _normalize_signs, computed inline because the
    # per-column signs are needed to update R as well.
    signs = np.sign(unsigned[np.argmax(np.abs(unsigned), axis=0), np.arange(k)])
    signs[signs == 0] = 1  # an all-zero column has no sign; keep R orthogonal
    R = R * signs          # flipping columns of R keeps it orthogonal
    return unsigned * signs, R

def promax_rotation(loadings, power=4):
    """Promax (oblique) rotation built on a varimax solution.

    Parameters
    ----------
    loadings : (variables x factors) array of unrotated loadings.
    power : exponent used to sharpen the varimax loadings into the target.

    Returns
    -------
    (pattern, factor_corr): the sign-normalised pattern matrix and the
    factor correlation matrix rescaled to a unit diagonal.

    Two consistency fixes relative to the earlier version:

    * The transformation ``T`` is fitted to the sign-normalised varimax
      loadings, so it is now applied to those same loadings.  Applying it
      to ``loadings @ R_v`` gave wrong cross-loadings whenever varimax
      had flipped some, but not all, columns.
    * If the final sign normalisation flips a pattern column, the matching
      row and column of ``factor_corr`` are flipped too.

    NOTE(review): ``T`` is normalised to unit column length, as before.
    Results can differ from earlier runs when a sign flip occurred.
    """
    rotated, _ = varimax_rotation(loadings)
    # Target keeps each loading's sign but raises its magnitude to `power`.
    target = np.sign(rotated) * np.abs(rotated)**power
    T, _, _, _ = la.lstsq(rotated, target)
    T = T / np.sqrt(np.sum(T**2, axis=0))

    raw_pattern = rotated @ T
    pattern = _normalize_signs(raw_pattern)

    T_inv = la.inv(T)
    factor_corr = T_inv @ T_inv.T
    D = np.diag(1.0 / np.sqrt(np.diag(factor_corr)))
    factor_corr = D @ factor_corr @ D

    # Negating a factor negates its correlation with every other factor.
    flips = np.where(np.sum(raw_pattern * pattern, axis=0) < 0, -1.0, 1.0)
    factor_corr = factor_corr * np.outer(flips, flips)
    return pattern, factor_corr

def factor_diagnostics(df):
    """Compute and print KMO, Bartlett's test, and parallel analysis.

    Uses only rows with all ALL_COMPONENTS present.

    Prints the overall KMO with its verbal band, Bartlett's test of
    sphericity, the number of factors retained by parallel analysis
    (observed eigenvalue above the 95th percentile of 1000 random-normal
    replications), and the per-variable MSA.

    Returns
    -------
    dict with keys ``n``, ``kmo``, ``msa`` (array in ALL_COMPONENTS order),
    ``bartlett_chi2``, ``bartlett_df``, ``bartlett_p``, ``eigenvalues``,
    ``pa_threshold`` and ``n_factors_pa``.  Earlier versions returned
    ``None``; callers that ignore the result are unaffected.
    """
    sub = df[ALL_COMPONENTS].dropna()
    n, p = sub.shape
    X = sub.values
    X_std = (X - X.mean(axis=0)) / X.std(axis=0, ddof=0)
    R_corr = np.corrcoef(X_std.T)

    # KMO: squared correlations versus squared partial correlations,
    # both taken over the off-diagonal entries only.
    R_inv = la.inv(R_corr)
    D = np.diag(1.0 / np.sqrt(np.diag(R_inv)))
    partial_corr = -D @ R_inv @ D
    np.fill_diagonal(partial_corr, 1.0)
    off_diag = ~np.eye(p, dtype=bool)
    r2 = np.where(off_diag, R_corr**2, 0.0)
    a2 = np.where(off_diag, partial_corr**2, 0.0)
    kmo = r2.sum() / (r2.sum() + a2.sum())
    msa = r2.sum(axis=1) / (r2.sum(axis=1) + a2.sum(axis=1))

    # Bartlett's test of sphericity
    det_R = la.det(R_corr)
    df_b = p * (p - 1) / 2
    chi2_stat = -((n - 1) - (2 * p + 5) / 6) * np.log(det_R)
    p_val = chi2_dist.sf(chi2_stat, df_b)

    # PCA eigenvalues, largest first
    eigenvalues = np.sort(la.eigvalsh(R_corr))[::-1]

    # Parallel analysis (1000 replications).  A private RandomState(42)
    # yields the same draws as np.random.seed(42) followed by
    # np.random.normal, without resetting the global RNG for other code.
    rng = np.random.RandomState(42)
    random_eigs = np.zeros((1000, p))
    for rep in range(1000):
        rd = rng.normal(0, 1, (n, p))
        random_eigs[rep] = np.sort(la.eigvalsh(np.corrcoef(rd.T)))[::-1]
    pct95 = np.percentile(random_eigs, 95, axis=0)
    n_pa = int(np.sum(eigenvalues > pct95))

    bands = [("Marvelous", .9), ("Meritorious", .8), ("Middling", .7),
             ("Mediocre", .6), ("Miserable", .5)]
    kmo_label = next((label for label, lo in bands if kmo >= lo), "Unacceptable")
    # Report the p-value actually computed instead of a fixed "p < .001".
    p_text = "p < .001" if p_val < 0.001 else f"p = {p_val:.3f}"

    print(f"\nStep 2b: Factor Analysis Diagnostics (N = {n})")
    print(f"  KMO = {kmo:.3f} ({kmo_label})")
    print(f"  Bartlett's chi2({int(df_b)}) = {chi2_stat:.2f}, {p_text}")
    print(f"  Parallel analysis: {n_pa} factors (95th percentile criterion)")

    for var, msa_k in zip(ALL_COMPONENTS, msa):
        print(f"    {COMPONENT_LABELS[var]:35s} MSA = {msa_k:.3f}")

    return {
        'n': n,
        'kmo': kmo,
        'msa': msa,
        'bartlett_chi2': chi2_stat,
        'bartlett_df': df_b,
        'bartlett_p': p_val,
        'eigenvalues': eigenvalues,
        'pa_threshold': pct95,
        'n_factors_pa': n_pa,
    }

def run_factor_analysis(df):
    """Print the PCA spectrum and two rotated factor solutions.

    Uses only rows with all ALL_COMPONENTS present.  Prints, in order:
    every eigenvalue with its share and cumulative share of variance
    (eigenvalues above 1.0 are flagged), the 3-factor promax pattern
    matrix, and the 2-factor varimax loadings with communalities (h2)
    and uniquenesses (u2).  Returns nothing.
    """
    complete = df[ALL_COMPONENTS].dropna()
    values = complete.values
    standardized = (values - values.mean(axis=0)) / values.std(axis=0)
    corr = np.corrcoef(standardized.T)
    n_vars = len(ALL_COMPONENTS)

    print(f"\nStep 3: Factor Analysis (N = {len(complete)})")
    print("\n  PCA Eigenvalues:")
    running = 0
    for rank, ev in enumerate(np.sort(la.eigvalsh(corr))[::-1], start=1):
        share = ev / n_vars * 100
        running += share
        flag = "<- RETAIN" if ev > 1.0 else ""
        print(f"    PC{rank}: {ev:.4f}  ({share:.1f}%, cumulative: {running:.1f}%)  {flag}")

    corr_df = pd.DataFrame(corr, index=ALL_COMPONENTS, columns=ALL_COMPONENTS)
    short = [COMPONENT_LABELS[v].split('(')[0].strip()[:20] for v in ALL_COMPONENTS]

    def _print_table(title, headers, rows):
        # One right-aligned row per variable: name, then 8-wide numeric cells.
        print(f"\n  {title}:")
        print(f"    {'Variable':>20}  " + "  ".join(f"{h:>8}" for h in headers))
        for label, cells in zip(short, rows):
            print(f"    {label:>20}  " + "  ".join(f"{c:>8.3f}" for c in cells))

    # 3-factor promax (reported in manuscript Section 5.4)
    unrotated_3, _ = extract_factors_pca(corr_df, 3)
    pattern_3, factor_corr = promax_rotation(unrotated_3)
    _print_table("3-Factor Promax Pattern Matrix", ('F1', 'F2', 'F3'),
                 [(row[0], row[1], row[2]) for row in pattern_3])

    # 2-factor varimax (reported in supplementary)
    unrotated_2, _ = extract_factors_pca(corr_df, 2)
    varimax_2, _ = varimax_rotation(unrotated_2)
    communality = np.sum(varimax_2**2, axis=1)
    _print_table("2-Factor Varimax Loadings and Communalities",
                 ('F1', 'F2', 'h2', 'u2'),
                 [(row[0], row[1], h2, 1 - h2)
                  for row, h2 in zip(varimax_2, communality)])

# ============================================================
# STEP 4: SENSITIVITY ANALYSIS
# ============================================================

def _compute_quads(df, r_vars, m_vars, min_components=2):
    """Compute median-split quadrant labels for one specification.

    Parameters
    ----------
    df : DataFrame holding the component columns.
    r_vars, m_vars : column names averaged into the resilience and market
        scores.  A name may be repeated to weight that component more
        heavily in the mean.
    min_components : minimum number of DISTINCT non-missing components
        needed for a score; the default matches the module-level
        MIN_COMPONENTS value of 2.

    Returns
    -------
    Object Series aligned to ``df.index`` holding the quadrant name, or
    NaN where either score is unavailable.

    Coverage is counted over distinct columns.  Counting the repeated list
    let a row with a single observed, double-weighted component pass the
    two-component rule, which also shifted that specification's medians.
    """
    def _axis_score(cols):
        # Mean over the (possibly repeated) list applies the weighting;
        # coverage is judged on the de-duplicated list.
        distinct = list(dict.fromkeys(cols))
        covered = df[distinct].notna().sum(axis=1) >= min_components
        return df[cols].mean(axis=1).where(covered)

    r = _axis_score(r_vars)
    m = _axis_score(m_vars)
    r_med, m_med = r.median(), m.median()

    quads = pd.Series(index=df.index, dtype=object)
    mask = r.notna() & m.notna()
    quads[(r >= r_med) & (m >= m_med) & mask] = 'High Capacity'
    quads[(r >= r_med) & (m <  m_med) & mask] = 'Market Misaligned'
    quads[(r <  r_med) & (m >= m_med) & mask] = 'Structurally Exposed'
    quads[(r <  r_med) & (m <  m_med) & mask] = 'High Stress'
    return quads

def sensitivity_analysis(df):
    """Print quadrant agreement between the baseline and alternative specs.

    For each alternative specification (one component dropped, completion
    moved to the resilience axis, a component double-weighted, or endowment
    half-weighted) prints how many institutions are classified under both
    the baseline and the alternative, and how many keep the same quadrant.
    """
    base = _compute_quads(df, RESILIENCE_VARS, MARKET_VARS)

    def _without(columns, dropped):
        # Same list, original order, minus one component.
        return [c for c in columns if c != dropped]

    def _report(label, alt):
        # One table row: overlap size, identical labels, percent agreement.
        both = base.notna() & alt.notna()
        n = both.sum()
        same = (base[both] == alt[both]).sum()
        print(f"  {label:<35} {n:>6} {same:>8} {same/n*100:>7.1f}%")

    alt_specs = [
        ('No selectivity', _without(RESILIENCE_VARS, 'R_SELECT'), MARKET_VARS),
        ('No AI exposure', RESILIENCE_VARS, _without(MARKET_VARS, 'L_AIEXP')),
        ('No demographics', RESILIENCE_VARS, _without(MARKET_VARS, 'L_DEMO')),
        ('No enrollment', _without(RESILIENCE_VARS, 'R_ENROLL'), MARKET_VARS),
        ('No rev diversification', _without(RESILIENCE_VARS, 'R_REVDIV'), MARKET_VARS),
        ('No endowment', _without(RESILIENCE_VARS, 'R_ENDOW'), MARKET_VARS),
        ('No completion', RESILIENCE_VARS, _without(MARKET_VARS, 'R_COMPLETION')),
        ('No earn/debt', RESILIENCE_VARS, _without(MARKET_VARS, 'L_EARNDEBT')),
        ('Completion on resilience',
         ['R_ENDOW', 'R_REVDIV', 'R_COMPLETION', 'R_ENROLL', 'R_SELECT'],
         ['L_EARNDEBT', 'L_AIEXP', 'L_DEMO']),
        ('Double-weight completion',
         RESILIENCE_VARS,
         ['R_COMPLETION', 'R_COMPLETION', 'L_EARNDEBT', 'L_AIEXP', 'L_DEMO']),
        ('Double-weight AI',
         RESILIENCE_VARS,
         ['R_COMPLETION', 'L_EARNDEBT', 'L_AIEXP', 'L_AIEXP', 'L_DEMO']),
    ]

    print("\nStep 4: Sensitivity Analysis")
    print(f"\n  {'Specification':<35} {'N':>6} {'Same':>8} {'%Agree':>8}")
    print(f"  {'-'*60}")

    for label, r_vars, m_vars in alt_specs:
        _report(label, _compute_quads(df, r_vars, m_vars))

    # Half-weight endowment (no external data needed): rescale the column
    # to 0.5 * value + 0.25 before averaging.
    half_weighted = df.assign(R_ENDOW=0.5 * df['R_ENDOW'] + 0.25)
    _report('Half-weight endowment',
            _compute_quads(half_weighted, RESILIENCE_VARS, MARKET_VARS))

    print("\n  NOTE: Additional v1.3 specifications (RevDiv HERD discount,")
    print("  endowment yield, admission yield) require external IPEDS/HERD data.")

# ============================================================
# STEP 5: Z-SCORE ROBUSTNESS CHECK (Manuscript Section 5.4)
# ============================================================

def zscore_robustness(df):
    """Re-derive quadrants from probit-transformed components and compare.

    Each component is ranked among its non-missing values, scaled to the
    open interval (0, 1) as rank / (N + 1), and passed through the inverse
    normal CDF.  Composites and median-split quadrants are rebuilt from
    these z-scores and compared with the baseline ``REPL_*`` columns, which
    must already exist in ``df``.

    The manuscript reports: 92.1% quadrant agreement with the baseline
    (rho = 0.987 on resilience, rho = 0.985 on market position).

    Returns a copy of ``df`` with the ``Z_*`` columns added.
    """
    from scipy.stats import spearmanr

    out = df.copy()

    # Probit transform, one component at a time.  rank/(N+1) keeps the
    # quantiles strictly inside (0, 1) so the inverse CDF stays finite.
    for comp in ALL_COMPONENTS:
        z_name = f'Z_{comp}'
        present = out[comp].notna()
        observed = out.loc[present, comp]
        quantiles = observed.rank(method='average') / (len(observed) + 1)
        out[z_name] = np.nan
        out.loc[present, z_name] = norm.ppf(quantiles)

    axes = (
        ('Z_RESILIENCE', [f'Z_{v}' for v in RESILIENCE_VARS]),
        ('Z_MARKET', [f'Z_{v}' for v in MARKET_VARS]),
    )

    # Equal-weight z-score composites, blanked where coverage is too thin.
    for target, members in axes:
        out[target] = out[members].mean(axis=1)
    for target, members in axes:
        too_sparse = out[members].notna().sum(axis=1) < MIN_COMPONENTS
        out.loc[too_sparse, target] = np.nan

    # Median split on the z-score composites.
    cut_r = out['Z_RESILIENCE'].median()
    cut_m = out['Z_MARKET'].median()

    def _label(record):
        zr, zm = record['Z_RESILIENCE'], record['Z_MARKET']
        if pd.isna(zr) or pd.isna(zm):
            return None
        return QUADRANT_LABELS[(zr >= cut_r, zm >= cut_m)]

    out['Z_QUADRANT'] = out.apply(_label, axis=1)

    # Agreement with the baseline quadrants.
    both = out['REPL_QUADRANT'].notna() & out['Z_QUADRANT'].notna()
    agrees = out['REPL_QUADRANT'] == out['Z_QUADRANT']
    n = both.sum()
    same = agrees[both].sum()
    pct = same / n * 100

    # Rank correlation between baseline and z-score composites.
    rho = {}
    for base_col, z_col in (('REPL_RESILIENCE', 'Z_RESILIENCE'),
                            ('REPL_MARKET', 'Z_MARKET')):
        usable = out[base_col].notna() & out[z_col].notna()
        rho[z_col], _ = spearmanr(out.loc[usable, base_col], out.loc[usable, z_col])

    print("\nStep 5: Z-Score Robustness Check (Probit Transform)")
    print(f"  Quadrant agreement: {same}/{n} ({pct:.1f}%)")
    print(f"  Spearman rho (resilience): {rho['Z_RESILIENCE']:.3f}")
    print(f"  Spearman rho (market position): {rho['Z_MARKET']:.3f}")
    print(f"  Institutions that shift: {n - same}")

    # Per-tier agreement
    if 'CARNEGIE25_TIER' in out.columns:
        print("\n  Per-tier agreement:")
        for tier in out['CARNEGIE25_TIER'].dropna().unique():
            in_tier = (out['CARNEGIE25_TIER'] == tier) & both
            tn = in_tier.sum()
            if tn <= 0:
                continue
            tsame = agrees[in_tier].sum()
            print(f"    {tier:<30} {tsame}/{tn} ({tsame/tn*100:.1f}%)")

    return out

# ============================================================
# STEP 6: CARNEGIE TIER STRATIFICATION (Manuscript Table 1)
# ============================================================

def tier_stratification(df):
    """Print the Carnegie tier x quadrant cross-tabulation (Table 1).

    Requires the ``REPL_QUADRANT``, ``R_ENROLL`` and ``REPL_RESILIENCE``
    columns; prints a skip notice and returns when ``CARNEGIE25_TIER`` is
    absent.  Each cell shows the count and its share of the tier.  Ends
    with the bifurcation statistic: institutions below the mapped-sample
    median on both ``R_ENROLL`` and ``REPL_RESILIENCE``.  Returns nothing.
    """
    if 'CARNEGIE25_TIER' not in df.columns:
        print("\nStep 6: Skipped (CARNEGIE25_TIER not in dataset)")
        return

    q_col = 'REPL_QUADRANT'
    valid = df[q_col].notna() & df['CARNEGIE25_TIER'].notna()
    ct = pd.crosstab(df.loc[valid, 'CARNEGIE25_TIER'], df.loc[valid, q_col])

    # Header cells use the same widths as the data cells below (10 per
    # quadrant, 6 for the total) so the columns line up; the header
    # previously used width 6 for the quadrant columns.
    print("\nStep 6: Carnegie Tier x Quadrant (Table 1)")
    print(f"\n  {'Tier':<25} {'HS':>10} {'SE':>10} {'MM':>10} {'HC':>10} {'Total':>6}")
    print(f"  {'-'*76}")

    quad_order = ['High Stress', 'Structurally Exposed', 'Market Misaligned', 'High Capacity']
    for tier in ct.index:
        row = ct.loc[tier]
        total = row.sum()
        # A quadrant absent from the crosstab columns counts as zero.
        vals = [row.get(q, 0) for q in quad_order]
        pcts = [f"{v} ({v/total*100:.0f}%)" if total > 0 else "0" for v in vals]
        print(f"  {tier:<25} {pcts[0]:>10} {pcts[1]:>10} {pcts[2]:>10} {pcts[3]:>10} {total:>6}")

    # Bifurcation statistic (Section 5.2): medians are taken over mapped
    # institutions only.
    mapped = df[q_col].notna()
    declining = df['R_ENROLL'] < df.loc[mapped, 'R_ENROLL'].median()
    low_res = df['REPL_RESILIENCE'] < df.loc[mapped, 'REPL_RESILIENCE'].median()
    both_stress = declining & low_res & mapped
    n_both = both_stress.sum()
    n_total = mapped.sum()
    # Avoid a divide-by-zero warning when nothing is mapped.
    share = n_both / n_total * 100 if n_total else float('nan')
    print(f"\n  Bifurcation: {n_both} institutions ({share:.1f}%) declining + below-median resilience")

# ============================================================
# STEP 7: AI EXPOSURE SUMMARY (Manuscript Section 5.3)
# ============================================================

def ai_exposure_summary(df):
    """Print summary statistics for the AI exposure columns.

    Reports mean, SD and N for ``AI_EXPOSURE_BLENDED`` and for ``L_AIEXP``,
    each only when that column is present; prints a skip notice when
    neither is.  When ``L_AIEXP`` and a PSEO earnings column
    (``PSEO_EARN_1YR``, else ``PSEO_EARNINGS``) share more than 10 rows,
    also prints their Spearman correlation.  Returns nothing.

    Every ``L_AIEXP`` read is guarded by a presence check; previously a
    dataset with only ``AI_EXPOSURE_BLENDED`` passed the skip test and
    then raised ``KeyError``.
    """
    has_blended = 'AI_EXPOSURE_BLENDED' in df.columns
    has_ranked = 'L_AIEXP' in df.columns
    if not (has_blended or has_ranked):
        print("\nStep 7: Skipped (AI exposure columns not in dataset)")
        return

    print("\nStep 7: AI Exposure Summary")

    if has_blended:
        ai = df['AI_EXPOSURE_BLENDED'].dropna()
        print(f"  Blended AI exposure: mean={ai.mean():.4f}, SD={ai.std():.4f}, N={len(ai)}")

    if not has_ranked:
        # Nothing further can be reported without the ranked column.
        return

    ai_rank = df['L_AIEXP'].dropna()
    print(f"  AI exposure (ranked, inverted): mean={ai_rank.mean():.4f}, SD={ai_rank.std():.4f}, N={len(ai_rank)}")

    # PSEO earnings correlation (if available); PSEO_EARN_1YR is preferred.
    earn_col = next(
        (c for c in ('PSEO_EARN_1YR', 'PSEO_EARNINGS') if c in df.columns),
        None,
    )
    if earn_col is None:
        return

    from scipy.stats import spearmanr
    valid = df['L_AIEXP'].notna() & df[earn_col].notna()
    if valid.sum() > 10:
        rho, pval = spearmanr(df.loc[valid, 'L_AIEXP'], df.loc[valid, earn_col])
        print(f"  PSEO earnings x AI exposure: rho={rho:.3f}, p={pval:.4f}, N={valid.sum()}")

# ============================================================
# MAIN
# ============================================================

if __name__ == '__main__':
    # An optional first command-line argument overrides the default CSV path.
    cli_args = sys.argv[1:]
    filepath = cli_args[0] if cli_args else 'university_mapping_dataset.csv'

    # Steps 1-2: composites, quadrants, correlation diagnostics.
    df = assign_quadrants(compute_composites(load_data(filepath)))
    corr = correlation_diagnostics(df)

    # Factor-analytic diagnostics and rotated solutions.
    factor_diagnostics(df)
    run_factor_analysis(df)

    # Robustness checks; the z-score step adds columns, so keep its result.
    sensitivity_analysis(df)
    df = zscore_robustness(df)

    # Descriptive tables.
    tier_stratification(df)
    ai_exposure_summary(df)

    banner = "=" * 60
    print("\n" + banner)
    print("Replication complete. All results should match manuscript.")
    print(banner)
