- Add scripts/scoring/ module with normalizer, sensitivity analysis, and config - Enhance stock_viewer.html with standardized scoring display - Add integration tests and normalization verification scripts - Add documentation for standardization implementation and usage guides - Add data distribution analysis reports for strength scoring dimensions - Update discussion documents with algorithm optimization plans
22 KiB
22 KiB
收敛三角形检测算法优化方案
目标场景
- 股票数量: 5000只
- 运行频率: 每天执行
- 当前耗时: 30-60秒(预估,与下文基准表一致)
- 优化目标: <5秒(10倍提速)
- 质量要求: 保持检测准确率
性能瓶颈分析
当前算法流程
检测流程(每只股票/每天):
┌─────────────────────────────────────────────────┐
│ 1. 枢轴点检测 (pivots_fractal_hybrid) 30% │ ← 热点1
│ 2. 边界线拟合 (fit_pivot_line) 25% │ ← 热点2
│ 3. 几何验证 (收敛度/触碰/斜率) 20% │
│ 4. 突破强度计算 (价格/成交量) 15% │
│ 5. DataFrame构建 + 数据复制 10% │ ← 热点3
└─────────────────────────────────────────────────┘
关键瓶颈
- 枢轴点检测: O(n*k) 滑动窗口,重复计算
- 边界线拟合: 迭代离群点移除,多次最小二乘
- Python循环: 大量stock×day双层循环
- 内存分配: 频繁创建临时数组
优化方案(分级实施)
🚀 Level 1: 向量化优化(预期提速2-3倍)
1.1 枢轴点检测向量化
当前实现 (O(n*k) 滑动窗口):
def pivots_fractal(high, low, k=15):
    """Locate fractal pivot points with a sliding-window local-extremum scan.

    A bar is a pivot high when its high is >= every other high within
    k bars on each side; pivot lows are defined symmetrically with <=.
    The first and last k bars never qualify (their window is incomplete).

    Returns a pair of integer index arrays: (pivot_highs, pivot_lows).
    """
    high_pivots = []
    low_pivots = []
    n = len(high)
    for center in range(k, n - k):
        neighbors = [j for j in range(center - k, center + k + 1) if j != center]
        if all(high[center] >= high[j] for j in neighbors):
            high_pivots.append(center)
        if all(low[center] <= low[j] for j in neighbors):
            low_pivots.append(center)
    return np.array(high_pivots), np.array(low_pivots)
优化方案 (向量化):
def pivots_fractal_vectorized(high, low, k=15):
    """Vectorized pivot detection built on scipy.signal.argrelextrema.

    Core idea: find every local maximum of `high` (pivot highs) and every
    local minimum of `low` (pivot lows) in a single vectorized pass each,
    comparing against `k` bars on either side.
    Expected speedup over the Python-loop version: roughly 3-5x.
    """
    from scipy.signal import argrelextrema
    # Local maxima of the highs -> pivot highs.
    pivot_highs = argrelextrema(high, np.greater_equal, order=k)[0]
    # Local minima of the lows -> pivot lows.
    pivot_lows = argrelextrema(low, np.less_equal, order=k)[0]
    return pivot_highs, pivot_lows
def pivots_fractal_rolling(high, low, k=15):
    """Pivot detection via pandas centered rolling windows.

    A bar is a pivot high when its value equals the max of the centered
    (2k+1)-bar window, and a pivot low when it equals the window min.
    Edge bars have an incomplete window (rolling emits NaN there) and
    therefore never match. Expected speedup: 2-4x.
    """
    import pandas as pd
    span = 2 * k + 1
    highs = pd.Series(high)
    lows = pd.Series(low)
    # Centered rolling extrema over the full comparison window.
    window_max = highs.rolling(span, center=True).max()
    window_min = lows.rolling(span, center=True).min()
    # Bars equal to their window extremum are the pivots.
    ph = np.where(highs.notna() & (highs == window_max))[0]
    pl = np.where(lows.notna() & (lows == window_min))[0]
    return ph, pl
