2 次代碼提交 17e5c9360e ... 014a58e7cd

作者 SHA1 備註 提交日期
  leod 014a58e7cd 优化:model_bundle.py stratify 关闭时增加 UserWarning 提示 4 天之前
  leod e41454aad6 优化:修复审核发现的 3 个 Low 级别问题 4 天之前
共有 3 個文件被更改,包括 19 次插入和 9 次删除
  1. 8 1
      defect_analysis/ml/model_bundle.py
  2. 7 7
      defect_analysis/ml/predict.py
  3. 4 1
      train_ml_models.py

+ 8 - 1
defect_analysis/ml/model_bundle.py

@@ -1,5 +1,6 @@
 """可持久化的 ML 模型包。"""
 
+import warnings
 from datetime import datetime, timezone
 
 import joblib
@@ -49,7 +50,13 @@ def create_model_bundle(
     if y.nunique() < 2:
         raise ValueError("目标标签只有一个类别,无法训练监督模型")
 
-    stratify = y if y.value_counts().min() >= 2 else None
+    min_count = int(y.value_counts().min())
+    if min_count < 2:
+        warnings.warn(
+            f"最小类别仅 {min_count} 个样本,已关闭分层抽样。验证集可能不包含少数类别。",
+            UserWarning,
+        )
+    stratify = y if min_count >= 2 else None
     X_train, X_valid, y_train, y_valid = train_test_split(
         X,
         y,

+ 7 - 7
defect_analysis/ml/predict.py

@@ -33,14 +33,14 @@ def predict_key_factors(df, *, target_defect_type=None, target_severity=None, mo
     probabilities = pd.Series(model.predict_proba(X)[:, 1], index=X.index)
 
     scored = key_factors.copy()
+    # 向量化:把 key_factors 的维度/因子值映射为 one-hot 列名后取概率均值
+    dimension = scored["维度"].astype(str)
+    value = scored["因子值"].astype(str)
+    column_names = dimension + "=" + value
     ml_scores = []
-    for _, row in scored.iterrows():
-        dimension = row["维度"]
-        value = row["因子值"]
-        column = f"{dimension}={value}"
-        if column in X.columns:
-            mask = X[column] == 1
-            ml_scores.append(float(probabilities.loc[mask].mean()) if mask.any() else 0.0)
+    for col in column_names:
+        if col in X.columns:
+            ml_scores.append(float(probabilities.loc[X[col] == 1].mean()) if X[col].any() else 0.0)
         else:
             ml_scores.append(0.0)
     scored["ml_probability"] = ml_scores

+ 4 - 1
train_ml_models.py

@@ -20,7 +20,10 @@ from defect_analysis.schemas import normalize_defect_schema
 
 
 def load_defect_csv(csv_path):
-    return normalize_defect_schema(pd.read_csv(csv_path, parse_dates=["timestamp"], encoding="utf-8-sig"))
+    try:
+        return normalize_defect_schema(pd.read_csv(csv_path, parse_dates=["timestamp"], encoding="utf-8-sig"))
+    except (ValueError, KeyError) as exc:
+        raise SystemExit(f"CSV 读取失败: 请确保文件包含 timestamp 列,格式为 utf-8 — {exc}")
 
 
 def build_bundle_report(bundle):