褚宏光 59622a6ef7 feat: 增加命令行参数以支持自定义枢轴点窗口、收敛比和最小收敛比例
docs: 更新 README 示例以包含比亚迪的日、周、月 K 线图
fix: 修复趋势线绘制逻辑,支持周K/月K聚合后的日期匹配
docs: 添加 K 线形态参数调整建议文档
2026-03-04 17:39:27 +08:00

362 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
收敛三角形可视化绘制模块
根据 chart_data (包含 klines) 绘制 K 线图 + 三角形趋势线
"""
import os
import numpy as np
import pandas as pd
import mplfinance as mpf
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
# ── 中文字体支持 ──────────────────────────────────────────────────────────────
def _setup_chinese_font():
"""尝试配置中文字体,避免乱码"""
candidates = [
'Microsoft YaHei', 'SimHei', 'PingFang SC', 'STSong',
'Noto Sans CJK SC', 'WenQuanYi Zen Hei',
]
available = {f.name for f in fm.fontManager.ttflist}
for name in candidates:
if name in available:
plt.rcParams['font.family'] = name
plt.rcParams['axes.unicode_minus'] = False
return name
# 找不到就用默认,可能出现方框
return None
_CHINESE_FONT = _setup_chinese_font()
# ── 工具函数 ──────────────────────────────────────────────────────────────────
def _date_int_to_str(d: int) -> str:
"""20250327 → '2025-03-27'"""
s = str(int(d))
return f'{s[:4]}-{s[4:6]}-{s[6:]}'
def _build_dataframe(klines: dict) -> pd.DataFrame:
"""从 klines 字典构建 mplfinance 需要的 DataFrame"""
dates_str = [_date_int_to_str(d) for d in klines['dates']]
df = pd.DataFrame(
{
'Open': [v if v is not None else np.nan for v in klines['open']],
'High': [v if v is not None else np.nan for v in klines['high']],
'Low': [v if v is not None else np.nan for v in klines['low']],
'Close': [v if v is not None else np.nan for v in klines['close']],
'Volume': [v if v is not None else np.nan for v in klines['volume']],
},
index=pd.DatetimeIndex(dates_str),
)
return df, dates_str
def _make_aline(line_pts: list):
"""
将 [{"date": int, "price": float}, ...] 转换为 mplfinance aline 格式的 2-端点元组
Returns:
(("2025-03-27", price1), ("2026-01-20", price2)) or None
"""
if not line_pts or len(line_pts) < 2:
return None
p1 = (_date_int_to_str(line_pts[0]['date']), line_pts[0]['price'])
p2 = (_date_int_to_str(line_pts[-1]['date']), line_pts[-1]['price'])
return (p1, p2)
def _build_pivot_arrays(pivots: list, dates_str: list, n: int) -> np.ndarray:
"""
将枢轴点列表映射为长度 n 的 float 数组(非枢轴位置为 NaN
"""
date_to_idx = {d: i for i, d in enumerate(dates_str)}
arr = np.full(n, np.nan)
for pt in pivots:
ds = _date_int_to_str(pt['date'])
idx = date_to_idx.get(ds)
if idx is not None:
arr[idx] = pt['price']
return arr
# ── 主绘图函数 ────────────────────────────────────────────────────────────────
def draw_triangle_chart(
detail: dict,
output_path: str = None,
show: bool = True,
show_pivots: bool = False, # 是否显示枢轴点(默认不显示,与前端一致)
):
"""
绘制收敛三角形 K 线图
Args:
detail: 收敛三角形详情字典,即 ``result.info['detail']``
output_path: 保存路径(如 ``'output/CSCO_日.png'``None 则不保存
show: 是否弹出交互窗口(在无 GUI 环境下设 False
show_pivots: 是否显示枢轴点标记(默认 False与前端一致
Returns:
output_path (便于链式调用)
"""
chart = detail.get('chart_data', {})
klines = chart.get('klines')
if not klines:
raise ValueError(
"chart_data 中没有 klines请用 include_klines=True 调用收敛三角形详情"
)
# ── 基本属性 ─────────────────────────────────
ticker = chart.get('ticker', '')
freq = chart.get('freq', 'D')
freq_label = {'D': '日K', 'W': '周K', 'M': '月K'}.get(freq, freq)
strength = detail.get('strength', 0.0)
direction = detail.get('direction', 'none')
dir_label = {'up': '↑上涨突破', 'down': '↓下跌突破', 'none': '整理中'}.get(direction, direction)
window_start = chart.get('window_start_date', '')
window_end = chart.get('window_end_date', '')
touches_upper = detail.get('touches_upper', '?')
touches_lower = detail.get('touches_lower', '?')
# ── 构建 DataFrame ─────────────────────────────
df, dates_str = _build_dataframe(klines)
n = len(df)
# ── 趋势线:从 window_start_date 延伸到 window_end_date与前端一致──────────
alines_list = []
alines_colors = []
upper_line = chart.get('upper_line', [])
lower_line = chart.get('lower_line', [])
# 将日期字符串转为集合,用于查找
date_to_idx = {d: i for i, d in enumerate(dates_str)}
window_start_str = _date_int_to_str(window_start) if window_start else None
window_end_str = _date_int_to_str(window_end) if window_end else None
# 辅助函数:查找最近的日期(与前端 getClosestDateCategory 一致)
def get_closest_date_idx(target_date: int, dates_str: list) -> int:
"""查找目标日期在 K 线数据中最接近的位置"""
target_str = _date_int_to_str(target_date)
if target_str in date_to_idx:
return date_to_idx[target_str]
# 找不到精确匹配时,找最近的日期
from datetime import datetime
target_dt = datetime.strptime(target_str, '%Y-%m-%d')
min_diff = float('inf')
closest_idx = 0
for i, d in enumerate(dates_str):
dt = datetime.strptime(d, '%Y-%m-%d')
diff = abs((dt - target_dt).days)
if diff < min_diff:
min_diff = diff
closest_idx = i
return closest_idx
# 辅助函数:计算延伸后的趋势线(与前端 getExtendedLinePoints 完全一致)
def make_extended_aline(line_pts, pivot_point, color):
"""
与前端 TechPattern.vue 的 getExtendedLinePoints 函数逻辑完全一致
趋势线从 window_start_date 延伸到 window_end_date
关键修复:
1. 使用 K 线数据的实际索引,而不是公式内部的 index
2. 使用 get_closest_date_idx 查找日期支持周K/月K 聚合后的日期匹配
"""
if not line_pts or len(line_pts) < 2:
return None
p1 = line_pts[0]
p2 = line_pts[-1]
if window_start and window_end:
# 将枢轴点日期映射到 K 线数据的实际索引
# 使用 get_closest_date_idx 支持周K/月K 聚合后的日期匹配
idx1 = get_closest_date_idx(p1['date'], dates_str)
idx2 = get_closest_date_idx(p2['date'], dates_str)
d1 = _date_int_to_str(p1['date'])
d2 = _date_int_to_str(p2['date'])
print(f'[chart] 枢轴点日期映射: {d1} -> K线索引{idx1}, {d2} -> K线索引{idx2}')
# 使用 K 线数据的实际索引计算斜率
slope = (p2['price'] - p1['price']) / (idx2 - idx1) if idx2 != idx1 else 0
# 找到 window_start_date 和 window_end_date 在 K 线数据中的索引
kline_start_idx = get_closest_date_idx(window_start, dates_str)
kline_end_idx = get_closest_date_idx(window_end, dates_str)
# 使用 K 线数据的索引计算延伸后的价格
# 价格 = pivot_price + slope * (目标索引 - pivot索引)
pivot_kline_idx = idx1 # 使用第一个枢轴点的 K 线索引
pivot_price = p1['price']
start_price = pivot_price + slope * (kline_start_idx - pivot_kline_idx)
end_price = pivot_price + slope * (kline_end_idx - pivot_kline_idx)
print(f'[chart] 趋势线延伸: {window_start_str}({kline_start_idx}) -> {window_end_str}({kline_end_idx})')
print(f'[chart] 枢轴点: {d1}({idx1}) 价格={pivot_price:.2f}')
print(f'[chart] 斜率: {slope:.4f}')
print(f'[chart] 价格: {start_price:.2f} -> {end_price:.2f}')
return (
(kline_start_idx, start_price),
(kline_end_idx, end_price)
)
# 降级处理:仅连接两点
d1 = _date_int_to_str(p1['date'])
d2 = _date_int_to_str(p2['date'])
# 使用 get_closest_date_idx 支持日期模糊匹配
idx1 = get_closest_date_idx(p1['date'], dates_str)
idx2 = get_closest_date_idx(p2['date'], dates_str)
return (
(idx1, p1['price']),
(idx2, p2['price'])
)
# 绘制上轨线(红色)- 与前端一致
if upper_line and window_start and window_end:
upper_aline = make_extended_aline(upper_line, upper_line[0] if upper_line else None, '#ef4444')
if upper_aline:
alines_list.append(upper_aline)
alines_colors.append('#ef4444') # 红色
# 绘制下轨线(绿色)- 与前端一致
if lower_line and window_start and window_end:
lower_aline = make_extended_aline(lower_line, lower_line[0] if lower_line else None, '#10b981')
if lower_aline:
alines_list.append(lower_aline)
alines_colors.append('#10b981') # 绿色
# ── 枢轴点散点图(可选)─────────────────────────
addplots = []
if show_pivots:
upper_pivots_arr = _build_pivot_arrays(chart.get('upper_pivots', []), dates_str, n)
lower_pivots_arr = _build_pivot_arrays(chart.get('lower_pivots', []), dates_str, n)
if np.any(~np.isnan(upper_pivots_arr)):
addplots.append(mpf.make_addplot(
upper_pivots_arr, type='scatter',
markersize=100, marker='^', color='#ef4444', alpha=0.85,
))
if np.any(~np.isnan(lower_pivots_arr)):
addplots.append(mpf.make_addplot(
lower_pivots_arr, type='scatter',
markersize=100, marker='v', color='#10b981', alpha=0.85,
))
# ── 突破价位水平线 ─────────────────────────────
bp_up = detail.get('breakout_price_up')
bp_down = detail.get('breakout_price_down')
hlines_vals = []
hlines_colors = []
if bp_up and float(bp_up) > 0:
hlines_vals.append(float(bp_up))
hlines_colors.append('#ef4444')
if bp_down and float(bp_down) > 0:
hlines_vals.append(float(bp_down))
hlines_colors.append('#10b981')
# ── 标题 ────────────────────────────────────────
title = (
f'{ticker} {freq_label} '
f'强度={strength:.3f} {dir_label}\n'
f'窗口 {window_start}~{window_end} '
f'(上沿触碰={touches_upper} 下沿触碰={touches_lower})'
)
# ── 组装 plot 参数 ────────────────────────────────
_rc = {'axes.unicode_minus': False}
if _CHINESE_FONT:
_rc['font.family'] = _CHINESE_FONT
_rc['font.sans-serif'] = [_CHINESE_FONT, 'DejaVu Sans']
_style = mpf.make_mpf_style(base_mpf_style='yahoo', rc=_rc)
kwargs = dict(
type='candle',
volume=False, # 不显示成交量
title=title,
style=_style,
figsize=(18, 10),
tight_layout=True,
)
if alines_list:
kwargs['alines'] = dict(
alines=alines_list,
colors=alines_colors,
linewidths=2.0,
alpha=0.90,
)
if addplots:
kwargs['addplot'] = addplots
if hlines_vals:
kwargs['hlines'] = dict(
hlines=hlines_vals,
colors=hlines_colors,
linestyle='--',
linewidths=1.2,
alpha=0.7,
)
if output_path:
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
# ── 绘图 ─────────────────────────────────────────
# 先不绘制趋势线,获取 figure 和 axes
_kwargs = {k: v for k, v in kwargs.items() if k not in ['alines', 'savefig']}
fig, axes = mpf.plot(df, **_kwargs, returnfig=True, warn_too_much_data=1000)
# 使用 matplotlib 直接在主图上绘制趋势线
# alines_list 格式: [((idx1, price1), (idx2, price2)), ...]
if alines_list and len(axes) > 0:
main_ax = axes[0] # 主图K线图
# 收集趋势线价格,用于调整 Y 轴范围(与前端一致)
trend_line_prices = []
for aline, color in zip(alines_list, alines_colors):
(idx1, p1), (idx2, p2) = aline
trend_line_prices.extend([p1, p2])
# 直接使用整数索引作为 x 坐标(与 mplfinance 内部一致)
main_ax.plot([idx1, idx2], [p1, p2], color=color, linewidth=2, alpha=0.9)
# 调整 Y 轴范围以包含趋势线(与前端一致,增加 5% 边距)
if trend_line_prices:
current_ylim = main_ax.get_ylim()
kline_min, kline_max = current_ylim
# 合并 K 线和趋势线的价格范围
all_min = min(kline_min, min(trend_line_prices))
all_max = max(kline_max, max(trend_line_prices))
price_range = all_max - all_min
# 增加 5% 边距(与前端一致)
new_min = all_min - price_range * 0.05
new_max = all_max + price_range * 0.05
main_ax.set_ylim(new_min, new_max)
print(f'[chart] Y 轴范围调整: {kline_min:.2f}~{kline_max:.2f} -> {new_min:.2f}~{new_max:.2f}')
if show:
plt.show()
if output_path:
fig.savefig(output_path, dpi=150, bbox_inches='tight')
print(f'[chart] 已保存: {output_path}')
plt.close(fig)
return output_path