实施:
- 文件: src/converging_triangle.py
- 函数: pivots_fractal() 和 pivots_fractal_hybrid()
- 向下兼容: 保留原函数作为 fallback
1.2 边界线拟合优化
当前实现 (迭代离群点移除):
def fit_pivot_line(pivot_indices, pivot_values, mode="upper", max_iter=10):
    """Iteratively drop outliers, refitting a least-squares line each pass.

    NOTE(review): illustrative pseudocode from the design document —
    `indices`, `values`, `find_outliers`, `no_outliers` and
    `remove_outliers` are placeholders, not defined in this snippet.
    """
    for iter in range(max_iter):
        a, b = np.polyfit(indices, values, 1)  # ordinary least squares
        residuals = values - (a * indices + b)
        outliers = find_outliers(residuals)
        if no_outliers: break
        remove_outliers()
    return a, b
优化方案A (预计算+缓存):
def fit_pivot_line_cached(pivot_indices, pivot_values, mode="upper", cache=None):
    """Memoized boundary-line fit.

    Adjacent trading days share most of their pivot points, so the same
    fit is requested repeatedly in rolling-window runs; caching keyed on
    the exact (indices, values, mode) triple saves roughly 30-50% there.
    `cache` is a plain dict owned by the caller; pass None to disable.
    """
    key = (tuple(pivot_indices), tuple(pivot_values), mode)
    if cache and key in cache:
        return cache[key]
    # Cache miss: run the actual fitting routine.
    fitted = _fit_pivot_line_core(pivot_indices, pivot_values, mode)
    if cache is not None:
        cache[key] = fitted
    return fitted
优化方案B (快速拟合算法):
def fit_pivot_line_ransac(pivot_indices, pivot_values, mode="upper", residual_threshold=None):
    """Robust boundary-line fit using sklearn's RANSAC.

    RANSAC tolerates outlier pivots without the iterative
    remove-and-refit loop; expected speedup 2-3x.

    Args:
        pivot_indices: 1-D array of bar indices of the pivot points.
        pivot_values: 1-D array of pivot prices (same length).
        mode: "upper" or "lower" boundary (kept for interface parity).
        residual_threshold: max residual for a point to count as an
            inlier. Defaults to the median absolute deviation of the
            values so the tolerance scales with the price level.
            (Fixes a NameError in the draft, which referenced an
            undefined `threshold`.)

    Returns:
        (slope, intercept, inlier_positions) — inlier_positions index
        into `pivot_indices`/`pivot_values`.
    """
    from sklearn.linear_model import RANSACRegressor
    if residual_threshold is None:
        # Robust, scale-aware default; fall back to a tiny epsilon when
        # the MAD is zero (e.g. all values identical).
        med = np.median(pivot_values)
        residual_threshold = np.median(np.abs(pivot_values - med)) or 1e-9
    X = pivot_indices.reshape(-1, 1)
    y = pivot_values
    ransac = RANSACRegressor(
        residual_threshold=residual_threshold,
        max_trials=100,
        random_state=42
    )
    ransac.fit(X, y)
    a = ransac.estimator_.coef_[0]
    b = ransac.estimator_.intercept_
    inlier_mask = ransac.inlier_mask_
    return a, b, np.where(inlier_mask)[0]
推荐: 先实施方案A(缓存),简单且收益稳定
1.3 消除Python循环
当前实现 (双层循环):
def detect_converging_triangle_batch(...):
    """Current batch detection: a stock x day double Python loop.

    NOTE(review): design-doc pseudocode — the literal `...` in the
    signature and the free variables (n_stocks, start_day, end_day)
    are placeholders; this snippet is not runnable as written.
    """
    results = []
    for stock_idx in range(n_stocks):
        for date_idx in range(start_day, end_day + 1):
            # One full single-point detection per (stock, day) pair —
            # this nested loop is the hot path the optimization targets.
            result = detect_converging_triangle_single(
                stock_idx, date_idx, ...
            )
            results.append(result)
    return pd.DataFrame(results)
