tabular_models.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. """表格模型训练入口。"""
  2. from sklearn.ensemble import IsolationForest, RandomForestClassifier
  3. from sklearn.linear_model import LogisticRegression
  4. from sklearn.metrics import accuracy_score, roc_auc_score
  5. from sklearn.pipeline import make_pipeline
  6. from sklearn.preprocessing import StandardScaler
  7. def classification_metrics(model, X, y, *, prefix="train"):
  8. """计算二分类评估指标。"""
  9. pred = model.predict(X)
  10. metrics = {f"{prefix}_accuracy": float(accuracy_score(y, pred))}
  11. if hasattr(model, "predict_proba") and len(set(y)) > 1:
  12. metrics[f"{prefix}_auc"] = float(roc_auc_score(y, model.predict_proba(X)[:, 1]))
  13. return metrics
  14. def extract_feature_importance(model, feature_columns, *, top_n=20):
  15. """提取模型特征贡献,用于工程复盘和模型审计。"""
  16. if hasattr(model, "feature_importances_"):
  17. values = model.feature_importances_
  18. elif hasattr(model, "named_steps") and "logisticregression" in model.named_steps:
  19. values = abs(model.named_steps["logisticregression"].coef_[0])
  20. elif hasattr(model, "coef_"):
  21. values = abs(model.coef_[0])
  22. else:
  23. return []
  24. rows = [
  25. {"feature": feature, "importance": float(importance)}
  26. for feature, importance in zip(feature_columns, values)
  27. ]
  28. rows.sort(key=lambda item: item["importance"], reverse=True)
  29. return rows[:top_n]
  30. def train_tabular_model(model_name, X, y=None, *, random_state=42):
  31. """训练表格模型。
  32. 支持 random_forest、logistic_regression、isolation_forest。
  33. """
  34. if model_name == "random_forest":
  35. if y is None:
  36. raise ValueError("random_forest 需要监督标签 y")
  37. model = RandomForestClassifier(
  38. n_estimators=100,
  39. max_depth=8,
  40. min_samples_leaf=2,
  41. random_state=random_state,
  42. class_weight="balanced",
  43. )
  44. model.fit(X, y)
  45. return {"model_name": model_name, "model": model, "metrics": classification_metrics(model, X, y)}
  46. if model_name == "logistic_regression":
  47. if y is None:
  48. raise ValueError("logistic_regression 需要监督标签 y")
  49. model = make_pipeline(
  50. StandardScaler(with_mean=False),
  51. LogisticRegression(max_iter=3000, class_weight="balanced", solver="liblinear"),
  52. )
  53. model.fit(X, y)
  54. return {"model_name": model_name, "model": model, "metrics": classification_metrics(model, X, y)}
  55. if model_name == "isolation_forest":
  56. model = IsolationForest(n_estimators=100, contamination="auto", random_state=random_state)
  57. model.fit(X)
  58. scores = -model.decision_function(X)
  59. return {"model_name": model_name, "model": model, "anomaly_scores": scores}
  60. if model_name == "xgboost":
  61. if y is None:
  62. raise ValueError("xgboost 需要监督标签 y")
  63. try:
  64. from xgboost import XGBClassifier
  65. except ImportError as exc:
  66. raise RuntimeError("XGBoost 未安装,请安装 xgboost 后再启用该模型") from exc
  67. model = XGBClassifier(
  68. n_estimators=100,
  69. max_depth=4,
  70. learning_rate=0.08,
  71. eval_metric="logloss",
  72. random_state=random_state,
  73. )
  74. model.fit(X, y)
  75. return {"model_name": model_name, "model": model, "metrics": classification_metrics(model, X, y)}
  76. if model_name == "lightgbm":
  77. if y is None:
  78. raise ValueError("lightgbm 需要监督标签 y")
  79. try:
  80. from lightgbm import LGBMClassifier
  81. except ImportError as exc:
  82. raise RuntimeError("LightGBM 未安装,请安装 lightgbm 后再启用该模型") from exc
  83. model = LGBMClassifier(
  84. n_estimators=100,
  85. max_depth=4,
  86. learning_rate=0.08,
  87. random_state=random_state,
  88. verbose=-1,
  89. )
  90. model.fit(X, y)
  91. return {"model_name": model_name, "model": model, "metrics": classification_metrics(model, X, y)}
  92. raise ValueError(f"不支持的模型: {model_name}")