# test_ml_platform.py (6.1 KB)

  1. import unittest
  2. import pandas as pd
  3. from defect_analysis.ml.datasets import build_supervised_dataset
  4. from defect_analysis.ml.features import build_feature_frame
  5. from defect_analysis.ml.image_models import ImageModelUnavailable, ImageModelWrapper
  6. from defect_analysis.ml.model_bundle import (
  7. create_model_bundle,
  8. load_model_bundle,
  9. predict_with_bundle,
  10. save_model_bundle,
  11. )
  12. from defect_analysis.ml.model_registry import detect_optional_model_backends
  13. from defect_analysis.ml.predict import predict_key_factors
  14. from defect_analysis.ml.tabular_models import train_tabular_model
  15. from defect_analysis.schemas import normalize_defect_schema
  16. from train_ml_models import build_bundle_report
  17. class MLPlatformTest(unittest.TestCase):
  18. def setUp(self):
  19. rows = []
  20. for i in range(40):
  21. hot = i < 24
  22. rows.append(
  23. {
  24. "defect_id": f"D{i}",
  25. "panel_id": f"P{i}",
  26. "batch_id": "B1",
  27. "equipment_id": "LAM-A01" if hot else "LAM-B01",
  28. "seat_id": "R1C1" if hot else "R2C2",
  29. "inspection_station": "AOI-1",
  30. "timestamp": pd.Timestamp("2026-04-01 08:00:00"),
  31. "defect_type": "气泡" if hot else "划痕",
  32. "severity": "严重" if i % 5 == 0 else "轻微",
  33. "x_mm": 10.0 + i,
  34. "y_mm": 20.0,
  35. "panel_width_mm": 155.0,
  36. "panel_height_mm": 340.0,
  37. "hour": 8,
  38. "shift": "白班",
  39. "day": "2026-04-01",
  40. "lam_fixture_id": "FIX-HOT" if hot else "FIX-OK",
  41. "material_lot_oca": "OCA-HOT" if hot else "OCA-OK",
  42. }
  43. )
  44. self.df = normalize_defect_schema(pd.DataFrame(rows))
  45. def test_build_feature_frame_creates_numeric_matrix(self):
  46. features = build_feature_frame(self.df)
  47. self.assertEqual(len(self.df), len(features))
  48. self.assertTrue(all(dtype.kind in "biufc" for dtype in features.dtypes))
  49. self.assertTrue(any(col.startswith("equipment_id=") for col in features.columns))
  50. def test_build_supervised_dataset_targets_defect_type(self):
  51. X, y = build_supervised_dataset(self.df, target_defect_type="气泡")
  52. self.assertEqual(len(self.df), len(X))
  53. self.assertEqual(24, int(y.sum()))
  54. def test_train_random_forest_and_logistic_regression(self):
  55. X, y = build_supervised_dataset(self.df, target_defect_type="气泡")
  56. rf = train_tabular_model("random_forest", X, y)
  57. lr = train_tabular_model("logistic_regression", X, y)
  58. self.assertIn("model", rf)
  59. self.assertIn("metrics", rf)
  60. self.assertIn("model", lr)
  61. self.assertGreaterEqual(rf["metrics"]["train_accuracy"], 0.5)
  62. def test_train_isolation_forest_outputs_anomaly_scores(self):
  63. X = build_feature_frame(self.df)
  64. result = train_tabular_model("isolation_forest", X)
  65. self.assertIn("anomaly_scores", result)
  66. self.assertEqual(len(self.df), len(result["anomaly_scores"]))
  67. def test_predict_key_factors_returns_model_scores(self):
  68. predictions = predict_key_factors(self.df, target_defect_type="气泡")
  69. self.assertFalse(predictions.empty)
  70. self.assertIn("ml_probability", predictions.columns)
  71. self.assertIn("model_name", predictions.columns)
  72. def test_optional_backends_are_reported_without_import_failure(self):
  73. backends = detect_optional_model_backends()
  74. self.assertIn("xgboost", backends)
  75. self.assertIn("lightgbm", backends)
  76. def test_image_model_wrapper_is_explicitly_unavailable_without_backend(self):
  77. wrapper = ImageModelWrapper()
  78. with self.assertRaises(ImageModelUnavailable):
  79. wrapper.predict([])
  80. def test_model_bundle_can_be_saved_loaded_and_score_new_data(self):
  81. bundle = create_model_bundle(
  82. self.df,
  83. model_name="random_forest",
  84. target_defect_type="气泡",
  85. )
  86. self.assertEqual("random_forest", bundle["model_name"])
  87. self.assertEqual("气泡", bundle["target"]["defect_type"])
  88. self.assertGreater(len(bundle["feature_columns"]), 0)
  89. self.assertIn("metrics", bundle)
  90. self.assertIn("validation_metrics", bundle)
  91. self.assertIn("feature_importance", bundle)
  92. self.assertGreater(len(bundle["feature_importance"]), 0)
  93. self.assertIn("feature", bundle["feature_importance"][0])
  94. self.assertIn("importance", bundle["feature_importance"][0])
  95. path = "tmp_test_model_bundle.joblib"
  96. try:
  97. save_model_bundle(bundle, path)
  98. loaded = load_model_bundle(path)
  99. scored = predict_with_bundle(loaded, self.df.tail(5))
  100. finally:
  101. import os
  102. if os.path.exists(path):
  103. os.remove(path)
  104. self.assertEqual(5, len(scored))
  105. self.assertIn("ml_probability", scored.columns)
  106. self.assertTrue(scored["ml_probability"].between(0, 1).all())
  107. def test_model_bundle_aligns_missing_feature_columns_for_new_data(self):
  108. bundle = create_model_bundle(
  109. self.df,
  110. model_name="logistic_regression",
  111. target_defect_type="气泡",
  112. )
  113. new_df = self.df.tail(3).copy()
  114. new_df["equipment_id"] = "NEW-LAM"
  115. new_df["seat_id"] = "NEW-SEAT"
  116. scored = predict_with_bundle(bundle, new_df)
  117. self.assertEqual(3, len(scored))
  118. self.assertIn("ml_prediction", scored.columns)
  119. def test_bundle_report_excludes_model_object_and_keeps_audit_fields(self):
  120. bundle = create_model_bundle(
  121. self.df,
  122. model_name="random_forest",
  123. target_defect_type="气泡",
  124. )
  125. report = build_bundle_report(bundle)
  126. self.assertNotIn("model", report)
  127. self.assertIn("validation_metrics", report)
  128. self.assertIn("feature_importance", report)
  129. self.assertGreater(report["feature_count"], 0)
  130. if __name__ == "__main__":
  131. unittest.main()