优化方案 (向量化外层):
def detect_converging_triangle_batch_vectorized(...):
    """
    Vectorize the outer loop.

    Strategy:
    1. Group by date_idx and process all stocks for that day at once
    2. Use numpy broadcasting for the parallel computation
    Expected speedup: 1.5-2x

    NOTE(review): design-doc pseudocode — the `...` signature and the
    helpers detect_pivots_batch / fit_lines_batch / calc_strengths_batch
    are placeholders to be implemented.
    """
    all_results = []
    for date_idx in range(start_day, end_day + 1):
        # Process every stock for this single day in one shot.
        # Extract the window data (vectorized slice).
        window_start = date_idx - window + 1
        high_windows = high_mtx[:, window_start:date_idx+1]  # (n_stocks, window)
        low_windows = low_mtx[:, window_start:date_idx+1]
        # Batch pivot detection (numpy vector ops).
        pivots_batch = detect_pivots_batch(high_windows, low_windows)
        # Batch boundary-line fitting.
        fits_batch = fit_lines_batch(pivots_batch)
        # Batch strength computation.
        strengths_batch = calc_strengths_batch(fits_batch, ...)
        all_results.append(strengths_batch)
    return np.vstack(all_results)
关键: 需要重构算法,使单个函数能处理 (n_stocks, window) 维度
⚡ Level 2: Numba JIT加速(预期提速5-10倍)
2.1 Numba加速核心函数
from numba import jit, prange

@jit(nopython=True, parallel=True, cache=True)
def pivots_fractal_numba(high, low, k=15):
    """Numba-accelerated pivot detection.

    - nopython=True: compiled to machine code
    - parallel=True: multi-threaded prange over bars
    - cache=True: compilation result cached across runs
    Expected speedup vs pure Python: 10-20x.

    Fix vs the draft: appending to Python lists from inside a prange
    loop is a data race (and reflected-list append is unsupported in
    parallel nopython mode). Each iteration now writes only to its own
    slot of a preallocated flag array — a parallel-safe pattern — and
    the index arrays are extracted after the loop.
    """
    n = len(high)
    is_high_pivot = np.zeros(n, dtype=np.bool_)
    is_low_pivot = np.zeros(n, dtype=np.bool_)
    for i in prange(k, n - k):  # parallel loop, one slot per bar
        # Pivot high: no neighbor within k bars exceeds high[i].
        high_flag = True
        for j in range(i - k, i + k + 1):
            if j != i and high[i] < high[j]:
                high_flag = False
                break
        is_high_pivot[i] = high_flag
        # Pivot low: no neighbor within k bars undercuts low[i].
        low_flag = True
        for j in range(i - k, i + k + 1):
            if j != i and low[i] > low[j]:
                low_flag = False
                break
        is_low_pivot[i] = low_flag
    return np.nonzero(is_high_pivot)[0], np.nonzero(is_low_pivot)[0]
@jit(nopython=True, cache=True)
def fit_line_numba(x, y):
    """Numba-compiled closed-form least-squares line fit.

    Returns (slope, intercept) minimizing sum((y - (a*x + b))**2).

    Fix vs the draft: when all x values are identical the slope
    denominator is zero and the draft divided by it (inf/nan). We now
    return a horizontal line through the mean of y in that degenerate
    case.
    """
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    dx = x - x_mean
    denominator = np.sum(dx * dx)
    if denominator == 0.0:
        # Zero variance in x: slope is undefined; use a flat line.
        return 0.0, y_mean
    a = np.sum(dx * (y - y_mean)) / denominator
    b = y_mean - a * x_mean
    return a, b
