technical-patterns-lab/scripts/plot_converging_triangles.py
褚宏光 6d545eb231 Enhance converging triangle detection with new features and documentation updates
- Added support for a detailed chart mode in plot_converging_triangles.py, allowing users to visualize all pivot points and fitting lines.
- Improved pivot fitting logic to utilize multiple representative points, enhancing detection accuracy and reducing false positives.
- Introduced a new real-time detection mode with flexible zone parameters for better responsiveness in stock analysis.
- Updated README.md and USAGE.md to reflect new features and usage instructions.
- Added multiple documentation files detailing recent improvements, including pivot point fitting and visualization enhancements.
- Cleaned up and archived outdated scripts to streamline the project structure.
2026-01-26 16:21:36 +08:00

546 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
为当日满足收敛三角形的个股生成图表
用法:
# 简洁模式(默认)- 仅显示收盘价、上沿、下沿
python scripts/plot_converging_triangles.py
# 详细模式 - 显示所有枢轴点、拟合点、分段线
python scripts/plot_converging_triangles.py --show-details
# 指定日期
python scripts/plot_converging_triangles.py --date 20260120
"""
from __future__ import annotations
import argparse
import csv
import os
import pickle
import sys
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
# 配置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS'] # 支持中文
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# 让脚本能找到 src/ 下的模块
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from converging_triangle import (
ConvergingTriangleParams,
detect_converging_triangle,
fit_pivot_line,
line_y,
pivots_fractal,
)
# 导入统一的参数配置
from triangle_config import DETECTION_PARAMS, DISPLAY_WINDOW, SHOW_CHART_DETAILS
class FakeModule:
"""空壳模块,绕过 model 依赖"""
ndarray = np.ndarray
def load_pkl(pkl_path: str) -> dict:
"""加载 pkl 文件"""
sys.modules['model'] = FakeModule()
sys.modules['model.index_info'] = FakeModule()
with open(pkl_path, 'rb') as f:
data = pickle.load(f)
return data
def load_ohlcv_from_pkl(data_dir: str) -> tuple:
"""从 pkl 文件加载 OHLCV 数据"""
open_data = load_pkl(os.path.join(data_dir, "open.pkl"))
high_data = load_pkl(os.path.join(data_dir, "high.pkl"))
low_data = load_pkl(os.path.join(data_dir, "low.pkl"))
close_data = load_pkl(os.path.join(data_dir, "close.pkl"))
volume_data = load_pkl(os.path.join(data_dir, "volume.pkl"))
dates = close_data["dtes"]
tkrs = close_data["tkrs"]
tkrs_name = close_data["tkrs_name"]
return (
open_data["mtx"],
high_data["mtx"],
low_data["mtx"],
close_data["mtx"],
volume_data["mtx"],
dates,
tkrs,
tkrs_name,
)
def load_daily_stocks(csv_path: str, target_date: int) -> List[Dict]:
"""从CSV读取指定日期的股票列表"""
stocks = []
with open(csv_path, newline="", encoding="utf-8-sig") as f:
reader = csv.DictReader(f)
for row in reader:
try:
date = int(row.get("date", "0"))
if date == target_date:
stocks.append({
"stock_idx": int(row.get("stock_idx", "0")),
"stock_code": row.get("stock_code", ""),
"stock_name": row.get("stock_name", ""),
"breakout_dir": row.get("breakout_dir", "none"),
"breakout_strength_up": float(row.get("breakout_strength_up", "0")),
"breakout_strength_down": float(row.get("breakout_strength_down", "0")),
})
except (ValueError, TypeError):
continue
return stocks
def plot_triangle(
stock_idx: int,
stock_code: str,
stock_name: str,
date_idx: int,
high_mtx: np.ndarray,
low_mtx: np.ndarray,
close_mtx: np.ndarray,
volume_mtx: np.ndarray,
dates: np.ndarray,
params: ConvergingTriangleParams,
output_path: str,
display_window: int = 500, # 显示窗口大小
show_details: bool = False, # 是否显示详细调试信息
) -> None:
"""绘制单只股票的收敛三角形图"""
# 提取该股票数据并过滤NaN
high_stock = high_mtx[stock_idx, :]
low_stock = low_mtx[stock_idx, :]
close_stock = close_mtx[stock_idx, :]
volume_stock = volume_mtx[stock_idx, :]
valid_mask = ~np.isnan(close_stock)
valid_indices = np.where(valid_mask)[0]
if len(valid_indices) < params.window:
print(f" [跳过] {stock_code} {stock_name}: 有效数据不足")
return
# 找到date_idx在有效数据中的位置
if date_idx not in valid_indices:
print(f" [跳过] {stock_code} {stock_name}: date_idx不在有效范围")
return
valid_end = np.where(valid_indices == date_idx)[0][0]
if valid_end < params.window - 1:
print(f" [跳过] {stock_code} {stock_name}: 窗口数据不足")
return
# 提取检测窗口数据(用于三角形检测)
detect_start = valid_end - params.window + 1
high_win = high_stock[valid_mask][detect_start:valid_end + 1]
low_win = low_stock[valid_mask][detect_start:valid_end + 1]
close_win = close_stock[valid_mask][detect_start:valid_end + 1]
volume_win = volume_stock[valid_mask][detect_start:valid_end + 1]
# 提取显示窗口数据(用于绘图,更长的历史)
display_start = max(0, valid_end - display_window + 1)
display_high = high_stock[valid_mask][display_start:valid_end + 1]
display_low = low_stock[valid_mask][display_start:valid_end + 1]
display_close = close_stock[valid_mask][display_start:valid_end + 1]
display_volume = volume_stock[valid_mask][display_start:valid_end + 1]
display_dates = dates[valid_indices[display_start:valid_end + 1]]
# 检测三角形(使用检测窗口数据)
result = detect_converging_triangle(
high=high_win,
low=low_win,
close=close_win,
volume=volume_win,
params=params,
stock_idx=stock_idx,
date_idx=date_idx,
)
if not result.is_valid:
print(f" [跳过] {stock_code} {stock_name}: 未识别到有效三角形")
return
# 绘图准备
x_display = np.arange(len(display_close), dtype=float)
# 计算三角形在显示窗口中的位置偏移
triangle_offset = len(display_close) - len(close_win)
# 获取检测窗口的起止索引(相对于检测窗口内部)
n = len(close_win)
x_win = np.arange(n, dtype=float)
# 计算枢轴点(与检测算法一致)
ph_idx, pl_idx = pivots_fractal(high_win, low_win, k=params.pivot_k)
# 使用枢轴点连线法拟合边界线(与检测算法一致)
a_u, b_u, selected_ph = fit_pivot_line(
pivot_indices=ph_idx,
pivot_values=high_win[ph_idx],
mode="upper",
)
a_l, b_l, selected_pl = fit_pivot_line(
pivot_indices=pl_idx,
pivot_values=low_win[pl_idx],
mode="lower",
)
# 三角形线段在显示窗口中的X坐标只画检测窗口范围
xw_in_display = np.arange(triangle_offset, triangle_offset + n, dtype=float)
# 计算Y值使用检测窗口内部的X坐标
upper_line = line_y(a_u, b_u, x_win)
lower_line = line_y(a_l, b_l, x_win)
# 获取选中的枢轴点在显示窗口中的位置(用于标注)
ph_display_idx = ph_idx + triangle_offset
pl_display_idx = pl_idx + triangle_offset
selected_ph_pos = ph_idx[selected_ph] if len(selected_ph) > 0 else np.array([], dtype=int)
selected_pl_pos = pl_idx[selected_pl] if len(selected_pl) > 0 else np.array([], dtype=int)
selected_ph_display = selected_ph_pos + triangle_offset
selected_pl_display = selected_pl_pos + triangle_offset
# 三角形检测窗口的日期范围(用于标题)
detect_dates = dates[valid_indices[detect_start:valid_end + 1]]
# 创建图表
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8),
gridspec_kw={'height_ratios': [3, 1]})
# 主图:价格和趋势线(使用显示窗口数据)
ax1.plot(x_display, display_close, linewidth=1.5, label='收盘价', color='black', alpha=0.7)
ax1.plot(xw_in_display, upper_line, linewidth=2, label='上沿', color='red', linestyle='--')
ax1.plot(xw_in_display, lower_line, linewidth=2, label='下沿', color='green', linestyle='--')
ax1.axvline(len(display_close) - 1, color='gray', linestyle=':', linewidth=1, alpha=0.5)
# ========================================================================
# 详细模式:显示所有枢轴点、拟合点、分段线(仅在 show_details=True 时)
# ========================================================================
if show_details:
# 标注所有枢轴点(小实心点,较浅颜色)
if len(ph_display_idx) > 0:
ax1.scatter(
ph_display_idx,
high_win[ph_idx],
marker='o',
s=50,
facecolors='red',
edgecolors='none',
alpha=0.4,
zorder=4,
label=f'所有高点枢轴点({len(ph_idx)})',
)
if len(pl_display_idx) > 0:
ax1.scatter(
pl_display_idx,
low_win[pl_idx],
marker='o',
s=50,
facecolors='green',
edgecolors='none',
alpha=0.4,
zorder=4,
label=f'所有低点枢轴点({len(pl_idx)})',
)
# 标注选中的枢轴点(用于拟合线的关键点,大空心圆)
if len(selected_ph_display) >= 2:
ax1.scatter(
selected_ph_display,
high_win[selected_ph_pos],
marker='o',
s=120,
facecolors='none',
edgecolors='red',
linewidths=2.5,
zorder=5,
label=f'上沿拟合点({len(selected_ph_pos)})',
)
if len(selected_pl_display) >= 2:
ax1.scatter(
selected_pl_display,
low_win[selected_pl_pos],
marker='o',
s=120,
facecolors='none',
edgecolors='green',
linewidths=2.5,
zorder=5,
label=f'下沿拟合点({len(selected_pl_pos)})',
)
# 绘制分段竖线(显示算法如何分段选择枢轴点)
# 高点和低点分别独立分段,用不同颜色显示
y_min, y_max = ax1.get_ylim()
# 绘制高点枢轴点的分段线(红色)
if len(ph_idx) > 4:
n_high = len(ph_idx)
segment_size_high = n_high // 3
# 第1段结束 = 第2段开始
if segment_size_high < n_high:
boundary_1 = ph_idx[segment_size_high] + triangle_offset
ax1.axvline(
boundary_1,
color='red',
linestyle='-.',
linewidth=1.2,
alpha=0.4,
zorder=3,
)
ax1.text(
boundary_1,
y_max * 0.96,
'高1|2',
ha='center',
va='top',
fontsize=7,
color='red',
bbox=dict(boxstyle='round,pad=0.2', facecolor='white', edgecolor='red', alpha=0.7),
)
# 第2段结束 = 第3段开始
if 2 * segment_size_high < n_high:
boundary_2 = ph_idx[2 * segment_size_high] + triangle_offset
ax1.axvline(
boundary_2,
color='red',
linestyle='-.',
linewidth=1.2,
alpha=0.4,
zorder=3,
)
ax1.text(
boundary_2,
y_max * 0.96,
'高2|3',
ha='center',
va='top',
fontsize=7,
color='red',
bbox=dict(boxstyle='round,pad=0.2', facecolor='white', edgecolor='red', alpha=0.7),
)
# 绘制低点枢轴点的分段线(绿色)
if len(pl_idx) > 4:
n_low = len(pl_idx)
segment_size_low = n_low // 3
# 第1段结束 = 第2段开始
if segment_size_low < n_low:
boundary_1 = pl_idx[segment_size_low] + triangle_offset
ax1.axvline(
boundary_1,
color='green',
linestyle='-.',
linewidth=1.2,
alpha=0.4,
zorder=3,
)
ax1.text(
boundary_1,
y_min + (y_max - y_min) * 0.04,
'低1|2',
ha='center',
va='bottom',
fontsize=7,
color='green',
bbox=dict(boxstyle='round,pad=0.2', facecolor='white', edgecolor='green', alpha=0.7),
)
# 第2段结束 = 第3段开始
if 2 * segment_size_low < n_low:
boundary_2 = pl_idx[2 * segment_size_low] + triangle_offset
ax1.axvline(
boundary_2,
color='green',
linestyle='-.',
linewidth=1.2,
alpha=0.4,
zorder=3,
)
ax1.text(
boundary_2,
y_min + (y_max - y_min) * 0.04,
'低2|3',
ha='center',
va='bottom',
fontsize=7,
color='green',
bbox=dict(boxstyle='round,pad=0.2', facecolor='white', edgecolor='green', alpha=0.7),
)
ax1.set_title(
f"{stock_code} {stock_name} - 收敛三角形 (检测窗口: {detect_dates[0]} ~ {detect_dates[-1]})\n"
f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日) "
f"突破方向: {result.breakout_dir} 宽度比: {result.width_ratio:.2f} "
f"枢轴点: 高{len(ph_idx)}/低{len(pl_idx)} 触碰: 上{result.touches_upper}/下{result.touches_lower} "
f"放量确认: {'' if result.volume_confirmed else '' if result.volume_confirmed is False else '-'}",
fontsize=11, pad=10
)
ax1.set_ylabel('价格', fontsize=10)
ax1.legend(loc='best', fontsize=9)
ax1.grid(True, alpha=0.3)
# X轴日期标签稀疏显示基于显示窗口
step = max(1, len(display_dates) // 10)
tick_indices = np.arange(0, len(display_dates), step)
ax1.set_xticks(tick_indices)
ax1.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8)
# 副图:成交量(使用显示窗口数据)
ax2.bar(x_display, display_volume, width=0.8, color='skyblue', alpha=0.6)
ax2.set_ylabel('成交量', fontsize=10)
ax2.set_xlabel('交易日', fontsize=10)
ax2.grid(True, alpha=0.3, axis='y')
ax2.set_xticks(tick_indices)
ax2.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8)
plt.tight_layout()
plt.savefig(output_path, dpi=120)
plt.close()
print(f" [完成] {stock_code} {stock_name} -> {output_path}")
def main() -> None:
parser = argparse.ArgumentParser(description="为当日满足收敛三角形的个股生成图表")
parser.add_argument(
"--input",
default=os.path.join("outputs", "converging_triangles", "all_results.csv"),
help="输入 CSV 路径",
)
parser.add_argument(
"--date",
type=int,
default=None,
help="指定日期YYYYMMDD默认为数据最新日",
)
parser.add_argument(
"--output-dir",
default=os.path.join("outputs", "converging_triangles", "charts"),
help="图表输出目录",
)
parser.add_argument(
"--show-details",
action="store_true",
help="显示详细调试信息(枢轴点、拟合点、分段线等)",
)
args = parser.parse_args()
# 确定是否显示详细信息(命令行参数优先)
show_details = args.show_details if hasattr(args, 'show_details') else SHOW_CHART_DETAILS
print("=" * 70)
print("收敛三角形图表生成")
print("=" * 70)
print(f"详细模式: {'开启' if show_details else '关闭'} {'(--show-details)' if show_details else '(简洁模式)'}")
# 1. 加载数据
print("\n[1] 加载 OHLCV 数据...")
data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(data_dir)
print(f" 股票数: {len(tkrs)}, 交易日数: {len(dates)}")
# 2. 确定目标日期
if args.date:
target_date = args.date
else:
# 从CSV读取最新日期
with open(args.input, newline="", encoding="utf-8-sig") as f:
reader = csv.DictReader(f)
all_dates = [int(r.get("date", "0")) for r in reader if r.get("date") and r["date"].isdigit()]
target_date = max(all_dates) if all_dates else 0
if target_date == 0:
print("错误: 无法确定目标日期")
return
print(f"\n[2] 目标日期: {target_date}")
# 3. 加载当日股票列表
stocks = load_daily_stocks(args.input, target_date)
print(f" 当日满足三角形的股票数: {len(stocks)}")
if not stocks:
print("当日无满足条件的股票")
return
# 4. 创建输出目录并清空对应模式的旧图片
os.makedirs(args.output_dir, exist_ok=True)
# 只清空当前模式的图片(简洁模式或详细模式)
print(f"\n[4] 清空当前模式的旧图片...")
suffix = "_detail.png" if show_details else ".png"
# 找出当前模式的文件简洁模式是不含_detail的.png详细模式是_detail.png
if show_details:
old_files = [f for f in os.listdir(args.output_dir) if f.endswith('_detail.png')]
else:
old_files = [f for f in os.listdir(args.output_dir)
if f.endswith('.png') and not f.endswith('_detail.png')]
for f in old_files:
os.remove(os.path.join(args.output_dir, f))
print(f" 已删除 {len(old_files)} 个旧图片 ({'详细模式' if show_details else '简洁模式'})")
# 5. 检测参数(从统一配置导入)
params = DETECTION_PARAMS
# 6. 找到target_date在dates中的索引
date_idx = np.where(dates == target_date)[0]
if len(date_idx) == 0:
print(f"错误: 日期 {target_date} 不在数据范围内")
return
date_idx = date_idx[0]
# 7. 生成图表
print(f"\n[3] 生成图表...")
for stock in stocks:
stock_idx = stock["stock_idx"]
stock_code = stock["stock_code"]
stock_name = stock["stock_name"]
# 根据详细模式添加文件名后缀
suffix = "_detail" if show_details else ""
output_filename = f"{target_date}_{stock_code}_{stock_name}{suffix}.png"
output_path = os.path.join(args.output_dir, output_filename)
try:
plot_triangle(
stock_idx=stock_idx,
stock_code=stock_code,
stock_name=stock_name,
date_idx=date_idx,
high_mtx=high_mtx,
low_mtx=low_mtx,
close_mtx=close_mtx,
volume_mtx=volume_mtx,
dates=dates,
params=params,
output_path=output_path,
display_window=DISPLAY_WINDOW, # 从配置文件读取
show_details=show_details, # 传递详细模式参数
)
except Exception as e:
print(f" [错误] {stock_code} {stock_name}: {e}")
print(f"\n完成!图表已保存至: {args.output_dir}")
if __name__ == "__main__":
main()