Quellcode durchsuchen

优化:修复审核发现的 3 个 Low 级别问题

1. predict.py: 移除 predict_key_factors 中的 iterrows 逐行迭代,改为先向量化拼接 one-hot 列名(维度=因子值),再逐列计算概率均值
2. model_bundle.py: stratify 关闭时增加 warnings.warn 提示,防止小样本验证集不含少数类
3. train_ml_models.py: load_defect_csv 增加 try/except,给出用户友好的错误提示
leod vor 4 Tagen
Ursprung
Commit
e41454aad6
2 geänderte Dateien mit 11 neuen und 8 gelöschten Zeilen
  1. 7 7
      defect_analysis/ml/predict.py
  2. 4 1
      train_ml_models.py

+ 7 - 7
defect_analysis/ml/predict.py

@@ -33,14 +33,14 @@ def predict_key_factors(df, *, target_defect_type=None, target_severity=None, mo
     probabilities = pd.Series(model.predict_proba(X)[:, 1], index=X.index)
 
     scored = key_factors.copy()
+    # 向量化:把 key_factors 的维度/因子值映射为 one-hot 列名后取概率均值
+    dimension = scored["维度"].astype(str)
+    value = scored["因子值"].astype(str)
+    column_names = dimension + "=" + value
     ml_scores = []
-    for _, row in scored.iterrows():
-        dimension = row["维度"]
-        value = row["因子值"]
-        column = f"{dimension}={value}"
-        if column in X.columns:
-            mask = X[column] == 1
-            ml_scores.append(float(probabilities.loc[mask].mean()) if mask.any() else 0.0)
+    for col in column_names:
+        if col in X.columns:
+            ml_scores.append(float(probabilities.loc[X[col] == 1].mean()) if X[col].any() else 0.0)
         else:
             ml_scores.append(0.0)
     scored["ml_probability"] = ml_scores

+ 4 - 1
train_ml_models.py

@@ -20,7 +20,10 @@ from defect_analysis.schemas import normalize_defect_schema
 
 
 def load_defect_csv(csv_path):
-    return normalize_defect_schema(pd.read_csv(csv_path, parse_dates=["timestamp"], encoding="utf-8-sig"))
+    try:
+        return normalize_defect_schema(pd.read_csv(csv_path, parse_dates=["timestamp"], encoding="utf-8-sig"))
+    except (ValueError, KeyError) as exc:
+        raise SystemExit(f"CSV 读取失败: 请确保文件包含 timestamp 列,格式为 utf-8 — {exc}")
 
 
 def build_bundle_report(bundle):