technical-patterns-lab/scripts/plot_converging_triangles.py
褚宏光 8dea3fbccb Enhance converging triangle analysis with new scripts and data outputs
- Updated all_results.csv with additional stock data and breakout strength metrics.
- Revised report.md to improve clarity and detail on stock selection criteria and results.
- Expanded strong_breakout_down.csv and strong_breakout_up.csv with new entries reflecting recent analysis.
- Introduced new chart images for selected stocks to visualize breakout patterns.
- Added plot_converging_triangles.py script for generating visualizations of stocks meeting convergence criteria.
- Enhanced report_converging_triangles.py to allow for date-specific reporting and improved output formatting.
- Optimized run_converging_triangle.py for performance and added execution time logging.
2026-01-22 10:00:47 +08:00

356 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
为当日满足收敛三角形的个股生成图表
用法:
python scripts/plot_converging_triangles.py
python scripts/plot_converging_triangles.py --date 20260120
"""
from __future__ import annotations
import argparse
import csv
import os
import pickle
import sys
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
# Use fonts that include CJK glyphs so the Chinese titles/labels render, and
# draw the minus sign as ASCII '-' to avoid the minus-sign display problem
# with these fonts.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS'],
    'axes.unicode_minus': False,
})

# Put the project's src/ directory on the import path so the
# `converging_triangle` import resolves when this file is run as a script.
_SRC_DIR = os.path.join(os.path.dirname(__file__), "..", "src")
sys.path.append(_SRC_DIR)
from converging_triangle import (
ConvergingTriangleParams,
detect_converging_triangle,
line_y,
)
class FakeModule:
    """Empty stand-in object used in place of the ``model`` package.

    ``load_pkl`` registers instances of this class in ``sys.modules`` so the
    data pickles can be loaded without the real ``model`` dependency. The
    only attribute provided is ``ndarray``, aliased to NumPy's array type.
    """

    # NOTE(review): presumably the pickles only resolve ``ndarray`` through
    # these modules; any other attribute lookup on the stub will fail.
    ndarray = np.ndarray
def load_pkl(pkl_path: str) -> dict:
    """Unpickle one data file, stubbing the ``model`` package if needed.

    Before loading, ``model`` and ``model.index_info`` are registered in
    ``sys.modules`` as ``FakeModule`` stubs, but only if they are not
    already present. Presumably the pickles reference those modules; this
    lets them load without the real package installed.

    Args:
        pkl_path: Path to the ``.pkl`` file.

    Returns:
        The unpickled object (callers index it like a dict, e.g. ``["mtx"]``).

    Security:
        ``pickle.load`` can execute arbitrary code. Only call this on trusted,
        locally produced files.
    """
    # setdefault instead of assignment: keep a real ``model`` package if one
    # is already imported rather than replacing it process-wide on each call.
    sys.modules.setdefault('model', FakeModule())
    sys.modules.setdefault('model.index_info', FakeModule())
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)
def load_ohlcv_from_pkl(data_dir: str) -> tuple:
    """Load the open/high/low/close/volume pickles stored in ``data_dir``.

    Each ``<field>.pkl`` is read with ``load_pkl`` and its ``"mtx"`` entry is
    returned. Date and ticker metadata are taken from ``close.pkl`` only.

    Returns:
        ``(open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs,
        tkrs_name)``, where the last three are the ``"dtes"``, ``"tkrs"`` and
        ``"tkrs_name"`` entries of ``close.pkl``.
    """
    fields = ("open", "high", "low", "close", "volume")
    # Files are loaded in the same order as ``fields``.
    payloads = {
        field: load_pkl(os.path.join(data_dir, f"{field}.pkl"))
        for field in fields
    }

    meta = payloads["close"]
    dates, tkrs, tkrs_name = meta["dtes"], meta["tkrs"], meta["tkrs_name"]
    matrices = tuple(payloads[field]["mtx"] for field in fields)
    return (*matrices, dates, tkrs, tkrs_name)
def load_daily_stocks(csv_path: str, target_date: int) -> List[Dict]:
    """Read the stocks recorded for ``target_date`` from a results CSV.

    Args:
        csv_path: CSV file with ``date`` and ``stock_idx`` columns. The
            ``stock_code``, ``stock_name``, ``breakout_dir``,
            ``breakout_strength_up`` and ``breakout_strength_down`` columns
            are copied through when present.
        target_date: Date to select, as an integer such as ``20260120``.

    Returns:
        One dict per matching row, in file order. Rows whose ``date`` or
        ``stock_idx`` is missing or not an integer are skipped. A blank or
        missing breakout strength is read as ``0.0``.
    """

    def _to_float(raw: Optional[str]) -> float:
        # Strength values are informational only; a blank/missing cell must
        # not drop the whole stock from the chart list.
        if raw is None or not raw.strip():
            return 0.0
        return float(raw)

    stocks: List[Dict] = []
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            try:
                if int(row.get("date", "0")) != target_date:
                    continue
                stocks.append({
                    # No default: a missing index must not silently map to
                    # stock 0 (int(None) raises TypeError -> row skipped).
                    "stock_idx": int(row.get("stock_idx")),
                    "stock_code": row.get("stock_code", ""),
                    "stock_name": row.get("stock_name", ""),
                    "breakout_dir": row.get("breakout_dir", "none"),
                    "breakout_strength_up": _to_float(row.get("breakout_strength_up")),
                    "breakout_strength_down": _to_float(row.get("breakout_strength_down")),
                })
            except (ValueError, TypeError):
                continue
    return stocks
def plot_triangle(
    stock_idx: int,
    stock_code: str,
    stock_name: str,
    date_idx: int,
    high_mtx: np.ndarray,
    low_mtx: np.ndarray,
    close_mtx: np.ndarray,
    volume_mtx: np.ndarray,
    dates: np.ndarray,
    params: ConvergingTriangleParams,
    output_path: str,
    display_window: int = 500,  # number of valid bars to draw
) -> None:
    """Detect a converging triangle for one stock and save a price/volume chart.

    Bars whose close is NaN are dropped first, so every window below is
    counted in valid bars. The detection window is the ``params.window``
    valid bars ending at ``date_idx``. The chart shows up to
    ``display_window`` valid bars ending on the same day, widened if
    necessary so the whole detection window is visible.

    Args:
        stock_idx: Row of the stock in the ``*_mtx`` matrices.
        stock_code: Stock code, used in the title and log lines.
        stock_name: Stock name, used in the title and log lines.
        date_idx: Column (index into ``dates``) of the evaluation day.
        high_mtx, low_mtx, close_mtx, volume_mtx: Matrices indexed
            ``[stock_idx, date_idx]``.
        dates: Per-column date labels.
        params: Detection parameters for ``detect_converging_triangle``.
        output_path: Destination image file.
        display_window: Number of valid bars to draw.

    Returns:
        None. Prints a ``[跳过]`` line and writes nothing when data is
        insufficient or no valid triangle is detected. Otherwise saves the
        image and prints ``[完成]``.
    """
    close_stock = close_mtx[stock_idx, :]
    valid_mask = ~np.isnan(close_stock)
    valid_indices = np.flatnonzero(valid_mask)
    if len(valid_indices) < params.window:
        print(f" [跳过] {stock_code} {stock_name}: 有效数据不足")
        return

    # Position of date_idx among the valid bars.
    matches = np.flatnonzero(valid_indices == date_idx)
    if matches.size == 0:
        print(f" [跳过] {stock_code} {stock_name}: date_idx不在有效范围")
        return
    valid_end = int(matches[0])
    if valid_end < params.window - 1:
        print(f" [跳过] {stock_code} {stock_name}: 窗口数据不足")
        return

    # Compact each series to its valid bars once, then slice.
    high_valid = high_mtx[stock_idx, :][valid_mask]
    low_valid = low_mtx[stock_idx, :][valid_mask]
    close_valid = close_stock[valid_mask]
    volume_valid = volume_mtx[stock_idx, :][valid_mask]

    # Detection window: the params.window valid bars ending at date_idx.
    detect_start = valid_end - params.window + 1
    detect_slice = slice(detect_start, valid_end + 1)
    high_win = high_valid[detect_slice]
    low_win = low_valid[detect_slice]
    close_win = close_valid[detect_slice]
    volume_win = volume_valid[detect_slice]
    detect_dates = dates[valid_indices[detect_slice]]

    # Display window: longer history for context. It never starts after the
    # detection window; otherwise a display_window < params.window would make
    # triangle_offset negative and push the boundary lines off the chart.
    display_start = max(0, min(valid_end - display_window + 1, detect_start))
    display_slice = slice(display_start, valid_end + 1)
    display_close = close_valid[display_slice]
    display_volume = volume_valid[display_slice]
    display_dates = dates[valid_indices[display_slice]]

    result = detect_converging_triangle(
        high=high_win,
        low=low_win,
        close=close_win,
        volume=volume_win,
        params=params,
        stock_idx=stock_idx,
        date_idx=date_idx,
    )
    if not result.is_valid:
        print(f" [跳过] {stock_code} {stock_name}: 未识别到有效三角形")
        return

    x_display = np.arange(len(display_close), dtype=float)
    # Detection-window bar i is drawn at x = i + triangle_offset.
    triangle_offset = len(display_close) - len(close_win)

    a_u, a_l = result.upper_slope, result.lower_slope
    start, end = result.window_start, result.window_end
    # Intercepts (detection-window coordinates) anchor each boundary line at
    # the high/low of bar `end`.
    # NOTE(review): if the detector exposes its own fitted intercepts, using
    # them would avoid shifting the lines when bar `end` is a breakout bar —
    # confirm against converging_triangle.
    b_u = high_win[end] - a_u * end
    b_l = low_win[end] - a_l * end
    # Y values use detection-window x; plotting uses display-window x.
    xw_original = np.arange(start, end + 1, dtype=float)
    xw_in_display = xw_original + triangle_offset
    upper_line = line_y(a_u, b_u, xw_original)
    lower_line = line_y(a_l, b_l, xw_original)

    # Previously both True and False rendered as an empty string. Truthiness
    # (not `is False`) is used so NumPy booleans are labelled too; CJK text
    # is used because the configured fonts can render it.
    if result.volume_confirmed is None:
        volume_label = '-'
    elif result.volume_confirmed:
        volume_label = '是'
    else:
        volume_label = '否'

    step = max(1, len(display_dates) // 10)
    tick_indices = np.arange(0, len(display_dates), step)
    tick_labels = display_dates[tick_indices]

    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [3, 1]}
    )
    try:
        # Main panel: close price and triangle boundaries.
        ax1.plot(x_display, display_close, linewidth=1.5, label='收盘价', color='black', alpha=0.7)
        ax1.plot(xw_in_display, upper_line, linewidth=2, label='上沿', color='red', linestyle='--')
        ax1.plot(xw_in_display, lower_line, linewidth=2, label='下沿', color='green', linestyle='--')
        ax1.axvline(len(display_close) - 1, color='gray', linestyle=':', linewidth=1, alpha=0.5)
        ax1.set_title(
            f"{stock_code} {stock_name} - 收敛三角形 (检测窗口: {detect_dates[0]} ~ {detect_dates[-1]})\n"
            f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日) "
            f"突破方向: {result.breakout_dir} 宽度比: {result.width_ratio:.2f} "
            f"触碰: 上{result.touches_upper}/下{result.touches_lower} "
            f"放量确认: {volume_label}",
            fontsize=11, pad=10
        )
        ax1.set_ylabel('价格', fontsize=10)
        ax1.legend(loc='best', fontsize=9)
        ax1.grid(True, alpha=0.3)
        ax1.set_xticks(tick_indices)
        ax1.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=8)

        # Lower panel: volume.
        ax2.bar(x_display, display_volume, width=0.8, color='skyblue', alpha=0.6)
        ax2.set_ylabel('成交量', fontsize=10)
        ax2.set_xlabel('交易日', fontsize=10)
        ax2.grid(True, alpha=0.3, axis='y')
        ax2.set_xticks(tick_indices)
        ax2.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=8)

        fig.tight_layout()
        fig.savefig(output_path, dpi=120)
    finally:
        # Always release the figure: main() continues after a per-stock error,
        # and unclosed figures would otherwise accumulate in memory.
        plt.close(fig)
    print(f" [完成] {stock_code} {stock_name} -> {output_path}")
# Characters that are invalid in Windows file names (plus control chars).
_INVALID_FILENAME_CHARS = '<>:"/\\|?*'


def sanitize_filename(text: object) -> str:
    """Return ``str(text)`` with characters unsafe in file names replaced by '_'.

    Stock names such as ``*ST…`` contain ``*``, which Windows rejects in
    file names.
    """
    return "".join(
        "_" if ch in _INVALID_FILENAME_CHARS or ord(ch) < 32 else ch
        for ch in str(text)
    )


def main() -> None:
    """Command-line entry point: chart every stock listed for one date.

    Steps:
    1. Load the OHLCV pickles from ``<script dir>/../data``.
    2. Pick the target date (``--date``, or the latest date in ``--input``).
    3. Read that date's stocks from the CSV and locate the date in the data.
    4. Delete old ``.png`` files from ``--output-dir``, but only once the
       run is known to be able to proceed.
    5. Call ``plot_triangle`` for each stock.
    """
    parser = argparse.ArgumentParser(description="为当日满足收敛三角形的个股生成图表")
    parser.add_argument(
        "--input",
        default=os.path.join("outputs", "converging_triangles", "all_results.csv"),
        help="输入 CSV 路径",
    )
    parser.add_argument(
        "--date",
        type=int,
        default=None,
        help="指定日期YYYYMMDD默认为数据最新日",
    )
    parser.add_argument(
        "--output-dir",
        default=os.path.join("outputs", "converging_triangles", "charts"),
        help="图表输出目录",
    )
    args = parser.parse_args()

    print("=" * 70)
    print("收敛三角形图表生成")
    print("=" * 70)

    # Fail with a message instead of a traceback if the CSV is missing.
    if not os.path.isfile(args.input):
        print(f"错误: 输入文件不存在: {args.input}")
        return

    # 1. Load data.
    print("\n[1] 加载 OHLCV 数据...")
    data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
    _open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, _tkrs_name = load_ohlcv_from_pkl(data_dir)
    # Array form so `dates == target_date` and fancy indexing both work.
    dates = np.asarray(dates)
    print(f" 股票数: {len(tkrs)}, 交易日数: {len(dates)}")

    # 2. Determine the target date (default: latest date in the CSV).
    if args.date:
        target_date = args.date
    else:
        with open(args.input, newline="", encoding="utf-8-sig") as f:
            all_dates = [
                int(r["date"])
                for r in csv.DictReader(f)
                if r.get("date") and r["date"].isdigit()
            ]
        target_date = max(all_dates, default=0)
    if target_date == 0:
        print("错误: 无法确定目标日期")
        return
    print(f"\n[2] 目标日期: {target_date}")

    stocks = load_daily_stocks(args.input, target_date)
    print(f" 当日满足三角形的股票数: {len(stocks)}")
    if not stocks:
        print("当日无满足条件的股票")
        return

    # Detection parameters (keep in sync with run_converging_triangle.py).
    params = ConvergingTriangleParams(
        window=120,
        pivot_k=15,
        boundary_n_segments=2,
        boundary_source="full",
        upper_slope_max=0.10,
        lower_slope_min=-0.10,
        touch_tol=0.10,
        touch_loss_max=0.10,
        shrink_ratio=0.8,
        break_tol=0.001,
        vol_window=20,
        vol_k=1.3,
        false_break_m=5,
    )

    # Locate the target date BEFORE deleting anything, so an unknown date
    # does not wipe the existing charts and then produce nothing.
    matches = np.flatnonzero(dates == target_date)
    if matches.size == 0:
        print(f"错误: 日期 {target_date} 不在数据范围内")
        return
    date_idx = int(matches[0])

    # 3. Create the output directory and remove old charts.
    os.makedirs(args.output_dir, exist_ok=True)
    print("\n[3] 清空输出目录...")
    old_files = [
        name for name in os.listdir(args.output_dir)
        if name.endswith('.png') and os.path.isfile(os.path.join(args.output_dir, name))
    ]
    for name in old_files:
        os.remove(os.path.join(args.output_dir, name))
    print(f" 已删除 {len(old_files)} 个旧图片")

    # 4. Generate charts.
    print("\n[4] 生成图表...")
    for stock in stocks:
        stock_code = stock["stock_code"]
        stock_name = stock["stock_name"]
        output_filename = (
            f"{target_date}_{sanitize_filename(stock_code)}_{sanitize_filename(stock_name)}.png"
        )
        output_path = os.path.join(args.output_dir, output_filename)
        try:
            plot_triangle(
                stock_idx=stock["stock_idx"],
                stock_code=stock_code,
                stock_name=stock_name,
                date_idx=date_idx,
                high_mtx=high_mtx,
                low_mtx=low_mtx,
                close_mtx=close_mtx,
                volume_mtx=volume_mtx,
                dates=dates,
                params=params,
                output_path=output_path,
                display_window=500,  # show 500 trading days
            )
        except Exception as e:
            # Per-stock boundary: report and continue with the remaining stocks.
            print(f" [错误] {stock_code} {stock_name}: {e}")

    print(f"\n完成!图表已保存至: {args.output_dir}")


if __name__ == "__main__":
    main()