| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- """表格模型训练入口。"""
- from sklearn.ensemble import IsolationForest, RandomForestClassifier
- from sklearn.linear_model import LogisticRegression
- from sklearn.metrics import accuracy_score, roc_auc_score
- from sklearn.pipeline import make_pipeline
- from sklearn.preprocessing import StandardScaler
def _classification_metrics(model, X, y):
    """Compute training-set metrics for a fitted classifier.

    Args:
        model: fitted estimator exposing ``predict`` (and optionally
            ``predict_proba``).
        X: the features the model was trained on.
        y: the matching labels.

    Returns:
        dict with ``"train_accuracy"``, plus ``"train_auc"`` when the model
        exposes ``predict_proba`` and ``y`` contains more than one class.
        Binary targets are scored with the probability column of the greater
        (positive) class. Multiclass targets are scored with one-vs-rest AUC
        over the full probability matrix. Slicing column 1 for them would make
        ``roc_auc_score`` raise ``ValueError``.
    """
    pred = model.predict(X)
    metrics = {"train_accuracy": float(accuracy_score(y, pred))}
    if hasattr(model, "predict_proba"):
        n_classes = len(set(y))
        # AUC is undefined when only one class is present.
        if n_classes > 1:
            proba = model.predict_proba(X)
            if n_classes == 2:
                # Column order follows the model's sorted classes_, so column 1
                # is the class roc_auc_score treats as positive.
                auc = roc_auc_score(y, proba[:, 1])
            else:
                auc = roc_auc_score(y, proba, multi_class="ovr")
            metrics["train_auc"] = float(auc)
    return metrics
- def train_tabular_model(model_name, X, y=None, *, random_state=42):
- """训练表格模型。
- 支持 random_forest、logistic_regression、isolation_forest。
- """
- if model_name == "random_forest":
- if y is None:
- raise ValueError("random_forest 需要监督标签 y")
- model = RandomForestClassifier(
- n_estimators=100,
- max_depth=8,
- min_samples_leaf=2,
- random_state=random_state,
- class_weight="balanced",
- )
- model.fit(X, y)
- return {"model_name": model_name, "model": model, "metrics": _classification_metrics(model, X, y)}
- if model_name == "logistic_regression":
- if y is None:
- raise ValueError("logistic_regression 需要监督标签 y")
- model = make_pipeline(
- StandardScaler(with_mean=False),
- LogisticRegression(max_iter=3000, class_weight="balanced", solver="liblinear"),
- )
- model.fit(X, y)
- return {"model_name": model_name, "model": model, "metrics": _classification_metrics(model, X, y)}
- if model_name == "isolation_forest":
- model = IsolationForest(n_estimators=100, contamination="auto", random_state=random_state)
- model.fit(X)
- scores = -model.decision_function(X)
- return {"model_name": model_name, "model": model, "anomaly_scores": scores}
- if model_name == "xgboost":
- if y is None:
- raise ValueError("xgboost 需要监督标签 y")
- try:
- from xgboost import XGBClassifier
- except ImportError as exc:
- raise RuntimeError("XGBoost 未安装,请安装 xgboost 后再启用该模型") from exc
- model = XGBClassifier(
- n_estimators=100,
- max_depth=4,
- learning_rate=0.08,
- eval_metric="logloss",
- random_state=random_state,
- )
- model.fit(X, y)
- return {"model_name": model_name, "model": model, "metrics": _classification_metrics(model, X, y)}
- if model_name == "lightgbm":
- if y is None:
- raise ValueError("lightgbm 需要监督标签 y")
- try:
- from lightgbm import LGBMClassifier
- except ImportError as exc:
- raise RuntimeError("LightGBM 未安装,请安装 lightgbm 后再启用该模型") from exc
- model = LGBMClassifier(
- n_estimators=100,
- max_depth=4,
- learning_rate=0.08,
- random_state=random_state,
- verbose=-1,
- )
- model.fit(X, y)
- return {"model_name": model_name, "model": model, "metrics": _classification_metrics(model, X, y)}
- raise ValueError(f"不支持的模型: {model_name}")
|