@jit(nopython=True, parallel=True)
def detect_batch_numba(
    high_mtx, low_mtx, close_mtx, volume_mtx,
    window, k, start_day, end_day
):
    """
    Numba-accelerated batch detection.

    Core optimizations:
    - eliminate Python object overhead
    - parallelize the outermost loop
    - preallocate the result arrays
    Expected speedup: 5-10x

    NOTE(review): design-doc skeleton — computed_strength_up,
    computed_strength_down and computed_is_valid are placeholders for
    the omitted "..." processing section.
    """
    n_stocks, n_days = high_mtx.shape
    total_points = n_stocks * (end_day - start_day + 1)
    # Preallocate the result arrays (parallel-safe per-slot writes).
    strength_up = np.zeros(total_points, dtype=np.float64)
    strength_down = np.zeros(total_points, dtype=np.float64)
    is_valid = np.zeros(total_points, dtype=np.bool_)
    # Process every (stock, day) detection point in parallel.
    for idx in prange(total_points):
        # Decode the flat index back into (stock, day) coordinates.
        stock_idx = idx // (end_day - start_day + 1)
        day_offset = idx % (end_day - start_day + 1)
        date_idx = start_day + day_offset
        # Extract this point's lookback window.
        window_start = date_idx - window + 1
        high_win = high_mtx[stock_idx, window_start:date_idx+1]
        low_win = low_mtx[stock_idx, window_start:date_idx+1]
        # Detect pivots within the window.
        ph, pl = pivots_fractal_numba(high_win, low_win, k)
        # ... remaining processing ...
        strength_up[idx] = computed_strength_up
        strength_down[idx] = computed_strength_down
        is_valid[idx] = computed_is_valid
    return strength_up, strength_down, is_valid
实施要点:
- Numba要求函数纯数值计算,不能有pandas/字典等Python对象
- 首次运行会有JIT编译开销(~1-2秒),后续调用极快
- 需要将算法拆分为纯数值函数
🔥 Level 3: 并行化+缓存策略(预期提速10-20倍)
3.1 多进程并行
from multiprocessing import Pool, cpu_count
from functools import partial
def detect_stock_range(stock_indices, high_mtx, low_mtx, ...):
    """Run detection for one group of stocks (a single worker's task).

    NOTE(review): design-doc pseudocode — the literal `...` in the
    signature and the free variables start_day/end_day are
    placeholders; the real signature must accept every argument that
    detect_converging_triangle_parallel binds via functools.partial.
    """
    results = []
    for stock_idx in stock_indices:
        for date_idx in range(start_day, end_day + 1):
            result = detect_converging_triangle_single(
                stock_idx, date_idx, high_mtx, low_mtx, ...
            )
            results.append(result)
    return results
def detect_converging_triangle_parallel(
    high_mtx, low_mtx, close_mtx, volume_mtx,
    params, start_day, end_day,
    n_workers=None
):
    """Multiprocess batch detection.

    Splits the stock universe into n_workers groups, runs one process
    per group via a Pool, and concatenates the per-group result rows
    into a single DataFrame. Speedup is near linear for this CPU-bound
    workload (~7x on 8 cores). Each worker receives full copies of the
    price matrices, so memory must cover n_workers duplicates.
    """
    n_stocks = high_mtx.shape[0]
    n_workers = n_workers or cpu_count() - 1
    # One contiguous slice of stock indices per worker.
    stock_groups = np.array_split(range(n_stocks), n_workers)
    # Bind the shared read-only arguments once so workers only receive
    # their group of stock indices as the varying argument.
    detect_fn = partial(
        detect_stock_range,
        high_mtx=high_mtx,
        low_mtx=low_mtx,
        close_mtx=close_mtx,
        volume_mtx=volume_mtx,
        params=params,
        start_day=start_day,
        end_day=end_day
    )
    # Fan out, then flatten the per-group result lists.
    with Pool(n_workers) as pool:
        per_group = pool.map(detect_fn, stock_groups)
    merged = [row for rows in per_group for row in rows]
    return pd.DataFrame(merged)
