"""
Batch rolling detection of converging triangles, reading OHLCV data from pkl files.

Every trading day of every stock inside the selected range is evaluated; the
results are collected into a DataFrame and written to CSV.
"""

import os
import sys
import pickle
import time

import numpy as np
import pandas as pd

# Let the script find the modules under src/
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))

from converging_triangle import (  # noqa: E402  (needs the sys.path tweak above)
    ConvergingTriangleParams,
    ConvergingTriangleResult,
    detect_converging_triangle_batch,
)

# ============================================================================
# Tunable parameters
# ============================================================================

# --- Data source ---
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")

# --- Detection parameters ---
PARAMS = ConvergingTriangleParams(
    window=120,
    pivot_k=15,
    boundary_n_segments=2,
    boundary_source="full",
    upper_slope_max=0.10,
    lower_slope_min=-0.10,
    touch_tol=0.10,
    touch_loss_max=0.10,
    shrink_ratio=0.8,
    break_tol=0.001,
    vol_window=20,
    vol_k=1.3,
    false_break_m=5,
)

# --- Computation range ---
# None computes the full history; an integer restricts detection to the most recent N days.
RECENT_DAYS = 500  # default: only the last 500 days, to keep runtime manageable

# --- Output control ---
ONLY_VALID = True  # True: only output records where a triangle was detected; False: output all records
VERBOSE = True

# Valid records whose breakout strength exceeds this value are additionally written
# to strong_breakout_up.csv / strong_breakout_down.csv.
STRONG_BREAKOUT_THRESHOLD = 0.3


# ============================================================================
# pkl data loading
# ============================================================================

class FakeModule:
    """Empty stand-in module used to bypass the ``model`` dependency when unpickling."""

    # NOTE(review): only ``ndarray`` is exposed; presumably that is the only attribute
    # the pickles resolve through ``model`` — confirm against how the files were written.
    ndarray = np.ndarray


