""" 为当日满足收敛三角形的个股生成图表 用法: # 简洁模式(默认)- 仅显示收盘价、上沿、下沿 python scripts/plot_converging_triangles.py # 详细模式 - 显示所有枢轴点、拟合点、分段线 python scripts/plot_converging_triangles.py --show-details # 指定日期 python scripts/plot_converging_triangles.py --date 20260120 """ from __future__ import annotations import argparse import csv import os import pickle import sys from typing import Dict, List, Optional import matplotlib.pyplot as plt import numpy as np # 配置中文字体 plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS'] # 支持中文 plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 # 让脚本能找到 src/ 下的模块 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src")) from converging_triangle import ( ConvergingTriangleParams, calc_fitting_adherence, detect_converging_triangle, fit_pivot_line_dispatch, line_y, pivots_fractal, pivots_fractal_hybrid, ) # 导入统一的参数配置 from triangle_config import DETECTION_PARAMS, DISPLAY_WINDOW, SHOW_CHART_DETAILS, REALTIME_MODE, FLEXIBLE_ZONE class FakeModule: """空壳模块,绕过 model 依赖""" ndarray = np.ndarray def load_pkl(pkl_path: str) -> dict: """加载 pkl 文件""" sys.modules['model'] = FakeModule() sys.modules['model.index_info'] = FakeModule() with open(pkl_path, 'rb') as f: data = pickle.load(f) return data def load_ohlcv_from_pkl(data_dir: str) -> tuple: """从 pkl 文件加载 OHLCV 数据""" open_data = load_pkl(os.path.join(data_dir, "open.pkl")) high_data = load_pkl(os.path.join(data_dir, "high.pkl")) low_data = load_pkl(os.path.join(data_dir, "low.pkl")) close_data = load_pkl(os.path.join(data_dir, "close.pkl")) volume_data = load_pkl(os.path.join(data_dir, "volume.pkl")) dates = close_data["dtes"] tkrs = close_data["tkrs"] tkrs_name = close_data["tkrs_name"] return ( open_data["mtx"], high_data["mtx"], low_data["mtx"], close_data["mtx"], volume_data["mtx"], dates, tkrs, tkrs_name, ) def load_daily_stocks(csv_path: str, target_date: int) -> List[Dict]: """从CSV读取指定日期的股票列表""" stocks = [] with open(csv_path, newline="", encoding="utf-8-sig") as f: reader = csv.DictReader(f) for row in reader: try: date = int(row.get("date", "0")) if date == target_date: stocks.append({ "stock_idx": int(row.get("stock_idx", "0")), "stock_code": row.get("stock_code", ""), "stock_name": row.get("stock_name", ""), "breakout_dir": row.get("breakout_dir", "none"), "breakout_strength_up": float(row.get("breakout_strength_up", "0")), "breakout_strength_down": float(row.get("breakout_strength_down", "0")), }) except (ValueError, TypeError): continue return stocks def plot_triangle( stock_idx: int, stock_code: str, stock_name: str, date_idx: int, high_mtx: np.ndarray, low_mtx: np.ndarray, close_mtx: np.ndarray, volume_mtx: np.ndarray, dates: np.ndarray, params: ConvergingTriangleParams, output_path: str, display_window: int = 500, # 显示窗口大小 show_details: bool = False, # 是否显示详细调试信息 force_plot: bool = False, # 强制绘图(即使不满足三角形条件) plot_boundary_source: str = "hl", # 边界线拟合数据源: "hl" | "close" show_high_low: bool = False, # 是否显示日内高低价范围 ) -> None: """绘制单只股票的收敛三角形图""" # 提取该股票数据并过滤NaN high_stock = high_mtx[stock_idx, :] low_stock = low_mtx[stock_idx, :] close_stock = close_mtx[stock_idx, :] volume_stock = volume_mtx[stock_idx, :] valid_mask = ~np.isnan(close_stock) valid_indices = np.where(valid_mask)[0] if len(valid_indices) < params.window: if force_plot: print(f" [警告] {stock_code} {stock_name}: 有效数据不足,仅绘制基础K线") # 继续绘制基础K线 else: print(f" [跳过] {stock_code} {stock_name}: 有效数据不足") return # 找到date_idx在有效数据中的位置 if date_idx not in valid_indices: print(f" [跳过] {stock_code} {stock_name}: date_idx不在有效范围") return valid_end = np.where(valid_indices == date_idx)[0][0] # 检查窗口数据是否充足 has_enough_data = valid_end >= params.window - 1 if not has_enough_data and not force_plot: print(f" [跳过] {stock_code} {stock_name}: 窗口数据不足") return # 提取检测窗口数据(用于三角形检测) if has_enough_data: detect_start = valid_end - params.window + 1 high_win = high_stock[valid_mask][detect_start:valid_end + 1] low_win = low_stock[valid_mask][detect_start:valid_end + 1] close_win = close_stock[valid_mask][detect_start:valid_end + 1] volume_win = volume_stock[valid_mask][detect_start:valid_end + 1] else: # 数据不足,使用所有可用数据 detect_start = 0 high_win = high_stock[valid_mask][:valid_end + 1] low_win = low_stock[valid_mask][:valid_end + 1] close_win = close_stock[valid_mask][:valid_end + 1] volume_win = volume_stock[valid_mask][:valid_end + 1] # 提取显示窗口数据(用于绘图,更长的历史) display_start = max(0, valid_end - display_window + 1) display_high = high_stock[valid_mask][display_start:valid_end + 1] display_low = low_stock[valid_mask][display_start:valid_end + 1] display_close = close_stock[valid_mask][display_start:valid_end + 1] display_volume = volume_stock[valid_mask][display_start:valid_end + 1] display_dates = dates[valid_indices[display_start:valid_end + 1]] # ======================================================================== # 计算三角形参数(用于绘图) # force_plot模式:即使不满足条件也尝试检测,检测失败则只画K线 # ======================================================================== result = None has_triangle = False fitting_adherence_plot = 0.0 # 初始化贴合度 if has_enough_data: result = detect_converging_triangle( high=high_win, low=low_win, close=close_win, volume=volume_win, params=params, stock_idx=stock_idx, date_idx=date_idx, real_time_mode=REALTIME_MODE, flexible_zone=FLEXIBLE_ZONE, ) has_triangle = result.is_valid if result else False # 在force_plot模式下,即使没有有效三角形也继续绘图(仅K线) if not force_plot and not has_triangle: print(f" [跳过] {stock_code} {stock_name}: 不满足收敛三角形条件") return # 绘图准备 x_display = np.arange(len(display_close), dtype=float) # 只在有三角形时计算三角形相关参数 if has_triangle and has_enough_data: # 计算三角形在显示窗口中的位置偏移 triangle_offset = len(display_close) - len(close_win) # 获取检测窗口的起止索引(相对于检测窗口内部) n = len(close_win) x_win = np.arange(n, dtype=float) # 计算枢轴点(与检测算法一致,考虑实时模式) if REALTIME_MODE: confirmed_ph, confirmed_pl, candidate_ph, candidate_pl = pivots_fractal_hybrid( high_win, low_win, k=params.pivot_k, flexible_zone=FLEXIBLE_ZONE ) # 合并确认枢轴点和候选枢轴点 ph_idx = np.concatenate([confirmed_ph, candidate_ph]) if len(candidate_ph) > 0 else confirmed_ph pl_idx = np.concatenate([confirmed_pl, candidate_pl]) if len(candidate_pl) > 0 else confirmed_pl # 排序以保证顺序 ph_idx = np.sort(ph_idx) pl_idx = np.sort(pl_idx) else: ph_idx, pl_idx = pivots_fractal(high_win, low_win, k=params.pivot_k) # 使用枢轴点连线法拟合边界线(与检测算法一致) # 注意:绘图用的是检测窗口数据,因此 window_start=0, window_end=n-1 if plot_boundary_source == "close": upper_fit_values = close_win[ph_idx] lower_fit_values = close_win[pl_idx] upper_all_prices = close_win lower_all_prices = close_win else: upper_fit_values = high_win[ph_idx] lower_fit_values = low_win[pl_idx] upper_all_prices = high_win lower_all_prices = low_win a_u, b_u, selected_ph = fit_pivot_line_dispatch( pivot_indices=ph_idx, pivot_values=upper_fit_values, mode="upper", method=params.fitting_method, all_prices=upper_all_prices, window_start=0, window_end=n - 1, ) a_l, b_l, selected_pl = fit_pivot_line_dispatch( pivot_indices=pl_idx, pivot_values=lower_fit_values, mode="lower", method=params.fitting_method, all_prices=lower_all_prices, window_start=0, window_end=n - 1, ) # 三角形线段在显示窗口中的X坐标(只画检测窗口范围) xw_in_display = np.arange(triangle_offset, triangle_offset + n, dtype=float) # 计算Y值使用检测窗口内部的X坐标 upper_line = line_y(a_u, b_u, x_win) lower_line = line_y(a_l, b_l, x_win) # 获取选中的枢轴点在显示窗口中的位置(用于标注) ph_display_idx = ph_idx + triangle_offset pl_display_idx = pl_idx + triangle_offset selected_ph_pos = ph_idx[selected_ph] if len(selected_ph) > 0 else np.array([], dtype=int) selected_pl_pos = pl_idx[selected_pl] if len(selected_pl) > 0 else np.array([], dtype=int) selected_ph_display = selected_ph_pos + triangle_offset selected_pl_display = selected_pl_pos + triangle_offset # 三角形检测窗口的日期范围(用于标题) detect_dates = dates[valid_indices[detect_start:valid_end + 1]] # 如果使用收盘价拟合,重新计算贴合度(基于实际拟合线) if plot_boundary_source == "close" and len(selected_ph) > 0 and len(selected_pl) > 0: # 使用收盘价重新计算贴合度 adherence_upper_close = calc_fitting_adherence( pivot_indices=selected_ph_pos.astype(float), pivot_values=close_win[selected_ph_pos], slope=a_u, intercept=b_u, ) adherence_lower_close = calc_fitting_adherence( pivot_indices=selected_pl_pos.astype(float), pivot_values=close_win[selected_pl_pos], slope=a_l, intercept=b_l, ) fitting_adherence_plot = (adherence_upper_close + adherence_lower_close) / 2.0 else: # 使用检测算法计算的贴合度 fitting_adherence_plot = result.fitting_score if result else 0.0 else: # 无三角形时,贴合度为0 fitting_adherence_plot = 0.0 # 创建图表 fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [3, 1]}) # 主图:价格和趋势线(使用显示窗口数据) ax1.plot(x_display, display_close, linewidth=1.5, label='收盘价', color='black', alpha=0.7) if show_high_low: ax1.fill_between( x_display, display_low, display_high, color='gray', alpha=0.12, label='日内高低范围', ) # 只在有三角形时绘制趋势线 if has_triangle and has_enough_data: boundary_label = "收盘价" if plot_boundary_source == "close" else "高低价" ax1.plot(xw_in_display, upper_line, linewidth=2, label=f'上沿({boundary_label})', color='red', linestyle='--') ax1.plot(xw_in_display, lower_line, linewidth=2, label=f'下沿({boundary_label})', color='green', linestyle='--') ax1.axvline(len(display_close) - 1, color='gray', linestyle=':', linewidth=1, alpha=0.5) # ======================================================================== # 详细模式:显示拟合点(仅在 show_details=True 且有三角形时) # ======================================================================== if show_details and has_triangle and has_enough_data: # 根据数据源选择枢轴点的Y坐标 if plot_boundary_source == "close": ph_pivot_y = close_win[ph_idx] pl_pivot_y = close_win[pl_idx] selected_ph_y = close_win[selected_ph_pos] if len(selected_ph_pos) > 0 else np.array([]) selected_pl_y = close_win[selected_pl_pos] if len(selected_pl_pos) > 0 else np.array([]) else: ph_pivot_y = high_win[ph_idx] pl_pivot_y = low_win[pl_idx] selected_ph_y = high_win[selected_ph_pos] if len(selected_ph_pos) > 0 else np.array([]) selected_pl_y = low_win[selected_pl_pos] if len(selected_pl_pos) > 0 else np.array([]) # 标注所有枢轴点(用于查看拐点分布) if len(ph_display_idx) > 0: ax1.scatter( ph_display_idx, ph_pivot_y, marker='x', s=60, color='red', alpha=0.6, zorder=4, label=f'上沿枢轴点({len(ph_idx)})', ) if len(pl_display_idx) > 0: ax1.scatter( pl_display_idx, pl_pivot_y, marker='x', s=60, color='green', alpha=0.6, zorder=4, label=f'下沿枢轴点({len(pl_idx)})', ) # 标注选中的枢轴点(用于拟合线的关键点) if len(selected_ph_display) >= 2: ax1.scatter( selected_ph_display, selected_ph_y, marker='o', s=120, facecolors='none', edgecolors='red', linewidths=2.5, zorder=5, label=f'上沿拟合点({len(selected_ph_pos)})', ) if len(selected_pl_display) >= 2: ax1.scatter( selected_pl_display, selected_pl_y, marker='o', s=120, facecolors='none', edgecolors='green', linewidths=2.5, zorder=5, label=f'下沿拟合点({len(selected_pl_pos)})', ) # 准备标题内容 if has_triangle and has_enough_data and result: # 有效三角形:显示完整信息和强度分 if result.breakout_dir == "up": strength = result.breakout_strength_up price_score = result.price_score_up elif result.breakout_dir == "down": strength = result.breakout_strength_down price_score = result.price_score_down else: strength = max(result.breakout_strength_up, result.breakout_strength_down) price_score = max(result.price_score_up, result.price_score_down) # 获取边界利用率与惩罚系数(兼容旧数据) boundary_util = getattr(result, 'boundary_utilization', 0.0) utilization_floor = 0.20 if utilization_floor > 0: utilization_penalty = min(1.0, boundary_util / utilization_floor) else: utilization_penalty = 1.0 # 选择显示的贴合度:如果使用收盘价拟合,显示重新计算的贴合度 if plot_boundary_source == "close" and has_triangle and has_enough_data: display_fitting_score = fitting_adherence_plot fitting_note = f"拟合贴合度(收盘价): {display_fitting_score:.3f}" else: display_fitting_score = result.fitting_score fitting_note = f"拟合贴合度: {display_fitting_score:.3f}" ax1.set_title( f"{stock_code} {stock_name} - 收敛三角形 (检测窗口: {detect_dates[0]} ~ {detect_dates[-1]})\n" f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日) " f"突破方向: {result.breakout_dir} 宽度比: {result.width_ratio:.2f} " f"枢轴点: 高{len(ph_idx)}/低{len(pl_idx)} 触碰: 上{result.touches_upper}/下{result.touches_lower} " f"放量确认: {'是' if result.volume_confirmed else '否' if result.volume_confirmed is False else '-'}\n" f"强度分: {strength:.3f} " f"(价格: {price_score:.3f}×50% + 收敛: {result.convergence_score:.3f}×15% + " f"成交量: {result.volume_score:.3f}×10% + {fitting_note}×10% + " f"边界利用率: {boundary_util:.3f}×15%) × 利用率惩罚: {utilization_penalty:.2f}", fontsize=11, pad=10 ) else: # 无三角形:仅显示基础信息和强度分0分 ax1.set_title( f"{stock_code} {stock_name} - K线图(不满足收敛三角形条件)\n" f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日)\n" f"强度分: 0.000 (未检测到收敛三角形形态)", fontsize=11, pad=10 ) ax1.set_ylabel('价格', fontsize=10) ax1.legend(loc='best', fontsize=9) ax1.grid(True, alpha=0.3) # X轴日期标签(稀疏显示,基于显示窗口) step = max(1, len(display_dates) // 10) tick_indices = np.arange(0, len(display_dates), step) ax1.set_xticks(tick_indices) ax1.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8) # 副图:成交量(使用显示窗口数据) ax2.bar(x_display, display_volume, width=0.8, color='skyblue', alpha=0.6) ax2.set_ylabel('成交量', fontsize=10) ax2.set_xlabel('交易日', fontsize=10) ax2.grid(True, alpha=0.3, axis='y') ax2.set_xticks(tick_indices) ax2.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8) plt.tight_layout() plt.savefig(output_path, dpi=120) plt.close() print(f" [完成] {stock_code} {stock_name} -> {output_path}") def main() -> None: parser = argparse.ArgumentParser(description="为当日满足收敛三角形的个股生成图表") parser.add_argument( "--input", default=os.path.join("outputs", "converging_triangles", "all_results.csv"), help="输入 CSV 路径", ) parser.add_argument( "--date", type=int, default=None, help="指定日期(YYYYMMDD),默认为数据最新日", ) parser.add_argument( "--output-dir", default=os.path.join("outputs", "converging_triangles", "charts"), help="图表输出目录", ) parser.add_argument( "--show-details", action="store_true", help="显示详细调试信息(枢轴点、拟合点、分段线等)", ) parser.add_argument( "--plot-boundary-source", choices=["hl", "close"], default="close", help="绘图时边界线拟合数据源: hl=高低价, close=收盘价(不影响检测)", ) parser.add_argument( "--show-high-low", action="store_true", help="显示日内高低价范围(仅影响图形展示)", ) parser.add_argument( "--all-stocks", action="store_true", help="为所有108只股票生成图表(包括不满足收敛三角形条件的)", ) parser.add_argument( "--clear", action="store_true", help="清空当前模式的旧图片(默认不清空)", ) args = parser.parse_args() # 确定是否显示详细信息(命令行参数优先) show_details = args.show_details if hasattr(args, 'show_details') else SHOW_CHART_DETAILS all_stocks = args.all_stocks if hasattr(args, 'all_stocks') else False plot_boundary_source = args.plot_boundary_source if hasattr(args, 'plot_boundary_source') else "hl" show_high_low = args.show_high_low if hasattr(args, 'show_high_low') else False print("=" * 70) print("收敛三角形图表生成") print("=" * 70) print(f"详细模式: {'开启' if show_details else '关闭'} {'(--show-details)' if show_details else '(简洁模式)'}") print(f"图表范围: {'所有108只股票' if all_stocks else '仅满足条件的股票'} {'(--all-stocks)' if all_stocks else ''}") # 1. 加载数据 print("\n[1] 加载 OHLCV 数据...") data_dir = os.path.join(os.path.dirname(__file__), "..", "data") open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(data_dir) print(f" 股票数: {len(tkrs)}, 交易日数: {len(dates)}") # 2. 确定目标日期 if args.date: target_date = args.date else: # 从CSV读取最新日期 with open(args.input, newline="", encoding="utf-8-sig") as f: reader = csv.DictReader(f) all_dates = [int(r.get("date", "0")) for r in reader if r.get("date") and r["date"].isdigit()] target_date = max(all_dates) if all_dates else 0 if target_date == 0: print("错误: 无法确定目标日期") return print(f"\n[2] 目标日期: {target_date}") # 3. 加载当日股票列表 if all_stocks: # 模式1: 绘制所有股票(包括不满足条件的) print(f" 模式: 所有股票") print(f" 股票总数: {len(tkrs)}") # 为所有股票创建stock字典(从CSV获取已有的强度分,未检测的置为0) stocks = [] csv_stocks = load_daily_stocks(args.input, target_date) csv_map = {s["stock_idx"]: s for s in csv_stocks} for idx in range(len(tkrs)): if idx in csv_map: # 已检测过的股票,使用CSV数据 stocks.append(csv_map[idx]) else: # 未检测过的股票,使用默认值 stocks.append({ "stock_idx": idx, "stock_code": tkrs[idx], "stock_name": tkrs_name[idx], "breakout_dir": "none", "breakout_strength_up": 0.0, "breakout_strength_down": 0.0, }) print(f" 其中满足三角形条件: {len(csv_stocks)} 只") print(f" 不满足条件(将显示基础K线): {len(tkrs) - len(csv_stocks)} 只") else: # 模式2: 仅绘制满足条件的股票 stocks = load_daily_stocks(args.input, target_date) print(f" 当日满足三角形的股票数: {len(stocks)}") if not stocks: print("当日无满足条件的股票") return # 4. 创建输出目录,并按需清空对应模式的旧图片 os.makedirs(args.output_dir, exist_ok=True) if args.clear: # 只清空当前模式的图片(简洁模式或详细模式) print(f"\n[4] 清空当前模式的旧图片...") # 找出当前模式的文件:简洁模式是不含_detail的.png,详细模式是_detail.png if show_details: old_files = [f for f in os.listdir(args.output_dir) if f.endswith('_detail.png')] else: old_files = [f for f in os.listdir(args.output_dir) if f.endswith('.png') and not f.endswith('_detail.png')] for f in old_files: os.remove(os.path.join(args.output_dir, f)) print(f" 已删除 {len(old_files)} 个旧图片 ({'详细模式' if show_details else '简洁模式'})") else: print(f"\n[4] 跳过清空旧图片(使用 --clear 可手动清空)") # 5. 检测参数(从统一配置导入) params = DETECTION_PARAMS # 6. 找到target_date在dates中的索引 date_idx = np.where(dates == target_date)[0] if len(date_idx) == 0: print(f"错误: 日期 {target_date} 不在数据范围内") return date_idx = date_idx[0] # 7. 生成图表 print(f"\n[3] 生成图表...") for stock in stocks: stock_idx = stock["stock_idx"] stock_code = stock["stock_code"] stock_name = stock["stock_name"] # 清理文件名中的非法字符(Windows文件名不允许: * ? " < > | : / \) stock_name_clean = stock_name.replace('*', '').replace('?', '').replace('"', '').replace('<', '').replace('>', '').replace('|', '').replace(':', '').replace('/', '').replace('\\', '') # 根据详细模式添加文件名后缀 suffix = "_detail" if show_details else "" output_filename = f"{target_date}_{stock_code}_{stock_name_clean}{suffix}.png" output_path = os.path.join(args.output_dir, output_filename) try: plot_triangle( stock_idx=stock_idx, stock_code=stock_code, stock_name=stock_name, date_idx=date_idx, high_mtx=high_mtx, low_mtx=low_mtx, close_mtx=close_mtx, volume_mtx=volume_mtx, dates=dates, params=params, output_path=output_path, display_window=DISPLAY_WINDOW, # 从配置文件读取 show_details=show_details, # 传递详细模式参数 force_plot=all_stocks, # 在all_stocks模式下强制绘图 plot_boundary_source=plot_boundary_source, show_high_low=show_high_low, ) except Exception as e: print(f" [错误] {stock_code} {stock_name}: {e}") print(f"\n完成!图表已保存至: {args.output_dir}") if __name__ == "__main__": main()