""" 批量识别对称三角形 - 从 pkl 文件读取 OHLCV 数据 每个 pkl 文件包含 108 个股票 × N 个交易日的矩阵 """ import os import sys import pickle import numpy as np import pandas as pd # 让脚本能找到 src/ 下的模块 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "src")) from sym_triangle import detect_sym_triangle, line_y, pivots_fractal, fit_boundary_line # ============================================================================ # 【可调参数区】 # ============================================================================ # --- 数据源 --- DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "data") # --- 窗口大小 --- WINDOW = 400 # --- 枢轴点检测 --- PIVOT_K = 20 # --- 边界线拟合 --- BOUNDARY_N_SEGMENTS = 2 BOUNDARY_SOURCE = "full" # --- 斜率约束 --- UPPER_SLOPE_MAX = 0.10 LOWER_SLOPE_MIN = -0.10 # --- 触碰判定 --- TOUCH_TOL = 0.10 TOUCH_LOSS_MAX = 0.10 # --- 收敛要求 --- SHRINK_RATIO = 0.8 # --- 突破判定 --- BREAK_TOL = 0.001 VOL_WINDOW = 20 VOL_K = 1.3 FALSE_BREAK_M = 5 # --- 输出控制 --- PRINT_DEBUG = False # 批量时关闭调试输出 SAVE_ALL_CHARTS = False # True=保存所有股票图,False=只保存识别到的 # ============================================================================ # pkl 数据加载 # ============================================================================ class FakeModule: """空壳模块,绕过 model 依赖""" ndarray = np.ndarray def load_pkl(pkl_path: str) -> dict: """加载 pkl 文件,返回字典 {mtx, dtes, tkrs, tkrs_name, ...}""" # 注入空壳模块 sys.modules['model'] = FakeModule() sys.modules['model.index_info'] = FakeModule() with open(pkl_path, 'rb') as f: data = pickle.load(f) return data def load_ohlcv_from_pkl(data_dir: str) -> tuple: """ 从 pkl 文件加载 OHLCV 数据 Returns: open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: shape=(n_stocks, n_days) dates: shape=(n_days,) 真实日期 (如 20050104) tkrs: shape=(n_stocks,) 股票代码 tkrs_name: shape=(n_stocks,) 股票名称 """ open_data = load_pkl(os.path.join(data_dir, "open.pkl")) high_data = load_pkl(os.path.join(data_dir, "high.pkl")) low_data = load_pkl(os.path.join(data_dir, "low.pkl")) close_data = load_pkl(os.path.join(data_dir, "close.pkl")) volume_data = load_pkl(os.path.join(data_dir, "volume.pkl")) # 使用 close 的元数据 dates = close_data["dtes"] tkrs = close_data["tkrs"] tkrs_name = close_data["tkrs_name"] return ( open_data["mtx"], high_data["mtx"], low_data["mtx"], close_data["mtx"], volume_data["mtx"], dates, tkrs, tkrs_name, ) def get_stock_df( stock_idx: int, open_mtx: np.ndarray, high_mtx: np.ndarray, low_mtx: np.ndarray, close_mtx: np.ndarray, volume_mtx: np.ndarray, dates: np.ndarray, ) -> pd.DataFrame: """提取单个股票的 DataFrame""" df = pd.DataFrame({ "date": dates, "open": open_mtx[stock_idx, :], "high": high_mtx[stock_idx, :], "low": low_mtx[stock_idx, :], "close": close_mtx[stock_idx, :], "volume": volume_mtx[stock_idx, :], }) # 过滤掉 NaN/0 值 df = df.replace(0, np.nan).dropna().reset_index(drop=True) return df # ============================================================================ # 绘图 # ============================================================================ def plot_sym_triangle(df: pd.DataFrame, res, stock_id: str, out_path: str) -> None: import matplotlib.pyplot as plt close = df["close"].to_numpy(dtype=float) x = np.arange(len(df), dtype=float) dates = df["date"].to_numpy() a_u, b_u = res.upper_coef a_l, b_l = res.lower_coef start, end = res.start, res.end xw = np.arange(start, end + 1, dtype=float) upper = line_y(a_u, b_u, xw) lower = line_y(a_l, b_l, xw) plt.figure(figsize=(12, 5)) plt.plot(x, close, linewidth=1.2, label="close") plt.plot(xw, upper, linewidth=2, label="upper") plt.plot(xw, lower, linewidth=2, label="lower") plt.axvline(end, color="gray", linestyle="--", linewidth=1) start_date = dates[start] if len(dates) > start else start end_date = dates[end] if len(dates) > end else end plt.title( f"[{stock_id}] sym_triangle: " f"range={start_date}-{end_date}, " f"slope=({a_u:.4f},{a_l:.4f}), " f"width_ratio={res.width_ratio:.2f}, " f"touches=({res.touches_upper},{res.touches_lower})" ) if len(dates) > 0: step = max(1, len(dates) // 8) idx = np.arange(0, len(dates), step) plt.xticks(idx, dates[idx], rotation=45, ha="right") plt.legend() plt.tight_layout() plt.savefig(out_path, dpi=150) plt.close() # ============================================================================ # 主流程 # ============================================================================ def main() -> None: print("=" * 60) print("Symmetric Triangle Batch Detection - from pkl files") print("=" * 60) # 1. 加载数据 print("\n[1] Loading OHLCV pkl files...") open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR) n_stocks, n_days = close_mtx.shape print(f" Stocks: {n_stocks}") print(f" Days: {n_days}") print(f" Date range: {dates[0]} ~ {dates[-1]}") # 2. 准备输出目录 outputs_dir = os.path.join(os.path.dirname(__file__), "..", "..", "outputs", "sym_triangles") os.makedirs(outputs_dir, exist_ok=True) # 3. 遍历所有股票 print(f"\n[2] Scanning {n_stocks} stocks...") detected = [] for i in range(n_stocks): stock_code = tkrs[i] stock_name = tkrs_name[i] stock_id = f"{stock_code}" # 使用真实股票代码 # 提取单只股票数据 df = get_stock_df(i, open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates) if len(df) < WINDOW: if PRINT_DEBUG: print(f" [{stock_id}] 跳过: 数据不足 ({len(df)} < {WINDOW})") continue # 运行检测 res = detect_sym_triangle( df, window=WINDOW, pivot_k=PIVOT_K, touch_tol=TOUCH_TOL, touch_loss_max=TOUCH_LOSS_MAX, shrink_ratio=SHRINK_RATIO, break_tol=BREAK_TOL, vol_window=VOL_WINDOW, vol_k=VOL_K, false_break_m=FALSE_BREAK_M, upper_slope_max=UPPER_SLOPE_MAX, lower_slope_min=LOWER_SLOPE_MIN, boundary_fit=True, boundary_n_segments=BOUNDARY_N_SEGMENTS, boundary_source=BOUNDARY_SOURCE, ) if res is not None: detected.append((i, stock_id, res, df)) out_path = os.path.join(outputs_dir, f"{stock_id}.png") plot_sym_triangle(df, res, stock_id, out_path) print(f" [OK] {stock_id} -> {out_path}") elif PRINT_DEBUG: print(f" [--] {stock_id} not detected") # 4. 汇总结果 print("\n" + "=" * 60) print(f"Scan completed! {len(detected)}/{n_stocks} stocks have symmetric triangles") print("=" * 60) if detected: print("\nDetected stocks:") for i, stock_id, res, _ in detected: name = tkrs_name[i] print(f" - {stock_id} ({name}): slope=({res.upper_coef[0]:.4f},{res.lower_coef[0]:.4f}), " f"width_ratio={res.width_ratio:.2f}, breakout={res.breakout}") print(f"\nCharts saved to: {outputs_dir}") # 5. Save summary CSV if detected: summary_path = os.path.join(outputs_dir, "summary.csv") summary_data = [] for i, stock_id, res, _ in detected: summary_data.append({ "stock_idx": i, "stock_code": stock_id, "stock_name": tkrs_name[i], "start_date": dates[res.start], "end_date": dates[res.end], "upper_slope": res.upper_coef[0], "lower_slope": res.lower_coef[0], "width_ratio": res.width_ratio, "touches_upper": res.touches_upper, "touches_lower": res.touches_lower, "breakout": res.breakout, }) pd.DataFrame(summary_data).to_csv(summary_path, index=False, encoding="utf-8-sig") print(f"Summary saved: {summary_path}") if __name__ == "__main__": main()