褚宏光 430511e8c4 feat: 收敛三角形形态可视化验证工具
功能:
- 支持中文名称搜索股票(如"中控技术"、"英伟达")
- 支持 A 股和美股
- 趋势线从窗口起点延伸到终点(与前端一致)
- Y 轴自动调整以包含趋势线
- 支持 --window 参数(1Y/3Y/5Y/ALL)

用法:
python validate.py 中控技术 日 --window 3Y --save
python validate.py 英伟达 日 --window 3Y --save

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-04 14:51:16 +08:00

362 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
收敛三角形可视化绘制模块
根据 chart_data (包含 klines) 绘制 K 线图 + 三角形趋势线
"""
import os
import numpy as np
import pandas as pd
import mplfinance as mpf
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
# ── 中文字体支持 ──────────────────────────────────────────────────────────────
def _setup_chinese_font():
"""尝试配置中文字体,避免乱码"""
candidates = [
'Microsoft YaHei', 'SimHei', 'PingFang SC', 'STSong',
'Noto Sans CJK SC', 'WenQuanYi Zen Hei',
]
available = {f.name for f in fm.fontManager.ttflist}
for name in candidates:
if name in available:
plt.rcParams['font.family'] = name
plt.rcParams['axes.unicode_minus'] = False
return name
# 找不到就用默认,可能出现方框
return None
_CHINESE_FONT = _setup_chinese_font()
# ── 工具函数 ──────────────────────────────────────────────────────────────────
def _date_int_to_str(d: int) -> str:
"""20250327 → '2025-03-27'"""
s = str(int(d))
return f'{s[:4]}-{s[4:6]}-{s[6:]}'
def _build_dataframe(klines: dict) -> pd.DataFrame:
"""从 klines 字典构建 mplfinance 需要的 DataFrame"""
dates_str = [_date_int_to_str(d) for d in klines['dates']]
df = pd.DataFrame(
{
'Open': [v if v is not None else np.nan for v in klines['open']],
'High': [v if v is not None else np.nan for v in klines['high']],
'Low': [v if v is not None else np.nan for v in klines['low']],
'Close': [v if v is not None else np.nan for v in klines['close']],
'Volume': [v if v is not None else np.nan for v in klines['volume']],
},
index=pd.DatetimeIndex(dates_str),
)
return df, dates_str
def _make_aline(line_pts: list):
"""
将 [{"date": int, "price": float}, ...] 转换为 mplfinance aline 格式的 2-端点元组
Returns:
(("2025-03-27", price1), ("2026-01-20", price2)) or None
"""
if not line_pts or len(line_pts) < 2:
return None
p1 = (_date_int_to_str(line_pts[0]['date']), line_pts[0]['price'])
p2 = (_date_int_to_str(line_pts[-1]['date']), line_pts[-1]['price'])
return (p1, p2)
def _build_pivot_arrays(pivots: list, dates_str: list, n: int) -> np.ndarray:
"""
将枢轴点列表映射为长度 n 的 float 数组(非枢轴位置为 NaN
"""
date_to_idx = {d: i for i, d in enumerate(dates_str)}
arr = np.full(n, np.nan)
for pt in pivots:
ds = _date_int_to_str(pt['date'])
idx = date_to_idx.get(ds)
if idx is not None:
arr[idx] = pt['price']
return arr
# ── 主绘图函数 ────────────────────────────────────────────────────────────────
def draw_triangle_chart(
detail: dict,
output_path: str = None,
show: bool = True,
show_pivots: bool = False, # 是否显示枢轴点(默认不显示,与前端一致)
):
"""
绘制收敛三角形 K 线图
Args:
detail: 收敛三角形详情字典,即 ``result.info['detail']``
output_path: 保存路径(如 ``'output/CSCO_日.png'``None 则不保存
show: 是否弹出交互窗口(在无 GUI 环境下设 False
show_pivots: 是否显示枢轴点标记(默认 False与前端一致
Returns:
output_path (便于链式调用)
"""
chart = detail.get('chart_data', {})
klines = chart.get('klines')
if not klines:
raise ValueError(
"chart_data 中没有 klines请用 include_klines=True 调用收敛三角形详情"
)
# ── 基本属性 ─────────────────────────────────
ticker = chart.get('ticker', '')
freq = chart.get('freq', 'D')
freq_label = {'D': '日K', 'W': '周K', 'M': '月K'}.get(freq, freq)
strength = detail.get('strength', 0.0)
direction = detail.get('direction', 'none')
dir_label = {'up': '↑上涨突破', 'down': '↓下跌突破', 'none': '整理中'}.get(direction, direction)
window_start = chart.get('window_start_date', '')
window_end = chart.get('window_end_date', '')
touches_upper = detail.get('touches_upper', '?')
touches_lower = detail.get('touches_lower', '?')
# ── 构建 DataFrame ─────────────────────────────
df, dates_str = _build_dataframe(klines)
n = len(df)
# ── 趋势线:从 window_start_date 延伸到 window_end_date与前端一致──────────
alines_list = []
alines_colors = []
upper_line = chart.get('upper_line', [])
lower_line = chart.get('lower_line', [])
# 将日期字符串转为集合,用于查找
date_to_idx = {d: i for i, d in enumerate(dates_str)}
window_start_str = _date_int_to_str(window_start) if window_start else None
window_end_str = _date_int_to_str(window_end) if window_end else None
# 辅助函数:查找最近的日期(与前端 getClosestDateCategory 一致)
def get_closest_date_idx(target_date: int, dates_str: list) -> int:
"""查找目标日期在 K 线数据中最接近的位置"""
target_str = _date_int_to_str(target_date)
if target_str in date_to_idx:
return date_to_idx[target_str]
# 找不到精确匹配时,找最近的日期
from datetime import datetime
target_dt = datetime.strptime(target_str, '%Y-%m-%d')
min_diff = float('inf')
closest_idx = 0
for i, d in enumerate(dates_str):
dt = datetime.strptime(d, '%Y-%m-%d')
diff = abs((dt - target_dt).days)
if diff < min_diff:
min_diff = diff
closest_idx = i
return closest_idx
# 辅助函数:计算延伸后的趋势线(与前端 getExtendedLinePoints 完全一致)
def make_extended_aline(line_pts, pivot_point, color):
"""
与前端 TechPattern.vue 的 getExtendedLinePoints 函数逻辑完全一致
趋势线从 window_start_date 延伸到 window_end_date
"""
if not line_pts or len(line_pts) < 2:
return None
p1 = line_pts[0]
p2 = line_pts[-1]
# 检查是否有 index 字段(公式返回的数据)
has_index = 'index' in p1 and 'index' in p2 and pivot_point and 'index' in pivot_point
if has_index and window_start and window_end:
# 使用原始 index 计算斜率(与前端一致)
slope = (p2['price'] - p1['price']) / (p2['index'] - p1['index']) if p2['index'] != p1['index'] else 0
# 前端: startIdx = 0 (window_start_date 对应 index 0)
start_idx = 0
# 查找 window_end_date 对应的 index与前端一致
end_idx = p2['index']
all_points = (chart.get('upper_line', []) + chart.get('lower_line', []) +
chart.get('upper_pivots', []) + chart.get('lower_pivots', []))
for pt in all_points:
if pt.get('date') == window_end:
end_idx = pt.get('index', end_idx)
break
# 计算延伸后的价格(与前端一致)
pivot_index = pivot_point['index']
pivot_price = pivot_point['price']
start_price = pivot_price + slope * (start_idx - pivot_index)
end_price = pivot_price + slope * (end_idx - pivot_index)
# 将 window_start_date 和 window_end_date 映射到 K 线数据的索引
kline_start_idx = get_closest_date_idx(window_start, dates_str)
kline_end_idx = get_closest_date_idx(window_end, dates_str)
print(f'[chart] 趋势线延伸: {window_start_str}({kline_start_idx}) -> {window_end_str}({kline_end_idx})')
print(f'[chart] 价格: {start_price:.2f} -> {end_price:.2f}')
return (
(kline_start_idx, start_price),
(kline_end_idx, end_price)
)
# 降级处理:仅连接两点(与前端一致)
d1 = _date_int_to_str(p1['date'])
d2 = _date_int_to_str(p2['date'])
idx1 = date_to_idx.get(d1)
idx2 = date_to_idx.get(d2)
if idx1 is None or idx2 is None:
print(f'[chart] 警告: 趋势线日期不在 K 线数据中: {d1}{d2}')
return None
return (
(idx1, p1['price']),
(idx2, p2['price'])
)
# 绘制上轨线(红色)- 与前端一致
if upper_line and window_start and window_end:
upper_aline = make_extended_aline(upper_line, upper_line[0] if upper_line else None, '#ef4444')
if upper_aline:
alines_list.append(upper_aline)
alines_colors.append('#ef4444') # 红色
# 绘制下轨线(绿色)- 与前端一致
if lower_line and window_start and window_end:
lower_aline = make_extended_aline(lower_line, lower_line[0] if lower_line else None, '#10b981')
if lower_aline:
alines_list.append(lower_aline)
alines_colors.append('#10b981') # 绿色
# ── 枢轴点散点图(可选)─────────────────────────
addplots = []
if show_pivots:
upper_pivots_arr = _build_pivot_arrays(chart.get('upper_pivots', []), dates_str, n)
lower_pivots_arr = _build_pivot_arrays(chart.get('lower_pivots', []), dates_str, n)
if np.any(~np.isnan(upper_pivots_arr)):
addplots.append(mpf.make_addplot(
upper_pivots_arr, type='scatter',
markersize=100, marker='^', color='#ef4444', alpha=0.85,
))
if np.any(~np.isnan(lower_pivots_arr)):
addplots.append(mpf.make_addplot(
lower_pivots_arr, type='scatter',
markersize=100, marker='v', color='#10b981', alpha=0.85,
))
# ── 突破价位水平线 ─────────────────────────────
bp_up = detail.get('breakout_price_up')
bp_down = detail.get('breakout_price_down')
hlines_vals = []
hlines_colors = []
if bp_up and float(bp_up) > 0:
hlines_vals.append(float(bp_up))
hlines_colors.append('#ef4444')
if bp_down and float(bp_down) > 0:
hlines_vals.append(float(bp_down))
hlines_colors.append('#10b981')
# ── 标题 ────────────────────────────────────────
title = (
f'{ticker} {freq_label} '
f'强度={strength:.3f} {dir_label}\n'
f'窗口 {window_start}~{window_end} '
f'(上沿触碰={touches_upper} 下沿触碰={touches_lower})'
)
# ── 组装 plot 参数 ────────────────────────────────
_rc = {'axes.unicode_minus': False}
if _CHINESE_FONT:
_rc['font.family'] = _CHINESE_FONT
_rc['font.sans-serif'] = [_CHINESE_FONT, 'DejaVu Sans']
_style = mpf.make_mpf_style(base_mpf_style='yahoo', rc=_rc)
kwargs = dict(
type='candle',
volume=False, # 不显示成交量
title=title,
style=_style,
figsize=(18, 10),
tight_layout=True,
)
if alines_list:
kwargs['alines'] = dict(
alines=alines_list,
colors=alines_colors,
linewidths=2.0,
alpha=0.90,
)
if addplots:
kwargs['addplot'] = addplots
if hlines_vals:
kwargs['hlines'] = dict(
hlines=hlines_vals,
colors=hlines_colors,
linestyle='--',
linewidths=1.2,
alpha=0.7,
)
if output_path:
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
# ── 绘图 ─────────────────────────────────────────
# 先不绘制趋势线,获取 figure 和 axes
_kwargs = {k: v for k, v in kwargs.items() if k not in ['alines', 'savefig']}
fig, axes = mpf.plot(df, **_kwargs, returnfig=True, warn_too_much_data=1000)
# 使用 matplotlib 直接在主图上绘制趋势线
# alines_list 格式: [((idx1, price1), (idx2, price2)), ...]
if alines_list and len(axes) > 0:
main_ax = axes[0] # 主图K线图
# 收集趋势线价格,用于调整 Y 轴范围(与前端一致)
trend_line_prices = []
for aline, color in zip(alines_list, alines_colors):
(idx1, p1), (idx2, p2) = aline
trend_line_prices.extend([p1, p2])
# 直接使用整数索引作为 x 坐标(与 mplfinance 内部一致)
main_ax.plot([idx1, idx2], [p1, p2], color=color, linewidth=2, alpha=0.9)
# 调整 Y 轴范围以包含趋势线(与前端一致,增加 5% 边距)
if trend_line_prices:
current_ylim = main_ax.get_ylim()
kline_min, kline_max = current_ylim
# 合并 K 线和趋势线的价格范围
all_min = min(kline_min, min(trend_line_prices))
all_max = max(kline_max, max(trend_line_prices))
price_range = all_max - all_min
# 增加 5% 边距(与前端一致)
new_min = all_min - price_range * 0.05
new_max = all_max + price_range * 0.05
main_ax.set_ylim(new_min, new_max)
print(f'[chart] Y 轴范围调整: {kline_min:.2f}~{kline_max:.2f} -> {new_min:.2f}~{new_max:.2f}')
if show:
plt.show()
if output_path:
fig.savefig(output_path, dpi=150, bbox_inches='tight')
print(f'[chart] 已保存: {output_path}')
plt.close(fig)
return output_path