功能: - 支持中文名称搜索股票(如"中控技术"、"英伟达") - 支持 A 股和美股 - 趋势线从窗口起点延伸到终点(与前端一致) - Y 轴自动调整以包含趋势线 - 支持 --window 参数(1Y/3Y/5Y/ALL) 用法: python validate.py 中控技术 日 --window 3Y --save python validate.py 英伟达 日 --window 3Y --save Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
362 lines
14 KiB
Python
362 lines
14 KiB
Python
"""
|
||
收敛三角形可视化绘制模块
|
||
|
||
根据 chart_data (包含 klines) 绘制 K 线图 + 三角形趋势线
|
||
"""
|
||
import os
|
||
import numpy as np
|
||
import pandas as pd
|
||
import mplfinance as mpf
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.font_manager as fm
|
||
|
||
|
||
# ── 中文字体支持 ──────────────────────────────────────────────────────────────
|
||
def _setup_chinese_font():
|
||
"""尝试配置中文字体,避免乱码"""
|
||
candidates = [
|
||
'Microsoft YaHei', 'SimHei', 'PingFang SC', 'STSong',
|
||
'Noto Sans CJK SC', 'WenQuanYi Zen Hei',
|
||
]
|
||
available = {f.name for f in fm.fontManager.ttflist}
|
||
for name in candidates:
|
||
if name in available:
|
||
plt.rcParams['font.family'] = name
|
||
plt.rcParams['axes.unicode_minus'] = False
|
||
return name
|
||
# 找不到就用默认,可能出现方框
|
||
return None
|
||
|
||
|
||
_CHINESE_FONT = _setup_chinese_font()
|
||
|
||
|
||
# ── 工具函数 ──────────────────────────────────────────────────────────────────
|
||
def _date_int_to_str(d: int) -> str:
|
||
"""20250327 → '2025-03-27'"""
|
||
s = str(int(d))
|
||
return f'{s[:4]}-{s[4:6]}-{s[6:]}'
|
||
|
||
|
||
def _build_dataframe(klines: dict) -> pd.DataFrame:
|
||
"""从 klines 字典构建 mplfinance 需要的 DataFrame"""
|
||
dates_str = [_date_int_to_str(d) for d in klines['dates']]
|
||
df = pd.DataFrame(
|
||
{
|
||
'Open': [v if v is not None else np.nan for v in klines['open']],
|
||
'High': [v if v is not None else np.nan for v in klines['high']],
|
||
'Low': [v if v is not None else np.nan for v in klines['low']],
|
||
'Close': [v if v is not None else np.nan for v in klines['close']],
|
||
'Volume': [v if v is not None else np.nan for v in klines['volume']],
|
||
},
|
||
index=pd.DatetimeIndex(dates_str),
|
||
)
|
||
return df, dates_str
|
||
|
||
|
||
def _make_aline(line_pts: list):
|
||
"""
|
||
将 [{"date": int, "price": float}, ...] 转换为 mplfinance aline 格式的 2-端点元组
|
||
|
||
Returns:
|
||
(("2025-03-27", price1), ("2026-01-20", price2)) or None
|
||
"""
|
||
if not line_pts or len(line_pts) < 2:
|
||
return None
|
||
p1 = (_date_int_to_str(line_pts[0]['date']), line_pts[0]['price'])
|
||
p2 = (_date_int_to_str(line_pts[-1]['date']), line_pts[-1]['price'])
|
||
return (p1, p2)
|
||
|
||
|
||
def _build_pivot_arrays(pivots: list, dates_str: list, n: int) -> np.ndarray:
|
||
"""
|
||
将枢轴点列表映射为长度 n 的 float 数组(非枢轴位置为 NaN)
|
||
"""
|
||
date_to_idx = {d: i for i, d in enumerate(dates_str)}
|
||
arr = np.full(n, np.nan)
|
||
for pt in pivots:
|
||
ds = _date_int_to_str(pt['date'])
|
||
idx = date_to_idx.get(ds)
|
||
if idx is not None:
|
||
arr[idx] = pt['price']
|
||
return arr
|
||
|
||
|
||
# ── 主绘图函数 ────────────────────────────────────────────────────────────────
|
||
def draw_triangle_chart(
|
||
detail: dict,
|
||
output_path: str = None,
|
||
show: bool = True,
|
||
show_pivots: bool = False, # 是否显示枢轴点(默认不显示,与前端一致)
|
||
):
|
||
"""
|
||
绘制收敛三角形 K 线图
|
||
|
||
Args:
|
||
detail: 收敛三角形详情字典,即 ``result.info['detail']``
|
||
output_path: 保存路径(如 ``'output/CSCO_日.png'``),None 则不保存
|
||
show: 是否弹出交互窗口(在无 GUI 环境下设 False)
|
||
show_pivots: 是否显示枢轴点标记(默认 False,与前端一致)
|
||
|
||
Returns:
|
||
output_path (便于链式调用)
|
||
"""
|
||
chart = detail.get('chart_data', {})
|
||
klines = chart.get('klines')
|
||
if not klines:
|
||
raise ValueError(
|
||
"chart_data 中没有 klines,请用 include_klines=True 调用收敛三角形详情"
|
||
)
|
||
|
||
# ── 基本属性 ─────────────────────────────────
|
||
ticker = chart.get('ticker', '')
|
||
freq = chart.get('freq', 'D')
|
||
freq_label = {'D': '日K', 'W': '周K', 'M': '月K'}.get(freq, freq)
|
||
strength = detail.get('strength', 0.0)
|
||
direction = detail.get('direction', 'none')
|
||
dir_label = {'up': '↑上涨突破', 'down': '↓下跌突破', 'none': '整理中'}.get(direction, direction)
|
||
window_start = chart.get('window_start_date', '')
|
||
window_end = chart.get('window_end_date', '')
|
||
touches_upper = detail.get('touches_upper', '?')
|
||
touches_lower = detail.get('touches_lower', '?')
|
||
|
||
# ── 构建 DataFrame ─────────────────────────────
|
||
df, dates_str = _build_dataframe(klines)
|
||
n = len(df)
|
||
|
||
# ── 趋势线:从 window_start_date 延伸到 window_end_date(与前端一致)──────────
|
||
alines_list = []
|
||
alines_colors = []
|
||
|
||
upper_line = chart.get('upper_line', [])
|
||
lower_line = chart.get('lower_line', [])
|
||
|
||
# 将日期字符串转为集合,用于查找
|
||
date_to_idx = {d: i for i, d in enumerate(dates_str)}
|
||
window_start_str = _date_int_to_str(window_start) if window_start else None
|
||
window_end_str = _date_int_to_str(window_end) if window_end else None
|
||
|
||
# 辅助函数:查找最近的日期(与前端 getClosestDateCategory 一致)
|
||
def get_closest_date_idx(target_date: int, dates_str: list) -> int:
|
||
"""查找目标日期在 K 线数据中最接近的位置"""
|
||
target_str = _date_int_to_str(target_date)
|
||
if target_str in date_to_idx:
|
||
return date_to_idx[target_str]
|
||
|
||
# 找不到精确匹配时,找最近的日期
|
||
from datetime import datetime
|
||
target_dt = datetime.strptime(target_str, '%Y-%m-%d')
|
||
min_diff = float('inf')
|
||
closest_idx = 0
|
||
|
||
for i, d in enumerate(dates_str):
|
||
dt = datetime.strptime(d, '%Y-%m-%d')
|
||
diff = abs((dt - target_dt).days)
|
||
if diff < min_diff:
|
||
min_diff = diff
|
||
closest_idx = i
|
||
|
||
return closest_idx
|
||
|
||
# 辅助函数:计算延伸后的趋势线(与前端 getExtendedLinePoints 完全一致)
|
||
def make_extended_aline(line_pts, pivot_point, color):
|
||
"""
|
||
与前端 TechPattern.vue 的 getExtendedLinePoints 函数逻辑完全一致
|
||
趋势线从 window_start_date 延伸到 window_end_date
|
||
"""
|
||
if not line_pts or len(line_pts) < 2:
|
||
return None
|
||
|
||
p1 = line_pts[0]
|
||
p2 = line_pts[-1]
|
||
|
||
# 检查是否有 index 字段(公式返回的数据)
|
||
has_index = 'index' in p1 and 'index' in p2 and pivot_point and 'index' in pivot_point
|
||
|
||
if has_index and window_start and window_end:
|
||
# 使用原始 index 计算斜率(与前端一致)
|
||
slope = (p2['price'] - p1['price']) / (p2['index'] - p1['index']) if p2['index'] != p1['index'] else 0
|
||
|
||
# 前端: startIdx = 0 (window_start_date 对应 index 0)
|
||
start_idx = 0
|
||
|
||
# 查找 window_end_date 对应的 index(与前端一致)
|
||
end_idx = p2['index']
|
||
all_points = (chart.get('upper_line', []) + chart.get('lower_line', []) +
|
||
chart.get('upper_pivots', []) + chart.get('lower_pivots', []))
|
||
for pt in all_points:
|
||
if pt.get('date') == window_end:
|
||
end_idx = pt.get('index', end_idx)
|
||
break
|
||
|
||
# 计算延伸后的价格(与前端一致)
|
||
pivot_index = pivot_point['index']
|
||
pivot_price = pivot_point['price']
|
||
start_price = pivot_price + slope * (start_idx - pivot_index)
|
||
end_price = pivot_price + slope * (end_idx - pivot_index)
|
||
|
||
# 将 window_start_date 和 window_end_date 映射到 K 线数据的索引
|
||
kline_start_idx = get_closest_date_idx(window_start, dates_str)
|
||
kline_end_idx = get_closest_date_idx(window_end, dates_str)
|
||
|
||
print(f'[chart] 趋势线延伸: {window_start_str}({kline_start_idx}) -> {window_end_str}({kline_end_idx})')
|
||
print(f'[chart] 价格: {start_price:.2f} -> {end_price:.2f}')
|
||
|
||
return (
|
||
(kline_start_idx, start_price),
|
||
(kline_end_idx, end_price)
|
||
)
|
||
|
||
# 降级处理:仅连接两点(与前端一致)
|
||
d1 = _date_int_to_str(p1['date'])
|
||
d2 = _date_int_to_str(p2['date'])
|
||
|
||
idx1 = date_to_idx.get(d1)
|
||
idx2 = date_to_idx.get(d2)
|
||
|
||
if idx1 is None or idx2 is None:
|
||
print(f'[chart] 警告: 趋势线日期不在 K 线数据中: {d1} 或 {d2}')
|
||
return None
|
||
|
||
return (
|
||
(idx1, p1['price']),
|
||
(idx2, p2['price'])
|
||
)
|
||
|
||
# 绘制上轨线(红色)- 与前端一致
|
||
if upper_line and window_start and window_end:
|
||
upper_aline = make_extended_aline(upper_line, upper_line[0] if upper_line else None, '#ef4444')
|
||
if upper_aline:
|
||
alines_list.append(upper_aline)
|
||
alines_colors.append('#ef4444') # 红色
|
||
|
||
# 绘制下轨线(绿色)- 与前端一致
|
||
if lower_line and window_start and window_end:
|
||
lower_aline = make_extended_aline(lower_line, lower_line[0] if lower_line else None, '#10b981')
|
||
if lower_aline:
|
||
alines_list.append(lower_aline)
|
||
alines_colors.append('#10b981') # 绿色
|
||
|
||
# ── 枢轴点散点图(可选)─────────────────────────
|
||
addplots = []
|
||
if show_pivots:
|
||
upper_pivots_arr = _build_pivot_arrays(chart.get('upper_pivots', []), dates_str, n)
|
||
lower_pivots_arr = _build_pivot_arrays(chart.get('lower_pivots', []), dates_str, n)
|
||
|
||
if np.any(~np.isnan(upper_pivots_arr)):
|
||
addplots.append(mpf.make_addplot(
|
||
upper_pivots_arr, type='scatter',
|
||
markersize=100, marker='^', color='#ef4444', alpha=0.85,
|
||
))
|
||
if np.any(~np.isnan(lower_pivots_arr)):
|
||
addplots.append(mpf.make_addplot(
|
||
lower_pivots_arr, type='scatter',
|
||
markersize=100, marker='v', color='#10b981', alpha=0.85,
|
||
))
|
||
|
||
# ── 突破价位水平线 ─────────────────────────────
|
||
bp_up = detail.get('breakout_price_up')
|
||
bp_down = detail.get('breakout_price_down')
|
||
hlines_vals = []
|
||
hlines_colors = []
|
||
if bp_up and float(bp_up) > 0:
|
||
hlines_vals.append(float(bp_up))
|
||
hlines_colors.append('#ef4444')
|
||
if bp_down and float(bp_down) > 0:
|
||
hlines_vals.append(float(bp_down))
|
||
hlines_colors.append('#10b981')
|
||
|
||
# ── 标题 ────────────────────────────────────────
|
||
title = (
|
||
f'{ticker} {freq_label} '
|
||
f'强度={strength:.3f} {dir_label}\n'
|
||
f'窗口 {window_start}~{window_end} '
|
||
f'(上沿触碰={touches_upper} 下沿触碰={touches_lower})'
|
||
)
|
||
|
||
# ── 组装 plot 参数 ────────────────────────────────
|
||
_rc = {'axes.unicode_minus': False}
|
||
if _CHINESE_FONT:
|
||
_rc['font.family'] = _CHINESE_FONT
|
||
_rc['font.sans-serif'] = [_CHINESE_FONT, 'DejaVu Sans']
|
||
|
||
_style = mpf.make_mpf_style(base_mpf_style='yahoo', rc=_rc)
|
||
|
||
kwargs = dict(
|
||
type='candle',
|
||
volume=False, # 不显示成交量
|
||
title=title,
|
||
style=_style,
|
||
figsize=(18, 10),
|
||
tight_layout=True,
|
||
)
|
||
|
||
if alines_list:
|
||
kwargs['alines'] = dict(
|
||
alines=alines_list,
|
||
colors=alines_colors,
|
||
linewidths=2.0,
|
||
alpha=0.90,
|
||
)
|
||
|
||
if addplots:
|
||
kwargs['addplot'] = addplots
|
||
|
||
if hlines_vals:
|
||
kwargs['hlines'] = dict(
|
||
hlines=hlines_vals,
|
||
colors=hlines_colors,
|
||
linestyle='--',
|
||
linewidths=1.2,
|
||
alpha=0.7,
|
||
)
|
||
|
||
if output_path:
|
||
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
|
||
|
||
# ── 绘图 ─────────────────────────────────────────
|
||
# 先不绘制趋势线,获取 figure 和 axes
|
||
_kwargs = {k: v for k, v in kwargs.items() if k not in ['alines', 'savefig']}
|
||
fig, axes = mpf.plot(df, **_kwargs, returnfig=True, warn_too_much_data=1000)
|
||
|
||
# 使用 matplotlib 直接在主图上绘制趋势线
|
||
# alines_list 格式: [((idx1, price1), (idx2, price2)), ...]
|
||
if alines_list and len(axes) > 0:
|
||
main_ax = axes[0] # 主图(K线图)
|
||
|
||
# 收集趋势线价格,用于调整 Y 轴范围(与前端一致)
|
||
trend_line_prices = []
|
||
for aline, color in zip(alines_list, alines_colors):
|
||
(idx1, p1), (idx2, p2) = aline
|
||
trend_line_prices.extend([p1, p2])
|
||
# 直接使用整数索引作为 x 坐标(与 mplfinance 内部一致)
|
||
main_ax.plot([idx1, idx2], [p1, p2], color=color, linewidth=2, alpha=0.9)
|
||
|
||
# 调整 Y 轴范围以包含趋势线(与前端一致,增加 5% 边距)
|
||
if trend_line_prices:
|
||
current_ylim = main_ax.get_ylim()
|
||
kline_min, kline_max = current_ylim
|
||
|
||
# 合并 K 线和趋势线的价格范围
|
||
all_min = min(kline_min, min(trend_line_prices))
|
||
all_max = max(kline_max, max(trend_line_prices))
|
||
price_range = all_max - all_min
|
||
|
||
# 增加 5% 边距(与前端一致)
|
||
new_min = all_min - price_range * 0.05
|
||
new_max = all_max + price_range * 0.05
|
||
|
||
main_ax.set_ylim(new_min, new_max)
|
||
print(f'[chart] Y 轴范围调整: {kline_min:.2f}~{kline_max:.2f} -> {new_min:.2f}~{new_max:.2f}')
|
||
|
||
if show:
|
||
plt.show()
|
||
|
||
if output_path:
|
||
fig.savefig(output_path, dpi=150, bbox_inches='tight')
|
||
print(f'[chart] 已保存: {output_path}')
|
||
|
||
plt.close(fig)
|
||
|
||
return output_path
|