注意:
- 适合CPU密集型任务
- 需要足够内存(数据复制到子进程)
- 5000只股票场景下,8-16核最优
3.2 增量计算+缓存
class IncrementalDetector:
    """
    Incremental detector: caches results from previous runs.

    Scenario: one new trading day arrives per day; historical
    detection results can be reused.
    Strategy:
    1. Cache pivots / fitted lines for the most recent N days
    2. On a new date, compute only the incremental part
    3. Evict stale entries LRU-style
    Expected benefit:
    - First run: no speedup
    - Each subsequent day: 5-10x (only the newest day is computed)
    """
    def __init__(self, window=240, cache_size=100):
        # Lookback window length in bars.
        self.window = window
        self.pivot_cache = {}  # {stock_idx: {date_idx: (ph, pl)}}
        self.fit_cache = {}  # {stock_idx: {date_idx: fit_result}}
        # Intended LRU capacity — eviction itself is not implemented here.
        self.cache_size = cache_size
    def detect_incremental(self, stock_idx, new_date_idx, high, low, close, volume):
        """
        Incremental detection using the caches.

        Logic:
        1. Check whether the previous day's result is cached
        2. If so, only:
           - update the pivots with the one new bar
           - reuse the historical fit results
        3. Otherwise compute from scratch and cache

        NOTE(review): design-doc skeleton — `result` is returned but
        never assigned; the "..." section is unimplemented.
        """
        prev_date_idx = new_date_idx - 1
        # Try the cache for yesterday's pivots.
        if stock_idx in self.pivot_cache and prev_date_idx in self.pivot_cache[stock_idx]:
            # Incremental pivot update from yesterday's result.
            prev_ph, prev_pl = self.pivot_cache[stock_idx][prev_date_idx]
            new_ph, new_pl = self._update_pivots_incremental(
                prev_ph, prev_pl, high, low, new_date_idx
            )
        else:
            # Cache miss: full recomputation.
            new_ph, new_pl = pivots_fractal(high, low, k=15)
        # Cache the pivots for this date.
        if stock_idx not in self.pivot_cache:
            self.pivot_cache[stock_idx] = {}
        self.pivot_cache[stock_idx][new_date_idx] = (new_ph, new_pl)
        # ... remaining processing ...
        return result
    def _update_pivots_incremental(self, prev_ph, prev_pl, high, low, new_idx):
        """
        Incrementally update the pivot index arrays.

        Strategy:
        1. Most pivot positions are unchanged (relative index +1)
        2. Only the window edges need checking for added/removed pivots

        NOTE(review): `_is_pivot_high` / `_is_pivot_low` are not
        defined on this class — skeleton code.
        """
        # Simplified version — the real logic is more involved:
        # it should check whether any of the latest k bars form a new pivot.
        k = 15
        last_points = high[-2*k:]
        # Is the newest bar a pivot?
        if self._is_pivot_high(last_points, k):
            prev_ph = np.append(prev_ph, new_idx)
        if self._is_pivot_low(low[-2*k:], k):
            prev_pl = np.append(prev_pl, new_idx)
        return prev_ph, prev_pl
实施优先级:
- 中等(适合生产环境每日运行)
- 首次启动无收益,后续每日收益显著
💾 Level 4: 数据结构优化(预期提速1.5-2倍)
4.1 使用Numpy结构化数组替代DataFrame
def detect_converging_triangle_batch_numpy(...):
    """
    Use a numpy structured array instead of a pandas DataFrame.

    Advantages:
    - avoids pandas object overhead
    - contiguous memory, cache friendly
    - returns a numpy array directly for downstream processing
    Expected speedup: 30-50% (fewer allocations)

    NOTE(review): design-doc pseudocode — the `...` signature and
    detect_single are placeholders.
    """
    n_stocks, n_days = close_mtx.shape
    total_points = n_stocks * (end_day - start_day + 1)
    # Result record layout: one row per (stock, day) detection point.
    dtype = np.dtype([
        ('stock_idx', np.int32),
        ('date_idx', np.int32),
        ('is_valid', np.bool_),
        ('strength_up', np.float32),
        ('strength_down', np.float32),
        ('convergence_score', np.float32),
        ('volume_score', np.float32),
        ('geometry_score', np.float32),
        ('activity_score', np.float32),
        ('tilt_score', np.float32),
    ])
    # Preallocate the full result array once.
    results = np.empty(total_points, dtype=dtype)
    idx = 0
    for stock_idx in range(n_stocks):
        for date_idx in range(start_day, end_day + 1):
            result = detect_single(stock_idx, date_idx, ...)
            # Write directly into the structured array (no intermediate objects).
            results[idx]['stock_idx'] = stock_idx
            results[idx]['date_idx'] = date_idx
            results[idx]['is_valid'] = result.is_valid
            results[idx]['strength_up'] = result.strength_up
            # ...
            idx += 1
    return results  # convertible later via pd.DataFrame(results)
