- Updated all_results.csv with additional stock data and breakout strength metrics. - Revised report.md to improve clarity and detail on stock selection criteria and results. - Expanded strong_breakout_down.csv and strong_breakout_up.csv with new entries reflecting recent analysis. - Introduced new chart images for selected stocks to visualize breakout patterns. - Added plot_converging_triangles.py script for generating visualizations of stocks meeting convergence criteria. - Enhanced report_converging_triangles.py to allow for date-specific reporting and improved output formatting. - Optimized run_converging_triangle.py for performance and added execution time logging.
356 lines
12 KiB
Python
356 lines
12 KiB
Python
"""
|
||
Generate charts for the stocks that form a converging triangle on a given day.

Usage:
    python scripts/plot_converging_triangles.py
    python scripts/plot_converging_triangles.py --date 20260120
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import csv
|
||
import os
|
||
import pickle
|
||
import sys
|
||
from typing import Dict, List, Optional
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
|
||
# Matplotlib text setup: list sans-serif fonts that can render Chinese
# labels, and turn off the Unicode minus glyph so negative tick labels
# still display with those fonts.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS'],
    'axes.unicode_minus': False,
})

# Let the script find the modules that live under ../src.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||
|
||
from converging_triangle import (
|
||
ConvergingTriangleParams,
|
||
detect_converging_triangle,
|
||
line_y,
|
||
)
|
||
|
||
|
||
class FakeModule:
    """Empty stand-in module used to bypass the ``model`` dependency."""

    # The only attribute the stub exposes is NumPy's array type.
    # NOTE(review): presumably the pickles reference ``model.ndarray`` and
    # this lets that lookup resolve -- confirm against the pickle producer.
    ndarray = np.ndarray
|
||
|
||
|
||
def load_pkl(pkl_path: str) -> dict:
    """Unpickle the file at *pkl_path* and return the loaded object.

    Before loading, ``sys.modules['model']`` and
    ``sys.modules['model.index_info']`` are replaced with fresh
    ``FakeModule`` instances. The entries are not restored afterwards.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point
    this at trusted files.
    """
    # Register the stubs on every call, exactly like the original.
    for stub_name in ('model', 'model.index_info'):
        sys.modules[stub_name] = FakeModule()

    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)
|
||
|
||
|
||
def load_ohlcv_from_pkl(data_dir: str) -> tuple:
    """Load the OHLCV pickles found in *data_dir*.

    Reads ``open.pkl``, ``high.pkl``, ``low.pkl``, ``close.pkl`` and
    ``volume.pkl`` (in that order) through ``load_pkl``.

    Returns:
        An 8-tuple ``(open, high, low, close, volume, dates, tkrs,
        tkrs_name)``: the first five items are each file's ``"mtx"`` entry,
        the last three are the ``"dtes"``, ``"tkrs"`` and ``"tkrs_name"``
        entries of the close file.
    """
    file_names = ("open.pkl", "high.pkl", "low.pkl", "close.pkl", "volume.pkl")
    payloads = [load_pkl(os.path.join(data_dir, file_name)) for file_name in file_names]

    # Calendar and ticker metadata are taken from the close-price file only.
    close_payload = payloads[3]
    dates = close_payload["dtes"]
    tkrs = close_payload["tkrs"]
    tkrs_name = close_payload["tkrs_name"]

    matrices = tuple(payload["mtx"] for payload in payloads)
    return (*matrices, dates, tkrs, tkrs_name)
|
||
|
||
|
||
def load_daily_stocks(csv_path: str, target_date: int) -> List[Dict]:
    """Read the rows of *csv_path* whose ``date`` column equals *target_date*.

    The file is opened as UTF-8 with an optional BOM. Rows whose numeric
    fields cannot be parsed are skipped silently.

    Returns:
        One dict per matching row with the keys ``stock_idx`` (int),
        ``stock_code``, ``stock_name``, ``breakout_dir`` (str) and
        ``breakout_strength_up`` / ``breakout_strength_down`` (float),
        in file order.
    """
    matched: List[Dict] = []
    with open(csv_path, newline="", encoding="utf-8-sig") as csv_file:
        for record in csv.DictReader(csv_file):
            try:
                if int(record.get("date", "0")) != target_date:
                    continue
                entry = {
                    "stock_idx": int(record.get("stock_idx", "0")),
                    "stock_code": record.get("stock_code", ""),
                    "stock_name": record.get("stock_name", ""),
                    "breakout_dir": record.get("breakout_dir", "none"),
                    "breakout_strength_up": float(record.get("breakout_strength_up", "0")),
                    "breakout_strength_down": float(record.get("breakout_strength_down", "0")),
                }
            except (ValueError, TypeError):
                # Malformed or short row: ignore it and keep reading.
                continue
            matched.append(entry)
    return matched
|
||
|
||
|
||
def plot_triangle(
    stock_idx: int,
    stock_code: str,
    stock_name: str,
    date_idx: int,
    high_mtx: np.ndarray,
    low_mtx: np.ndarray,
    close_mtx: np.ndarray,
    volume_mtx: np.ndarray,
    dates: np.ndarray,
    params: ConvergingTriangleParams,
    output_path: str,
    display_window: int = 500,  # number of valid bars shown on the chart
) -> None:
    """Draw the converging-triangle chart of one stock and save it to a file.

    The stock's row is taken from each matrix, bars whose close is NaN are
    dropped, and the last ``params.window`` remaining bars ending at
    *date_idx* are passed to ``detect_converging_triangle``. When the result
    is valid, a two-panel figure (close price with the upper/lower boundary
    lines, and volume) covering up to *display_window* bars is written to
    *output_path*.

    Args:
        stock_idx: Row index of the stock in the matrices.
        stock_code: Code used in the title and in log messages.
        stock_name: Name used in the title and in log messages.
        date_idx: Column index of the chart's last day.
        high_mtx: High prices, indexed ``[stock_idx, day]``.
        low_mtx: Low prices, indexed ``[stock_idx, day]``.
        close_mtx: Close prices; NaN marks a bar to exclude.
        volume_mtx: Volumes, indexed ``[stock_idx, day]``.
        dates: Date label per column; used for the title and tick labels.
        params: Detection parameters; ``params.window`` is the number of
            bars in the detection window.
        output_path: Destination image path.
        display_window: Maximum number of valid bars to display, ending at
            *date_idx*.

    Returns:
        None. The function prints a skip message and returns early when
        there are fewer valid bars than ``params.window``, when *date_idx*
        is not a valid bar, when fewer than ``params.window`` valid bars end
        at *date_idx*, or when no valid triangle is detected.
    """
    high_stock = high_mtx[stock_idx, :]
    low_stock = low_mtx[stock_idx, :]
    close_stock = close_mtx[stock_idx, :]
    volume_stock = volume_mtx[stock_idx, :]

    # Only the close series decides which bars count as valid.
    valid_mask = ~np.isnan(close_stock)
    valid_indices = np.where(valid_mask)[0]

    if len(valid_indices) < params.window:
        print(f" [跳过] {stock_code} {stock_name}: 有效数据不足")
        return

    # Position of date_idx inside the NaN-filtered series. One scan answers
    # both "is it a valid bar?" and "where is it?".
    hits = np.where(valid_indices == date_idx)[0]
    if len(hits) == 0:
        print(f" [跳过] {stock_code} {stock_name}: date_idx不在有效范围")
        return

    valid_end = hits[0]
    if valid_end < params.window - 1:
        print(f" [跳过] {stock_code} {stock_name}: 窗口数据不足")
        return

    # Filter each series once instead of re-applying the mask per slice.
    high_valid = high_stock[valid_mask]
    low_valid = low_stock[valid_mask]
    close_valid = close_stock[valid_mask]
    volume_valid = volume_stock[valid_mask]

    # Detection window: the last params.window valid bars.
    detect_start = valid_end - params.window + 1
    detect_slice = slice(detect_start, valid_end + 1)
    high_win = high_valid[detect_slice]
    low_win = low_valid[detect_slice]
    close_win = close_valid[detect_slice]
    volume_win = volume_valid[detect_slice]

    # Display window: a longer history used only for drawing.
    display_start = max(0, valid_end - display_window + 1)
    display_slice = slice(display_start, valid_end + 1)
    display_close = close_valid[display_slice]
    display_volume = volume_valid[display_slice]
    display_dates = dates[valid_indices[display_slice]]

    result = detect_converging_triangle(
        high=high_win,
        low=low_win,
        close=close_win,
        volume=volume_win,
        params=params,
        stock_idx=stock_idx,
        date_idx=date_idx,
    )

    if not result.is_valid:
        print(f" [跳过] {stock_code} {stock_name}: 未识别到有效三角形")
        return

    x_display = np.arange(len(display_close), dtype=float)

    # Shift that maps a detection-window x position onto the display axis.
    triangle_offset = len(display_close) - len(close_win)

    a_u, a_l = result.upper_slope, result.lower_slope
    start = result.window_start
    end = result.window_end

    # Intercepts are chosen so each line passes through the high/low of the
    # bar at `end`, expressed in detection-window coordinates.
    b_u = high_win[end] - a_u * end
    b_l = low_win[end] - a_l * end

    # X positions are shifted into display coordinates for drawing, while the
    # Y values are evaluated with the original detection-window X values.
    xw_in_display = np.arange(start + triangle_offset, end + triangle_offset + 1, dtype=float)
    xw_original = np.arange(start, end + 1, dtype=float)
    upper_line = line_y(a_u, b_u, xw_original)
    lower_line = line_y(a_l, b_l, xw_original)

    # Date range of the detection window, shown in the title.
    detect_dates = dates[valid_indices[detect_slice]]

    # Tri-state label: truthy -> yes, exactly False -> no, anything else -> '-'.
    if result.volume_confirmed:
        volume_flag = '是'
    elif result.volume_confirmed is False:
        volume_flag = '否'
    else:
        volume_flag = '-'

    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [3, 1]}
    )
    try:
        # Main panel: close price plus the two boundary lines.
        ax1.plot(x_display, display_close, linewidth=1.5, label='收盘价', color='black', alpha=0.7)
        ax1.plot(xw_in_display, upper_line, linewidth=2, label='上沿', color='red', linestyle='--')
        ax1.plot(xw_in_display, lower_line, linewidth=2, label='下沿', color='green', linestyle='--')
        ax1.axvline(len(display_close) - 1, color='gray', linestyle=':', linewidth=1, alpha=0.5)

        ax1.set_title(
            f"{stock_code} {stock_name} - 收敛三角形 (检测窗口: {detect_dates[0]} ~ {detect_dates[-1]})\n"
            f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日) "
            f"突破方向: {result.breakout_dir} 宽度比: {result.width_ratio:.2f} "
            f"触碰: 上{result.touches_upper}/下{result.touches_lower} "
            f"放量确认: {volume_flag}",
            fontsize=11, pad=10
        )
        ax1.set_ylabel('价格', fontsize=10)
        ax1.legend(loc='best', fontsize=9)
        ax1.grid(True, alpha=0.3)

        # Sparse date ticks: roughly ten labels across the display window.
        step = max(1, len(display_dates) // 10)
        tick_indices = np.arange(0, len(display_dates), step)
        ax1.set_xticks(tick_indices)
        ax1.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8)

        # Lower panel: volume bars over the same display window.
        ax2.bar(x_display, display_volume, width=0.8, color='skyblue', alpha=0.6)
        ax2.set_ylabel('成交量', fontsize=10)
        ax2.set_xlabel('交易日', fontsize=10)
        ax2.grid(True, alpha=0.3, axis='y')
        ax2.set_xticks(tick_indices)
        ax2.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8)

        fig.tight_layout()
        fig.savefig(output_path, dpi=120)
    finally:
        # Always release the figure: the caller keeps going after a failed
        # stock, so an unclosed figure would otherwise pile up in pyplot.
        plt.close(fig)

    print(f" [完成] {stock_code} {stock_name} -> {output_path}")
|
||
|
||
|
||
def main() -> None:
    """Command-line entry point: chart every stock listed for the target day.

    Steps:
      1. Load the OHLCV matrices from the ``data`` directory next to this
         script's parent folder.
      2. Pick the target date (``--date``, or the largest ``date`` value in
         the input CSV) and read that day's stocks from the CSV.
      3. Check that the date exists in the price data, then create the
         output directory and delete the ``.png`` files already in it.
      4. Call ``plot_triangle`` for each stock; a failure on one stock is
         printed and does not stop the others.

    The date is validated before old charts are removed, so a date that is
    missing from the price data no longer wipes the output directory.
    """
    parser = argparse.ArgumentParser(description="为当日满足收敛三角形的个股生成图表")
    parser.add_argument(
        "--input",
        default=os.path.join("outputs", "converging_triangles", "all_results.csv"),
        help="输入 CSV 路径",
    )
    parser.add_argument(
        "--date",
        type=int,
        default=None,
        help="指定日期(YYYYMMDD),默认为数据最新日",
    )
    parser.add_argument(
        "--output-dir",
        default=os.path.join("outputs", "converging_triangles", "charts"),
        help="图表输出目录",
    )
    args = parser.parse_args()

    print("=" * 70)
    print("收敛三角形图表生成")
    print("=" * 70)

    # 1. Load price/volume data.
    print("\n[1] 加载 OHLCV 数据...")
    data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
    # The open-price matrix is returned by the loader but not needed here.
    _open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(data_dir)
    print(f" 股票数: {len(tkrs)}, 交易日数: {len(dates)}")

    # 2. Resolve the target date.
    if args.date:
        target_date = args.date
    else:
        # Fall back to the most recent date present in the CSV.
        with open(args.input, newline="", encoding="utf-8-sig") as f:
            reader = csv.DictReader(f)
            all_dates = [int(r.get("date", "0")) for r in reader if r.get("date") and r["date"].isdigit()]
        target_date = max(all_dates) if all_dates else 0

    if target_date == 0:
        print("错误: 无法确定目标日期")
        return

    print(f"\n[2] 目标日期: {target_date}")

    stocks = load_daily_stocks(args.input, target_date)
    print(f" 当日满足三角形的股票数: {len(stocks)}")

    if not stocks:
        print("当日无满足条件的股票")
        return

    # Locate the target date in the price data BEFORE any destructive step.
    date_matches = np.where(dates == target_date)[0]
    if len(date_matches) == 0:
        print(f"错误: 日期 {target_date} 不在数据范围内")
        return
    date_idx = date_matches[0]

    # 3. Prepare the output directory and drop charts from earlier runs.
    os.makedirs(args.output_dir, exist_ok=True)
    print("\n[3] 清空输出目录...")
    old_files = [f for f in os.listdir(args.output_dir) if f.endswith('.png')]
    for f in old_files:
        os.remove(os.path.join(args.output_dir, f))
    print(f" 已删除 {len(old_files)} 个旧图片")

    # Detection parameters (kept in sync with run_converging_triangle.py).
    params = ConvergingTriangleParams(
        window=120,
        pivot_k=15,
        boundary_n_segments=2,
        boundary_source="full",
        upper_slope_max=0.10,
        lower_slope_min=-0.10,
        touch_tol=0.10,
        touch_loss_max=0.10,
        shrink_ratio=0.8,
        break_tol=0.001,
        vol_window=20,
        vol_k=1.3,
        false_break_m=5,
    )

    # 4. Render one chart per stock.
    print("\n[4] 生成图表...")
    for stock in stocks:
        stock_idx = stock["stock_idx"]
        stock_code = stock["stock_code"]
        stock_name = stock["stock_name"]

        output_filename = f"{target_date}_{stock_code}_{stock_name}.png"
        output_path = os.path.join(args.output_dir, output_filename)

        try:
            plot_triangle(
                stock_idx=stock_idx,
                stock_code=stock_code,
                stock_name=stock_name,
                date_idx=date_idx,
                high_mtx=high_mtx,
                low_mtx=low_mtx,
                close_mtx=close_mtx,
                volume_mtx=volume_mtx,
                dates=dates,
                params=params,
                output_path=output_path,
                display_window=500,  # show 500 trading days
            )
        except Exception as e:
            # Best-effort batch: report the failing stock and continue.
            print(f" [错误] {stock_code} {stock_name}: {e}")

    print(f"\n完成!图表已保存至: {args.output_dir}")
|
||
|
||
|
||
# Run the chart generation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|