""" 收敛三角形可视化绘制模块 根据 chart_data (包含 klines) 绘制 K 线图 + 三角形趋势线 """ import os import numpy as np import pandas as pd import mplfinance as mpf import matplotlib.pyplot as plt import matplotlib.font_manager as fm # ── 中文字体支持 ────────────────────────────────────────────────────────────── def _setup_chinese_font(): """尝试配置中文字体,避免乱码""" candidates = [ 'Microsoft YaHei', 'SimHei', 'PingFang SC', 'STSong', 'Noto Sans CJK SC', 'WenQuanYi Zen Hei', ] available = {f.name for f in fm.fontManager.ttflist} for name in candidates: if name in available: plt.rcParams['font.family'] = name plt.rcParams['axes.unicode_minus'] = False return name # 找不到就用默认,可能出现方框 return None _CHINESE_FONT = _setup_chinese_font() # ── 工具函数 ────────────────────────────────────────────────────────────────── def _date_int_to_str(d: int) -> str: """20250327 → '2025-03-27'""" s = str(int(d)) return f'{s[:4]}-{s[4:6]}-{s[6:]}' def _build_dataframe(klines: dict) -> pd.DataFrame: """从 klines 字典构建 mplfinance 需要的 DataFrame""" dates_str = [_date_int_to_str(d) for d in klines['dates']] df = pd.DataFrame( { 'Open': [v if v is not None else np.nan for v in klines['open']], 'High': [v if v is not None else np.nan for v in klines['high']], 'Low': [v if v is not None else np.nan for v in klines['low']], 'Close': [v if v is not None else np.nan for v in klines['close']], 'Volume': [v if v is not None else np.nan for v in klines['volume']], }, index=pd.DatetimeIndex(dates_str), ) return df, dates_str def _make_aline(line_pts: list): """ 将 [{"date": int, "price": float}, ...] 转换为 mplfinance aline 格式的 2-端点元组 Returns: (("2025-03-27", price1), ("2026-01-20", price2)) or None """ if not line_pts or len(line_pts) < 2: return None p1 = (_date_int_to_str(line_pts[0]['date']), line_pts[0]['price']) p2 = (_date_int_to_str(line_pts[-1]['date']), line_pts[-1]['price']) return (p1, p2) def _build_pivot_arrays(pivots: list, dates_str: list, n: int) -> np.ndarray: """ 将枢轴点列表映射为长度 n 的 float 数组(非枢轴位置为 NaN) """ date_to_idx = {d: i for i, d in enumerate(dates_str)} arr = np.full(n, np.nan) for pt in pivots: ds = _date_int_to_str(pt['date']) idx = date_to_idx.get(ds) if idx is not None: arr[idx] = pt['price'] return arr # ── 主绘图函数 ──────────────────────────────────────────────────────────────── def draw_triangle_chart( detail: dict, output_path: str = None, show: bool = True, show_pivots: bool = False, # 是否显示枢轴点(默认不显示,与前端一致) ): """ 绘制收敛三角形 K 线图 Args: detail: 收敛三角形详情字典,即 ``result.info['detail']`` output_path: 保存路径(如 ``'output/CSCO_日.png'``),None 则不保存 show: 是否弹出交互窗口(在无 GUI 环境下设 False) show_pivots: 是否显示枢轴点标记(默认 False,与前端一致) Returns: output_path (便于链式调用) """ chart = detail.get('chart_data', {}) klines = chart.get('klines') if not klines: raise ValueError( "chart_data 中没有 klines,请用 include_klines=True 调用收敛三角形详情" ) # ── 基本属性 ───────────────────────────────── ticker = chart.get('ticker', '') freq = chart.get('freq', 'D') freq_label = {'D': '日K', 'W': '周K', 'M': '月K'}.get(freq, freq) strength = detail.get('strength', 0.0) direction = detail.get('direction', 'none') dir_label = {'up': '↑上涨突破', 'down': '↓下跌突破', 'none': '整理中'}.get(direction, direction) window_start = chart.get('window_start_date', '') window_end = chart.get('window_end_date', '') touches_upper = detail.get('touches_upper', '?') touches_lower = detail.get('touches_lower', '?') # ── 构建 DataFrame ───────────────────────────── df, dates_str = _build_dataframe(klines) n = len(df) # ── 趋势线:从 window_start_date 延伸到 window_end_date(与前端一致)────────── alines_list = [] alines_colors = [] upper_line = chart.get('upper_line', []) lower_line = chart.get('lower_line', []) # 将日期字符串转为集合,用于查找 date_to_idx = {d: i for i, d in enumerate(dates_str)} window_start_str = _date_int_to_str(window_start) if window_start else None window_end_str = _date_int_to_str(window_end) if window_end else None # 辅助函数:查找最近的日期(与前端 getClosestDateCategory 一致) def get_closest_date_idx(target_date: int, dates_str: list) -> int: """查找目标日期在 K 线数据中最接近的位置""" target_str = _date_int_to_str(target_date) if target_str in date_to_idx: return date_to_idx[target_str] # 找不到精确匹配时,找最近的日期 from datetime import datetime target_dt = datetime.strptime(target_str, '%Y-%m-%d') min_diff = float('inf') closest_idx = 0 for i, d in enumerate(dates_str): dt = datetime.strptime(d, '%Y-%m-%d') diff = abs((dt - target_dt).days) if diff < min_diff: min_diff = diff closest_idx = i return closest_idx # 辅助函数:计算延伸后的趋势线(与前端 getExtendedLinePoints 完全一致) def make_extended_aline(line_pts, pivot_point, color): """ 与前端 TechPattern.vue 的 getExtendedLinePoints 函数逻辑完全一致 趋势线从 window_start_date 延伸到 window_end_date """ if not line_pts or len(line_pts) < 2: return None p1 = line_pts[0] p2 = line_pts[-1] # 检查是否有 index 字段(公式返回的数据) has_index = 'index' in p1 and 'index' in p2 and pivot_point and 'index' in pivot_point if has_index and window_start and window_end: # 使用原始 index 计算斜率(与前端一致) slope = (p2['price'] - p1['price']) / (p2['index'] - p1['index']) if p2['index'] != p1['index'] else 0 # 前端: startIdx = 0 (window_start_date 对应 index 0) start_idx = 0 # 查找 window_end_date 对应的 index(与前端一致) end_idx = p2['index'] all_points = (chart.get('upper_line', []) + chart.get('lower_line', []) + chart.get('upper_pivots', []) + chart.get('lower_pivots', [])) for pt in all_points: if pt.get('date') == window_end: end_idx = pt.get('index', end_idx) break # 计算延伸后的价格(与前端一致) pivot_index = pivot_point['index'] pivot_price = pivot_point['price'] start_price = pivot_price + slope * (start_idx - pivot_index) end_price = pivot_price + slope * (end_idx - pivot_index) # 将 window_start_date 和 window_end_date 映射到 K 线数据的索引 kline_start_idx = get_closest_date_idx(window_start, dates_str) kline_end_idx = get_closest_date_idx(window_end, dates_str) print(f'[chart] 趋势线延伸: {window_start_str}({kline_start_idx}) -> {window_end_str}({kline_end_idx})') print(f'[chart] 价格: {start_price:.2f} -> {end_price:.2f}') return ( (kline_start_idx, start_price), (kline_end_idx, end_price) ) # 降级处理:仅连接两点(与前端一致) d1 = _date_int_to_str(p1['date']) d2 = _date_int_to_str(p2['date']) idx1 = date_to_idx.get(d1) idx2 = date_to_idx.get(d2) if idx1 is None or idx2 is None: print(f'[chart] 警告: 趋势线日期不在 K 线数据中: {d1} 或 {d2}') return None return ( (idx1, p1['price']), (idx2, p2['price']) ) # 绘制上轨线(红色)- 与前端一致 if upper_line and window_start and window_end: upper_aline = make_extended_aline(upper_line, upper_line[0] if upper_line else None, '#ef4444') if upper_aline: alines_list.append(upper_aline) alines_colors.append('#ef4444') # 红色 # 绘制下轨线(绿色)- 与前端一致 if lower_line and window_start and window_end: lower_aline = make_extended_aline(lower_line, lower_line[0] if lower_line else None, '#10b981') if lower_aline: alines_list.append(lower_aline) alines_colors.append('#10b981') # 绿色 # ── 枢轴点散点图(可选)───────────────────────── addplots = [] if show_pivots: upper_pivots_arr = _build_pivot_arrays(chart.get('upper_pivots', []), dates_str, n) lower_pivots_arr = _build_pivot_arrays(chart.get('lower_pivots', []), dates_str, n) if np.any(~np.isnan(upper_pivots_arr)): addplots.append(mpf.make_addplot( upper_pivots_arr, type='scatter', markersize=100, marker='^', color='#ef4444', alpha=0.85, )) if np.any(~np.isnan(lower_pivots_arr)): addplots.append(mpf.make_addplot( lower_pivots_arr, type='scatter', markersize=100, marker='v', color='#10b981', alpha=0.85, )) # ── 突破价位水平线 ───────────────────────────── bp_up = detail.get('breakout_price_up') bp_down = detail.get('breakout_price_down') hlines_vals = [] hlines_colors = [] if bp_up and float(bp_up) > 0: hlines_vals.append(float(bp_up)) hlines_colors.append('#ef4444') if bp_down and float(bp_down) > 0: hlines_vals.append(float(bp_down)) hlines_colors.append('#10b981') # ── 标题 ──────────────────────────────────────── title = ( f'{ticker} {freq_label} ' f'强度={strength:.3f} {dir_label}\n' f'窗口 {window_start}~{window_end} ' f'(上沿触碰={touches_upper} 下沿触碰={touches_lower})' ) # ── 组装 plot 参数 ──────────────────────────────── _rc = {'axes.unicode_minus': False} if _CHINESE_FONT: _rc['font.family'] = _CHINESE_FONT _rc['font.sans-serif'] = [_CHINESE_FONT, 'DejaVu Sans'] _style = mpf.make_mpf_style(base_mpf_style='yahoo', rc=_rc) kwargs = dict( type='candle', volume=False, # 不显示成交量 title=title, style=_style, figsize=(18, 10), tight_layout=True, ) if alines_list: kwargs['alines'] = dict( alines=alines_list, colors=alines_colors, linewidths=2.0, alpha=0.90, ) if addplots: kwargs['addplot'] = addplots if hlines_vals: kwargs['hlines'] = dict( hlines=hlines_vals, colors=hlines_colors, linestyle='--', linewidths=1.2, alpha=0.7, ) if output_path: os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) # ── 绘图 ───────────────────────────────────────── # 先不绘制趋势线,获取 figure 和 axes _kwargs = {k: v for k, v in kwargs.items() if k not in ['alines', 'savefig']} fig, axes = mpf.plot(df, **_kwargs, returnfig=True, warn_too_much_data=1000) # 使用 matplotlib 直接在主图上绘制趋势线 # alines_list 格式: [((idx1, price1), (idx2, price2)), ...] if alines_list and len(axes) > 0: main_ax = axes[0] # 主图(K线图) # 收集趋势线价格,用于调整 Y 轴范围(与前端一致) trend_line_prices = [] for aline, color in zip(alines_list, alines_colors): (idx1, p1), (idx2, p2) = aline trend_line_prices.extend([p1, p2]) # 直接使用整数索引作为 x 坐标(与 mplfinance 内部一致) main_ax.plot([idx1, idx2], [p1, p2], color=color, linewidth=2, alpha=0.9) # 调整 Y 轴范围以包含趋势线(与前端一致,增加 5% 边距) if trend_line_prices: current_ylim = main_ax.get_ylim() kline_min, kline_max = current_ylim # 合并 K 线和趋势线的价格范围 all_min = min(kline_min, min(trend_line_prices)) all_max = max(kline_max, max(trend_line_prices)) price_range = all_max - all_min # 增加 5% 边距(与前端一致) new_min = all_min - price_range * 0.05 new_max = all_max + price_range * 0.05 main_ax.set_ylim(new_min, new_max) print(f'[chart] Y 轴范围调整: {kline_min:.2f}~{kline_max:.2f} -> {new_min:.2f}~{new_max:.2f}') if show: plt.show() if output_path: fig.savefig(output_path, dpi=150, bbox_inches='tight') print(f'[chart] 已保存: {output_path}') plt.close(fig) return output_path