train_ml_models.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. """训练和验证结构化 ML 模型。"""
  2. import argparse
  3. import pandas as pd
  4. from defect_analysis.ml.datasets import build_supervised_dataset
  5. from defect_analysis.ml.features import build_feature_frame
  6. from defect_analysis.ml.model_registry import detect_optional_model_backends
  7. from defect_analysis.ml.predict import predict_key_factors
  8. from defect_analysis.ml.tabular_models import train_tabular_model
  9. from defect_analysis.schemas import normalize_defect_schema
  10. def load_defect_csv(csv_path):
  11. return normalize_defect_schema(pd.read_csv(csv_path, parse_dates=["timestamp"], encoding="utf-8-sig"))
  12. def main():
  13. parser = argparse.ArgumentParser(description="训练/运行不良分析 ML 模型")
  14. parser.add_argument("--csv", default="defect_data.csv")
  15. parser.add_argument(
  16. "--model",
  17. default="random_forest",
  18. choices=["random_forest", "logistic_regression", "isolation_forest", "xgboost", "lightgbm"],
  19. )
  20. parser.add_argument("--target-defect-type")
  21. parser.add_argument("--target-severity")
  22. parser.add_argument("--top-n", type=int, default=10)
  23. parser.add_argument("--show-backends", action="store_true")
  24. args = parser.parse_args()
  25. if args.show_backends:
  26. print(detect_optional_model_backends())
  27. df = load_defect_csv(args.csv)
  28. if args.model == "isolation_forest":
  29. X = build_feature_frame(df)
  30. result = train_tabular_model("isolation_forest", X)
  31. scores = pd.Series(result["anomaly_scores"])
  32. print(f"IsolationForest 完成: 样本数={len(scores)}, 最高异常分={scores.max():.4f}, 平均异常分={scores.mean():.4f}")
  33. return
  34. X, y = build_supervised_dataset(
  35. df,
  36. target_defect_type=args.target_defect_type,
  37. target_severity=args.target_severity,
  38. )
  39. result = train_tabular_model(args.model, X, y)
  40. print(f"{args.model} 训练完成: {result['metrics']}")
  41. predictions = predict_key_factors(
  42. df,
  43. target_defect_type=args.target_defect_type,
  44. target_severity=args.target_severity,
  45. model_name=args.model,
  46. top_n=args.top_n,
  47. )
  48. if predictions.empty:
  49. print("未找到关键因子候选。")
  50. else:
  51. columns = ["维度", "因子值", "目标数", "异常倍数", "关键因子得分", "ml_probability", "model_name"]
  52. print(predictions[columns].to_string(index=False))
  53. if __name__ == "__main__":
  54. main()