""" 为当日满足收敛三角形的个股生成图表(K线图模式) 用法: # 简洁模式(默认)- 仅显示K线、上沿、下沿 python scripts/plot_converging_triangles.py # 详细模式 - 显示所有枢轴点、拟合点 python scripts/plot_converging_triangles.py --show-details # 指定日期 python scripts/plot_converging_triangles.py --date 20260120 """ from __future__ import annotations import argparse import csv import os import pickle import sys from typing import Dict, List, Optional import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib.patches import Rectangle import numpy as np import pandas as pd # 配置中文字体 plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS'] # 支持中文 plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 # 让脚本能找到 src/ 下的模块 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src")) from converging_triangle import ( ConvergingTriangleParams, calc_fitting_adherence, detect_converging_triangle, fit_pivot_line_dispatch, line_y, pivots_fractal, pivots_fractal_hybrid, ) # 导入统一的参数配置 from triangle_config import DETECTION_PARAMS, DISPLAY_WINDOW, SHOW_CHART_DETAILS, REALTIME_MODE, FLEXIBLE_ZONE class FakeModule: """空壳模块,绕过 model 依赖""" ndarray = np.ndarray def load_pkl(pkl_path: str) -> dict: """加载 pkl 文件""" sys.modules['model'] = FakeModule() sys.modules['model.index_info'] = FakeModule() with open(pkl_path, 'rb') as f: data = pickle.load(f) return data def load_ohlcv_from_pkl(data_dir: str) -> tuple: """从 pkl 文件加载 OHLCV 数据""" open_data = load_pkl(os.path.join(data_dir, "open.pkl")) high_data = load_pkl(os.path.join(data_dir, "high.pkl")) low_data = load_pkl(os.path.join(data_dir, "low.pkl")) close_data = load_pkl(os.path.join(data_dir, "close.pkl")) volume_data = load_pkl(os.path.join(data_dir, "volume.pkl")) dates = close_data["dtes"] tkrs = close_data["tkrs"] tkrs_name = close_data["tkrs_name"] return ( open_data["mtx"], high_data["mtx"], low_data["mtx"], close_data["mtx"], volume_data["mtx"], dates, tkrs, tkrs_name, ) def load_daily_stocks(csv_path: str, target_date: int) -> List[Dict]: """从CSV读取指定日期的股票列表""" stocks = [] with open(csv_path, newline="", encoding="utf-8-sig") as f: reader = csv.DictReader(f) for row in reader: try: date = int(row.get("date", "0")) if date == target_date: stocks.append({ "stock_idx": int(row.get("stock_idx", "0")), "stock_code": row.get("stock_code", ""), "stock_name": row.get("stock_name", ""), "breakout_dir": row.get("breakout_dir", "none"), "breakout_strength_up": float(row.get("breakout_strength_up", "0")), "breakout_strength_down": float(row.get("breakout_strength_down", "0")), }) except (ValueError, TypeError): continue return stocks def plot_triangle( stock_idx: int, stock_code: str, stock_name: str, date_idx: int, open_mtx: np.ndarray, high_mtx: np.ndarray, low_mtx: np.ndarray, close_mtx: np.ndarray, volume_mtx: np.ndarray, dates: np.ndarray, params: ConvergingTriangleParams, output_path: str, display_window: int = 500, # 显示窗口大小 show_details: bool = False, # 是否显示详细调试信息 force_plot: bool = False, # 强制绘图(即使不满足三角形条件) plot_boundary_source: str = "hl", # 边界线拟合数据源: "hl" | "close" ) -> None: """绘制单只股票的收敛三角形图(K线图模式)""" # 提取该股票数据并过滤NaN open_stock = open_mtx[stock_idx, :] high_stock = high_mtx[stock_idx, :] low_stock = low_mtx[stock_idx, :] close_stock = close_mtx[stock_idx, :] volume_stock = volume_mtx[stock_idx, :] valid_mask = ~np.isnan(close_stock) valid_indices = np.where(valid_mask)[0] if len(valid_indices) < params.window: if force_plot: print(f" [警告] {stock_code} {stock_name}: 有效数据不足,仅绘制基础K线") # 继续绘制基础K线 else: print(f" [跳过] {stock_code} {stock_name}: 有效数据不足") return # 找到date_idx在有效数据中的位置 if date_idx not in valid_indices: print(f" [跳过] {stock_code} {stock_name}: date_idx不在有效范围") return valid_end = np.where(valid_indices == date_idx)[0][0] # 检查窗口数据是否充足 has_enough_data = valid_end >= params.window - 1 if not has_enough_data and not force_plot: print(f" [跳过] {stock_code} {stock_name}: 窗口数据不足") return # 提取检测窗口数据(用于三角形检测) if has_enough_data: detect_start = valid_end - params.window + 1 high_win = high_stock[valid_mask][detect_start:valid_end + 1] low_win = low_stock[valid_mask][detect_start:valid_end + 1] close_win = close_stock[valid_mask][detect_start:valid_end + 1] volume_win = volume_stock[valid_mask][detect_start:valid_end + 1] else: # 数据不足,使用所有可用数据 detect_start = 0 high_win = high_stock[valid_mask][:valid_end + 1] low_win = low_stock[valid_mask][:valid_end + 1] close_win = close_stock[valid_mask][:valid_end + 1] volume_win = volume_stock[valid_mask][:valid_end + 1] # 提取显示窗口数据(用于绘图,更长的历史) display_start = max(0, valid_end - display_window + 1) display_open = open_stock[valid_mask][display_start:valid_end + 1] display_high = high_stock[valid_mask][display_start:valid_end + 1] display_low = low_stock[valid_mask][display_start:valid_end + 1] display_close = close_stock[valid_mask][display_start:valid_end + 1] display_volume = volume_stock[valid_mask][display_start:valid_end + 1] display_dates = dates[valid_indices[display_start:valid_end + 1]] # ======================================================================== # 计算三角形参数(用于绘图) # force_plot模式:即使不满足条件也尝试检测,检测失败则只画K线 # ======================================================================== result = None has_triangle = False fitting_adherence_plot = 0.0 # 初始化贴合度 if has_enough_data: result = detect_converging_triangle( high=high_win, low=low_win, close=close_win, volume=volume_win, params=params, stock_idx=stock_idx, date_idx=date_idx, real_time_mode=REALTIME_MODE, flexible_zone=FLEXIBLE_ZONE, ) has_triangle = result.is_valid if result else False # 在force_plot模式下,即使没有有效三角形也继续绘图(仅K线) if not force_plot and not has_triangle: print(f" [跳过] {stock_code} {stock_name}: 不满足收敛三角形条件") return # 绘图准备 x_display = np.arange(len(display_close), dtype=float) # 只在有三角形时计算三角形相关参数 if has_triangle and has_enough_data: # 计算三角形在显示窗口中的位置偏移 triangle_offset = len(display_close) - len(close_win) # 获取检测窗口的起止索引(相对于检测窗口内部) n = len(close_win) x_win = np.arange(n, dtype=float) # 计算枢轴点(与检测算法一致,考虑实时模式) if REALTIME_MODE: confirmed_ph, confirmed_pl, candidate_ph, candidate_pl = pivots_fractal_hybrid( high_win, low_win, k=params.pivot_k, flexible_zone=FLEXIBLE_ZONE ) # 合并确认枢轴点和候选枢轴点 ph_idx = np.concatenate([confirmed_ph, candidate_ph]) if len(candidate_ph) > 0 else confirmed_ph pl_idx = np.concatenate([confirmed_pl, candidate_pl]) if len(candidate_pl) > 0 else confirmed_pl # 排序以保证顺序 ph_idx = np.sort(ph_idx) pl_idx = np.sort(pl_idx) else: ph_idx, pl_idx = pivots_fractal(high_win, low_win, k=params.pivot_k) # 使用枢轴点连线法拟合边界线(与检测算法一致) # 注意:绘图用的是检测窗口数据,因此 window_start=0, window_end=n-1 if plot_boundary_source == "close": upper_fit_values = close_win[ph_idx] lower_fit_values = close_win[pl_idx] upper_all_prices = close_win lower_all_prices = close_win else: upper_fit_values = high_win[ph_idx] lower_fit_values = low_win[pl_idx] upper_all_prices = high_win lower_all_prices = low_win a_u, b_u, selected_ph = fit_pivot_line_dispatch( pivot_indices=ph_idx, pivot_values=upper_fit_values, mode="upper", method=params.fitting_method, all_prices=upper_all_prices, window_start=0, window_end=n - 1, ) a_l, b_l, selected_pl = fit_pivot_line_dispatch( pivot_indices=pl_idx, pivot_values=lower_fit_values, mode="lower", method=params.fitting_method, all_prices=lower_all_prices, window_start=0, window_end=n - 1, ) # 三角形线段在显示窗口中的X坐标(只画检测窗口范围) xw_in_display = np.arange(triangle_offset, triangle_offset + n, dtype=float) # 计算Y值使用检测窗口内部的X坐标 upper_line = line_y(a_u, b_u, x_win) lower_line = line_y(a_l, b_l, x_win) # 获取选中的枢轴点在显示窗口中的位置(用于标注) ph_display_idx = ph_idx + triangle_offset pl_display_idx = pl_idx + triangle_offset selected_ph_pos = ph_idx[selected_ph] if len(selected_ph) > 0 else np.array([], dtype=int) selected_pl_pos = pl_idx[selected_pl] if len(selected_pl) > 0 else np.array([], dtype=int) selected_ph_display = selected_ph_pos + triangle_offset selected_pl_display = selected_pl_pos + triangle_offset # 三角形检测窗口的日期范围(用于标题) detect_dates = dates[valid_indices[detect_start:valid_end + 1]] # 如果使用收盘价拟合,重新计算贴合度(基于实际拟合线) if plot_boundary_source == "close" and len(selected_ph) > 0 and len(selected_pl) > 0: # 使用收盘价重新计算贴合度 adherence_upper_close = calc_fitting_adherence( pivot_indices=selected_ph_pos.astype(float), pivot_values=close_win[selected_ph_pos], slope=a_u, intercept=b_u, ) adherence_lower_close = calc_fitting_adherence( pivot_indices=selected_pl_pos.astype(float), pivot_values=close_win[selected_pl_pos], slope=a_l, intercept=b_l, ) fitting_adherence_plot = (adherence_upper_close + adherence_lower_close) / 2.0 else: # 使用检测算法计算的贴合度 fitting_adherence_plot = result.fitting_score if result else 0.0 else: # 无三角形时,贴合度为0 fitting_adherence_plot = 0.0 # 创建图表 fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [3, 1]}) # ======================================================================== # 主图:K线图 + 趋势线 # ======================================================================== # 绘制K线 for i in range(len(x_display)): # 确定颜色:涨为红色,跌为绿色 if i == 0: color = 'red' if display_close[i] >= display_open[i] else 'green' else: color = 'red' if display_close[i] >= display_open[i] else 'green' # 绘制影线(最高到最低) ax1.plot([x_display[i], x_display[i]], [display_low[i], display_high[i]], color=color, linewidth=0.8, alpha=0.8) # 绘制实体(开盘到收盘) body_height = abs(display_close[i] - display_open[i]) body_bottom = min(display_open[i], display_close[i]) if body_height < 0.001: # 十字星 ax1.plot([x_display[i] - 0.3, x_display[i] + 0.3], [display_close[i], display_close[i]], color=color, linewidth=1.5) else: rect = Rectangle((x_display[i] - 0.3, body_bottom), 0.6, body_height, facecolor=color if display_close[i] < display_open[i] else 'white', edgecolor=color, linewidth=1.2) ax1.add_patch(rect) # 只在有三角形时绘制趋势线 if has_triangle and has_enough_data: boundary_label = "收盘价" if plot_boundary_source == "close" else "高低价" ax1.plot(xw_in_display, upper_line, linewidth=2.5, label=f'上沿({boundary_label})', color='darkred', linestyle='--', alpha=0.8, zorder=5) ax1.plot(xw_in_display, lower_line, linewidth=2.5, label=f'下沿({boundary_label})', color='darkgreen', linestyle='--', alpha=0.8, zorder=5) # 当前日期的标记线 ax1.axvline(len(display_close) - 1, color='blue', linestyle=':', linewidth=1.5, alpha=0.6) # ======================================================================== # 详细模式:显示拟合点(仅在 show_details=True 且有三角形时) # ======================================================================== if show_details and has_triangle and has_enough_data: # 根据数据源选择枢轴点的Y坐标 if plot_boundary_source == "close": ph_pivot_y = close_win[ph_idx] pl_pivot_y = close_win[pl_idx] selected_ph_y = close_win[selected_ph_pos] if len(selected_ph_pos) > 0 else np.array([]) selected_pl_y = close_win[selected_pl_pos] if len(selected_pl_pos) > 0 else np.array([]) else: ph_pivot_y = high_win[ph_idx] pl_pivot_y = low_win[pl_idx] selected_ph_y = high_win[selected_ph_pos] if len(selected_ph_pos) > 0 else np.array([]) selected_pl_y = low_win[selected_pl_pos] if len(selected_pl_pos) > 0 else np.array([]) # 标注所有枢轴点(用于查看拐点分布) if len(ph_display_idx) > 0: ax1.scatter( ph_display_idx, ph_pivot_y, marker='x', s=60, color='red', alpha=0.6, zorder=4, label=f'上沿枢轴点({len(ph_idx)})', ) if len(pl_display_idx) > 0: ax1.scatter( pl_display_idx, pl_pivot_y, marker='x', s=60, color='green', alpha=0.6, zorder=4, label=f'下沿枢轴点({len(pl_idx)})', ) # 标注选中的枢轴点(用于拟合线的关键点) if len(selected_ph_display) >= 2: ax1.scatter( selected_ph_display, selected_ph_y, marker='o', s=120, facecolors='none', edgecolors='red', linewidths=2.5, zorder=5, label=f'上沿拟合点({len(selected_ph_pos)})', ) if len(selected_pl_display) >= 2: ax1.scatter( selected_pl_display, selected_pl_y, marker='o', s=120, facecolors='none', edgecolors='green', linewidths=2.5, zorder=5, label=f'下沿拟合点({len(selected_pl_pos)})', ) # 准备标题内容 if has_triangle and has_enough_data and result: # 有效三角形:显示完整信息和强度分 if result.breakout_dir == "up": strength = result.breakout_strength_up price_score = result.price_score_up elif result.breakout_dir == "down": strength = result.breakout_strength_down price_score = result.price_score_down else: strength = max(result.breakout_strength_up, result.breakout_strength_down) price_score = max(result.price_score_up, result.price_score_down) # 获取边界利用率与惩罚系数(兼容旧数据) boundary_util = getattr(result, 'boundary_utilization', 0.0) utilization_floor = 0.20 if utilization_floor > 0: utilization_penalty = min(1.0, boundary_util / utilization_floor) else: utilization_penalty = 1.0 # 选择显示的贴合度:如果使用收盘价拟合,显示重新计算的贴合度 if plot_boundary_source == "close" and has_triangle and has_enough_data: display_fitting_score = fitting_adherence_plot fitting_note = f"拟合贴合度(收盘价): {display_fitting_score:.3f}" else: display_fitting_score = result.fitting_score fitting_note = f"拟合贴合度: {display_fitting_score:.3f}" ax1.set_title( f"{stock_code} {stock_name} - 收敛三角形 (检测窗口: {detect_dates[0]} ~ {detect_dates[-1]})\n" f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日) " f"突破方向: {result.breakout_dir} 宽度比: {result.width_ratio:.2f} " f"枢轴点: 高{len(ph_idx)}/低{len(pl_idx)} 触碰: 上{result.touches_upper}/下{result.touches_lower} " f"放量确认: {'是' if result.volume_confirmed else '否' if result.volume_confirmed is False else '-'}\n" f"强度分: {strength:.3f} " f"(价格: {price_score:.3f}×50% + 收敛: {result.convergence_score:.3f}×15% + " f"成交量: {result.volume_score:.3f}×10% + {fitting_note}×10% + " f"边界利用率: {boundary_util:.3f}×15%) × 利用率惩罚: {utilization_penalty:.2f}", fontsize=11, pad=10 ) else: # 无三角形:仅显示基础信息和强度分0分 ax1.set_title( f"{stock_code} {stock_name} - K线图(不满足收敛三角形条件)\n" f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日)\n" f"强度分: 0.000 (未检测到收敛三角形形态)", fontsize=11, pad=10 ) ax1.set_ylabel('价格', fontsize=10) ax1.legend(loc='best', fontsize=9) ax1.grid(True, alpha=0.3) # X轴日期标签(稀疏显示,基于显示窗口) step = max(1, len(display_dates) // 10) tick_indices = np.arange(0, len(display_dates), step) ax1.set_xticks(tick_indices) ax1.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8) # 副图:成交量(使用显示窗口数据) ax2.bar(x_display, display_volume, width=0.8, color='skyblue', alpha=0.6) ax2.set_ylabel('成交量', fontsize=10) ax2.set_xlabel('交易日', fontsize=10) ax2.grid(True, alpha=0.3, axis='y') ax2.set_xticks(tick_indices) ax2.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8) plt.tight_layout() plt.savefig(output_path, dpi=120) plt.close() print(f" [完成] {stock_code} {stock_name} -> {output_path}") def main() -> None: parser = argparse.ArgumentParser(description="为当日满足收敛三角形的个股生成图表") parser.add_argument( "--input", default=os.path.join("outputs", "converging_triangles", "all_results.csv"), help="输入 CSV 路径", ) parser.add_argument( "--date", type=int, default=None, help="指定日期(YYYYMMDD),默认为数据最新日", ) parser.add_argument( "--output-dir", default=os.path.join("outputs", "converging_triangles", "charts"), help="图表输出目录", ) parser.add_argument( "--show-details", action="store_true", help="显示详细调试信息(枢轴点、拟合点、分段线等)", ) parser.add_argument( "--plot-boundary-source", choices=["hl", "close"], default="hl", help="绘图时边界线拟合数据源: hl=高低价(默认), close=收盘价(不影响检测)", ) parser.add_argument( "--all-stocks", action="store_true", help="为所有108只股票生成图表(包括不满足收敛三角形条件的)", ) parser.add_argument( "--clear", action="store_true", help="清空当前模式的旧图片(默认不清空)", ) args = parser.parse_args() # 确定是否显示详细信息(命令行参数优先) show_details = args.show_details if hasattr(args, 'show_details') else SHOW_CHART_DETAILS all_stocks = args.all_stocks if hasattr(args, 'all_stocks') else False plot_boundary_source = args.plot_boundary_source if hasattr(args, 'plot_boundary_source') else "hl" print("=" * 70) print("收敛三角形图表生成") print("=" * 70) print(f"详细模式: {'开启' if show_details else '关闭'} {'(--show-details)' if show_details else '(简洁模式)'}") print(f"图表范围: {'所有108只股票' if all_stocks else '仅满足条件的股票'} {'(--all-stocks)' if all_stocks else ''}") # 1. 加载数据 print("\n[1] 加载 OHLCV 数据...") data_dir = os.path.join(os.path.dirname(__file__), "..", "data") open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(data_dir) print(f" 股票数: {len(tkrs)}, 交易日数: {len(dates)}") # 2. 确定目标日期 if args.date: target_date = args.date else: # 从CSV读取最新日期 with open(args.input, newline="", encoding="utf-8-sig") as f: reader = csv.DictReader(f) all_dates = [int(r.get("date", "0")) for r in reader if r.get("date") and r["date"].isdigit()] target_date = max(all_dates) if all_dates else 0 if target_date == 0: print("错误: 无法确定目标日期") return print(f"\n[2] 目标日期: {target_date}") # 3. 加载当日股票列表 if all_stocks: # 模式1: 绘制所有股票(包括不满足条件的) print(f" 模式: 所有股票") print(f" 股票总数: {len(tkrs)}") # 为所有股票创建stock字典(从CSV获取已有的强度分,未检测的置为0) stocks = [] csv_stocks = load_daily_stocks(args.input, target_date) csv_map = {s["stock_idx"]: s for s in csv_stocks} for idx in range(len(tkrs)): if idx in csv_map: # 已检测过的股票,使用CSV数据 stocks.append(csv_map[idx]) else: # 未检测过的股票,使用默认值 stocks.append({ "stock_idx": idx, "stock_code": tkrs[idx], "stock_name": tkrs_name[idx], "breakout_dir": "none", "breakout_strength_up": 0.0, "breakout_strength_down": 0.0, }) print(f" 其中满足三角形条件: {len(csv_stocks)} 只") print(f" 不满足条件(将显示基础K线): {len(tkrs) - len(csv_stocks)} 只") else: # 模式2: 仅绘制满足条件的股票 stocks = load_daily_stocks(args.input, target_date) print(f" 当日满足三角形的股票数: {len(stocks)}") if not stocks: print("当日无满足条件的股票") return # 4. 创建输出目录,并按需清空对应模式的旧图片 os.makedirs(args.output_dir, exist_ok=True) if args.clear: # 只清空当前模式的图片(简洁模式或详细模式) print(f"\n[4] 清空当前模式的旧图片...") # 找出当前模式的文件:简洁模式是不含_detail的.png,详细模式是_detail.png if show_details: old_files = [f for f in os.listdir(args.output_dir) if f.endswith('_detail.png')] else: old_files = [f for f in os.listdir(args.output_dir) if f.endswith('.png') and not f.endswith('_detail.png')] for f in old_files: os.remove(os.path.join(args.output_dir, f)) print(f" 已删除 {len(old_files)} 个旧图片 ({'详细模式' if show_details else '简洁模式'})") else: print(f"\n[4] 跳过清空旧图片(使用 --clear 可手动清空)") # 5. 检测参数(从统一配置导入) params = DETECTION_PARAMS # 6. 找到target_date在dates中的索引 date_idx = np.where(dates == target_date)[0] if len(date_idx) == 0: print(f"错误: 日期 {target_date} 不在数据范围内") return date_idx = date_idx[0] # 7. 生成图表 print(f"\n[3] 生成图表...") for stock in stocks: stock_idx = stock["stock_idx"] stock_code = stock["stock_code"] stock_name = stock["stock_name"] # 清理文件名中的非法字符(Windows文件名不允许: * ? " < > | : / \) stock_name_clean = stock_name.replace('*', '').replace('?', '').replace('"', '').replace('<', '').replace('>', '').replace('|', '').replace(':', '').replace('/', '').replace('\\', '') # 根据详细模式添加文件名后缀 suffix = "_detail" if show_details else "" output_filename = f"{target_date}_{stock_code}_{stock_name_clean}{suffix}.png" output_path = os.path.join(args.output_dir, output_filename) try: plot_triangle( stock_idx=stock_idx, stock_code=stock_code, stock_name=stock_name, date_idx=date_idx, open_mtx=open_mtx, high_mtx=high_mtx, low_mtx=low_mtx, close_mtx=close_mtx, volume_mtx=volume_mtx, dates=dates, params=params, output_path=output_path, display_window=DISPLAY_WINDOW, # 从配置文件读取 show_details=show_details, # 传递详细模式参数 force_plot=all_stocks, # 在all_stocks模式下强制绘图 plot_boundary_source=plot_boundary_source, ) except Exception as e: print(f" [错误] {stock_code} {stock_name}: {e}") print(f"\n完成!图表已保存至: {args.output_dir}") if __name__ == "__main__": main()