technical-patterns-lab/scripts/archive/demo_pivot_detection.py
褚宏光 6d545eb231 Enhance converging triangle detection with new features and documentation updates
- Added support for a detailed chart mode in plot_converging_triangles.py, allowing users to visualize all pivot points and fitting lines.
- Improved pivot fitting logic to utilize multiple representative points, enhancing detection accuracy and reducing false positives.
- Introduced a new real-time detection mode with flexible zone parameters for better responsiveness in stock analysis.
- Updated README.md and USAGE.md to reflect new features and usage instructions.
- Added multiple documentation files detailing recent improvements, including pivot point fitting and visualization enhancements.
- Cleaned up and archived outdated scripts to streamline the project structure.
2026-01-26 16:21:36 +08:00

184 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
枢轴点检测可视化示例
用法:
python scripts/demo_pivot_detection.py
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
# 配置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from converging_triangle import pivots_fractal
def create_sample_data():
"""创建示例价格数据"""
# 模拟一个对称三角形
days = 100
base_price = 50
# 创建波动的价格数据
high = np.zeros(days)
low = np.zeros(days)
for i in range(days):
# 整体趋势:收敛
upper_trend = base_price + 15 - (i / days) * 10 # 上沿下降
lower_trend = base_price - 10 + (i / days) * 8 # 下沿上升
# 添加局部波动
wave = 3 * np.sin(i / 10) + 2 * np.sin(i / 15)
high[i] = upper_trend + wave + np.random.uniform(0, 1)
low[i] = lower_trend + wave - np.random.uniform(0, 1)
return high, low
def demo_pivot_k_comparison():
"""演示不同k值对枢轴点检测的影响"""
print("=" * 70)
print("枢轴点检测可视化演示")
print("=" * 70)
# 生成示例数据
high, low = create_sample_data()
close = (high + low) / 2
# 测试不同的k值
k_values = [3, 8, 15]
fig, axes = plt.subplots(len(k_values), 1, figsize=(14, 12))
for idx, k in enumerate(k_values):
ax = axes[idx]
# 检测枢轴点
ph_idx, pl_idx = pivots_fractal(high, low, k=k)
# 绘制价格线
x = np.arange(len(close))
ax.plot(x, close, 'k-', linewidth=1, alpha=0.6, label='收盘价')
ax.plot(x, high, 'gray', linewidth=0.5, alpha=0.3, label='最高价')
ax.plot(x, low, 'gray', linewidth=0.5, alpha=0.3, label='最低价')
# 标注枢轴点
ax.scatter(ph_idx, high[ph_idx],
marker='o', s=80, facecolors='none',
edgecolors='red', linewidths=2,
label=f'高点枢轴 ({len(ph_idx)}个)', zorder=5)
ax.scatter(pl_idx, low[pl_idx],
marker='o', s=80, facecolors='none',
edgecolors='green', linewidths=2,
label=f'低点枢轴 ({len(pl_idx)}个)', zorder=5)
# 标题和标签
ax.set_title(
f'k = {k} (左右各{k}根K线) | '
f'高点枢轴: {len(ph_idx)}个 | 低点枢轴: {len(pl_idx)}',
fontsize=12, pad=10
)
ax.set_ylabel('价格', fontsize=10)
ax.legend(loc='upper right', fontsize=9)
ax.grid(True, alpha=0.3)
if idx == len(k_values) - 1:
ax.set_xlabel('交易日', fontsize=10)
print(f"\nk={k}:")
print(f" 高点枢轴: {len(ph_idx)}个 (索引: {ph_idx[:10]}...)")
print(f" 低点枢轴: {len(pl_idx)}个 (索引: {pl_idx[:10]}...)")
plt.tight_layout()
# 保存图片
output_dir = os.path.join("docs", "images")
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "pivot_detection_demo.png")
plt.savefig(output_path, dpi=120)
print(f"\n图表已保存: {output_path}")
plt.show()
print("\n" + "=" * 70)
print("观察要点:")
print("=" * 70)
print("1. k=3: 检测到很多枢轴点,包括小幅波动")
print(" - 优点: 捕获更多细节")
print(" - 缺点: 容易受噪音影响")
print()
print("2. k=8: 中等敏感度,过滤了一些噪音")
print(" - 平衡灵敏度和稳定性")
print()
print("3. k=15: 只检测显著的转折点(当前配置)")
print(" - 优点: 稳定,适合识别主要形态")
print(" - 缺点: 可能遗漏一些细节")
print()
print("推荐: 日线级别使用 k=15分钟级别使用 k=3-5")
print("=" * 70)
def demo_pivot_logic():
"""演示枢轴点判定逻辑"""
print("\n" + "=" * 70)
print("枢轴点判定逻辑示例")
print("=" * 70)
# 简单示例数据
high = np.array([10, 12, 15, 18, 20, 17, 14, 16, 13, 11, 9])
low = np.array([8, 9, 11, 14, 16, 13, 10, 12, 9, 7, 5])
k = 3
print(f"\n最高价: {high}")
print(f"最低价: {low}")
print(f"k值: {k} (左右各{k}根K线)\n")
# 手动检查每个位置
print("逐个位置检查 (索引从0开始):")
print("-" * 70)
for i in range(k, len(high) - k):
# 检查高点
window_high = high[i - k : i + k + 1]
is_ph = (high[i] == np.max(window_high))
# 检查低点
window_low = low[i - k : i + k + 1]
is_pl = (low[i] == np.min(window_low))
print(f"索引 {i}:")
print(f" 高点检查: high[{i}]={high[i]:.0f}, "
f"窗口[{i-k}:{i+k+1}]最大值={np.max(window_high):.0f} "
f"-> {'是枢轴高点 [YES]' if is_ph else '不是'}")
print(f" 低点检查: low[{i}]={low[i]:.0f}, "
f"窗口[{i-k}:{i+k+1}]最小值={np.min(window_low):.0f} "
f"-> {'是枢轴低点 [YES]' if is_pl else '不是'}")
print()
# 使用函数检测
ph_idx, pl_idx = pivots_fractal(high, low, k=k)
print("-" * 70)
print(f"函数检测结果:")
print(f" 高点枢轴索引: {ph_idx}")
print(f" 低点枢轴索引: {pl_idx}")
print("=" * 70)
if __name__ == "__main__":
# 1. 演示判定逻辑
demo_pivot_logic()
# 2. 可视化不同k值的效果
print("\n按任意键继续查看可视化图表...")
input()
demo_pivot_k_comparison()