"""US cross-sectional momentum -> leveraged-ETF regime strategy (FinLab).

The momentum screen runs on a liquid US stock universe (by default the top 500
names by 60-day dollar volume). The tradable position is a small basket of
leveraged growth ETFs, held only when stock-level momentum breadth and the QQQ
trend are both positive. Otherwise the strategy holds a single defensive ETF
(treasuries, gold or cash-like). This reproduces the backtest in the article.
Requires the `finlab` package and a FinLab account.

    pip install finlab
    python strategy.py
"""
from finlab import data
from finlab.backtest import sim
from finlab.dataframe import FinlabDataFrame

# Tradable ETF tickers. Kept as lists because build_position concatenates them
# with ["QQQ"] to select columns from the fund price table.
RISK_ASSETS = ["TQQQ", "TECL"]            # leveraged growth ETFs (held in risk-on)
DEFENSIVE_ASSETS = ["IEF", "GLD", "SHY"]  # treasuries / gold / cash-like (risk-off)


def liquid_universe(
    close,
    volume,
    top_n=500,
    min_price=10,
    min_dollar_volume=10_000_000,
    *,
    lookback=60,
    min_periods=20,
    extreme_move=0.50,
    cooldown=252,
):
    """Return a point-in-time boolean mask of liquid, tradable stocks.

    A name is in the universe on a given row when all of these hold:

    * it ranks in the top ``top_n`` by ``lookback``-row mean dollar volume
      (``close * volume``), as selected by ``is_largest``;
    * that mean dollar volume exceeds ``min_dollar_volume``;
    * its close is above ``min_price``;
    * it had no single-row absolute return >= ``extreme_move`` within the
      previous ``cooldown`` rows. The move row itself is NOT excluded,
      because the flag is shifted by one row.

    Args:
        close: price frame. It must support ``is_largest`` (e.g. a FinLab
            ``FinlabDataFrame``). Presumably rows are dates and columns are
            tickers; confirm against ``data.get``.
        volume: volume frame aligned with ``close``.
        top_n: number of most-liquid names kept per row.
        min_price: minimum close price.
        min_dollar_volume: minimum rolling mean dollar volume.
        lookback: window length (rows) of the dollar-volume mean.
        min_periods: minimum observations required for that mean.
        extreme_move: absolute one-row return treated as a split/data artifact.
        cooldown: number of rows a name stays excluded after such a move.

    Returns:
        Boolean frame, True where the name is eligible on that row.
    """
    dollar_volume = (close * volume).rolling(lookback, min_periods=min_periods).mean()
    top_liquid = dollar_volume.is_largest(top_n)
    liquid_enough = dollar_volume > min_dollar_volume
    price_ok = close > min_price

    # Flag split/data artifacts as 1.0/0.0. Converting explicitly to float means
    # the rolling max operates on numbers rather than relying on bool coercion.
    extreme = (close.pct_change().abs() >= extreme_move).astype(float)
    # The window ends at row t; shift(1) moves it to end at t-1. A move on row m
    # therefore excludes rows m+1 .. m+cooldown. fillna(0) (not False) keeps the
    # frame float instead of mixing bools into float data.
    recently_bad = extreme.rolling(cooldown, min_periods=1).max().shift(1).fillna(0) > 0
    return top_liquid & liquid_enough & price_ok & ~recently_bad


