train_ml_models.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. """训练和验证结构化 ML 模型。"""
  2. import argparse
  3. import json
  4. import pandas as pd
  5. from defect_analysis.ml.datasets import build_supervised_dataset
  6. from defect_analysis.ml.features import build_feature_frame
  7. from defect_analysis.ml.model_bundle import (
  8. create_model_bundle,
  9. load_model_bundle,
  10. predict_with_bundle,
  11. save_model_bundle,
  12. )
  13. from defect_analysis.ml.model_registry import detect_optional_model_backends
  14. from defect_analysis.ml.predict import predict_key_factors
  15. from defect_analysis.ml.tabular_models import train_tabular_model
  16. from defect_analysis.schemas import normalize_defect_schema
  17. def load_defect_csv(csv_path):
  18. try:
  19. return normalize_defect_schema(pd.read_csv(csv_path, parse_dates=["timestamp"], encoding="utf-8-sig"))
  20. except (ValueError, KeyError) as exc:
  21. raise SystemExit(f"CSV 读取失败: 请确保文件包含 timestamp 列,格式为 utf-8 — {exc}")
  22. def build_bundle_report(bundle):
  23. """生成可序列化的模型训练报告。"""
  24. return {
  25. "bundle_version": bundle["bundle_version"],
  26. "created_at": bundle["created_at"],
  27. "model_name": bundle["model_name"],
  28. "target": bundle["target"],
  29. "feature_count": len(bundle["feature_columns"]),
  30. "metrics": bundle["metrics"],
  31. "validation_metrics": bundle["validation_metrics"],
  32. "feature_importance": bundle["feature_importance"],
  33. "optional_backends": bundle["optional_backends"],
  34. }
  35. def main():
  36. parser = argparse.ArgumentParser(description="训练/运行不良分析 ML 模型")
  37. parser.add_argument("--csv", default="defect_data.csv")
  38. parser.add_argument(
  39. "--model",
  40. default="random_forest",
  41. choices=["random_forest", "logistic_regression", "isolation_forest", "xgboost", "lightgbm"],
  42. )
  43. parser.add_argument("--target-defect-type")
  44. parser.add_argument("--target-severity")
  45. parser.add_argument("--top-n", type=int, default=10)
  46. parser.add_argument("--show-backends", action="store_true")
  47. parser.add_argument("--save-model", help="训练后保存监督模型包到指定路径,仅支持监督模型")
  48. parser.add_argument("--model-path", help="批量打分时加载的模型包路径")
  49. parser.add_argument("--predict-csv", help="使用已保存模型包对新 CSV 批量打分")
  50. parser.add_argument("--output-csv", help="批量打分结果导出路径,默认打印前 20 行")
  51. parser.add_argument("--report-json", help="导出训练评估报告 JSON")
  52. args = parser.parse_args()
  53. if args.show_backends:
  54. print(detect_optional_model_backends())
  55. if args.predict_csv:
  56. model_path = args.model_path or args.save_model
  57. if not model_path:
  58. raise SystemExit("--predict-csv 需要通过 --model-path 指定已保存的模型包路径")
  59. bundle = load_model_bundle(model_path)
  60. scored = predict_with_bundle(bundle, load_defect_csv(args.predict_csv))
  61. if args.output_csv:
  62. scored.to_csv(args.output_csv, index=False, encoding="utf-8-sig")
  63. print(f"批量打分完成: {args.output_csv},样本数={len(scored)}")
  64. else:
  65. columns = ["defect_id", "panel_id", "defect_type", "severity", "ml_prediction", "ml_probability", "model_name"]
  66. print(scored[[col for col in columns if col in scored.columns]].head(20).to_string(index=False))
  67. return
  68. df = load_defect_csv(args.csv)
  69. if args.model == "isolation_forest":
  70. X = build_feature_frame(df)
  71. result = train_tabular_model("isolation_forest", X)
  72. scores = pd.Series(result["anomaly_scores"])
  73. print(f"IsolationForest 完成: 样本数={len(scores)}, 最高异常分={scores.max():.4f}, 平均异常分={scores.mean():.4f}")
  74. return
  75. if args.save_model:
  76. bundle = create_model_bundle(
  77. df,
  78. model_name=args.model,
  79. target_defect_type=args.target_defect_type,
  80. target_severity=args.target_severity,
  81. )
  82. save_model_bundle(bundle, args.save_model)
  83. result = {"metrics": bundle["metrics"]}
  84. print(f"模型包已保存: {args.save_model}")
  85. if args.report_json:
  86. with open(args.report_json, "w", encoding="utf-8") as f:
  87. json.dump(build_bundle_report(bundle), f, ensure_ascii=False, indent=2)
  88. print(f"训练评估报告已保存: {args.report_json}")
  89. else:
  90. X, y = build_supervised_dataset(
  91. df,
  92. target_defect_type=args.target_defect_type,
  93. target_severity=args.target_severity,
  94. )
  95. result = train_tabular_model(args.model, X, y)
  96. print(f"{args.model} 训练完成: {result['metrics']}")
  97. predictions = predict_key_factors(
  98. df,
  99. target_defect_type=args.target_defect_type,
  100. target_severity=args.target_severity,
  101. model_name=args.model,
  102. top_n=args.top_n,
  103. )
  104. if predictions.empty:
  105. print("未找到关键因子候选。")
  106. else:
  107. columns = ["维度", "因子值", "目标数", "异常倍数", "关键因子得分", "ml_probability", "model_name"]
  108. print(predictions[columns].to_string(index=False))
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()