4.2 内存映射文件(大规模数据)
def load_data_mmap(data_dir):
    """Memory-map the `high.npy` matrix stored under *data_dir*.

    Suited to data larger than RAM and to multi-process sharing
    (no per-process copy): pages are faulted in on demand, so load
    time drops from seconds to milliseconds and resident memory stays
    near zero until slices are actually touched.

    Returns a read-only numpy memmap array.
    """
    import os
    path = os.path.join(data_dir, 'high.npy')
    # mmap_mode='r' keeps the data on disk; slices load lazily.
    return np.load(path, mmap_mode='r')
def save_data_for_mmap(data, filepath):
    """Persist *data* in .npy format so it can later be memory-mapped.

    Counterpart of load_data_mmap: np.save writes the numpy binary
    format, which np.load can reopen with mmap_mode='r'.
    """
    np.save(filepath, data)
优化实施路线图
Phase 1: 快速收益(1-2周,预期2-3倍提速)
优先级P0:
- ✅ 枢轴点检测向量化(使用scipy或pandas rolling)
- ✅ 边界线拟合缓存
- ✅ 消除简单的Python循环(能向量化的先向量化)
预期收益:
- 耗时:30-60秒 → 10-20秒
- 实施难度:低
- 风险:低(向下兼容)
Phase 2: Numba加速(2-3周,预期5-10倍提速)
优先级P1:
- ✅ 核心函数Numba化(pivots/fit_line/calc_strength)
- ✅ 批量检测主循环Numba化
- ⚠️ 单元测试(确保数值精度一致)
预期收益:
- 耗时:10-20秒 → 2-5秒
- 实施难度:中
- 风险:中(需要验证数值稳定性)
注意事项:
# Numba限制:
# 1. 不支持pandas DataFrame(需改用numpy)
# 2. 不支持字典/列表(需改用numpy数组)
# 3. 不支持动态类型(需显式类型标注)
# 解决方案:
# - 将pandas逻辑分离到外层
# - 核心计算用纯numpy实现
# - 添加类型标注
Phase 3: 并行化(1周,预期10-20倍提速)
优先级P2(如果Phase 2后仍需优化):
- ✅ 多进程并行检测(Pool/ProcessPoolExecutor)
- ✅ 增量计算+缓存策略(生产环境每日运行)
预期收益:
- 耗时:2-5秒 → <1秒
- 实施难度:低(如果已完成Phase 1-2)
- 风险:低
Phase 4: 极致优化(按需实施)
优先级P3(仅当前面优化不够):
- Cython重写核心模块(C扩展)
- GPU加速(CUDA/cupy)
- Rust扩展(pyo3)
预期收益:
- 耗时:<1秒 → 毫秒级
- 实施难度:高
- 风险:高(维护成本)
基准测试脚本
使用现有测试脚本
# 当前项目已有性能测试脚本
cd d:\project\technical-patterns-lab
# 运行小规模测试(10只股票)
python scripts/test_performance.py
# 查看profiling结果
pip install snakeviz
snakeviz outputs/performance/profile_*.prof
创建5000股测试脚本
# scripts/benchmark_5000_stocks.py
"""
5000只股票性能测试
"""
import time
import numpy as np
from src.converging_triangle import detect_converging_triangle_batch, ConvergingTriangleParams
def generate_synthetic_data(n_stocks=5000, n_days=500):
    """Generate a synthetic OHLCV panel for benchmarking.

    Prices follow a seeded geometric random walk around a base near 10,
    so every call returns identical data. Returns (open, high, low,
    close, volume) matrices of shape (n_stocks, n_days).

    NOTE: the order of the np.random calls below is part of the
    contract — reordering them changes the generated panel.
    """
    np.random.seed(42)  # fixed seed -> reproducible benchmark input
    base_price = 10 + np.random.randn(n_stocks, 1) * 2
    returns = np.random.randn(n_stocks, n_days) * 0.02
    # Cumulative product of (1 + daily return) -> random-walk closes.
    close = base_price * np.cumprod(1 + returns, axis=1)
    # High/low bracket the close by up to ~1% on average.
    high = close * (1 + np.abs(np.random.randn(n_stocks, n_days)) * 0.01)
    low = close * (1 - np.abs(np.random.randn(n_stocks, n_days)) * 0.01)
    open_ = close * (1 + np.random.randn(n_stocks, n_days) * 0.005)
    volume = np.random.randint(100000, 1000000, (n_stocks, n_days))
    return open_, high, low, close, volume
