technical-patterns-lab/scripts/plot_converging_triangles.py
褚宏光 95d13b2cce Enhance converging triangle analysis with detailed mode and outlier removal algorithm
- Added `--show-details` parameter to `pipeline_converging_triangle.py` for generating detailed charts that display all pivot points and fitting lines.
- Implemented an iterative outlier removal algorithm in `fit_pivot_line` to improve the accuracy of pivot point fitting by eliminating weak points.
- Updated `USAGE.md` to include new command examples for the detailed mode.
- Revised multiple documentation files to reflect recent changes and improvements in the pivot detection and visualization processes.
2026-01-26 18:43:18 +08:00

459 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
为当日满足收敛三角形的个股生成图表
用法:
# 简洁模式(默认)- 仅显示收盘价、上沿、下沿
python scripts/plot_converging_triangles.py
# 详细模式 - 显示所有枢轴点、拟合点、分段线
python scripts/plot_converging_triangles.py --show-details
# 指定日期
python scripts/plot_converging_triangles.py --date 20260120
"""
from __future__ import annotations
import argparse
import csv
import os
import pickle
import sys
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
# 配置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS'] # 支持中文
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# 让脚本能找到 src/ 下的模块
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from converging_triangle import (
ConvergingTriangleParams,
detect_converging_triangle,
fit_pivot_line,
line_y,
pivots_fractal,
pivots_fractal_hybrid,
)
# 导入统一的参数配置
from triangle_config import DETECTION_PARAMS, DISPLAY_WINDOW, SHOW_CHART_DETAILS, REALTIME_MODE, FLEXIBLE_ZONE
class FakeModule:
"""空壳模块,绕过 model 依赖"""
ndarray = np.ndarray
def load_pkl(pkl_path: str) -> dict:
"""加载 pkl 文件"""
sys.modules['model'] = FakeModule()
sys.modules['model.index_info'] = FakeModule()
with open(pkl_path, 'rb') as f:
data = pickle.load(f)
return data
def load_ohlcv_from_pkl(data_dir: str) -> tuple:
"""从 pkl 文件加载 OHLCV 数据"""
open_data = load_pkl(os.path.join(data_dir, "open.pkl"))
high_data = load_pkl(os.path.join(data_dir, "high.pkl"))
low_data = load_pkl(os.path.join(data_dir, "low.pkl"))
close_data = load_pkl(os.path.join(data_dir, "close.pkl"))
volume_data = load_pkl(os.path.join(data_dir, "volume.pkl"))
dates = close_data["dtes"]
tkrs = close_data["tkrs"]
tkrs_name = close_data["tkrs_name"]
return (
open_data["mtx"],
high_data["mtx"],
low_data["mtx"],
close_data["mtx"],
volume_data["mtx"],
dates,
tkrs,
tkrs_name,
)
def load_daily_stocks(csv_path: str, target_date: int) -> List[Dict]:
"""从CSV读取指定日期的股票列表"""
stocks = []
with open(csv_path, newline="", encoding="utf-8-sig") as f:
reader = csv.DictReader(f)
for row in reader:
try:
date = int(row.get("date", "0"))
if date == target_date:
stocks.append({
"stock_idx": int(row.get("stock_idx", "0")),
"stock_code": row.get("stock_code", ""),
"stock_name": row.get("stock_name", ""),
"breakout_dir": row.get("breakout_dir", "none"),
"breakout_strength_up": float(row.get("breakout_strength_up", "0")),
"breakout_strength_down": float(row.get("breakout_strength_down", "0")),
})
except (ValueError, TypeError):
continue
return stocks
def plot_triangle(
stock_idx: int,
stock_code: str,
stock_name: str,
date_idx: int,
high_mtx: np.ndarray,
low_mtx: np.ndarray,
close_mtx: np.ndarray,
volume_mtx: np.ndarray,
dates: np.ndarray,
params: ConvergingTriangleParams,
output_path: str,
display_window: int = 500, # 显示窗口大小
show_details: bool = False, # 是否显示详细调试信息
) -> None:
"""绘制单只股票的收敛三角形图"""
# 提取该股票数据并过滤NaN
high_stock = high_mtx[stock_idx, :]
low_stock = low_mtx[stock_idx, :]
close_stock = close_mtx[stock_idx, :]
volume_stock = volume_mtx[stock_idx, :]
valid_mask = ~np.isnan(close_stock)
valid_indices = np.where(valid_mask)[0]
if len(valid_indices) < params.window:
print(f" [跳过] {stock_code} {stock_name}: 有效数据不足")
return
# 找到date_idx在有效数据中的位置
if date_idx not in valid_indices:
print(f" [跳过] {stock_code} {stock_name}: date_idx不在有效范围")
return
valid_end = np.where(valid_indices == date_idx)[0][0]
if valid_end < params.window - 1:
print(f" [跳过] {stock_code} {stock_name}: 窗口数据不足")
return
# 提取检测窗口数据(用于三角形检测)
detect_start = valid_end - params.window + 1
high_win = high_stock[valid_mask][detect_start:valid_end + 1]
low_win = low_stock[valid_mask][detect_start:valid_end + 1]
close_win = close_stock[valid_mask][detect_start:valid_end + 1]
volume_win = volume_stock[valid_mask][detect_start:valid_end + 1]
# 提取显示窗口数据(用于绘图,更长的历史)
display_start = max(0, valid_end - display_window + 1)
display_high = high_stock[valid_mask][display_start:valid_end + 1]
display_low = low_stock[valid_mask][display_start:valid_end + 1]
display_close = close_stock[valid_mask][display_start:valid_end + 1]
display_volume = volume_stock[valid_mask][display_start:valid_end + 1]
display_dates = dates[valid_indices[display_start:valid_end + 1]]
# ========================================================================
# 计算三角形参数(用于绘图)
# 注意:不验证 is_valid因为CSV中已经验证通过了
# 这里只是重新计算参数用于可视化
# ========================================================================
result = detect_converging_triangle(
high=high_win,
low=low_win,
close=close_win,
volume=volume_win,
params=params,
stock_idx=stock_idx,
date_idx=date_idx,
)
# 不再检查 is_valid直接绘图
# 原因CSV中已经包含了通过验证的股票这里只需要可视化
# 绘图准备
x_display = np.arange(len(display_close), dtype=float)
# 计算三角形在显示窗口中的位置偏移
triangle_offset = len(display_close) - len(close_win)
# 获取检测窗口的起止索引(相对于检测窗口内部)
n = len(close_win)
x_win = np.arange(n, dtype=float)
# 计算枢轴点(与检测算法一致,考虑实时模式)
if REALTIME_MODE:
confirmed_ph, confirmed_pl, candidate_ph, candidate_pl = pivots_fractal_hybrid(
high_win, low_win, k=params.pivot_k, flexible_zone=FLEXIBLE_ZONE
)
# 合并确认枢轴点和候选枢轴点
ph_idx = np.concatenate([confirmed_ph, candidate_ph]) if len(candidate_ph) > 0 else confirmed_ph
pl_idx = np.concatenate([confirmed_pl, candidate_pl]) if len(candidate_pl) > 0 else confirmed_pl
# 排序以保证顺序
ph_idx = np.sort(ph_idx)
pl_idx = np.sort(pl_idx)
else:
ph_idx, pl_idx = pivots_fractal(high_win, low_win, k=params.pivot_k)
# 使用枢轴点连线法拟合边界线(与检测算法一致)
a_u, b_u, selected_ph = fit_pivot_line(
pivot_indices=ph_idx,
pivot_values=high_win[ph_idx],
mode="upper",
)
a_l, b_l, selected_pl = fit_pivot_line(
pivot_indices=pl_idx,
pivot_values=low_win[pl_idx],
mode="lower",
)
# 三角形线段在显示窗口中的X坐标只画检测窗口范围
xw_in_display = np.arange(triangle_offset, triangle_offset + n, dtype=float)
# 计算Y值使用检测窗口内部的X坐标
upper_line = line_y(a_u, b_u, x_win)
lower_line = line_y(a_l, b_l, x_win)
# 获取选中的枢轴点在显示窗口中的位置(用于标注)
ph_display_idx = ph_idx + triangle_offset
pl_display_idx = pl_idx + triangle_offset
selected_ph_pos = ph_idx[selected_ph] if len(selected_ph) > 0 else np.array([], dtype=int)
selected_pl_pos = pl_idx[selected_pl] if len(selected_pl) > 0 else np.array([], dtype=int)
selected_ph_display = selected_ph_pos + triangle_offset
selected_pl_display = selected_pl_pos + triangle_offset
# 三角形检测窗口的日期范围(用于标题)
detect_dates = dates[valid_indices[detect_start:valid_end + 1]]
# 创建图表
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8),
gridspec_kw={'height_ratios': [3, 1]})
# 主图:价格和趋势线(使用显示窗口数据)
ax1.plot(x_display, display_close, linewidth=1.5, label='收盘价', color='black', alpha=0.7)
ax1.plot(xw_in_display, upper_line, linewidth=2, label='上沿', color='red', linestyle='--')
ax1.plot(xw_in_display, lower_line, linewidth=2, label='下沿', color='green', linestyle='--')
ax1.axvline(len(display_close) - 1, color='gray', linestyle=':', linewidth=1, alpha=0.5)
# ========================================================================
# 详细模式:显示所有枢轴点、拟合点、分段线(仅在 show_details=True 时)
# ========================================================================
if show_details:
# 标注所有枢轴点(小实心点,较浅颜色)
if len(ph_display_idx) > 0:
ax1.scatter(
ph_display_idx,
high_win[ph_idx],
marker='o',
s=50,
facecolors='red',
edgecolors='none',
alpha=0.4,
zorder=4,
label=f'所有高点枢轴点({len(ph_idx)})',
)
if len(pl_display_idx) > 0:
ax1.scatter(
pl_display_idx,
low_win[pl_idx],
marker='o',
s=50,
facecolors='green',
edgecolors='none',
alpha=0.4,
zorder=4,
label=f'所有低点枢轴点({len(pl_idx)})',
)
# 标注选中的枢轴点(用于拟合线的关键点,大空心圆)
if len(selected_ph_display) >= 2:
ax1.scatter(
selected_ph_display,
high_win[selected_ph_pos],
marker='o',
s=120,
facecolors='none',
edgecolors='red',
linewidths=2.5,
zorder=5,
label=f'上沿拟合点({len(selected_ph_pos)})',
)
if len(selected_pl_display) >= 2:
ax1.scatter(
selected_pl_display,
low_win[selected_pl_pos],
marker='o',
s=120,
facecolors='none',
edgecolors='green',
linewidths=2.5,
zorder=5,
label=f'下沿拟合点({len(selected_pl_pos)})',
)
ax1.set_title(
f"{stock_code} {stock_name} - 收敛三角形 (检测窗口: {detect_dates[0]} ~ {detect_dates[-1]})\n"
f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日) "
f"突破方向: {result.breakout_dir} 宽度比: {result.width_ratio:.2f} "
f"枢轴点: 高{len(ph_idx)}/低{len(pl_idx)} 触碰: 上{result.touches_upper}/下{result.touches_lower} "
f"放量确认: {'' if result.volume_confirmed else '' if result.volume_confirmed is False else '-'}",
fontsize=11, pad=10
)
ax1.set_ylabel('价格', fontsize=10)
ax1.legend(loc='best', fontsize=9)
ax1.grid(True, alpha=0.3)
# X轴日期标签稀疏显示基于显示窗口
step = max(1, len(display_dates) // 10)
tick_indices = np.arange(0, len(display_dates), step)
ax1.set_xticks(tick_indices)
ax1.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8)
# 副图:成交量(使用显示窗口数据)
ax2.bar(x_display, display_volume, width=0.8, color='skyblue', alpha=0.6)
ax2.set_ylabel('成交量', fontsize=10)
ax2.set_xlabel('交易日', fontsize=10)
ax2.grid(True, alpha=0.3, axis='y')
ax2.set_xticks(tick_indices)
ax2.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8)
plt.tight_layout()
plt.savefig(output_path, dpi=120)
plt.close()
print(f" [完成] {stock_code} {stock_name} -> {output_path}")
def main() -> None:
parser = argparse.ArgumentParser(description="为当日满足收敛三角形的个股生成图表")
parser.add_argument(
"--input",
default=os.path.join("outputs", "converging_triangles", "all_results.csv"),
help="输入 CSV 路径",
)
parser.add_argument(
"--date",
type=int,
default=None,
help="指定日期YYYYMMDD默认为数据最新日",
)
parser.add_argument(
"--output-dir",
default=os.path.join("outputs", "converging_triangles", "charts"),
help="图表输出目录",
)
parser.add_argument(
"--show-details",
action="store_true",
help="显示详细调试信息(枢轴点、拟合点、分段线等)",
)
args = parser.parse_args()
# 确定是否显示详细信息(命令行参数优先)
show_details = args.show_details if hasattr(args, 'show_details') else SHOW_CHART_DETAILS
print("=" * 70)
print("收敛三角形图表生成")
print("=" * 70)
print(f"详细模式: {'开启' if show_details else '关闭'} {'(--show-details)' if show_details else '(简洁模式)'}")
# 1. 加载数据
print("\n[1] 加载 OHLCV 数据...")
data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(data_dir)
print(f" 股票数: {len(tkrs)}, 交易日数: {len(dates)}")
# 2. 确定目标日期
if args.date:
target_date = args.date
else:
# 从CSV读取最新日期
with open(args.input, newline="", encoding="utf-8-sig") as f:
reader = csv.DictReader(f)
all_dates = [int(r.get("date", "0")) for r in reader if r.get("date") and r["date"].isdigit()]
target_date = max(all_dates) if all_dates else 0
if target_date == 0:
print("错误: 无法确定目标日期")
return
print(f"\n[2] 目标日期: {target_date}")
# 3. 加载当日股票列表
stocks = load_daily_stocks(args.input, target_date)
print(f" 当日满足三角形的股票数: {len(stocks)}")
if not stocks:
print("当日无满足条件的股票")
return
# 4. 创建输出目录并清空对应模式的旧图片
os.makedirs(args.output_dir, exist_ok=True)
# 只清空当前模式的图片(简洁模式或详细模式)
print(f"\n[4] 清空当前模式的旧图片...")
suffix = "_detail.png" if show_details else ".png"
# 找出当前模式的文件简洁模式是不含_detail的.png详细模式是_detail.png
if show_details:
old_files = [f for f in os.listdir(args.output_dir) if f.endswith('_detail.png')]
else:
old_files = [f for f in os.listdir(args.output_dir)
if f.endswith('.png') and not f.endswith('_detail.png')]
for f in old_files:
os.remove(os.path.join(args.output_dir, f))
print(f" 已删除 {len(old_files)} 个旧图片 ({'详细模式' if show_details else '简洁模式'})")
# 5. 检测参数(从统一配置导入)
params = DETECTION_PARAMS
# 6. 找到target_date在dates中的索引
date_idx = np.where(dates == target_date)[0]
if len(date_idx) == 0:
print(f"错误: 日期 {target_date} 不在数据范围内")
return
date_idx = date_idx[0]
# 7. 生成图表
print(f"\n[3] 生成图表...")
for stock in stocks:
stock_idx = stock["stock_idx"]
stock_code = stock["stock_code"]
stock_name = stock["stock_name"]
# 根据详细模式添加文件名后缀
suffix = "_detail" if show_details else ""
output_filename = f"{target_date}_{stock_code}_{stock_name}{suffix}.png"
output_path = os.path.join(args.output_dir, output_filename)
try:
plot_triangle(
stock_idx=stock_idx,
stock_code=stock_code,
stock_name=stock_name,
date_idx=date_idx,
high_mtx=high_mtx,
low_mtx=low_mtx,
close_mtx=close_mtx,
volume_mtx=volume_mtx,
dates=dates,
params=params,
output_path=output_path,
display_window=DISPLAY_WINDOW, # 从配置文件读取
show_details=show_details, # 传递详细模式参数
)
except Exception as e:
print(f" [错误] {stock_code} {stock_name}: {e}")
print(f"\n完成!图表已保存至: {args.output_dir}")
if __name__ == "__main__":
main()