"""训练和验证结构化 ML 模型。""" import argparse import json import pandas as pd from defect_analysis.ml.datasets import build_supervised_dataset from defect_analysis.ml.features import build_feature_frame from defect_analysis.ml.model_bundle import ( create_model_bundle, load_model_bundle, predict_with_bundle, save_model_bundle, ) from defect_analysis.ml.model_registry import detect_optional_model_backends from defect_analysis.ml.predict import predict_key_factors from defect_analysis.ml.tabular_models import train_tabular_model from defect_analysis.schemas import normalize_defect_schema def load_defect_csv(csv_path): try: return normalize_defect_schema(pd.read_csv(csv_path, parse_dates=["timestamp"], encoding="utf-8-sig")) except (ValueError, KeyError) as exc: raise SystemExit(f"CSV 读取失败: 请确保文件包含 timestamp 列,格式为 utf-8 — {exc}") def build_bundle_report(bundle): """生成可序列化的模型训练报告。""" return { "bundle_version": bundle["bundle_version"], "created_at": bundle["created_at"], "model_name": bundle["model_name"], "target": bundle["target"], "feature_count": len(bundle["feature_columns"]), "metrics": bundle["metrics"], "validation_metrics": bundle["validation_metrics"], "feature_importance": bundle["feature_importance"], "optional_backends": bundle["optional_backends"], } def main(): parser = argparse.ArgumentParser(description="训练/运行不良分析 ML 模型") parser.add_argument("--csv", default="defect_data.csv") parser.add_argument( "--model", default="random_forest", choices=["random_forest", "logistic_regression", "isolation_forest", "xgboost", "lightgbm"], ) parser.add_argument("--target-defect-type") parser.add_argument("--target-severity") parser.add_argument("--top-n", type=int, default=10) parser.add_argument("--show-backends", action="store_true") parser.add_argument("--save-model", help="训练后保存监督模型包到指定路径,仅支持监督模型") parser.add_argument("--model-path", help="批量打分时加载的模型包路径") parser.add_argument("--predict-csv", help="使用已保存模型包对新 CSV 批量打分") parser.add_argument("--output-csv", help="批量打分结果导出路径,默认打印前 20 行") parser.add_argument("--report-json", help="导出训练评估报告 JSON") args = parser.parse_args() if args.show_backends: print(detect_optional_model_backends()) if args.predict_csv: model_path = args.model_path or args.save_model if not model_path: raise SystemExit("--predict-csv 需要通过 --model-path 指定已保存的模型包路径") bundle = load_model_bundle(model_path) scored = predict_with_bundle(bundle, load_defect_csv(args.predict_csv)) if args.output_csv: scored.to_csv(args.output_csv, index=False, encoding="utf-8-sig") print(f"批量打分完成: {args.output_csv},样本数={len(scored)}") else: columns = ["defect_id", "panel_id", "defect_type", "severity", "ml_prediction", "ml_probability", "model_name"] print(scored[[col for col in columns if col in scored.columns]].head(20).to_string(index=False)) return df = load_defect_csv(args.csv) if args.model == "isolation_forest": X = build_feature_frame(df) result = train_tabular_model("isolation_forest", X) scores = pd.Series(result["anomaly_scores"]) print(f"IsolationForest 完成: 样本数={len(scores)}, 最高异常分={scores.max():.4f}, 平均异常分={scores.mean():.4f}") return if args.save_model: bundle = create_model_bundle( df, model_name=args.model, target_defect_type=args.target_defect_type, target_severity=args.target_severity, ) save_model_bundle(bundle, args.save_model) result = {"metrics": bundle["metrics"]} print(f"模型包已保存: {args.save_model}") if args.report_json: with open(args.report_json, "w", encoding="utf-8") as f: json.dump(build_bundle_report(bundle), f, ensure_ascii=False, indent=2) print(f"训练评估报告已保存: {args.report_json}") else: X, y = build_supervised_dataset( df, target_defect_type=args.target_defect_type, target_severity=args.target_severity, ) result = train_tabular_model(args.model, X, y) print(f"{args.model} 训练完成: {result['metrics']}") predictions = predict_key_factors( df, target_defect_type=args.target_defect_type, target_severity=args.target_severity, model_name=args.model, top_n=args.top_n, ) if predictions.empty: print("未找到关键因子候选。") else: columns = ["维度", "因子值", "目标数", "异常倍数", "关键因子得分", "ml_probability", "model_name"] print(predictions[columns].to_string(index=False)) if __name__ == "__main__": main()