tabular_models.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. """表格模型训练入口。"""
  2. from sklearn.ensemble import IsolationForest, RandomForestClassifier
  3. from sklearn.linear_model import LogisticRegression
  4. from sklearn.metrics import accuracy_score, roc_auc_score
  5. from sklearn.pipeline import make_pipeline
  6. from sklearn.preprocessing import StandardScaler
  7. def _classification_metrics(model, X, y):
  8. pred = model.predict(X)
  9. metrics = {"train_accuracy": float(accuracy_score(y, pred))}
  10. if hasattr(model, "predict_proba") and len(set(y)) > 1:
  11. metrics["train_auc"] = float(roc_auc_score(y, model.predict_proba(X)[:, 1]))
  12. return metrics
  13. def train_tabular_model(model_name, X, y=None, *, random_state=42):
  14. """训练表格模型。
  15. 支持 random_forest、logistic_regression、isolation_forest。
  16. """
  17. if model_name == "random_forest":
  18. if y is None:
  19. raise ValueError("random_forest 需要监督标签 y")
  20. model = RandomForestClassifier(
  21. n_estimators=100,
  22. max_depth=8,
  23. min_samples_leaf=2,
  24. random_state=random_state,
  25. class_weight="balanced",
  26. )
  27. model.fit(X, y)
  28. return {"model_name": model_name, "model": model, "metrics": _classification_metrics(model, X, y)}
  29. if model_name == "logistic_regression":
  30. if y is None:
  31. raise ValueError("logistic_regression 需要监督标签 y")
  32. model = make_pipeline(
  33. StandardScaler(with_mean=False),
  34. LogisticRegression(max_iter=3000, class_weight="balanced", solver="liblinear"),
  35. )
  36. model.fit(X, y)
  37. return {"model_name": model_name, "model": model, "metrics": _classification_metrics(model, X, y)}
  38. if model_name == "isolation_forest":
  39. model = IsolationForest(n_estimators=100, contamination="auto", random_state=random_state)
  40. model.fit(X)
  41. scores = -model.decision_function(X)
  42. return {"model_name": model_name, "model": model, "anomaly_scores": scores}
  43. if model_name == "xgboost":
  44. if y is None:
  45. raise ValueError("xgboost 需要监督标签 y")
  46. try:
  47. from xgboost import XGBClassifier
  48. except ImportError as exc:
  49. raise RuntimeError("XGBoost 未安装,请安装 xgboost 后再启用该模型") from exc
  50. model = XGBClassifier(
  51. n_estimators=100,
  52. max_depth=4,
  53. learning_rate=0.08,
  54. eval_metric="logloss",
  55. random_state=random_state,
  56. )
  57. model.fit(X, y)
  58. return {"model_name": model_name, "model": model, "metrics": _classification_metrics(model, X, y)}
  59. if model_name == "lightgbm":
  60. if y is None:
  61. raise ValueError("lightgbm 需要监督标签 y")
  62. try:
  63. from lightgbm import LGBMClassifier
  64. except ImportError as exc:
  65. raise RuntimeError("LightGBM 未安装,请安装 lightgbm 后再启用该模型") from exc
  66. model = LGBMClassifier(
  67. n_estimators=100,
  68. max_depth=4,
  69. learning_rate=0.08,
  70. random_state=random_state,
  71. verbose=-1,
  72. )
  73. model.fit(X, y)
  74. return {"model_name": model_name, "model": model, "metrics": _classification_metrics(model, X, y)}
  75. raise ValueError(f"不支持的模型: {model_name}")