""" 批量滚动检测收敛三角形 - 从 pkl 文件读取 OHLCV 数据 每个股票的每个交易日都会计算,输出 DataFrame """ import os import sys import pickle import time import numpy as np import pandas as pd from datetime import datetime from io import StringIO # 让脚本能找到 src/ 下的模块 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src")) from converging_triangle import ( ConvergingTriangleParams, ConvergingTriangleResult, detect_converging_triangle_batch, detect_converging_triangle_batch_v2, # v2优化版本 ) # ============================================================================ # 【可调参数区】 # ============================================================================ # --- 数据源 --- DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") # --- 导入统一的参数配置 --- from triangle_config import ( DETECTION_PARAMS as PARAMS, RECENT_DAYS, ONLY_VALID, VERBOSE, REALTIME_MODE, # 新增 FLEXIBLE_ZONE, # 新增 USE_V2_OPTIMIZATION, # v2优化开关 ) # ============================================================================ # pkl 数据加载 # ============================================================================ class FakeModule: """空壳模块,绕过 model 依赖""" ndarray = np.ndarray def load_pkl(pkl_path: str) -> dict: """加载 pkl 文件,返回字典 {mtx, dtes, tkrs, tkrs_name, ...}""" sys.modules['model'] = FakeModule() sys.modules['model.index_info'] = FakeModule() with open(pkl_path, 'rb') as f: data = pickle.load(f) return data def load_ohlcv_from_pkl(data_dir: str) -> tuple: """ 从 pkl 文件加载 OHLCV 数据 Returns: open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: shape=(n_stocks, n_days) dates: shape=(n_days,) 真实日期 (如 20050104) tkrs: shape=(n_stocks,) 股票代码 tkrs_name: shape=(n_stocks,) 股票名称 """ open_data = load_pkl(os.path.join(data_dir, "open.pkl")) high_data = load_pkl(os.path.join(data_dir, "high.pkl")) low_data = load_pkl(os.path.join(data_dir, "low.pkl")) close_data = load_pkl(os.path.join(data_dir, "close.pkl")) volume_data = load_pkl(os.path.join(data_dir, "volume.pkl")) # 使用 close 的元数据 dates = close_data["dtes"] tkrs = close_data["tkrs"] tkrs_name = close_data["tkrs_name"] return ( open_data["mtx"], high_data["mtx"], low_data["mtx"], close_data["mtx"], volume_data["mtx"], dates, tkrs, tkrs_name, ) # ============================================================================ # 日志工具 # ============================================================================ class Logger: """同时输出到控制台和日志文件""" def __init__(self, log_path: str): self.log_path = log_path self.buffer = StringIO() def print(self, *args, **kwargs): """打印到控制台和缓冲区""" # 打印到控制台 print(*args, **kwargs) # 写入缓冲区 print(*args, **kwargs, file=self.buffer) def save(self): """保存日志到文件""" with open(self.log_path, 'w', encoding='utf-8') as f: f.write(self.buffer.getvalue()) print(f"\n日志已保存到: {self.log_path}") # ============================================================================ # 主流程 # ============================================================================ def main() -> None: start_time = time.time() # 初始化日志 outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles") os.makedirs(outputs_dir, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") log_path = os.path.join(outputs_dir, f"run_log_{timestamp}.txt") log = Logger(log_path) log.print("=" * 70) log.print("Converging Triangle Batch Detection") log.print(f"运行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") log.print("=" * 70) # 1. 加载数据 log.print("\n[1] Loading OHLCV pkl files...") load_start = time.time() open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR) load_time = time.time() - load_start n_stocks, n_days = close_mtx.shape log.print(f" Stocks: {n_stocks}") log.print(f" Days: {n_days}") log.print(f" Date range: {dates[0]} ~ {dates[-1]}") log.print(f" 加载耗时: {load_time:.2f} 秒") # 2. 找有效数据范围(排除全 NaN 的列) any_valid = np.any(~np.isnan(close_mtx), axis=0) valid_day_idx = np.where(any_valid)[0] if len(valid_day_idx) == 0: log.print("No valid data found!") return last_valid_day = valid_day_idx[-1] log.print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})") # 3. 计算范围 start_day = PARAMS.window - 1 end_day = last_valid_day # 使用最后有效日,而不是矩阵末尾 if RECENT_DAYS is not None: start_day = max(start_day, end_day - RECENT_DAYS + 1) log.print(f"\n[2] Detection range: day {start_day} ~ {end_day}") log.print(f" Window size: {PARAMS.window}") log.print(f" Total points: {n_stocks} x {end_day - start_day + 1} = {n_stocks * (end_day - start_day + 1)}") # 3. 批量检测 log.print("\n[3] Running batch detection...") detect_start = time.time() # 根据配置选择v1或v2版本 if USE_V2_OPTIMIZATION and not REALTIME_MODE: # v2优化版本(预计算枢轴点,不支持实时模式) log.print(" 使用: v2优化版本(预计算枢轴点)") df = detect_converging_triangle_batch_v2( open_mtx=open_mtx, high_mtx=high_mtx, low_mtx=low_mtx, close_mtx=close_mtx, volume_mtx=volume_mtx, params=PARAMS, start_day=start_day, end_day=end_day, only_valid=ONLY_VALID, verbose=VERBOSE, ) else: # v1原版(支持实时模式) if USE_V2_OPTIMIZATION and REALTIME_MODE: log.print(" 注意: v2优化不支持实时模式,自动回退到v1") log.print(" 使用: v1原版") df = detect_converging_triangle_batch( open_mtx=open_mtx, high_mtx=high_mtx, low_mtx=low_mtx, close_mtx=close_mtx, volume_mtx=volume_mtx, params=PARAMS, start_day=start_day, end_day=end_day, only_valid=ONLY_VALID, verbose=VERBOSE, real_time_mode=REALTIME_MODE, flexible_zone=FLEXIBLE_ZONE, ) detect_time = time.time() - detect_start log.print(f" 检测耗时: {detect_time:.2f} 秒") if not USE_V2_OPTIMIZATION or REALTIME_MODE: log.print(f" 检测模式: {'实时模式' if REALTIME_MODE else '标准模式'}") if REALTIME_MODE: log.print(f" 灵活区域: {FLEXIBLE_ZONE} 天") # 4. 添加股票代码、名称和真实日期 if len(df) > 0: df["stock_code"] = df["stock_idx"].map(lambda x: tkrs[x] if x < len(tkrs) else "") df["stock_name"] = df["stock_idx"].map(lambda x: tkrs_name[x] if x < len(tkrs_name) else "") df["date"] = df["date_idx"].map(lambda x: dates[x] if x < len(dates) else 0) # 计算综合强度(取向上和向下的最大值) df["max_strength"] = df[["breakout_strength_up", "breakout_strength_down"]].max(axis=1) # 5. 输出结果 log.print("\n" + "=" * 70) log.print("Detection Results") log.print("=" * 70) if ONLY_VALID: log.print(f"\nTotal valid triangles detected: {len(df)}") else: valid_count = df["is_valid"].sum() log.print(f"\nTotal records: {len(df)}") log.print(f"Valid triangles: {valid_count} ({valid_count/len(df)*100:.1f}%)") # 按突破方向统计 if len(df) > 0 and "breakout_dir" in df.columns: breakout_stats = df[df["is_valid"] == True]["breakout_dir"].value_counts() log.print(f"\nBreakout statistics:") for dir_name, count in breakout_stats.items(): log.print(f" - {dir_name}: {count}") # 突破强度统计 if len(df) > 0 and "breakout_strength_up" in df.columns: valid_df = df[df["is_valid"] == True] if len(valid_df) > 0: log.print(f"\nBreakout strength (valid triangles):") log.print(f" - Up mean: {valid_df['breakout_strength_up'].mean():.4f}, max: {valid_df['breakout_strength_up'].max():.4f}") log.print(f" - Down mean: {valid_df['breakout_strength_down'].mean():.4f}, max: {valid_df['breakout_strength_down'].max():.4f}") # 6. 保存结果 # 保存完整 CSV (使用 utf-8-sig 支持中文) csv_path = os.path.join(outputs_dir, "all_results.csv") df.to_csv(csv_path, index=False, encoding="utf-8-sig") log.print(f"\nResults saved to: {csv_path}") # 保存高强度突破记录 (strength > 0.3) if len(df) > 0: strong_up = df[(df["is_valid"] == True) & (df["breakout_strength_up"] > 0.3)] strong_down = df[(df["is_valid"] == True) & (df["breakout_strength_down"] > 0.3)] if len(strong_up) > 0: strong_up_path = os.path.join(outputs_dir, "strong_breakout_up.csv") strong_up.to_csv(strong_up_path, index=False, encoding="utf-8-sig") log.print(f"Strong up breakouts ({len(strong_up)}): {strong_up_path}") if len(strong_down) > 0: strong_down_path = os.path.join(outputs_dir, "strong_breakout_down.csv") strong_down.to_csv(strong_down_path, index=False, encoding="utf-8-sig") log.print(f"Strong down breakouts ({len(strong_down)}): {strong_down_path}") # 7. 每日最佳股票报告(核心新功能) if len(df) > 0 and "date" in df.columns: log.print("\n" + "=" * 70) log.print("每日最佳收敛三角形股票 (按日期)") log.print("=" * 70) valid_df = df[df["is_valid"] == True].copy() if len(valid_df) > 0: # 按日期分组,每天取分数最高的股票 daily_best = valid_df.loc[valid_df.groupby("date")["max_strength"].idxmax()] daily_best = daily_best.sort_values("date", ascending=True) log.print(f"\n共 {len(daily_best)} 个交易日检测到收敛三角形:") log.print("-" * 90) log.print(f"{'日期':<12} {'股票代码':<10} {'股票名称':<12} {'强度↑':>8} {'强度↓':>8} {'方向':<6} {'收敛比':>8}") log.print("-" * 90) for _, row in daily_best.iterrows(): log.print( f"{int(row['date']):<12} " f"{row['stock_code']:<10} " f"{row['stock_name']:<12} " f"{row['breakout_strength_up']:>8.4f} " f"{row['breakout_strength_down']:>8.4f} " f"{row['breakout_dir']:<6} " f"{row['width_ratio']:>8.3f}" ) log.print("-" * 90) # 保存每日最佳到 CSV daily_best_path = os.path.join(outputs_dir, "daily_best.csv") daily_best.to_csv(daily_best_path, index=False, encoding="utf-8-sig") log.print(f"\n每日最佳已保存到: {daily_best_path}") # 统计:每只股票被选为每日最佳的次数 log.print("\n" + "-" * 70) log.print("股票被选为每日最佳的次数排行:") log.print("-" * 70) stock_counts = daily_best.groupby(["stock_code", "stock_name"]).size().reset_index(name="count") stock_counts = stock_counts.sort_values("count", ascending=False).head(20) for _, row in stock_counts.iterrows(): log.print(f" {row['stock_code']} {row['stock_name']}: {row['count']} 次") else: log.print("\n没有检测到有效的收敛三角形") # 8. 显示样本 if len(df) > 0: log.print("\n" + "-" * 70) log.print("Sample results (first 10):") log.print("-" * 70) display_cols = [ "stock_code", "date", "is_valid", "breakout_strength_up", "breakout_strength_down", "breakout_dir", "width_ratio" ] display_cols = [c for c in display_cols if c in df.columns] log.print(df[display_cols].head(10).to_string(index=False)) total_time = time.time() - start_time log.print("\n" + "=" * 70) log.print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)") log.print("=" * 70) # 保存日志 log.save() if __name__ == "__main__": main()