model_bundle.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. """可持久化的 ML 模型包。"""
  2. import warnings
  3. from datetime import datetime, timezone
  4. import joblib
  5. import pandas as pd
  6. from sklearn.model_selection import train_test_split
  7. from defect_analysis.ml.datasets import build_supervised_dataset
  8. from defect_analysis.ml.features import build_feature_frame
  9. from defect_analysis.ml.model_registry import detect_optional_model_backends
  10. from defect_analysis.ml.tabular_models import classification_metrics, extract_feature_importance, train_tabular_model
  11. from defect_analysis.schemas import normalize_defect_schema
  12. MODEL_BUNDLE_VERSION = 1
  13. def _target_config(target_defect_type=None, target_severity=None):
  14. return {
  15. "defect_type": target_defect_type,
  16. "severity": target_severity,
  17. "default": target_defect_type is None and target_severity is None,
  18. }
  19. def _align_features(features, feature_columns):
  20. """按训练时特征签名对齐新数据,避免 one-hot 列漂移导致推理失败。"""
  21. aligned = features.reindex(columns=feature_columns, fill_value=0.0)
  22. return aligned.astype(float)
  23. def create_model_bundle(
  24. df,
  25. *,
  26. model_name="random_forest",
  27. target_defect_type=None,
  28. target_severity=None,
  29. random_state=42,
  30. test_size=0.25,
  31. ):
  32. """训练并创建可保存的模型包。"""
  33. normalized = normalize_defect_schema(df)
  34. X, y = build_supervised_dataset(
  35. normalized,
  36. target_defect_type=target_defect_type,
  37. target_severity=target_severity,
  38. )
  39. if y.nunique() < 2:
  40. raise ValueError("目标标签只有一个类别,无法训练监督模型")
  41. min_count = int(y.value_counts().min())
  42. if min_count < 2:
  43. warnings.warn(
  44. f"最小类别仅 {min_count} 个样本,已关闭分层抽样。验证集可能不包含少数类别。",
  45. UserWarning,
  46. )
  47. stratify = y if min_count >= 2 else None
  48. X_train, X_valid, y_train, y_valid = train_test_split(
  49. X,
  50. y,
  51. test_size=test_size,
  52. random_state=random_state,
  53. stratify=stratify,
  54. )
  55. validation_model = train_tabular_model(model_name, X_train, y_train, random_state=random_state)["model"]
  56. validation_metrics = classification_metrics(validation_model, X_valid, y_valid, prefix="validation")
  57. trained = train_tabular_model(model_name, X, y, random_state=random_state)
  58. feature_importance = extract_feature_importance(trained["model"], X.columns)
  59. return {
  60. "bundle_version": MODEL_BUNDLE_VERSION,
  61. "created_at": datetime.now(timezone.utc).isoformat(),
  62. "model_name": model_name,
  63. "target": _target_config(target_defect_type, target_severity),
  64. "feature_columns": list(X.columns),
  65. "metrics": trained.get("metrics", {}),
  66. "validation_metrics": validation_metrics,
  67. "feature_importance": feature_importance,
  68. "optional_backends": detect_optional_model_backends(),
  69. "model": trained["model"],
  70. }
  71. def save_model_bundle(bundle, path):
  72. """保存模型包。"""
  73. joblib.dump(bundle, path)
  74. return path
  75. def load_model_bundle(path):
  76. """加载模型包。"""
  77. bundle = joblib.load(path)
  78. if bundle.get("bundle_version") != MODEL_BUNDLE_VERSION:
  79. raise ValueError("模型包版本不兼容")
  80. return bundle
  81. def predict_with_bundle(bundle, df):
  82. """使用模型包对新数据打分。"""
  83. normalized = normalize_defect_schema(df).reset_index(drop=True)
  84. features = build_feature_frame(normalized)
  85. X = _align_features(features, bundle["feature_columns"])
  86. model = bundle["model"]
  87. scored = normalized.copy()
  88. scored["ml_prediction"] = model.predict(X)
  89. if hasattr(model, "predict_proba"):
  90. scored["ml_probability"] = model.predict_proba(X)[:, 1]
  91. else:
  92. scored["ml_probability"] = pd.NA
  93. scored["model_name"] = bundle["model_name"]
  94. return scored