- Add scripts/scoring/ module with normalizer, sensitivity analysis, and config - Enhance stock_viewer.html with standardized scoring display - Add integration tests and normalization verification scripts - Add documentation for standardization implementation and usage guides - Add data distribution analysis reports for strength scoring dimensions - Update discussion documents with algorithm optimization plans
386 lines
13 KiB
Python
386 lines
13 KiB
Python
"""
|
||
敏感性分析工具
|
||
|
||
分析参数变化对筛选结果的影响,帮助用户优化参数设置。
|
||
|
||
主要功能:
|
||
1. 阈值敏感性分析
|
||
2. 权重敏感性分析
|
||
3. 生成完整的敏感性分析报告
|
||
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from pathlib import Path
|
||
from typing import List, Dict, Any
|
||
import sys
|
||
import os
|
||
|
||
# Make sibling modules (config.py, normalizer.py) importable when this file is
# run directly. An absolute path keeps working even if the working directory
# changes later (a bare dirname(__file__) can be '' or relative). The check
# keeps this directory at the front of sys.path without stacking duplicate
# entries when it is already there.
script_dir = os.path.dirname(os.path.abspath(__file__))
if not sys.path or sys.path[0] != script_dir:
    sys.path.insert(0, script_dir)
|
||
|
||
from config import StrengthConfig, filter_signals, calculate_strength
|
||
from normalizer import normalize_all
|
||
|
||
# Configure CJK-capable fonts so Chinese axis labels/titles render instead of
# empty boxes. SimHei / Microsoft YaHei (Windows) stay first so Windows output
# is unchanged; the following entries cover macOS and common Linux installs.
# Matplotlib uses the first family in the list that is installed.
_CJK_FONTS = [
    'SimHei', 'Microsoft YaHei', 'PingFang SC', 'Heiti SC',
    'Noto Sans CJK SC', 'Source Han Sans SC', 'WenQuanYi Micro Hei',
]
# Keep the previous defaults (e.g. DejaVu Sans) as the final fallback.
plt.rcParams['font.sans-serif'] = _CJK_FONTS + [
    font for font in plt.rcParams['font.sans-serif'] if font not in _CJK_FONTS
]
# Use an ASCII hyphen for negative tick labels; many CJK fonts lack U+2212.
plt.rcParams['axes.unicode_minus'] = False
|
||
|
||
|
||
def analyze_threshold_sensitivity(
    df_normalized: pd.DataFrame,
    config: StrengthConfig,
    param_name: str,
    param_range: List[float]
) -> pd.DataFrame:
    """
    Analyze how one threshold parameter affects the screening result.

    For each candidate value, a copy of ``config`` is made with ``param_name``
    overridden. ``filter_signals`` is applied to ``df_normalized``, and the
    strength scores of the surviving rows are summarized.

    Args:
        df_normalized: Normalized DataFrame.
        config: Base configuration. It is copied, never mutated.
        param_name: Name of the threshold attribute on ``config``
            (e.g. 'threshold_price').
        param_range: Candidate values, reported in this order.

    Returns:
        One row per candidate value with columns '参数值', '信号数量', '占比%',
        '平均强度', '最小强度', '最大强度'. When a candidate raised, its row
        has zeros in those columns plus an '错误' column holding the message.

    Raises:
        ValueError: If ``config`` has no attribute named ``param_name``.
            Previously a typo was silently set as a new attribute and produced
            a meaningless sweep.
    """
    if not hasattr(config, param_name):
        raise ValueError(f"未知参数: {param_name}")

    total = len(df_normalized)
    results = []

    for value in param_range:
        # Work on a copy so the caller's configuration is left untouched.
        test_config = StrengthConfig(**config.__dict__)
        setattr(test_config, param_name, value)

        try:
            filtered = filter_signals(df_normalized, test_config)
            n_signals = len(filtered)
            # An empty input frame yields 0% instead of a ZeroDivisionError
            # that would turn every row into an error row.
            pct_selected = n_signals / total * 100 if total else 0.0

            # Strength statistics are only meaningful when something passed.
            if n_signals > 0:
                strength = calculate_strength(filtered, test_config)
                avg_strength = strength.mean()
                min_strength = strength.min()
                max_strength = strength.max()
            else:
                avg_strength = min_strength = max_strength = 0.0

            results.append({
                '参数值': value,
                '信号数量': n_signals,
                '占比%': pct_selected,
                '平均强度': avg_strength,
                '最小强度': min_strength,
                '最大强度': max_strength,
            })
        except Exception as e:
            # Best-effort sweep: record the failure for this value, keep going.
            results.append({
                '参数值': value,
                '信号数量': 0,
                '占比%': 0.0,
                '平均强度': 0.0,
                '最小强度': 0.0,
                '最大强度': 0.0,
                '错误': str(e)
            })

    return pd.DataFrame(results)
|
||
|
||
|
||
def analyze_weight_sensitivity(
    df_normalized: pd.DataFrame,
    config: StrengthConfig,
    weight_name: str,
    weight_range: List[float]
) -> pd.DataFrame:
    """
    Analyze how one weight parameter affects the strength-score distribution.

    For each candidate value, ``weight_name`` is set to that value. The
    remaining ``1 - value`` is redistributed over the other weights in
    proportion to their values in ``config``. If all other base weights are 0,
    the remainder is split evenly so the weights still sum to 1. The new
    config is validated and the strength score is computed for every row.

    Args:
        df_normalized: Normalized DataFrame.
        config: Base configuration. It is copied, never mutated.
        weight_name: Weight attribute to sweep (e.g. 'w_price'). Must be one
            of w_price, w_convergence, w_volume, w_geometry, w_activity,
            w_tilt.
        weight_range: Candidate values. Values > 1 are skipped, because the
            other weights would have to be negative.

    Returns:
        One row per evaluated value with columns '权重值', '平均强度',
        '中位数强度', '标准差', 'P90', 'P95'. If validation or scoring fails
        for a value, its row holds only '权重值' and an '错误' message; the
        statistics are NaN.

    Raises:
        ValueError: If ``weight_name`` is not a known weight parameter.
    """
    weight_params = ['w_price', 'w_convergence', 'w_volume',
                     'w_geometry', 'w_activity', 'w_tilt']
    if weight_name not in weight_params:
        raise ValueError(
            f"未知权重参数: {weight_name} (可选: {', '.join(weight_params)})"
        )

    other_weights = [w for w in weight_params if w != weight_name]
    # The relative proportions of the other weights come from the base config
    # and do not depend on the swept value, so compute their sum once.
    original_sum = sum(getattr(config, w) for w in other_weights)

    results = []
    for value in weight_range:
        remaining_weight = 1.0 - value
        if remaining_weight < 0:
            continue  # value > 1 cannot be part of weights summing to 1

        # Work on a copy so the caller's configuration is left untouched.
        test_config = StrengthConfig(**config.__dict__)
        setattr(test_config, weight_name, value)

        for w in other_weights:
            if original_sum > 0:
                share = getattr(config, w) / original_sum
            else:
                # All other base weights are zero: split the remainder evenly,
                # otherwise the weights could never sum to 1.
                share = 1.0 / len(other_weights)
            setattr(test_config, w, share * remaining_weight)

        try:
            test_config.validate()
            strength = calculate_strength(df_normalized, test_config)
        except Exception as e:
            # Record the failure instead of silently dropping the value.
            results.append({'权重值': value, '错误': str(e)})
            continue

        results.append({
            '权重值': value,
            '平均强度': strength.mean(),
            '中位数强度': strength.median(),
            '标准差': strength.std(),
            'P90': strength.quantile(0.90),
            'P95': strength.quantile(0.95),
        })

    return pd.DataFrame(results)
|
||
|
||
|
||
def plot_threshold_sensitivity(
    sensitivity_df: pd.DataFrame,
    param_name: str,
    output_path: Path
):
    """
    Draw a two-panel threshold sensitivity chart and save it to ``output_path``.

    Left panel: signal count (left y-axis) and selected share in % (right
    y-axis) versus the threshold value. Right panel: mean strength of the
    filtered signals versus the threshold value.

    Args:
        sensitivity_df: Output of ``analyze_threshold_sensitivity``. Must
            contain the '参数值', '信号数量', '占比%' and '平均强度' columns.
        param_name: Parameter name used in the axis labels and title.
        output_path: Destination image file.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    try:
        x = sensitivity_df['参数值']

        # Left panel: signal count on the left axis...
        ax1.plot(x, sensitivity_df['信号数量'],
                 marker='o', linewidth=2, markersize=8, color='steelblue')
        ax1.set_xlabel(f'{param_name} 阈值', fontsize=12)
        ax1.set_ylabel('信号数量', fontsize=12, color='steelblue')
        ax1.tick_params(axis='y', labelcolor='steelblue')
        ax1.grid(True, alpha=0.3)
        ax1.set_title(f'{param_name} 阈值敏感性分析', fontsize=14, fontweight='bold')

        # ...and the selected share on a twin right axis.
        ax1_twin = ax1.twinx()
        ax1_twin.plot(x, sensitivity_df['占比%'],
                      marker='s', linewidth=2, markersize=8, color='coral', alpha=0.7)
        ax1_twin.set_ylabel('占比 (%)', fontsize=12, color='coral')
        ax1_twin.tick_params(axis='y', labelcolor='coral')

        # Right panel: mean strength after filtering.
        ax2.plot(x, sensitivity_df['平均强度'],
                 marker='o', linewidth=2, markersize=8, color='forestgreen')
        ax2.set_xlabel(f'{param_name} 阈值', fontsize=12)
        ax2.set_ylabel('筛选后平均强度', fontsize=12, color='forestgreen')
        ax2.tick_params(axis='y', labelcolor='forestgreen')
        ax2.grid(True, alpha=0.3)
        ax2.set_title('阈值对强度分的影响', fontsize=14, fontweight='bold')

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Close exactly this figure, even if saving fails, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)
|
||
|
||
|
||
def _df_to_markdown(df: pd.DataFrame) -> str:
    """
    Render a DataFrame as a Markdown table.

    Floats are printed with 4 decimals, NaN as '-', everything else via str().
    Rows are read with ``itertuples`` so integer columns stay integers
    (``iterrows`` upcasts an all-numeric row to float, printing a signal count
    of 12 as "12.0000"). Pipes and newlines inside values are escaped so an
    error message cannot break the table layout.
    """
    if df.columns.empty:
        return '(无数据)'

    def fmt(v) -> str:
        if isinstance(v, float):
            return '-' if pd.isna(v) else f"{v:.4f}"
        return str(v).replace('|', '\\|').replace('\n', ' ')

    header = '| ' + ' | '.join(str(c) for c in df.columns) + ' |'
    separator = '|' + '|'.join('---' for _ in df.columns) + '|'
    rows = [
        '| ' + ' | '.join(fmt(v) for v in row) + ' |'
        for row in df.itertuples(index=False, name=None)
    ]
    return '\n'.join([header, separator, *rows])


