# -*- coding: utf-8 -*-
"""AI 選股提示詞研究:熱門提示詞 vs 有效提示詞結構。

對應文章:
  https://finlab.finance/blog/ai-stock-picking-prompts-backtest

執行:
  uv run --with finlab --with pandas --with numpy --with matplotlib python \
    static/blog/ai-stock-picking-prompts-backtest/prompt_strategies.py

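輸出(/tmp 目錄可用環境變數 OUT_DIR 覆寫):
  - /tmp/seo-bt/ai-prompts/metrics.json
  - /tmp/seo-bt/ai-prompts/equity.csv
  - /tmp/seo-bt/ai-prompts/popular_prompts.csv
  - /tmp/seo-bt/ai-prompts/prompt_design.csv
  - /tmp/seo-bt/ai-prompts/report_strategy.html
  - /tmp/seo-bt/ai-prompts/report_baseline.html
  - static/blog/ai-stock-picking-prompts-backtest/report_*.html  (上述兩份報告的本機副本)
  - static/blog/ai-stock-picking-prompts-backtest/*.png  (本機預覽用,正式發佈前上 R2)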

共同設定:
  - 回測區間:2018-01-01 到截止日 END(預設釘在 2026-06-09,可用環境變數 DATA_END 覆寫)
  - 股票池:60 日均量 > 100 萬股
  - 熱門提示詞測試:等權持有前 30 檔,每月再平衡
  - 有效提示詞測試:排名分數、低波動控制、月營收公布後 14 天換股
  - 交易成本:策略使用 finlab sim() 台股預設值;0050 基準為含息買進持有

投資警語:本程式僅供量化研究與教學用途,過去績效不代表未來表現,
不構成任何投資建議;實際交易前請自行評估風險、滑價與交易容量。
"""
from __future__ import annotations

import json
import os
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from finlab import data
from finlab.backtest import sim

warnings.filterwarnings("ignore")

# First date of the backtest window (inclusive).
START = "2018-01-01"
# The data cutoff is pinned to the site-wide canonical 0050 snapshot date
# (project "bible" §D, 2026-06-09) so that figures stay consistent across articles;
# on each quarterly refresh, update this date and re-sync the site-wide canonical numbers.
# The DATA_END environment variable overrides the default.
END = os.environ.get("DATA_END", "2026-06-09")
# Scratch output directory; the OUT_DIR environment variable overrides the default.
TMP_OUT = Path(os.environ.get("OUT_DIR", "/tmp/seo-bt/ai-prompts"))
# Directory that contains this script; charts and report copies are written here.
STATIC_OUT = Path(__file__).resolve().parent
# Create both output directories up front so later writes cannot fail on a missing path.
TMP_OUT.mkdir(parents=True, exist_ok=True)
STATIC_OUT.mkdir(parents=True, exist_ok=True)


# ---------- 資料(公告日對齊,避免前視) ----------
def cap(df):
    """Drop every row whose index value lies after the pinned END cutoff.

    Objects that expose no ``index`` attribute are returned unchanged.
    """
    if not hasattr(df, "index"):
        return df
    on_or_before_end = df.index <= END
    return df[on_or_before_end]

# Raw finlab datasets, each truncated to END by cap().
# Closing price.
close = cap(data.get("price:收盤價"))
# Adjusted close (used further down for the 0050 benchmark).
adj = cap(data.get("etl:adj_close"))
# Traded share volume.
vol = cap(data.get("price:成交股數"))
# Price-to-earnings ratio.
pe = cap(data.get("price_earning_ratio:本益比"))
# Dividend yield (%).
dy = cap(data.get("price_earning_ratio:殖利率(%)"))
# After-tax ROE; index_str_to_date() is applied to the index before capping.
roe = cap(data.get("fundamental_features:ROE稅後").index_str_to_date())
# Monthly revenue, change versus the same month of the previous year (%).
rev = cap(data.get("monthly_revenue:去年同月增減(%)"))
# Investment-trust net buy/sell, in shares.
trust = cap(data.get("institutional_investors_trading_summary:投信買賣超股數"))

# Put ROE and revenue growth on the rows of `close`, carrying the latest value forward.
roe_d = roe.reindex(close.index, method="ffill")
rev_d = rev.reindex(close.index, method="ffill")
# Rows of `close` without an investment-trust record count as zero net buying.
trust_d = trust.reindex(close.index).fillna(0)

# Universe: 60-row average share volume above 1,000,000 and a non-missing close.
pool = (vol.rolling(60).mean() > 1_000_000) & close.notna()
# NOTE(review): the price factors below are built from `close`, not `adj` —
# confirm that ex-dividend / split gaps are acceptable for these signals.
# 60-row price change.
mom60 = close.pct_change(60)
# Rolling 60-row maximum of the close.
high60 = close.rolling(60).max()
# Close relative to its 60-row high (1.0 means the close is at the high).
prox60 = close / high60
# Standard deviation of 1-row returns over 60 rows.
lowvol60 = close.pct_change().rolling(60).std()
# 10-row sum of (investment-trust net shares x close).
trust_amt = (trust_d * close).rolling(10).sum()


def monthly(df: pd.DataFrame) -> pd.DataFrame:
    """Return *df* at month-end frequency.

    The frame is first forward-filled onto the index of ``close``; the last
    row inside each "ME" (month-end) resampling bin is then kept.
    """
    on_close_rows = df.reindex(close.index, method="ffill")
    return on_close_rows.resample("ME").last()


def rk(df: pd.DataFrame, asc: bool = True, mask: pd.DataFrame | None = None) -> pd.DataFrame:
    """Row-wise percentile rank (0~1) restricted to a universe.

    Cells outside the universe (``mask``, or the module-level ``pool`` when no
    mask is given) and cells without a value receive a score of 0.
    With ``asc=False`` smaller values earn higher scores.
    """
    universe = pool if mask is None else mask
    pct_rank = df.where(universe).rank(axis=1, pct=True, ascending=asc)
    return pct_rank.fillna(0)


def top_weights(score: pd.DataFrame, mask: pd.DataFrame, topn: int, power: float = 1.0) -> pd.DataFrame:
    """Turn scores into per-row portfolio weights over the best ``topn`` names.

    Only cells where ``mask`` holds compete. The selected cells are weighted by
    ``score ** power`` (negative scores are floored at 0 first); ``power=0``
    means equal weight. Each row is normalised to sum to 1, and a row with
    nothing selected is all zeros.

    Ranking uses pandas' default tie handling (tied scores share their average
    rank), so a tie at the cutoff can select slightly fewer or more than ``topn``.
    """
    descending_rank = score.where(mask).rank(axis=1, ascending=False)
    selected = descending_rank <= topn

    if power == 0:
        raw = selected.astype(float)
    else:
        raw = score.where(selected, 0).clip(lower=0) ** power

    # A zero row total becomes NaN so the division yields NaN, which is then zeroed.
    row_total = raw.sum(axis=1).replace(0, np.nan)
    return raw.div(row_total, axis=0).fillna(0)


def clip_creturn(creturn: pd.Series) -> pd.Series:
    """Restrict a cumulative-return series to the START..END window (both inclusive).

    sim()'s creturn extends to the day the script is run, so it has to be cut
    at both ends to the pinned dates before any statistic is computed
    (lesson learned in 2026-06).
    """
    stamps = creturn.index
    inside_window = (stamps >= START) & (stamps <= END)
    return creturn[inside_window]


def summarize(name: str, report, position: pd.DataFrame) -> dict:
    """Build the metrics row for one backtest.

    Every statistic is plain arithmetic on the clipped cumulative-return curve,
    the same convention used for the site-wide canonical 0050 figures.

    Args:
        name: Label stored under the ``"name"`` key.
        report: Object exposing a ``creturn`` series (as returned by ``sim``).
        position: Position frame; rows from START onward are used to count holdings.

    Returns:
        Dict with CAGR, daily Sharpe, daily and monthly Sortino, max drawdown and
        total return (percent figures are multiplied by 100), plus the average
        number of columns with a positive position per row.
    """
    curve = clip_creturn(report.creturn)
    daily = curve.pct_change().dropna()
    month_end = curve.resample("ME").last().pct_change().dropna()
    growth = float(curve.iloc[-1] / curve.iloc[0] - 1)
    span_years = (curve.index[-1] - curve.index[0]).days / 365.25
    in_window = position[position.index >= START]
    # Rows with no holdings become NaN so they do not drag the average down.
    held = (in_window > 0).sum(axis=1).replace(0, np.nan)

    def downside_ratio(returns: pd.Series, periods: int) -> float:
        # Mean return over the deviation of negative returns only, annualised.
        return round(float(returns.mean() / returns[returns < 0].std() * periods ** 0.5), 2)

    row = {"name": name}
    row["cagr"] = round(((1 + growth) ** (1 / span_years) - 1) * 100, 2)
    row["daily_sharpe"] = round(float(daily.mean() / daily.std() * 252 ** 0.5), 2)
    row["daily_sortino"] = downside_ratio(daily, 252)
    row["monthly_sortino"] = downside_ratio(month_end, 12)
    row["max_drawdown"] = round(float((curve / curve.cummax() - 1).min()) * 100, 2)
    row["total_return"] = round(growth * 100, 1)
    if held.notna().any():
        row["avg_holdings"] = round(float(held.mean()), 1)
    else:
        row["avg_holdings"] = 0
    return row


def run(name: str, position: pd.DataFrame, *, resample_offset: str | None = None, fee_ratio=None):
    """Backtest one position frame and summarise it.

    The position is cut to rows from START onward and passed to ``sim`` with
    ``resample="M"`` and ``upload=False``. ``resample_offset`` and ``fee_ratio``
    are forwarded only when they are not None, so sim's own defaults apply otherwise.

    Returns:
        ``(metrics_row, report)``: the dict from ``summarize`` and the sim report.
    """
    windowed = position[position.index >= START]
    optional = {"resample_offset": resample_offset, "fee_ratio": fee_ratio}
    sim_args = {"resample": "M", "upload": False}
    sim_args.update({key: value for key, value in optional.items() if value is not None})
    report = sim(windowed, **sim_args)
    return summarize(name, report, windowed), report


# ---------- 0050 total-return benchmark ----------
# Canonical 0050 convention (project "bible" §D): buy-and-hold on the etl:adj_close
# adjusted price using plain arithmetic, without going through sim()
# (sim's entry lag understates the result by several percentage points and
# makes figures disagree across articles).
_bench_px = adj["0050"]
_bench_window = (_bench_px.index >= START) & (_bench_px.index <= END)
bser = _bench_px[_bench_window].dropna()
bret = bser.pct_change().dropna()
_mret = bser.resample("ME").last().pct_change().dropna()
_total = float(bser.iloc[-1] / bser.iloc[0] - 1)
_years = (bser.index[-1] - bser.index[0]).days / 365.25
# Same keys and formulas as the rows produced for the strategies.
bench = dict(
    name="0050 含息",
    cagr=round(((1 + _total) ** (1 / _years) - 1) * 100, 2),
    daily_sharpe=round(float(bret.mean() / bret.std() * (252 ** 0.5)), 2),
    daily_sortino=round(float(bret.mean() / bret[bret < 0].std() * (252 ** 0.5)), 2),
    monthly_sortino=round(float(_mret.mean() / _mret[_mret < 0].std() * (12 ** 0.5)), 2),
    max_drawdown=round(float((bser / bser.cummax() - 1).min()) * 100, 2),
    total_return=round(_total * 100, 1),
    avg_holdings=1.0,
)
# Equity curve rebased to 1 on the first day of the window.
bench_creturn = bser / bser.iloc[0]


# ---------- Popular prompts: common list-style prompts translated literally into rules ----------
# Each pos_* keeps 30 names via finlab's is_largest(30)
# (presumably the 30 largest values per row — confirm against the finlab docs).

# A: positive P/E and positive ROE; score = low-P/E rank + high-ROE rank inside cond_a.
cond_a = pool & (pe > 0) & (roe_d > 0)
pos_a = (rk(pe, asc=False, mask=cond_a) + rk(roe_d, mask=cond_a)).where(cond_a).is_largest(30)

# B: year-over-year monthly revenue growth above 30, ordered by that growth.
cond_b = pool & (rev_d > 30)
pos_b = rev_d.where(cond_b).is_largest(30)

# C: close at or above its 60-row high, ordered by 60-row price change.
cond_c = pool & (close >= high60)
pos_c = mom60.where(cond_c).is_largest(30)

# D: investment-trust net buying sustained over 5 rows (finlab sustain(5)),
# ordered by the 10-row net-buy amount.
cond_d = pool & (trust_d > 0).sustain(5)
pos_d = trust_amt.where(cond_d).is_largest(30)

# E: positive dividend yield, ordered by yield.
cond_e = pool & (dy > 0)
pos_e = dy.where(cond_e).is_largest(30)

# F: naive composite — the P/E and ROE ranks are averaged into one value term,
# then added to the revenue, proximity-to-high, trust-amount and yield ranks,
# all ranked inside `pool`.
naive_score = (
    (rk(pe, asc=False) + rk(roe_d)) / 2
    + rk(rev_d)
    + rk(prox60)
    + rk(trust_amt)
    + rk(dy)
)
pos_f = naive_score.where(pool).is_largest(30)

# Display name -> position frame; the names end up in the metrics rows and chart labels.
popular_positions = {
    "低本益比+高ROE": pos_a,
    "月營收年增>30%": pos_b,
    "創60日新高": pos_c,
    "投信連買5日": pos_d,
    "高殖利率": pos_e,
    "5因子天真複合": pos_f,
}

# Backtest every popular prompt with run()'s defaults (no offset, default fees),
# keeping both the metrics row and the full report.
popular_rows = []
popular_reports = {}
for name, pos in popular_positions.items():
    row, report = run(name, pos)
    popular_rows.append(row)
    popular_reports[name] = report


# ---------- Study of effective prompt structure ----------
# Month-end versions of the factors (see monthly()).
revm = monthly(rev)
momm = monthly(mom60)
roem = monthly(roe)
lowvolm = monthly(lowvol60)
# 60-row average share volume, sampled at month end.
liqm = monthly(vol.rolling(60).mean())
# 60-row average share volume multiplied by the close, sampled at month end.
amountm = monthly(vol.rolling(60).mean() * close)
# pool_m: liquidity above 1,000,000, revenue growth strictly between 10 and 150,
# and revenue growth positive in each of the last 3 monthly rows.
pool_m = (liqm > 1_000_000) & (revm > 10) & (revm < 150) & ((revm > 0).rolling(3).sum() == 3)
# pool_refined: the same revenue conditions as pool_m, without the liquidity condition.
pool_refined = (revm > 10) & (revm < 150) & ((revm > 0).rolling(3).sum() == 3)
# pool_amount: pool_refined plus amountm above 100,000,000.
pool_amount = pool_refined & (amountm > 100_000_000)


def rk_m(df: pd.DataFrame, asc: bool = True) -> pd.DataFrame:
    """Row-wise percentile rank (0~1) inside ``pool_m``; cells outside it score 0.

    With ``asc=False`` smaller values earn higher scores.
    """
    inside_pool = df.where(pool_m)
    pct_rank = inside_pool.rank(axis=1, pct=True, ascending=asc)
    return pct_rank.fillna(0)


def rk_refined(df: pd.DataFrame, asc: bool = True, mask: pd.DataFrame | None = None) -> pd.DataFrame:
    """Row-wise percentile rank (0~1) inside ``mask``, defaulting to ``pool_refined``.

    Cells outside the chosen universe, or without a value, score 0.
    With ``asc=False`` smaller values earn higher scores.
    """
    if mask is None:
        universe = pool_refined
    else:
        universe = mask
    pct_rank = df.where(universe).rank(axis=1, pct=True, ascending=asc)
    return pct_rank.fillna(0)


# Composite scores: sums of percentile ranks, so every factor contributes 0~1.
# Three factors (revenue growth, momentum, ROE) ranked inside pool_m.
score_rev_mom_roe = rk_m(revm) + rk_m(momm) + rk_m(roem)
# Add a low-volatility term (lower volatility -> higher score).
score_core = score_rev_mom_roe + rk_m(lowvolm, asc=False)
# The same four factors ranked inside pool_refined.
score_refined = (
    rk_refined(revm)
    + rk_refined(momm)
    + rk_refined(roem)
    + rk_refined(lowvolm, asc=False)
)
# The same four factors ranked inside pool_amount.
score_amount = (
    rk_refined(revm, mask=pool_amount)
    + rk_refined(momm, mask=pool_amount)
    + rk_refined(roem, mask=pool_amount)
    + rk_refined(lowvolm, asc=False, mask=pool_amount)
)

# Fixed-threshold "AND" prompt on daily data: every condition must hold,
# then 30 names are kept by 60-row momentum.
fixed_and = pool & (rev_d > 30) & (mom60 > 0.1) & (roe_d > 0)
fixed_and_ranked = mom60.where(fixed_and).is_largest(30)

# Top-40 weight frames: power=0 is equal weight, power=2 weights by squared score.
w_rank_3_equal = top_weights(score_rev_mom_roe, pool_m, topn=40, power=0)
w_rank_4_equal = top_weights(score_core, pool_m, topn=40, power=0)
w_rank_4_squared = top_weights(score_core, pool_m, topn=40, power=2)
# Identical weights to w_rank_4_squared; it is run below without a resample
# offset so the two results differ only in rebalance timing.
w_rank_4_no_offset = w_rank_4_squared.copy()
w_refined = top_weights(score_refined, pool_refined, topn=40, power=2)
w_refined_amount = top_weights(score_amount, pool_amount, topn=40, power=2)

# (display name, position/weight frame, resample_offset passed to run()).
# NOTE(review): the weight frames are indexed at calendar month-ends (see monthly());
# confirm how sim() maps the "14D" offset onto those rows, i.e. which month-end
# row a rebalance around the 14th actually uses.
design_specs = [
    ("固定門檻 AND 提示詞", fixed_and_ranked, None),
    ("排名複合：營收+動能+ROE", w_rank_3_equal, "14D"),
    ("加入低波動：四因子等權", w_rank_4_equal, "14D"),
    ("分數平方加權：四因子", w_rank_4_squared, "14D"),
    ("同四因子但不延後換股", w_rank_4_no_offset, None),
    ("精煉四因子：低相關+平方", w_refined, "14D"),
    ("精煉四因子+成交額>1億", w_refined_amount, "14D"),
]

# Backtest every design, keeping both the metrics row and the full report.
design_rows = []
design_reports = {}
for name, pos, offset in design_specs:
    row, report = run(name, pos, resample_offset=offset)
    design_rows.append(row)
    design_reports[name] = report

# Headline strategy and naive baseline; each HTML report is written twice,
# once to the scratch directory and once next to this script.
strategy_report = design_reports["精煉四因子：低相關+平方"]
baseline_report = popular_reports["5因子天真複合"]
strategy_report.to_html(str(TMP_OUT / "report_strategy.html"), title="有效提示詞：四因子排名複合")
baseline_report.to_html(str(TMP_OUT / "report_baseline.html"), title="天真提示詞：5 因子複合")
strategy_report.to_html(str(STATIC_OUT / "report_strategy.html"), title="有效提示詞：四因子排名複合")
baseline_report.to_html(str(STATIC_OUT / "report_baseline.html"), title="天真提示詞：5 因子複合")


# ---------- Write data outputs ----------
# Equity curves side by side: the benchmark plus the two clipped strategy curves.
eq = pd.DataFrame({
    "0050": bench_creturn,
    "naive_5_factor": clip_creturn(baseline_report.creturn),
    "useful_4_factor": clip_creturn(strategy_report.creturn),
})
eq.to_csv(TMP_OUT / "equity.csv")

# Human-readable description of the test settings, stored alongside the numbers.
_method_notes = {
    "popular_prompt_setting": "60日均量>100萬股,前30檔等權,每月再平衡",
    "useful_prompt_setting": "月營收年增10~150且連3月正成長,營收/動能/ROE/低波動排名分數,前40檔,分數平方加權,月營收公布後14天換股",
    "cost": "策略使用 finlab sim() 台股預設成本;0050 基準為含息買進持有",
}
metrics = {
    "data_end": str(close.index.max().date()),
    "start": START,
    "benchmark": bench,
    "popular_prompts": popular_rows,
    "prompt_design": design_rows,
    "method": _method_notes,
}
_metrics_text = json.dumps(metrics, ensure_ascii=False, indent=2)
(TMP_OUT / "metrics.json").write_text(_metrics_text, encoding="utf-8")

# One CSV per study, each with the benchmark as its first row.
for _csv_name, _study_rows in (
    ("popular_prompts.csv", popular_rows),
    ("prompt_design.csv", design_rows),
):
    pd.DataFrame([bench] + _study_rows).to_csv(TMP_OUT / _csv_name, index=False)


# ---------- 本機預覽圖(PNG 正式發佈前應上 R2) ----------
def save_charts() -> None:
    """Render the five local-preview PNG charts into STATIC_OUT.

    Reads the module-level ``eq``, ``bench``, ``popular_rows`` and
    ``design_rows``. matplotlib is imported here, with the Agg backend
    selected before pyplot is loaded. Any exception propagates to the caller.
    """
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    # Candidate fonts that can render the Chinese labels, in order of preference.
    plt.rcParams["font.sans-serif"] = [
        "Noto Sans CJK TC",
        "Heiti TC",
        "PingFang TC",
        "Arial Unicode MS",
        "DejaVu Sans",
    ]
    plt.rcParams["axes.unicode_minus"] = False

    colors = {
        "bench": "#9CA3AF",
        "naive": "#EF4444",
        "useful": "#2563EB",
        "green": "#10B981",
    }
    canvas = {"figsize": (12, 6.75), "dpi": 140}

    def style(ax):
        # Light grid, no top/right frame lines.
        ax.grid(True, alpha=0.22)
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)

    def name_ticks(ax, positions, names, rotation):
        # One slanted, right-aligned tick label per bar.
        ax.set_xticks(positions)
        ax.set_xticklabels(names, rotation=rotation, ha="right")

    def finish(fig, filename):
        # Tighten the layout, write the PNG next to this script, release the figure.
        fig.tight_layout()
        fig.savefig(STATIC_OUT / filename)
        plt.close(fig)

    # 1) Equity curves: benchmark vs naive composite vs headline strategy.
    eq_plot = eq.dropna(how="all").ffill()
    fig, ax = plt.subplots(**canvas)
    curves = (
        ("0050", "0050 含息", colors["bench"], 2.2),
        ("naive_5_factor", "5 因子天真提示詞", colors["naive"], 2.0),
        ("useful_4_factor", "有效提示詞四因子", colors["useful"], 2.4),
    )
    for column, legend_name, line_color, width in curves:
        ax.plot(eq_plot.index, eq_plot[column], label=legend_name, color=line_color, linewidth=width)
    ax.set_title("AI 選股提示詞回測：天真名單 vs 可驗證規格", fontsize=16, fontweight="bold")
    ax.set_ylabel("淨值")
    ax.legend(loc="upper left")
    style(ax)
    finish(fig, "chart_equity_all.png")

    # 2) CAGR and daily Sharpe bars for the benchmark and the popular prompts.
    popular_table = [bench] + popular_rows
    popular_names = [entry["name"] for entry in popular_table]
    popular_x = np.arange(len(popular_names))
    popular_palette = [colors["bench"]] + [colors["naive"]] * len(popular_rows)
    fig, (ax1, ax2) = plt.subplots(1, 2, **canvas)
    for panel, key, heading in ((ax1, "cagr", "年化報酬 CAGR"), (ax2, "daily_sharpe", "日夏普")):
        panel.bar(popular_x, [entry[key] for entry in popular_table], color=popular_palette)
        panel.set_title(heading)
        name_ticks(panel, popular_x, popular_names, 35)
        style(panel)
    fig.suptitle("5 條熱門 AI 選股提示詞:照字面翻成規則後沒有勝出", fontsize=15, fontweight="bold")
    finish(fig, "chart_cagr_sharpe_bars.png")

    # 3) Max-drawdown bars for the same rows.
    fig, ax = plt.subplots(**canvas)
    ax.bar(popular_x, [entry["max_drawdown"] for entry in popular_table], color=popular_palette)
    ax.set_title("最大回撤：名單型提示詞的主要代價是回撤深", fontsize=16, fontweight="bold")
    name_ticks(ax, popular_x, popular_names, 35)
    ax.set_ylabel("%")
    style(ax)
    finish(fig, "chart_max_drawdown_bars.png")

    # 4) Daily Sortino "ladder" across the prompt designs, with a 1.5 reference line.
    design_names = [entry["name"] for entry in design_rows]
    design_x = np.arange(len(design_names))
    bar_colors = [colors["naive"], "#F59E0B", colors["green"], "#60A5FA", "#8B5CF6", colors["useful"], "#0F766E"]
    fig, ax = plt.subplots(**canvas)
    ax.bar(design_x, [entry["daily_sortino"] for entry in design_rows], color=bar_colors[: len(design_rows)])
    ax.axhline(1.5, color="#111827", linestyle="--", linewidth=1.5, label="品質門檻 1.5")
    ax.set_title("有效提示詞結構：排名、低波、事件對齊會改善風險調整報酬", fontsize=15, fontweight="bold")
    ax.set_ylabel("日索提諾")
    name_ticks(ax, design_x, design_names, 30)
    ax.legend(loc="upper left")
    style(ax)
    finish(fig, "chart_prompt_design_ladder.png")

    # 5) Underwater (drawdown) curve of the headline strategy.
    useful_curve = eq_plot["useful_4_factor"]
    dd = useful_curve / useful_curve.cummax() - 1
    dd_pct = dd.values * 100
    fig, ax = plt.subplots(**canvas)
    ax.fill_between(dd.index, dd_pct, 0, color=colors["useful"], alpha=0.25)
    ax.plot(dd.index, dd_pct, color=colors["useful"], linewidth=2)
    ax.set_title("有效提示詞四因子策略回撤曲線", fontsize=16, fontweight="bold")
    ax.set_ylabel("回撤 %")
    style(ax)
    finish(fig, "chart_underwater_composite.png")


# Charts are best-effort: the data files are already on disk at this point,
# so a plotting failure is reported on stdout rather than aborting the run.
try:
    save_charts()
except Exception as exc:
    print("chart generation skipped:", repr(exc))


# Console summary: a heading per section followed by one metrics dict per line.
for _heading, _section_rows in (
    ("\n=== 0050 基準 ===", [bench]),
    ("\n=== 熱門提示詞 ===", popular_rows),
    ("\n=== 提示詞結構研究 ===", design_rows),
):
    print(_heading)
    for row in _section_rows:
        print(row)
print(f"\nSaved outputs to {TMP_OUT}")
print(f"Saved local review charts/reports to {STATIC_OUT}")
