model_bundle.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. """可持久化的 ML 模型包。"""
  2. from datetime import datetime, timezone
  3. import joblib
  4. import pandas as pd
  5. from sklearn.model_selection import train_test_split
  6. from defect_analysis.ml.datasets import build_supervised_dataset
  7. from defect_analysis.ml.features import build_feature_frame
  8. from defect_analysis.ml.model_registry import detect_optional_model_backends
  9. from defect_analysis.ml.tabular_models import classification_metrics, extract_feature_importance, train_tabular_model
  10. from defect_analysis.schemas import normalize_defect_schema
# Version number written into every bundle under the "bundle_version" key;
# load_model_bundle rejects any bundle whose stored value differs from it.
MODEL_BUNDLE_VERSION = 1
  12. def _target_config(target_defect_type=None, target_severity=None):
  13. return {
  14. "defect_type": target_defect_type,
  15. "severity": target_severity,
  16. "default": target_defect_type is None and target_severity is None,
  17. }
  18. def _align_features(features, feature_columns):
  19. """按训练时特征签名对齐新数据,避免 one-hot 列漂移导致推理失败。"""
  20. aligned = features.reindex(columns=feature_columns, fill_value=0.0)
  21. return aligned.astype(float)
  22. def create_model_bundle(
  23. df,
  24. *,
  25. model_name="random_forest",
  26. target_defect_type=None,
  27. target_severity=None,
  28. random_state=42,
  29. test_size=0.25,
  30. ):
  31. """训练并创建可保存的模型包。"""
  32. normalized = normalize_defect_schema(df)
  33. X, y = build_supervised_dataset(
  34. normalized,
  35. target_defect_type=target_defect_type,
  36. target_severity=target_severity,
  37. )
  38. if y.nunique() < 2:
  39. raise ValueError("目标标签只有一个类别,无法训练监督模型")
  40. stratify = y if y.value_counts().min() >= 2 else None
  41. X_train, X_valid, y_train, y_valid = train_test_split(
  42. X,
  43. y,
  44. test_size=test_size,
  45. random_state=random_state,
  46. stratify=stratify,
  47. )
  48. validation_model = train_tabular_model(model_name, X_train, y_train, random_state=random_state)["model"]
  49. validation_metrics = classification_metrics(validation_model, X_valid, y_valid, prefix="validation")
  50. trained = train_tabular_model(model_name, X, y, random_state=random_state)
  51. feature_importance = extract_feature_importance(trained["model"], X.columns)
  52. return {
  53. "bundle_version": MODEL_BUNDLE_VERSION,
  54. "created_at": datetime.now(timezone.utc).isoformat(),
  55. "model_name": model_name,
  56. "target": _target_config(target_defect_type, target_severity),
  57. "feature_columns": list(X.columns),
  58. "metrics": trained.get("metrics", {}),
  59. "validation_metrics": validation_metrics,
  60. "feature_importance": feature_importance,
  61. "optional_backends": detect_optional_model_backends(),
  62. "model": trained["model"],
  63. }
  64. def save_model_bundle(bundle, path):
  65. """保存模型包。"""
  66. joblib.dump(bundle, path)
  67. return path
  68. def load_model_bundle(path):
  69. """加载模型包。"""
  70. bundle = joblib.load(path)
  71. if bundle.get("bundle_version") != MODEL_BUNDLE_VERSION:
  72. raise ValueError("模型包版本不兼容")
  73. return bundle
  74. def predict_with_bundle(bundle, df):
  75. """使用模型包对新数据打分。"""
  76. normalized = normalize_defect_schema(df).reset_index(drop=True)
  77. features = build_feature_frame(normalized)
  78. X = _align_features(features, bundle["feature_columns"])
  79. model = bundle["model"]
  80. scored = normalized.copy()
  81. scored["ml_prediction"] = model.predict(X)
  82. if hasattr(model, "predict_proba"):
  83. scored["ml_probability"] = model.predict_proba(X)[:, 1]
  84. else:
  85. scored["ml_probability"] = pd.NA
  86. scored["model_name"] = bundle["model_name"]
  87. return scored