def benchmark_5000_stocks():
    """Benchmark converging-triangle detection on a 5000-stock panel.

    Generates a synthetic 5000x500 OHLCV panel, times one full batch
    detection over days 239..499, prints throughput, and rates the
    result against the <5 second target.
    """
    separator = "=" * 80
    print(separator)
    print("5000只股票性能测试")
    print(separator)
    # --- synthetic input ---
    print("\n生成测试数据...")
    open_, high, low, close, volume = generate_synthetic_data(5000, 500)
    print(f"数据形状: {close.shape}")
    # --- detection parameters ---
    params = ConvergingTriangleParams(window=240, pivot_k=15)
    # --- timed run ---
    print("\n开始检测...")
    start = time.time()
    df = detect_converging_triangle_batch(
        open_mtx=open_,
        high_mtx=high,
        low_mtx=low,
        close_mtx=close,
        volume_mtx=volume,
        params=params,
        start_day=239,
        end_day=499,
        only_valid=True,
        verbose=False
    )
    elapsed = time.time() - start
    # --- report ---
    print("\n" + separator)
    print("测试结果")
    print(separator)
    print(f"总耗时: {elapsed:.2f} 秒")
    print(f"检测点数: {5000 * 261}")  # days 239..499 inclusive = 261
    print(f"速度: {5000*261/elapsed:.1f} 点/秒")
    print(f"有效形态: {len(df)}")
    # --- verdict against the optimization targets ---
    print("\n" + separator)
    if elapsed < 5:
        print("✅ 性能优秀! (<5秒)")
    elif elapsed < 10:
        print("✔️ 性能良好 (5-10秒)")
    elif elapsed < 30:
        print("⚠️ 性能一般 (10-30秒), 建议优化")
    else:
        print("❌ 性能较差 (>30秒), 急需优化")


if __name__ == '__main__':
    benchmark_5000_stocks()
质量保证
回归测试
# tests/test_optimization_correctness.py
"""
优化正确性测试:确保优化后结果一致
"""
import numpy as np
import pytest
def test_optimized_vs_original():
    """Regression test: the optimized detector must match the original.

    NOTE(review): skeleton from the design doc — load_test_case,
    detect_original and detect_optimized are yet to be implemented.
    """
    # Load the fixture data.
    data = load_test_case()
    # Original implementation.
    result_orig = detect_original(data)
    # Optimized implementation.
    result_opt = detect_optimized(data)
    # Strength scores may differ only by float rounding noise.
    np.testing.assert_allclose(
        result_orig['strength_up'],
        result_opt['strength_up'],
        rtol=1e-5,
        atol=1e-8
    )
    # Validity flags must match exactly — no tolerance allowed.
    assert (result_orig['is_valid'] == result_opt['is_valid']).all()
预期效果总结
| 优化阶段 | 实施难度 | 预期提速 | 累计提速 | 总耗时(5000股) |
|---|---|---|---|---|
| 当前 | - | - | 1x | 30-60秒 |
| Phase 1 | 低 | 2-3x | 2-3x | 10-20秒 |
| Phase 2 | 中 | 5-10x | 10-30x | 2-5秒 |
| Phase 3 | 低 | 10-20x | 20-60x | <1秒 |
| Phase 4 | 高 | 50-100x | 100-300x | 毫秒级 |
推荐路径: Phase 1 → Phase 2 → 观察是否满足需求 → 按需进入Phase 3
立即行动
本周任务
1. 基准测试当前性能
   python scripts/benchmark_5000_stocks.py
2. 确认瓶颈函数
   python -m cProfile -o profile.stats scripts/benchmark_5000_stocks.py
   python -m pstats profile.stats
   # 交互命令依次执行:
   # >> sort cumulative
   # >> stats 20
3. 优先实施:枢轴点检测向量化(收益最大、难度最低)
需要我帮你实现具体的优化代码吗?比如从枢轴点检测向量化开始?