"""Entry point for training tabular models.

Supported model names: ``random_forest``, ``logistic_regression``,
``isolation_forest``, and -- when the optional packages are installed --
``xgboost`` and ``lightgbm``.
"""
from sklearn.ensemble import IsolationForest, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def _classification_metrics(model, X, y):
    """Compute in-sample (training-set) metrics for a fitted classifier.

    Args:
        model: A fitted estimator exposing ``predict`` and, optionally,
            ``predict_proba``.
        X: The feature matrix the model was fitted on.
        y: The labels the model was fitted on.

    Returns:
        dict: Always contains ``train_accuracy``. Contains ``train_auc``
        when the model exposes ``predict_proba`` and ``y`` has at least two
        distinct classes. For more than two classes, the AUC is the
        macro-averaged one-vs-rest score.
    """
    pred = model.predict(X)
    metrics = {"train_accuracy": float(accuracy_score(y, pred))}
    n_classes = len(set(y))
    if hasattr(model, "predict_proba") and n_classes > 1:
        proba = model.predict_proba(X)
        if n_classes == 2:
            # Binary case: score with the probability of the second
            # (larger) class label, assuming model.classes_ is sorted
            # ascending as it is for sklearn-style classifiers.
            score = roc_auc_score(y, proba[:, 1])
        else:
            # Multiclass y needs the full probability matrix. Passing a
            # single column makes roc_auc_score raise.
            score = roc_auc_score(y, proba, multi_class="ovr")
        metrics["train_auc"] = float(score)
    return metrics


def _build_random_forest(random_state):
    """Return an unfitted, class-balanced random forest classifier."""
    return RandomForestClassifier(
        n_estimators=100,
        max_depth=8,
        min_samples_leaf=2,
        random_state=random_state,
        class_weight="balanced",
    )


def _build_logistic_regression(random_state):
    """Return an unfitted scaler + logistic-regression pipeline.

    ``random_state`` is accepted for a uniform builder signature but is not
    used, matching the original configuration.
    """
    return make_pipeline(
        # with_mean=False: presumably chosen so sparse X can be scaled
        # (centering is not supported for sparse input) -- TODO confirm.
        StandardScaler(with_mean=False),
        LogisticRegression(max_iter=3000, class_weight="balanced", solver="liblinear"),
    )


def _build_xgboost(random_state):
    """Return an unfitted XGBoost classifier.

    Raises:
        RuntimeError: If the ``xgboost`` package is not installed.
    """
    try:
        from xgboost import XGBClassifier
    except ImportError as exc:
        raise RuntimeError("XGBoost 未安装,请安装 xgboost 后再启用该模型") from exc
    return XGBClassifier(
        n_estimators=100,
        max_depth=4,
        learning_rate=0.08,
        eval_metric="logloss",
        random_state=random_state,
    )


def _build_lightgbm(random_state):
    """Return an unfitted LightGBM classifier.

    Raises:
        RuntimeError: If the ``lightgbm`` package is not installed.
    """
    try:
        from lightgbm import LGBMClassifier
    except ImportError as exc:
        raise RuntimeError("LightGBM 未安装,请安装 lightgbm 后再启用该模型") from exc
    return LGBMClassifier(
        n_estimators=100,
        max_depth=4,
        learning_rate=0.08,
        random_state=random_state,
        verbose=-1,
    )


# Supervised model name -> factory taking ``random_state`` and returning an
# unfitted classifier. Optional dependencies are imported lazily by the
# factory, so a missing package only fails when that model is requested.
_CLASSIFIER_BUILDERS = {
    "random_forest": _build_random_forest,
    "logistic_regression": _build_logistic_regression,
    "xgboost": _build_xgboost,
    "lightgbm": _build_lightgbm,
}


def train_tabular_model(model_name, X, y=None, *, random_state=42):
    """Train a tabular model selected by name.

    Supports ``random_forest``, ``logistic_regression``, ``isolation_forest``,
    ``xgboost`` and ``lightgbm``. The last two require their optional packages.

    Args:
        model_name: Name of the model to train.
        X: Feature matrix.
        y: Labels. Required for every model except ``isolation_forest``,
            which ignores it.
        random_state: Seed forwarded to the estimator (unused by
            ``logistic_regression``).

    Returns:
        dict: ``{"model_name", "model", "metrics"}`` for supervised models,
        where ``metrics`` is computed on the training data. For
        ``isolation_forest`` the dict is ``{"model_name", "model",
        "anomaly_scores"}``.

    Raises:
        ValueError: If ``model_name`` is unsupported, or if a supervised
            model is requested without ``y``.
        RuntimeError: If ``xgboost``/``lightgbm`` is requested but not
            installed.
    """
    if model_name == "isolation_forest":
        model = IsolationForest(n_estimators=100, contamination="auto", random_state=random_state)
        model.fit(X)
        # decision_function is lower for more abnormal samples; negate it so
        # that larger scores mean "more anomalous".
        scores = -model.decision_function(X)
        return {"model_name": model_name, "model": model, "anomaly_scores": scores}

    builder = _CLASSIFIER_BUILDERS.get(model_name)
    if builder is None:
        raise ValueError(f"不支持的模型: {model_name}")
    if y is None:
        raise ValueError(f"{model_name} 需要监督标签 y")

    model = builder(random_state)
    model.fit(X, y)
    return {"model_name": model_name, "model": model, "metrics": _classification_metrics(model, X, y)}