- Introduced a new "tilt" parameter to the strength scoring system, allowing for the assessment of triangle slope directionality. - Renamed existing parameters: "拟合贴合度" to "形态规则度" and "边界利用率" to "价格活跃度" for improved clarity. - Updated normalization methods for all strength components to ensure they remain within the [0, 1] range, facilitating LLM tuning. - Enhanced documentation to reflect changes in parameter names and scoring logic, including detailed explanations of the new tilt parameter. - Modified multiple source files and scripts to accommodate the new scoring structure and ensure backward compatibility. Files modified: - `src/converging_triangle.py`, `src/converging_triangle_optimized.py`, `src/triangle_detector_api.py`: Updated parameter names and scoring logic. - `scripts/plot_converging_triangles.py`, `scripts/generate_stock_viewer.py`: Adjusted for new scoring parameters in output. - New documentation files created to explain the renaming and new scoring system in detail.
682 lines
28 KiB
Python
682 lines
28 KiB
Python
"""
|
||
为当日满足收敛三角形的个股生成图表(K线图模式)
|
||
|
||
用法:
|
||
# 简洁模式(默认)- 仅显示K线、上沿、下沿
|
||
python scripts/plot_converging_triangles.py
|
||
|
||
# 详细模式 - 显示所有枢轴点、拟合点
|
||
python scripts/plot_converging_triangles.py --show-details
|
||
|
||
# 指定日期
|
||
python scripts/plot_converging_triangles.py --date 20260120
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import csv
|
||
import os
|
||
import pickle
|
||
import sys
|
||
from typing import Dict, List, Optional
|
||
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.patches as mpatches
|
||
from matplotlib.patches import Rectangle
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
# 配置中文字体
|
||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS'] # 支持中文
|
||
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
|
||
|
||
# 让脚本能找到 src/ 下的模块
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||
|
||
from converging_triangle import (
|
||
ConvergingTriangleParams,
|
||
calc_geometry_score,
|
||
detect_converging_triangle,
|
||
fit_pivot_line_dispatch,
|
||
line_y,
|
||
pivots_fractal,
|
||
pivots_fractal_hybrid,
|
||
)
|
||
|
||
# 导入统一的参数配置
|
||
from triangle_config import DETECTION_PARAMS, DISPLAY_WINDOW, SHOW_CHART_DETAILS, REALTIME_MODE, FLEXIBLE_ZONE
|
||
|
||
|
||
class FakeModule:
|
||
"""空壳模块,绕过 model 依赖"""
|
||
ndarray = np.ndarray
|
||
|
||
|
||
def load_pkl(pkl_path: str) -> dict:
|
||
"""加载 pkl 文件"""
|
||
sys.modules['model'] = FakeModule()
|
||
sys.modules['model.index_info'] = FakeModule()
|
||
|
||
with open(pkl_path, 'rb') as f:
|
||
data = pickle.load(f)
|
||
return data
|
||
|
||
|
||
def load_ohlcv_from_pkl(data_dir: str) -> tuple:
|
||
"""从 pkl 文件加载 OHLCV 数据"""
|
||
open_data = load_pkl(os.path.join(data_dir, "open.pkl"))
|
||
high_data = load_pkl(os.path.join(data_dir, "high.pkl"))
|
||
low_data = load_pkl(os.path.join(data_dir, "low.pkl"))
|
||
close_data = load_pkl(os.path.join(data_dir, "close.pkl"))
|
||
volume_data = load_pkl(os.path.join(data_dir, "volume.pkl"))
|
||
|
||
dates = close_data["dtes"]
|
||
tkrs = close_data["tkrs"]
|
||
tkrs_name = close_data["tkrs_name"]
|
||
|
||
return (
|
||
open_data["mtx"],
|
||
high_data["mtx"],
|
||
low_data["mtx"],
|
||
close_data["mtx"],
|
||
volume_data["mtx"],
|
||
dates,
|
||
tkrs,
|
||
tkrs_name,
|
||
)
|
||
|
||
|
||
def load_daily_stocks(csv_path: str, target_date: int) -> List[Dict]:
|
||
"""从CSV读取指定日期的股票列表"""
|
||
stocks = []
|
||
with open(csv_path, newline="", encoding="utf-8-sig") as f:
|
||
reader = csv.DictReader(f)
|
||
for row in reader:
|
||
try:
|
||
date = int(row.get("date", "0"))
|
||
if date == target_date:
|
||
stocks.append({
|
||
"stock_idx": int(row.get("stock_idx", "0")),
|
||
"stock_code": row.get("stock_code", ""),
|
||
"stock_name": row.get("stock_name", ""),
|
||
"breakout_dir": row.get("breakout_dir", "none"),
|
||
"breakout_strength_up": float(row.get("breakout_strength_up", "0")),
|
||
"breakout_strength_down": float(row.get("breakout_strength_down", "0")),
|
||
})
|
||
except (ValueError, TypeError):
|
||
continue
|
||
return stocks
|
||
|
||
|
||
def plot_triangle(
|
||
stock_idx: int,
|
||
stock_code: str,
|
||
stock_name: str,
|
||
date_idx: int,
|
||
open_mtx: np.ndarray,
|
||
high_mtx: np.ndarray,
|
||
low_mtx: np.ndarray,
|
||
close_mtx: np.ndarray,
|
||
volume_mtx: np.ndarray,
|
||
dates: np.ndarray,
|
||
params: ConvergingTriangleParams,
|
||
output_path: str,
|
||
display_window: int = 500, # 显示窗口大小
|
||
show_details: bool = False, # 是否显示详细调试信息
|
||
force_plot: bool = False, # 强制绘图(即使不满足三角形条件)
|
||
plot_boundary_source: str = "hl", # 边界线拟合数据源: "hl" | "close"
|
||
) -> None:
|
||
"""绘制单只股票的收敛三角形图(K线图模式)"""
|
||
|
||
# 提取该股票数据并过滤NaN
|
||
open_stock = open_mtx[stock_idx, :]
|
||
high_stock = high_mtx[stock_idx, :]
|
||
low_stock = low_mtx[stock_idx, :]
|
||
close_stock = close_mtx[stock_idx, :]
|
||
volume_stock = volume_mtx[stock_idx, :]
|
||
|
||
valid_mask = ~np.isnan(close_stock)
|
||
valid_indices = np.where(valid_mask)[0]
|
||
|
||
if len(valid_indices) < params.window:
|
||
if force_plot:
|
||
print(f" [警告] {stock_code} {stock_name}: 有效数据不足,仅绘制基础K线")
|
||
# 继续绘制基础K线
|
||
else:
|
||
print(f" [跳过] {stock_code} {stock_name}: 有效数据不足")
|
||
return
|
||
|
||
# 找到date_idx在有效数据中的位置
|
||
if date_idx not in valid_indices:
|
||
print(f" [跳过] {stock_code} {stock_name}: date_idx不在有效范围")
|
||
return
|
||
|
||
valid_end = np.where(valid_indices == date_idx)[0][0]
|
||
|
||
# 检查窗口数据是否充足
|
||
has_enough_data = valid_end >= params.window - 1
|
||
|
||
if not has_enough_data and not force_plot:
|
||
print(f" [跳过] {stock_code} {stock_name}: 窗口数据不足")
|
||
return
|
||
|
||
# 提取检测窗口数据(用于三角形检测)
|
||
if has_enough_data:
|
||
detect_start = valid_end - params.window + 1
|
||
high_win = high_stock[valid_mask][detect_start:valid_end + 1]
|
||
low_win = low_stock[valid_mask][detect_start:valid_end + 1]
|
||
close_win = close_stock[valid_mask][detect_start:valid_end + 1]
|
||
volume_win = volume_stock[valid_mask][detect_start:valid_end + 1]
|
||
else:
|
||
# 数据不足,使用所有可用数据
|
||
detect_start = 0
|
||
high_win = high_stock[valid_mask][:valid_end + 1]
|
||
low_win = low_stock[valid_mask][:valid_end + 1]
|
||
close_win = close_stock[valid_mask][:valid_end + 1]
|
||
volume_win = volume_stock[valid_mask][:valid_end + 1]
|
||
|
||
# 提取显示窗口数据(用于绘图,更长的历史)
|
||
display_start = max(0, valid_end - display_window + 1)
|
||
display_open = open_stock[valid_mask][display_start:valid_end + 1]
|
||
display_high = high_stock[valid_mask][display_start:valid_end + 1]
|
||
display_low = low_stock[valid_mask][display_start:valid_end + 1]
|
||
display_close = close_stock[valid_mask][display_start:valid_end + 1]
|
||
display_volume = volume_stock[valid_mask][display_start:valid_end + 1]
|
||
display_dates = dates[valid_indices[display_start:valid_end + 1]]
|
||
|
||
# ========================================================================
|
||
# 计算三角形参数(用于绘图)
|
||
# force_plot模式:即使不满足条件也尝试检测,检测失败则只画K线
|
||
# ========================================================================
|
||
result = None
|
||
has_triangle = False
|
||
fitting_adherence_plot = 0.0 # 初始化贴合度
|
||
|
||
if has_enough_data:
|
||
result = detect_converging_triangle(
|
||
high=high_win,
|
||
low=low_win,
|
||
close=close_win,
|
||
volume=volume_win,
|
||
params=params,
|
||
stock_idx=stock_idx,
|
||
date_idx=date_idx,
|
||
real_time_mode=REALTIME_MODE,
|
||
flexible_zone=FLEXIBLE_ZONE,
|
||
)
|
||
has_triangle = result.is_valid if result else False
|
||
|
||
# 在force_plot模式下,即使没有有效三角形也继续绘图(仅K线)
|
||
if not force_plot and not has_triangle:
|
||
print(f" [跳过] {stock_code} {stock_name}: 不满足收敛三角形条件")
|
||
return
|
||
|
||
# 绘图准备
|
||
x_display = np.arange(len(display_close), dtype=float)
|
||
|
||
# 只在有三角形时计算三角形相关参数
|
||
if has_triangle and has_enough_data:
|
||
# 计算三角形在显示窗口中的位置偏移
|
||
triangle_offset = len(display_close) - len(close_win)
|
||
|
||
# 获取检测窗口的起止索引(相对于检测窗口内部)
|
||
n = len(close_win)
|
||
x_win = np.arange(n, dtype=float)
|
||
|
||
# 计算枢轴点(与检测算法一致,考虑实时模式)
|
||
if REALTIME_MODE:
|
||
confirmed_ph, confirmed_pl, candidate_ph, candidate_pl = pivots_fractal_hybrid(
|
||
high_win, low_win, k=params.pivot_k, flexible_zone=FLEXIBLE_ZONE
|
||
)
|
||
# 合并确认枢轴点和候选枢轴点
|
||
ph_idx = np.concatenate([confirmed_ph, candidate_ph]) if len(candidate_ph) > 0 else confirmed_ph
|
||
pl_idx = np.concatenate([confirmed_pl, candidate_pl]) if len(candidate_pl) > 0 else confirmed_pl
|
||
# 排序以保证顺序
|
||
ph_idx = np.sort(ph_idx)
|
||
pl_idx = np.sort(pl_idx)
|
||
else:
|
||
ph_idx, pl_idx = pivots_fractal(high_win, low_win, k=params.pivot_k)
|
||
|
||
# 使用枢轴点连线法拟合边界线(与检测算法一致)
|
||
# 注意:绘图用的是检测窗口数据,因此 window_start=0, window_end=n-1
|
||
if plot_boundary_source == "close":
|
||
upper_fit_values = close_win[ph_idx]
|
||
lower_fit_values = close_win[pl_idx]
|
||
upper_all_prices = close_win
|
||
lower_all_prices = close_win
|
||
else:
|
||
upper_fit_values = high_win[ph_idx]
|
||
lower_fit_values = low_win[pl_idx]
|
||
upper_all_prices = high_win
|
||
lower_all_prices = low_win
|
||
|
||
a_u, b_u, selected_ph = fit_pivot_line_dispatch(
|
||
pivot_indices=ph_idx,
|
||
pivot_values=upper_fit_values,
|
||
mode="upper",
|
||
method=params.fitting_method,
|
||
all_prices=upper_all_prices,
|
||
window_start=0,
|
||
window_end=n - 1,
|
||
)
|
||
a_l, b_l, selected_pl = fit_pivot_line_dispatch(
|
||
pivot_indices=pl_idx,
|
||
pivot_values=lower_fit_values,
|
||
mode="lower",
|
||
method=params.fitting_method,
|
||
all_prices=lower_all_prices,
|
||
window_start=0,
|
||
window_end=n - 1,
|
||
)
|
||
|
||
# 三角形线段在显示窗口中的X坐标(只画检测窗口范围)
|
||
xw_in_display = np.arange(triangle_offset, triangle_offset + n, dtype=float)
|
||
# 计算Y值使用检测窗口内部的X坐标
|
||
upper_line = line_y(a_u, b_u, x_win)
|
||
lower_line = line_y(a_l, b_l, x_win)
|
||
|
||
# 获取选中的枢轴点在显示窗口中的位置(用于标注)
|
||
ph_display_idx = ph_idx + triangle_offset
|
||
pl_display_idx = pl_idx + triangle_offset
|
||
selected_ph_pos = ph_idx[selected_ph] if len(selected_ph) > 0 else np.array([], dtype=int)
|
||
selected_pl_pos = pl_idx[selected_pl] if len(selected_pl) > 0 else np.array([], dtype=int)
|
||
selected_ph_display = selected_ph_pos + triangle_offset
|
||
selected_pl_display = selected_pl_pos + triangle_offset
|
||
|
||
# 三角形检测窗口的日期范围(用于标题)
|
||
detect_dates = dates[valid_indices[detect_start:valid_end + 1]]
|
||
|
||
# 如果使用收盘价拟合,重新计算贴合度(基于实际拟合线)
|
||
if plot_boundary_source == "close" and len(selected_ph) > 0 and len(selected_pl) > 0:
|
||
# 使用收盘价重新计算贴合度
|
||
adherence_upper_close = calc_geometry_score(
|
||
pivot_indices=selected_ph_pos.astype(float),
|
||
pivot_values=close_win[selected_ph_pos],
|
||
slope=a_u,
|
||
intercept=b_u,
|
||
)
|
||
adherence_lower_close = calc_geometry_score(
|
||
pivot_indices=selected_pl_pos.astype(float),
|
||
pivot_values=close_win[selected_pl_pos],
|
||
slope=a_l,
|
||
intercept=b_l,
|
||
)
|
||
geometry_score_plot = (adherence_upper_close + adherence_lower_close) / 2.0
|
||
else:
|
||
# 使用检测算法计算的贴合度
|
||
geometry_score_plot = result.geometry_score if result else 0.0
|
||
else:
|
||
# 无三角形时,贴合度为0
|
||
geometry_score_plot = 0.0
|
||
|
||
# 创建图表
|
||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8),
|
||
gridspec_kw={'height_ratios': [3, 1]})
|
||
|
||
# ========================================================================
|
||
# 主图:K线图 + 趋势线
|
||
# ========================================================================
|
||
# 绘制K线
|
||
for i in range(len(x_display)):
|
||
# 确定颜色:涨为红色,跌为绿色
|
||
if i == 0:
|
||
color = 'red' if display_close[i] >= display_open[i] else 'green'
|
||
else:
|
||
color = 'red' if display_close[i] >= display_open[i] else 'green'
|
||
|
||
# 绘制影线(最高到最低)
|
||
ax1.plot([x_display[i], x_display[i]],
|
||
[display_low[i], display_high[i]],
|
||
color=color, linewidth=0.8, alpha=0.8)
|
||
|
||
# 绘制实体(开盘到收盘)
|
||
body_height = abs(display_close[i] - display_open[i])
|
||
body_bottom = min(display_open[i], display_close[i])
|
||
|
||
if body_height < 0.001: # 十字星
|
||
ax1.plot([x_display[i] - 0.3, x_display[i] + 0.3],
|
||
[display_close[i], display_close[i]],
|
||
color=color, linewidth=1.5)
|
||
else:
|
||
rect = Rectangle((x_display[i] - 0.3, body_bottom),
|
||
0.6, body_height,
|
||
facecolor=color if display_close[i] < display_open[i] else 'white',
|
||
edgecolor=color, linewidth=1.2)
|
||
ax1.add_patch(rect)
|
||
|
||
# 只在有三角形时绘制趋势线
|
||
if has_triangle and has_enough_data:
|
||
boundary_label = "收盘价" if plot_boundary_source == "close" else "高低价"
|
||
ax1.plot(xw_in_display, upper_line, linewidth=2.5, label=f'上沿({boundary_label})',
|
||
color='darkred', linestyle='--', alpha=0.8, zorder=5)
|
||
ax1.plot(xw_in_display, lower_line, linewidth=2.5, label=f'下沿({boundary_label})',
|
||
color='darkgreen', linestyle='--', alpha=0.8, zorder=5)
|
||
|
||
# 当前日期的标记线
|
||
ax1.axvline(len(display_close) - 1, color='blue', linestyle=':', linewidth=1.5, alpha=0.6)
|
||
|
||
# ========================================================================
|
||
# 详细模式:显示拟合点(仅在 show_details=True 且有三角形时)
|
||
# ========================================================================
|
||
if show_details and has_triangle and has_enough_data:
|
||
# 根据数据源选择枢轴点的Y坐标
|
||
if plot_boundary_source == "close":
|
||
ph_pivot_y = close_win[ph_idx]
|
||
pl_pivot_y = close_win[pl_idx]
|
||
selected_ph_y = close_win[selected_ph_pos] if len(selected_ph_pos) > 0 else np.array([])
|
||
selected_pl_y = close_win[selected_pl_pos] if len(selected_pl_pos) > 0 else np.array([])
|
||
else:
|
||
ph_pivot_y = high_win[ph_idx]
|
||
pl_pivot_y = low_win[pl_idx]
|
||
selected_ph_y = high_win[selected_ph_pos] if len(selected_ph_pos) > 0 else np.array([])
|
||
selected_pl_y = low_win[selected_pl_pos] if len(selected_pl_pos) > 0 else np.array([])
|
||
|
||
# 标注所有枢轴点(用于查看拐点分布)
|
||
if len(ph_display_idx) > 0:
|
||
ax1.scatter(
|
||
ph_display_idx,
|
||
ph_pivot_y,
|
||
marker='x',
|
||
s=60,
|
||
color='red',
|
||
alpha=0.6,
|
||
zorder=4,
|
||
label=f'上沿枢轴点({len(ph_idx)})',
|
||
)
|
||
if len(pl_display_idx) > 0:
|
||
ax1.scatter(
|
||
pl_display_idx,
|
||
pl_pivot_y,
|
||
marker='x',
|
||
s=60,
|
||
color='green',
|
||
alpha=0.6,
|
||
zorder=4,
|
||
label=f'下沿枢轴点({len(pl_idx)})',
|
||
)
|
||
|
||
# 标注选中的枢轴点(用于拟合线的关键点)
|
||
if len(selected_ph_display) >= 2:
|
||
ax1.scatter(
|
||
selected_ph_display,
|
||
selected_ph_y,
|
||
marker='o',
|
||
s=120,
|
||
facecolors='none',
|
||
edgecolors='red',
|
||
linewidths=2.5,
|
||
zorder=5,
|
||
label=f'上沿拟合点({len(selected_ph_pos)})',
|
||
)
|
||
if len(selected_pl_display) >= 2:
|
||
ax1.scatter(
|
||
selected_pl_display,
|
||
selected_pl_y,
|
||
marker='o',
|
||
s=120,
|
||
facecolors='none',
|
||
edgecolors='green',
|
||
linewidths=2.5,
|
||
zorder=5,
|
||
label=f'下沿拟合点({len(selected_pl_pos)})',
|
||
)
|
||
|
||
# 准备标题内容
|
||
if has_triangle and has_enough_data and result:
|
||
# 有效三角形:显示完整信息和强度分
|
||
if result.breakout_dir == "up":
|
||
strength = result.breakout_strength_up
|
||
price_score = result.price_score_up
|
||
elif result.breakout_dir == "down":
|
||
strength = result.breakout_strength_down
|
||
price_score = result.price_score_down
|
||
else:
|
||
strength = max(result.breakout_strength_up, result.breakout_strength_down)
|
||
price_score = max(result.price_score_up, result.price_score_down)
|
||
|
||
# 获取价格活跃度与惩罚系数(兼容旧数据)
|
||
activity_score = getattr(result, 'activity_score', 0.0)
|
||
tilt_score = getattr(result, 'tilt_score', 0.0) # 新增:获取倾斜度分
|
||
activity_floor = 0.20
|
||
if activity_floor > 0:
|
||
activity_penalty = min(1.0, activity_score / activity_floor)
|
||
else:
|
||
activity_penalty = 1.0
|
||
|
||
# 选择显示的贴合度:如果使用收盘价拟合,显示重新计算的贴合度
|
||
if plot_boundary_source == "close" and has_triangle and has_enough_data:
|
||
display_geometry = geometry_score_plot
|
||
geometry_note = f"形态规则度(收盘价): {display_geometry:.3f}"
|
||
else:
|
||
display_geometry = result.geometry_score
|
||
geometry_note = f"形态规则度: {display_geometry:.3f}"
|
||
|
||
ax1.set_title(
|
||
f"{stock_code} {stock_name} - 收敛三角形 (检测窗口: {detect_dates[0]} ~ {detect_dates[-1]})\n"
|
||
f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日) "
|
||
f"突破方向: {result.breakout_dir} 宽度比: {result.width_ratio:.2f} "
|
||
f"枢轴点: 高{len(ph_idx)}/低{len(pl_idx)} 触碰: 上{result.touches_upper}/下{result.touches_lower} "
|
||
f"放量确认: {'是' if result.volume_confirmed else '否' if result.volume_confirmed is False else '-'}\n"
|
||
f"强度分: {strength:.3f} "
|
||
f"(突破幅度: {price_score:.3f}×45% + 收敛度: {result.convergence_score:.3f}×15% + "
|
||
f"成交量: {result.volume_score:.3f}×10% + {geometry_note}×10% + "
|
||
f"价格活跃度: {activity_score:.3f}×15% + 倾斜度: {tilt_score:.3f}×5%) × 活跃度惩罚: {activity_penalty:.2f}",
|
||
fontsize=11, pad=10
|
||
)
|
||
else:
|
||
# 无三角形:仅显示基础信息和强度分0分
|
||
ax1.set_title(
|
||
f"{stock_code} {stock_name} - K线图(不满足收敛三角形条件)\n"
|
||
f"显示范围: {display_dates[0]} ~ {display_dates[-1]} ({len(display_dates)}个交易日)\n"
|
||
f"强度分: 0.000 (未检测到收敛三角形形态)",
|
||
fontsize=11, pad=10
|
||
)
|
||
ax1.set_ylabel('价格', fontsize=10)
|
||
ax1.legend(loc='best', fontsize=9)
|
||
ax1.grid(True, alpha=0.3)
|
||
|
||
# X轴日期标签(稀疏显示,基于显示窗口)
|
||
step = max(1, len(display_dates) // 10)
|
||
tick_indices = np.arange(0, len(display_dates), step)
|
||
ax1.set_xticks(tick_indices)
|
||
ax1.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8)
|
||
|
||
# 副图:成交量(使用显示窗口数据)
|
||
ax2.bar(x_display, display_volume, width=0.8, color='skyblue', alpha=0.6)
|
||
ax2.set_ylabel('成交量', fontsize=10)
|
||
ax2.set_xlabel('交易日', fontsize=10)
|
||
ax2.grid(True, alpha=0.3, axis='y')
|
||
ax2.set_xticks(tick_indices)
|
||
ax2.set_xticklabels(display_dates[tick_indices], rotation=45, ha='right', fontsize=8)
|
||
|
||
plt.tight_layout()
|
||
plt.savefig(output_path, dpi=120)
|
||
plt.close()
|
||
|
||
print(f" [完成] {stock_code} {stock_name} -> {output_path}")
|
||
|
||
|
||
def main() -> None:
|
||
parser = argparse.ArgumentParser(description="为当日满足收敛三角形的个股生成图表")
|
||
parser.add_argument(
|
||
"--input",
|
||
default=os.path.join("outputs", "converging_triangles", "all_results.csv"),
|
||
help="输入 CSV 路径",
|
||
)
|
||
parser.add_argument(
|
||
"--date",
|
||
type=int,
|
||
default=None,
|
||
help="指定日期(YYYYMMDD),默认为数据最新日",
|
||
)
|
||
parser.add_argument(
|
||
"--output-dir",
|
||
default=os.path.join("outputs", "converging_triangles", "charts"),
|
||
help="图表输出目录",
|
||
)
|
||
parser.add_argument(
|
||
"--show-details",
|
||
action="store_true",
|
||
help="显示详细调试信息(枢轴点、拟合点、分段线等)",
|
||
)
|
||
parser.add_argument(
|
||
"--plot-boundary-source",
|
||
choices=["hl", "close"],
|
||
default="hl",
|
||
help="绘图时边界线拟合数据源: hl=高低价(默认), close=收盘价(不影响检测)",
|
||
)
|
||
parser.add_argument(
|
||
"--all-stocks",
|
||
action="store_true",
|
||
help="为所有108只股票生成图表(包括不满足收敛三角形条件的)",
|
||
)
|
||
parser.add_argument(
|
||
"--clear",
|
||
action="store_true",
|
||
help="清空当前模式的旧图片(默认不清空)",
|
||
)
|
||
args = parser.parse_args()
|
||
|
||
# 确定是否显示详细信息(命令行参数优先)
|
||
show_details = args.show_details if hasattr(args, 'show_details') else SHOW_CHART_DETAILS
|
||
all_stocks = args.all_stocks if hasattr(args, 'all_stocks') else False
|
||
plot_boundary_source = args.plot_boundary_source if hasattr(args, 'plot_boundary_source') else "hl"
|
||
|
||
print("=" * 70)
|
||
print("收敛三角形图表生成")
|
||
print("=" * 70)
|
||
print(f"详细模式: {'开启' if show_details else '关闭'} {'(--show-details)' if show_details else '(简洁模式)'}")
|
||
print(f"图表范围: {'所有108只股票' if all_stocks else '仅满足条件的股票'} {'(--all-stocks)' if all_stocks else ''}")
|
||
|
||
# 1. 加载数据
|
||
print("\n[1] 加载 OHLCV 数据...")
|
||
data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
|
||
open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(data_dir)
|
||
print(f" 股票数: {len(tkrs)}, 交易日数: {len(dates)}")
|
||
|
||
# 2. 确定目标日期
|
||
if args.date:
|
||
target_date = args.date
|
||
else:
|
||
# 从CSV读取最新日期
|
||
with open(args.input, newline="", encoding="utf-8-sig") as f:
|
||
reader = csv.DictReader(f)
|
||
all_dates = [int(r.get("date", "0")) for r in reader if r.get("date") and r["date"].isdigit()]
|
||
target_date = max(all_dates) if all_dates else 0
|
||
|
||
if target_date == 0:
|
||
print("错误: 无法确定目标日期")
|
||
return
|
||
|
||
print(f"\n[2] 目标日期: {target_date}")
|
||
|
||
# 3. 加载当日股票列表
|
||
if all_stocks:
|
||
# 模式1: 绘制所有股票(包括不满足条件的)
|
||
print(f" 模式: 所有股票")
|
||
print(f" 股票总数: {len(tkrs)}")
|
||
|
||
# 为所有股票创建stock字典(从CSV获取已有的强度分,未检测的置为0)
|
||
stocks = []
|
||
csv_stocks = load_daily_stocks(args.input, target_date)
|
||
csv_map = {s["stock_idx"]: s for s in csv_stocks}
|
||
|
||
for idx in range(len(tkrs)):
|
||
if idx in csv_map:
|
||
# 已检测过的股票,使用CSV数据
|
||
stocks.append(csv_map[idx])
|
||
else:
|
||
# 未检测过的股票,使用默认值
|
||
stocks.append({
|
||
"stock_idx": idx,
|
||
"stock_code": tkrs[idx],
|
||
"stock_name": tkrs_name[idx],
|
||
"breakout_dir": "none",
|
||
"breakout_strength_up": 0.0,
|
||
"breakout_strength_down": 0.0,
|
||
})
|
||
|
||
print(f" 其中满足三角形条件: {len(csv_stocks)} 只")
|
||
print(f" 不满足条件(将显示基础K线): {len(tkrs) - len(csv_stocks)} 只")
|
||
else:
|
||
# 模式2: 仅绘制满足条件的股票
|
||
stocks = load_daily_stocks(args.input, target_date)
|
||
print(f" 当日满足三角形的股票数: {len(stocks)}")
|
||
|
||
if not stocks:
|
||
print("当日无满足条件的股票")
|
||
return
|
||
|
||
# 4. 创建输出目录,并按需清空对应模式的旧图片
|
||
os.makedirs(args.output_dir, exist_ok=True)
|
||
|
||
if args.clear:
|
||
# 只清空当前模式的图片(简洁模式或详细模式)
|
||
print(f"\n[4] 清空当前模式的旧图片...")
|
||
# 找出当前模式的文件:简洁模式是不含_detail的.png,详细模式是_detail.png
|
||
if show_details:
|
||
old_files = [f for f in os.listdir(args.output_dir) if f.endswith('_detail.png')]
|
||
else:
|
||
old_files = [f for f in os.listdir(args.output_dir)
|
||
if f.endswith('.png') and not f.endswith('_detail.png')]
|
||
|
||
for f in old_files:
|
||
os.remove(os.path.join(args.output_dir, f))
|
||
print(f" 已删除 {len(old_files)} 个旧图片 ({'详细模式' if show_details else '简洁模式'})")
|
||
else:
|
||
print(f"\n[4] 跳过清空旧图片(使用 --clear 可手动清空)")
|
||
|
||
# 5. 检测参数(从统一配置导入)
|
||
params = DETECTION_PARAMS
|
||
|
||
# 6. 找到target_date在dates中的索引
|
||
date_idx = np.where(dates == target_date)[0]
|
||
if len(date_idx) == 0:
|
||
print(f"错误: 日期 {target_date} 不在数据范围内")
|
||
return
|
||
date_idx = date_idx[0]
|
||
|
||
# 7. 生成图表
|
||
print(f"\n[3] 生成图表...")
|
||
for stock in stocks:
|
||
stock_idx = stock["stock_idx"]
|
||
stock_code = stock["stock_code"]
|
||
stock_name = stock["stock_name"]
|
||
|
||
# 清理文件名中的非法字符(Windows文件名不允许: * ? " < > | : / \)
|
||
stock_name_clean = stock_name.replace('*', '').replace('?', '').replace('"', '').replace('<', '').replace('>', '').replace('|', '').replace(':', '').replace('/', '').replace('\\', '')
|
||
|
||
# 根据详细模式添加文件名后缀
|
||
suffix = "_detail" if show_details else ""
|
||
output_filename = f"{target_date}_{stock_code}_{stock_name_clean}{suffix}.png"
|
||
output_path = os.path.join(args.output_dir, output_filename)
|
||
|
||
try:
|
||
plot_triangle(
|
||
stock_idx=stock_idx,
|
||
stock_code=stock_code,
|
||
stock_name=stock_name,
|
||
date_idx=date_idx,
|
||
open_mtx=open_mtx,
|
||
high_mtx=high_mtx,
|
||
low_mtx=low_mtx,
|
||
close_mtx=close_mtx,
|
||
volume_mtx=volume_mtx,
|
||
dates=dates,
|
||
params=params,
|
||
output_path=output_path,
|
||
display_window=DISPLAY_WINDOW, # 从配置文件读取
|
||
show_details=show_details, # 传递详细模式参数
|
||
force_plot=all_stocks, # 在all_stocks模式下强制绘图
|
||
plot_boundary_source=plot_boundary_source,
|
||
)
|
||
except Exception as e:
|
||
print(f" [错误] {stock_code} {stock_name}: {e}")
|
||
|
||
print(f"\n完成!图表已保存至: {args.output_dir}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|