""" 为当日满足收敛三角形的个股生成图表 用法: python scripts/plot_converging_triangles.py python scripts/plot_converging_triangles.py --date 20260120 """ from __future__ import annotations import argparse import csv import os import pickle import sys from typing import Dict, List, Optional import matplotlib.pyplot as plt import numpy as np # 配置中文字体 plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS'] # 支持中文 plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 # 让脚本能找到 src/ 下的模块 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src")) from converging_triangle import ( ConvergingTriangleParams, detect_converging_triangle, fit_pivot_line, line_y, pivots_fractal, ) # 导入统一的参数配置 from triangle_config import DETECTION_PARAMS, DISPLAY_WINDOW class FakeModule: """空壳模块,绕过 model 依赖""" ndarray = np.ndarray def load_pkl(pkl_path: str) -> dict: """加载 pkl 文件""" sys.modules['model'] = FakeModule() sys.modules['model.index_info'] = FakeModule() with open(pkl_path, 'rb') as f: data = pickle.load(f) return data def load_ohlcv_from_pkl(data_dir: str) -> tuple: """从 pkl 文件加载 OHLCV 数据""" open_data = load_pkl(os.path.join(data_dir, "open.pkl")) high_data = load_pkl(os.path.join(data_dir, "high.pkl")) low_data = load_pkl(os.path.join(data_dir, "low.pkl")) close_data = load_pkl(os.path.join(data_dir, "close.pkl")) volume_data = load_pkl(os.path.join(data_dir, "volume.pkl")) dates = close_data["dtes"] tkrs = close_data["tkrs"] tkrs_name = close_data["tkrs_name"] return ( open_data["mtx"], high_data["mtx"], low_data["mtx"], close_data["mtx"], volume_data["mtx"], dates, tkrs, tkrs_name, ) def load_daily_stocks(csv_path: str, target_date: int) -> List[Dict]: """从CSV读取指定日期的股票列表""" stocks = [] with open(csv_path, newline="", encoding="utf-8-sig") as f: reader = csv.DictReader(f) for row in reader: try: date = int(row.get("date", "0")) if date == target_date: stocks.append({ "stock_idx": int(row.get("stock_idx", "0")), "stock_code": row.get("stock_code", ""), "stock_name": row.get("stock_name", ""), "breakout_dir": row.get("breakout_dir", "none"), "breakout_strength_up": float(row.get("breakout_strength_up", "0")), "breakout_strength_down": float(row.get("breakout_strength_down", "0")), }) except (ValueError, TypeError): continue return stocks def plot_triangle( stock_idx: int, stock_code: str, stock_name: str, date_idx: int, high_mtx: np.ndarray, low_mtx: np.ndarray, close_mtx: np.ndarray, volume_mtx: np.ndarray, dates: np.ndarray, params: ConvergingTriangleParams, output_path: str, display_window: int = 500, # 新增:显示窗口大小 ) -> None: """绘制单只股票的收敛三角形图""" # 提取该股票数据并过滤NaN high_stock = high_mtx[stock_idx, :] low_stock = low_mtx[stock_idx, :] close_stock = close_mtx[stock_idx, :] volume_stock = volume_mtx[stock_idx, :] valid_mask = ~np.isnan(close_stock) valid_indices = np.where(valid_mask)[0] if len(valid_indices) < params.window: print(f" [跳过] {stock_code} {stock_name}: 有效数据不足") return # 找到date_idx在有效数据中的位置 if date_idx not in valid_indices: print(f" [跳过] {stock_code} {stock_name}: date_idx不在有效范围") return valid_end = np.where(valid_indices == date_idx)[0][0] if valid_end < params.window - 1: print(f" [跳过] {stock_code} {stock_name}: 窗口数据不足") return # 提取检测窗口数据(用于三角形检测) detect_start = valid_end - params.window + 1 high_win = high_stock[valid_mask][detect_start:valid_end + 1] low_win = low_stock[valid_mask][detect_start:valid_end + 1] close_win = close_stock[valid_mask][detect_start:valid_end + 1] volume_win = volume_stock[valid_mask][detect_start:valid_end + 1] # 提取显示窗口数据(用于绘图,更长的历史) display_start = max(0, valid_end - display_window + 1) display_high = high_stock[valid_mask][display_start:valid_end + 1] display_low = low_stock[valid_mask][display_start:valid_end + 1] display_close = close_stock[valid_mask][display_start:valid_end + 1] display_volume = volume_stock[valid_mask][display_start:valid_end + 1] display_dates = dates[valid_indices[display_start:valid_end + 1]] # 检测三角形(使用检测窗口数据) result = detect_converging_triangle( high=high_win, low=low_win, close=close_win, volume=volume_win, params=params, stock_idx=stock_idx, date_idx=date_idx, ) if not result.is_valid: print(f" [跳过] {stock_code} {stock_name}: 未识别到有效三角形") return # 绘图准备 x_display = np.arange(len(display_close), dtype=float) # 计算三角形在显示窗口中的位置偏移 triangle_offset = len(display_close) - len(close_win) # 获取检测窗口的起止索引(相对于检测窗口内部) n = len(close_win) x_win = np.arange(n, dtype=float) # 计算枢轴点(与检测算法一致) ph_idx, pl_idx = pivots_fractal(high_win, low_win, k=params.pivot_k) # 使用枢轴点连线法拟合边界线(与检测算法一致) a_u, b_u, selected_ph = fit_pivot_line( pivot_indices=ph_idx, pivot_values=high_win[ph_idx], mode="upper", ) a_l, b_l, selected_pl = fit_pivot_line( pivot_indices=pl_idx, pivot_values=low_win[pl_idx], mode="lower", ) # 三角形线段在显示窗口中的X坐标(只画检测窗口范围) xw_in_display = np.arange(triangle_offset, triangle_offset + n, dtype=float) # 计算Y值使用检测窗口内部的X坐标 upper_line = line_y(a_u, b_u, x_win) lower_line = line_y(a_l, b_l, x_win) # 获取选中的枢轴点在显示窗口中的位置(用于标注) ph_display_idx = ph_idx + triangle_offset pl_display_idx = pl_idx + triangle_offset selected_ph_pos = ph_idx[selected_ph] if len(selected_ph) > 0 else np.array([], dtype=int) selected_pl_pos = pl_idx[selected_pl] if len(selected_pl) > 0 else np.array([], dtype=int) selected_ph_display = selected_ph_pos + triangle_offset selected_pl_display = selected_pl_pos + triangle_offset # 三角形检测窗口的日期范围(用于标题) detect_dates = dates[valid_indices[detect_start:valid_end + 1]] # 创建图表 fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [3, 1]}) # 主图:价格和趋势线(使用显示窗口数据) ax1.plot(x_display, display_close, linewidth=1.5, label='收盘价', color='black', alpha=0.7) ax1.plot(xw_in_display, upper_line, linewidth=2, label='上沿', color='red', linestyle='--') ax1.plot(xw_in_display, lower_line, linewidth=2, label='下沿', color='green', linestyle='--') ax1.axvline(len(display_close) - 1, color='gray', linestyle=':', linewidth=1, alpha=0.5) # 标注选中的枢轴点(用于连线的关键点) if len(selected_ph_display) >= 2: ax1.scatter( selected_ph_display, high_win[selected_ph_pos], marker='o', s=90, facecolors='none', edgecolors='red', linewidths=1.5, zorder=5, label='上沿枢轴点', ) if len(selected_pl_display) >= 2: ax1.scatter( selected_pl_display, low_win[selected_pl_pos], marker='o', s=90, facecolors='none', edgecolors='green', linewidths=1.5, zorder=5, label='下沿枢轴点', ) ax1.set_title( f"{stock_code} {stock_name} - 收敛三角形 (检测窗口: {detect_dates[0]} ~ {detect_dates[-1]})\n" f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日) " f"突破方向: {result.breakout_dir} 宽度比: {result.width_ratio:.2f} " f"触碰: 上{result.touches_upper}/下{result.touches_lower} " f"放量确认: {'是' if result.volume_confirmed else '否' if result.volume_confirmed is False else '-'}", fontsize=11, pad=10 ) ax1.set_ylabel('价格', fontsize=10) ax1.legend(loc='best', fontsize=9) ax1.grid(True, alpha=0.3) # X轴日期标签(稀疏显示,基于显示窗口) step = max(1, len(display_dates) // 10) tick_indices = np.arange(0, len(display_dates), step) ax1.set_xticks(tick_indices) ax1.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8) # 副图:成交量(使用显示窗口数据) ax2.bar(x_display, display_volume, width=0.8, color='skyblue', alpha=0.6) ax2.set_ylabel('成交量', fontsize=10) ax2.set_xlabel('交易日', fontsize=10) ax2.grid(True, alpha=0.3, axis='y') ax2.set_xticks(tick_indices) ax2.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8) plt.tight_layout() plt.savefig(output_path, dpi=120) plt.close() print(f" [完成] {stock_code} {stock_name} -> {output_path}") def main() -> None: parser = argparse.ArgumentParser(description="为当日满足收敛三角形的个股生成图表") parser.add_argument( "--input", default=os.path.join("outputs", "converging_triangles", "all_results.csv"), help="输入 CSV 路径", ) parser.add_argument( "--date", type=int, default=None, help="指定日期(YYYYMMDD),默认为数据最新日", ) parser.add_argument( "--output-dir", default=os.path.join("outputs", "converging_triangles", "charts"), help="图表输出目录", ) args = parser.parse_args() print("=" * 70) print("收敛三角形图表生成") print("=" * 70) # 1. 加载数据 print("\n[1] 加载 OHLCV 数据...") data_dir = os.path.join(os.path.dirname(__file__), "..", "data") open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(data_dir) print(f" 股票数: {len(tkrs)}, 交易日数: {len(dates)}") # 2. 确定目标日期 if args.date: target_date = args.date else: # 从CSV读取最新日期 with open(args.input, newline="", encoding="utf-8-sig") as f: reader = csv.DictReader(f) all_dates = [int(r.get("date", "0")) for r in reader if r.get("date") and r["date"].isdigit()] target_date = max(all_dates) if all_dates else 0 if target_date == 0: print("错误: 无法确定目标日期") return print(f"\n[2] 目标日期: {target_date}") # 3. 加载当日股票列表 stocks = load_daily_stocks(args.input, target_date) print(f" 当日满足三角形的股票数: {len(stocks)}") if not stocks: print("当日无满足条件的股票") return # 4. 创建输出目录并清空旧图片 os.makedirs(args.output_dir, exist_ok=True) # 清空目录中的旧图片 print(f"\n[4] 清空输出目录...") old_files = [f for f in os.listdir(args.output_dir) if f.endswith('.png')] for f in old_files: os.remove(os.path.join(args.output_dir, f)) print(f" 已删除 {len(old_files)} 个旧图片") # 5. 检测参数(从统一配置导入) params = DETECTION_PARAMS # 6. 找到target_date在dates中的索引 date_idx = np.where(dates == target_date)[0] if len(date_idx) == 0: print(f"错误: 日期 {target_date} 不在数据范围内") return date_idx = date_idx[0] # 7. 生成图表 print(f"\n[3] 生成图表...") for stock in stocks: stock_idx = stock["stock_idx"] stock_code = stock["stock_code"] stock_name = stock["stock_name"] output_filename = f"{target_date}_{stock_code}_{stock_name}.png" output_path = os.path.join(args.output_dir, output_filename) try: plot_triangle( stock_idx=stock_idx, stock_code=stock_code, stock_name=stock_name, date_idx=date_idx, high_mtx=high_mtx, low_mtx=low_mtx, close_mtx=close_mtx, volume_mtx=volume_mtx, dates=dates, params=params, output_path=output_path, display_window=DISPLAY_WINDOW, # 从配置文件读取 ) except Exception as e: print(f" [错误] {stock_code} {stock_name}: {e}") print(f"\n完成!图表已保存至: {args.output_dir}") if __name__ == "__main__": main()