def build_position(*, breadth_threshold=0.08, start="2016-01-01"):
    """Build the daily target-weight frame for the leveraged-ETF regime strategy.

    Steps:
      1. Restrict US stocks to ``liquid_universe`` and compute a 12-month-minus-
         1-month momentum score.
      2. Breadth per date = (# universe names in the top momentum decile that
         also have a positive 12-month return) / (# universe names).
      3. Risk-on when breadth > ``breadth_threshold``, QQQ is above its 200-day
         mean, and QQQ's 126-day return is positive.
      4. Risk-on: equal-weight the ``is_largest(2)`` RISK_ASSETS by 126-day
         return. Risk-off: weight 1.0 on the ``is_largest(1)`` DEFENSIVE_ASSETS
         ETF by (63-day return minus 21-day return).

    Side effect: calls ``data.set_market("us")``.

    Args:
        breadth_threshold: breadth level above which the stock market counts
            as risk-on (default 0.08, the original hard-coded value).
        start: first date kept in the returned frame (``None`` keeps all rows).

    Returns:
        Frame indexed like the fund prices from ``start`` onward. There is one
        column per ticker in RISK_ASSETS + DEFENSIVE_ASSETS, with weights in
        [0, 1].
    """
    data.set_market("us")
    close = data.get("us_price:adj_close")
    volume = data.get("us_price:volume")
    fund_close = data.get("us_fund_price:adj_close")[["QQQ"] + RISK_ASSETS + DEFENSIVE_ASSETS]

    # 1. Liquid universe + the 12-month-minus-1-month momentum factor.
    # NOTE: this is the arithmetic difference ret_12m - ret_1m. It equals the
    # skip-month return (close[t-21] / close[t-252] - 1) scaled by (1 + ret_1m),
    # so it is not the textbook 12-1 factor. It is kept as-is so the
    # reproduced backtest matches.
    universe = liquid_universe(close, volume)
    ret_12m = close / close.shift(252) - 1
    ret_1m = close / close.shift(21) - 1
    momentum = ret_12m - ret_1m

    # 2. Breadth: share of the universe that are top-decile, positive-trend leaders.
    # Ranking only the in-universe names means leaders are always a subset of
    # the universe. An empty universe gives 0/0 = NaN, which is never risk-on.
    leaders = (momentum[universe].rank(axis=1, pct=True) > 0.90) & (ret_12m > 0)
    breadth = leaders.sum(axis=1) / universe.sum(axis=1)

    # 3. Risk-on regime: broad momentum AND QQQ above its 200-day with positive 126-day trend.
    qqq = fund_close["QQQ"]
    qqq_trend = qqq > qqq.rolling(200, min_periods=100).mean()
    qqq_mom = qqq / qqq.shift(126) - 1
    qqq_ok = (qqq_trend & (qqq_mom > 0)).reindex(breadth.index, fill_value=False)
    # fill_value=False plus astype(bool) guarantee a real bool dtype. Reindexing
    # with NaN and then fillna(False) can leave an object Series (pandas'
    # implicit downcast there is deprecated). On such a Series, ~risk_on would
    # yield truthy -1/-2 integers and enable the defensive leg every day.
    risk_on = (
        ((breadth > breadth_threshold) & qqq_ok)
        .reindex(fund_close.index, fill_value=False)
        .astype(bool)
    )
    risk_off = ~risk_on

    # 4. In risk-on: hold the 2 strongest leveraged ETFs. Otherwise: 1 defensive ETF.
    risk_score = FinlabDataFrame(fund_close[RISK_ASSETS].pct_change(126))
    defensive_score = FinlabDataFrame(
        fund_close[DEFENSIVE_ASSETS].pct_change(63) - fund_close[DEFENSIVE_ASSETS].pct_change(21)
    )
    # NOTE(review): RISK_ASSETS has exactly two tickers, so is_largest(2)
    # presumably selects both whenever both have a score. In that case the
    # risk leg is a plain 50/50 split.
    risky = risk_score.is_largest(2) & risk_on.reindex(risk_score.index, fill_value=False)
    risky = risky.astype(float)
    # Equal-weight the selected ETFs. Rows with nothing selected give 0/0 -> 0.
    risky = risky.div(risky.sum(axis=1), axis=0).fillna(0)
    defensive = (
        defensive_score.is_largest(1) & risk_off.reindex(defensive_score.index, fill_value=False)
    ).astype(float)

    # RISK_ASSETS and DEFENSIVE_ASSETS must be disjoint: DataFrame.join raises on
    # overlapping column names when no suffixes are given. The transpose /
    # groupby(level=0).max() pass collapses duplicate labels and sorts the columns.
    position = (
        risky.reindex(fund_close.index).fillna(0)
        .join(defensive.reindex(fund_close.index).fillna(0), how="outer")
        .fillna(0)
        .T.groupby(level=0).max().T
    )
    return position.loc[start:]


def main():
    """Run the strategy end to end: log in, build weights, backtest, report."""
    data.login()  # finlab guides you through login automatically
    position = build_position()

    # The traded instruments are ETFs. build_position left the market at "us",
    # so switch to the fund market before simulating.
    data.set_market("us_fund")

    sim_settings = {
        "resample": "M",  # rebalance monthly
        # Max weight PER asset (1.0 = 100%). The weights themselves come from
        # build_position: risk-on splits equally across the selected RISK_ASSETS
        # (2 ETFs at 50% each); risk-off puts 100% in one DEFENSIVE_ASSETS ETF.
        "position_limit": 1,
        # 8% stop loss; touched_exit=True makes it a "touched" stop (see the
        # finlab sim documentation for the exact trigger rule).
        "stop_loss": 0.08,
        "touched_exit": True,
        "trade_at_price": "close",
        "upload": False,
    }
    report = sim(position, **sim_settings)

    print(report.get_stats())
    report.display()


# Run the backtest only when executed as a script (`python strategy.py`),
# not when this module is imported.
if __name__ == "__main__":
    main()