def load_pkl(pkl_path: str) -> dict:
    """Load a pkl file and return its dict ({mtx, dtes, tkrs, tkrs_name, ...}).

    Before unpickling, FakeModule instances are registered in ``sys.modules`` as
    ``model`` and ``model.index_info`` so the files load without the real package.
    Any existing entries under those names are overwritten.

    Security: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    sys.modules["model"] = FakeModule()
    sys.modules["model.index_info"] = FakeModule()
    with open(pkl_path, "rb") as f:
        return pickle.load(f)


def load_ohlcv_from_pkl(data_dir: str) -> tuple:
    """Load OHLCV data from ``{open,high,low,close,volume}.pkl`` in ``data_dir``.

    Returns:
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: shape=(n_stocks, n_days)
        dates: shape=(n_days,) real dates (e.g. 20050104)
        tkrs: shape=(n_stocks,) stock codes
        tkrs_name: shape=(n_stocks,) stock names

    Raises:
        ValueError: if any matrix's shape differs from the close matrix's shape.
    """
    fields = ("open", "high", "low", "close", "volume")
    # Dict comprehension keeps the original load order: open, high, low, close, volume.
    loaded = {name: load_pkl(os.path.join(data_dir, f"{name}.pkl")) for name in fields}
    matrices = [loaded[name]["mtx"] for name in fields]

    # Dates and tickers come only from close, so every matrix must share its shape;
    # otherwise rows/columns would not line up with that metadata.
    close_shape = np.shape(loaded["close"]["mtx"])
    for name, mtx in zip(fields, matrices):
        if np.shape(mtx) != close_shape:
            raise ValueError(
                f"{name}.pkl matrix shape {np.shape(mtx)} does not match "
                f"close.pkl shape {close_shape}"
            )

    # Use the close file's metadata.
    close_data = loaded["close"]
    return (*matrices, close_data["dtes"], close_data["tkrs"], close_data["tkrs_name"])


# ============================================================================
# Helpers
# ============================================================================

def _last_valid_day(close_mtx: np.ndarray) -> "int | None":
    """Return the index of the last day with any non-NaN close, or None if there is none."""
    any_valid = np.any(~np.isnan(close_mtx), axis=0)
    valid_idx = np.flatnonzero(any_valid)
    return int(valid_idx[-1]) if valid_idx.size else None


def _detection_range(last_valid_day: int, window: int, recent_days) -> tuple:
    """Return (start_day, end_day); start_day > end_day means there is nothing to detect."""
    start_day = window - 1  # first day that has a full lookback window
    end_day = last_valid_day  # use the last valid day rather than the matrix end
    if recent_days is not None:
        start_day = max(start_day, end_day - recent_days + 1)
    return start_day, end_day


def _lookup(values, idx_series: pd.Series, default) -> pd.Series:
    """Map each index in ``idx_series`` to ``values[idx]``, or ``default`` if idx >= len(values)."""
    n = len(values)
    return idx_series.map(lambda i: values[i] if i < n else default)


def _valid_rows(df: pd.DataFrame) -> pd.DataFrame:
    """Rows where ``is_valid`` is True; an empty frame if the column is missing."""
    if "is_valid" not in df.columns:
        return df.iloc[0:0]
    return df[df["is_valid"].eq(True)]


def _print_summary(df: pd.DataFrame, only_valid: bool) -> None:
    """Print record counts, breakout-direction counts and breakout-strength stats."""
    print("\n" + "=" * 70)
    print("Detection Results")
    print("=" * 70)

    if only_valid:
        print(f"\nTotal valid triangles detected: {len(df)}")
    else:
        print(f"\nTotal records: {len(df)}")
        # Guard the percentage: an empty result would divide by zero.
        if len(df) > 0 and "is_valid" in df.columns:
            valid_count = df["is_valid"].sum()
            print(f"Valid triangles: {valid_count} ({valid_count / len(df) * 100:.1f}%)")

    valid_df = _valid_rows(df)

    # Counts per breakout direction
    if len(df) > 0 and "breakout_dir" in df.columns:
        print("\nBreakout statistics:")
        for dir_name, count in valid_df["breakout_dir"].value_counts().items():
            print(f" - {dir_name}: {count}")

    # Breakout strength statistics
    strength_cols = ("breakout_strength_up", "breakout_strength_down")
    if len(valid_df) > 0 and all(c in df.columns for c in strength_cols):
        up = valid_df["breakout_strength_up"]
        down = valid_df["breakout_strength_down"]
        print("\nBreakout strength (valid triangles):")
        print(f" - Up mean: {up.mean():.4f}, max: {up.max():.4f}")
        print(f" - Down mean: {down.mean():.4f}, max: {down.max():.4f}")


def _save_results(df: pd.DataFrame, outputs_dir: str, threshold: float) -> None:
    """Write all results plus per-direction strong-breakout subsets as CSV files."""
    os.makedirs(outputs_dir, exist_ok=True)

    # Full CSV; utf-8-sig (UTF-8 with BOM) is used for Chinese text support.
    csv_path = os.path.join(outputs_dir, "all_results.csv")
    df.to_csv(csv_path, index=False, encoding="utf-8-sig")
    print(f"\nResults saved to: {csv_path}")

    # High-strength breakouts among the valid triangles (strength > threshold).
    valid_df = _valid_rows(df)
    for direction, label in (("up", "Strong up breakouts"), ("down", "Strong down breakouts")):
        col = f"breakout_strength_{direction}"
        if col not in valid_df.columns:
            continue
        strong = valid_df[valid_df[col] > threshold]
        if len(strong) > 0:
            path = os.path.join(outputs_dir, f"strong_breakout_{direction}.csv")
            strong.to_csv(path, index=False, encoding="utf-8-sig")
            print(f"{label} ({len(strong)}): {path}")


def _print_sample(df: pd.DataFrame, n_rows: int = 10) -> None:
    """Print the first ``n_rows`` records using whichever display columns exist."""
    if len(df) == 0:
        return
    print("\n" + "-" * 70)
    print("Sample results (first 10):")
    print("-" * 70)
    wanted = [
        "stock_code", "date", "is_valid",
        "breakout_strength_up", "breakout_strength_down",
        "breakout_dir", "width_ratio",
    ]
    display_cols = [c for c in wanted if c in df.columns]
    print(df[display_cols].head(n_rows).to_string(index=False))


# ============================================================================
# Main flow
# ============================================================================

def main() -> None:
    """Load data, run batch detection, print statistics and save CSV outputs."""
    start_time = time.perf_counter()

    print("=" * 70)
    print("Converging Triangle Batch Detection")
    print("=" * 70)

    # 1. Load data
    print("\n[1] Loading OHLCV pkl files...")
    load_start = time.perf_counter()
    (open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
     dates, tkrs, tkrs_name) = load_ohlcv_from_pkl(DATA_DIR)
    load_time = time.perf_counter() - load_start

    n_stocks, n_days = close_mtx.shape
    print(f" Stocks: {n_stocks}")
    print(f" Days: {n_days}")
    if n_days > 0:
        print(f" Date range: {dates[0]} ~ {dates[-1]}")
    print(f" 加载耗时: {load_time:.2f} 秒")

    # 2. Find the valid data range (ignore all-NaN day columns)
    last_valid_day = _last_valid_day(close_mtx)
    if last_valid_day is None:
        print("No valid data found!")
        return
    print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})")

    # 3. Detection range
    start_day, end_day = _detection_range(last_valid_day, PARAMS.window, RECENT_DAYS)
    days_per_stock = end_day - start_day + 1
    print(f"\n[2] Detection range: day {start_day} ~ {end_day}")
    print(f" Window size: {PARAMS.window}")
    if days_per_stock <= 0:
        print(" Detection range is empty (not enough history for the window); nothing to do.")
        return
    print(f" Total points: {n_stocks} x {days_per_stock} = {n_stocks * days_per_stock}")

    # 4. Batch detection
    print("\n[3] Running batch detection...")
    detect_start = time.perf_counter()
    df = detect_converging_triangle_batch(
        open_mtx=open_mtx,
        high_mtx=high_mtx,
        low_mtx=low_mtx,
        close_mtx=close_mtx,
        volume_mtx=volume_mtx,
        params=PARAMS,
        start_day=start_day,
        end_day=end_day,
        only_valid=ONLY_VALID,
        verbose=VERBOSE,
    )
    detect_time = time.perf_counter() - detect_start
    print(f" 检测耗时: {detect_time:.2f} 秒")

    # 5. Attach stock code, stock name and real date
    if len(df) > 0:
        df["stock_code"] = _lookup(tkrs, df["stock_idx"], "")
        df["stock_name"] = _lookup(tkrs_name, df["stock_idx"], "")
        df["date"] = _lookup(dates, df["date_idx"], 0)

    # 6. Report statistics
    _print_summary(df, ONLY_VALID)

    # 7. Save results
    outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles")
    _save_results(df, outputs_dir, STRONG_BREAKOUT_THRESHOLD)

    # 8. Show a sample
    _print_sample(df)

    total_time = time.perf_counter() - start_time
    print("\n" + "=" * 70)
    print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)")
    print("=" * 70)


if __name__ == "__main__":
    main()