def _suggest_thresholds(sens_df: pd.DataFrame, param_name: str) -> List[str]:
    """
    Build Markdown bullet lines recommending a threshold for several target
    signal shares. For each target, pick the swept value whose '占比%' is
    closest. Rows that failed during the sweep (non-null '错误') are ignored.
    """
    valid = sens_df
    if '错误' in valid.columns:
        valid = valid[valid['错误'].isna()]
    if valid.empty:
        return ["- 无有效结果,无法给出建议"]

    lines = []
    for target_pct, desc in [(10.0, "宽松"), (5.0, "适中"), (2.0, "严格"), (1.0, "极严格")]:
        row = valid.loc[(valid['占比%'] - target_pct).abs().idxmin()]
        lines.append(
            f"- {desc}筛选 (目标≈{target_pct:g}%信号): {param_name} ≈ {row['参数值']:.2f} "
            f"(实际{row['占比%']:.1f}%, {int(row['信号数量'])}个信号)"
        )
    return lines


def generate_full_sensitivity_report(
    df_normalized: pd.DataFrame,
    base_config: StrengthConfig,
    output_dir: Path
):
    """
    Generate the full sensitivity analysis report.

    Sweeps threshold_price, threshold_convergence and threshold_volume, writing
    a CSV and a chart for each. Then sweeps the w_price weight (CSV only).
    Finally writes a Markdown summary, 'sensitivity_analysis_report.md', with
    all tables and data-derived threshold_price suggestions.

    Args:
        df_normalized: Normalized DataFrame.
        base_config: Base configuration (its ``name`` appears in the report).
        output_dir: Output directory. Created if it does not exist; a str
            path is also accepted.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 80)
    print("生成完整敏感性分析报告")
    print("=" * 80)

    # Stop values 0.91 / 0.81 make the upper bound inclusive. Values are
    # rounded to 2 decimals so outputs show 0.65, not 0.6500000000000001.
    threshold_sweeps = [
        ('threshold_price', np.arange(0.50, 0.91, 0.05)),
        ('threshold_convergence', np.arange(0.30, 0.81, 0.05)),
        ('threshold_volume', np.arange(0.50, 0.91, 0.05)),
    ]
    threshold_results = {}
    for step, (param_name, values) in enumerate(threshold_sweeps, start=1):
        print(f"\n[{step}] 分析 {param_name} 敏感性...")
        sens = analyze_threshold_sensitivity(
            df_normalized, base_config, param_name, np.round(values, 2).tolist()
        )
        csv_path = output_dir / f'sensitivity_{param_name}.csv'
        sens.to_csv(csv_path, index=False, encoding='utf-8-sig')
        print(f" 已保存: {csv_path}")

        plot_path = output_dir / f'sensitivity_{param_name}.png'
        plot_threshold_sensitivity(sens, param_name, plot_path)
        print(f" 图表已保存: {plot_path}")

        threshold_results[param_name] = sens

    print("\n[4] 分析 w_price 权重敏感性...")
    price_weight_sens = analyze_weight_sensitivity(
        df_normalized, base_config, 'w_price',
        np.round(np.arange(0.10, 0.51, 0.05), 2).tolist()
    )
    price_weight_path = output_dir / 'sensitivity_weight_price.csv'
    price_weight_sens.to_csv(price_weight_path, index=False, encoding='utf-8-sig')
    print(f" 已保存: {price_weight_path}")

    print("\n[5] 生成汇总报告...")
    price_sens = threshold_results['threshold_price']
    summary_lines = [
        "# 敏感性分析汇总报告",
        "",
        f"基础配置: {base_config.name}",
        f"样本数量: {len(df_normalized):,}",
        "",
        "## 1. 突破幅度阈值敏感性",
        "",
        _df_to_markdown(price_sens),
        "",
        "### 建议:",
        # Derived from the sweep instead of hard-coded values that may not
        # match the data.
        *_suggest_thresholds(price_sens, 'threshold_price'),
        "",
        "## 2. 收敛度阈值敏感性",
        "",
        _df_to_markdown(threshold_results['threshold_convergence']),
        "",
        "## 3. 成交量阈值敏感性",
        "",
        _df_to_markdown(threshold_results['threshold_volume']),
        "",
        "### 注意:",
        "成交量阈值 > 0.5 时才启用筛选,≤ 0.5 表示不限制",
        "",
        "## 4. 突破幅度权重敏感性",
        "",
        _df_to_markdown(price_weight_sens),
        "",
    ]

    summary_path = output_dir / 'sensitivity_analysis_report.md'
    summary_path.write_text('\n'.join(summary_lines), encoding='utf-8')
    print(f" 汇总报告已保存: {summary_path}")

    print("\n" + "=" * 80)
    print("敏感性分析完成!")
    print("=" * 80)
|
||
|
||
|
||
def _print_threshold_table(sens: pd.DataFrame, param_name: str) -> None:
    """
    Print a threshold sweep as a fixed-width console table.

    The first column is as wide as ``param_name``, so values line up under
    the header (15 characters for 'threshold_price', 16 for
    'threshold_volume').
    """
    width = len(param_name)
    print(f"\n{param_name} | 信号数 | 占比 | 平均强度")
    print("-" * 60)
    for value, n_signals, pct, avg in zip(
        sens['参数值'], sens['信号数量'], sens['占比%'], sens['平均强度']
    ):
        print(f"{value:{width}.2f} | {n_signals:6.0f} | {pct:5.1f}% | {avg:8.4f}")


def quick_analysis():
    """
    Quick sensitivity analysis of the key threshold parameters.

    Loads outputs/converging_triangles/all_results_normalized.csv, located
    three directory levels above this file. Sweeps threshold_price and
    threshold_volume on top of CONFIG_EQUAL and prints both tables. Then
    suggests threshold_price values for several target signal shares. Returns
    early with a message if the normalized data file is missing.
    """
    data_path = (
        Path(__file__).parent.parent.parent
        / 'outputs' / 'converging_triangles' / 'all_results_normalized.csv'
    )

    if not data_path.exists():
        print(f"标准化数据不存在: {data_path}")
        print("请先运行 verify_normalization.py")
        return

    print("=" * 80)
    print("快速敏感性分析")
    print("=" * 80)

    df = pd.read_csv(data_path)
    print(f"\n加载数据: {len(df)} 条记录")

    # The equal-weight configuration serves as the baseline.
    from config import CONFIG_EQUAL

    print("\n[1] 突破幅度阈值敏感性")
    print("-" * 80)
    price_range = [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90]
    price_sens = analyze_threshold_sensitivity(df, CONFIG_EQUAL, 'threshold_price', price_range)
    _print_threshold_table(price_sens, 'threshold_price')

    print("\n[2] 成交量阈值敏感性")
    print("-" * 80)
    vol_range = [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80]
    vol_sens = analyze_threshold_sensitivity(df, CONFIG_EQUAL, 'threshold_volume', vol_range)
    _print_threshold_table(vol_sens, 'threshold_volume')

    print("\n" + "=" * 80)
    print("阈值设置建议")
    print("=" * 80)

    # Ignore values whose sweep failed. Their placeholder 0% share would
    # otherwise be picked as the "closest" match for small targets.
    valid = price_sens
    if '错误' in valid.columns:
        valid = valid[valid['错误'].isna()]
    if valid.empty:
        print("没有有效的筛选结果,无法给出建议")
        return

    # For each target share, pick the swept value whose share is closest.
    for target_pct, desc in [(10.0, "宽松"), (5.0, "适中"), (2.0, "严格"), (1.0, "极严格")]:
        row = valid.loc[(valid['占比%'] - target_pct).abs().idxmin()]
        print(f"{desc:6s} (目标{target_pct:4.1f}%信号): threshold_price ≈ {row['参数值']:.2f} "
              f"(实际{row['占比%']:.1f}%, {int(row['信号数量'])}个信号)")
|
||
if __name__ == "__main__":
    # Executed as a script: run the quick sweep of the key thresholds.
    quick_analysis()
|