datasets.py 827 B

12345678910111213141516171819202122
  1. """训练数据集构建。"""
  2. from defect_analysis.ml.features import build_feature_frame
  3. from defect_analysis.schemas import normalize_defect_schema
  4. def build_target_series(df, *, target_defect_type=None, target_severity=None):
  5. normalized = normalize_defect_schema(df)
  6. if target_defect_type:
  7. return (normalized["defect_type"] == target_defect_type).astype(int)
  8. if target_severity:
  9. return (normalized["severity"] == target_severity).astype(int)
  10. return (normalized["severity"] == "严重").astype(int)
  11. def build_supervised_dataset(df, *, target_defect_type=None, target_severity=None):
  12. """构建监督学习数据集。"""
  13. return build_feature_frame(df), build_target_series(
  14. df,
  15. target_defect_type=target_defect_type,
  16. target_severity=target_severity,
  17. )