Source code for abstractnn.soundness_checker

"""
Vérification statistique de la soundness des bornes calculées
"""

import numpy as np
import onnxruntime as ort
from typing import Dict, Tuple, List
from PIL import Image


[docs] class SoundnessChecker: """Vérifie statistiquement que les bornes calculées sont correctes (sound)"""
[docs] def __init__(self, model_path: str): self.model_path = model_path self.session = ort.InferenceSession(model_path) self.input_name = self.session.get_inputs()[0].name self.violations = []
[docs] def check_soundness(self, image_path: str, bounds: Dict[str, Tuple[float, float]], noise_level: float, num_samples: int = 1000) -> Dict: """ Vérifie la soundness en échantillonnant aléatoirement dans la région bruitée Args: image_path: Chemin vers l'image d'entrée bounds: Bornes calculées par le moteur formel {class: (lower, upper)} noise_level: Niveau de bruit appliqué num_samples: Nombre d'échantillons à tester Returns: Dictionnaire contenant les statistiques de soundness """ # Charge l'image de base img = Image.open(image_path).convert('L') img = img.resize((28, 28)) base_image = np.array(img, dtype=np.float32) / 255.0 base_image = base_image.reshape(1, 1, 28, 28) # Statistiques violations_per_class = {cls: 0 for cls in bounds.keys()} samples_per_class = {cls: [] for cls in bounds.keys()} print(f"\nVérification de soundness avec {num_samples} échantillons...") # Échantillonne et teste for i in range(num_samples): # Génère une perturbation aléatoire (force float32) noise = np.random.uniform(-noise_level, noise_level, base_image.shape).astype(np.float32) perturbed_image = np.clip(base_image + noise, 0.0, 1.0).astype(np.float32) # Inférence output = self.session.run(None, {self.input_name: perturbed_image})[0][0] # Vérifie les bornes pour chaque classe for class_idx, (class_name, (lower_bound, upper_bound)) in enumerate(bounds.items()): actual_value = output[class_idx] samples_per_class[class_name].append(actual_value) # Vérifie si la valeur est dans les bornes (avec tolérance numérique) tolerance = 1e-6 if actual_value < lower_bound - tolerance or actual_value > upper_bound + tolerance: violations_per_class[class_name] += 1 self.violations.append({ 'sample': i, 'class': class_name, 'actual': actual_value, 'bounds': (lower_bound, upper_bound), 'violation_type': 'lower' if actual_value < lower_bound else 'upper' }) if (i + 1) % 200 == 0: print(f" Progression: {i + 1}/{num_samples} échantillons testés") # Calcule les statistiques results = { 'num_samples': num_samples, 'total_violations': sum(violations_per_class.values()), 'violation_rate': sum(violations_per_class.values()) / (num_samples * len(bounds)), 'violations_per_class': violations_per_class, 'is_sound': sum(violations_per_class.values()) == 0, 'statistics_per_class': {} } # Statistiques détaillées par classe for class_name, samples in samples_per_class.items(): lower_bound, upper_bound = bounds[class_name] samples_array = np.array(samples) results['statistics_per_class'][class_name] = { 'empirical_min': float(np.min(samples_array)), 'empirical_max': float(np.max(samples_array)), 'empirical_mean': float(np.mean(samples_array)), 'empirical_std': float(np.std(samples_array)), 'formal_lower': lower_bound, 'formal_upper': upper_bound, 'lower_margin': float(np.min(samples_array) - lower_bound), 'upper_margin': float(upper_bound - np.max(samples_array)), 'violations': violations_per_class[class_name] } return results
[docs] def print_soundness_report(self, results: Dict): """Affiche un rapport de soundness lisible""" print("\n" + "="*70) print("RAPPORT DE SOUNDNESS") print("="*70) print(f"\nNombre d'échantillons testés: {results['num_samples']}") print(f"Violations totales: {results['total_violations']}") print(f"Taux de violation: {results['violation_rate']*100:.4f}%") print(f"Sound (correct): {'✓ OUI' if results['is_sound'] else '✗ NON'}") if results['is_sound']: print("\n🎉 Les bornes formelles sont SOUND (correctes) !") print(" Toutes les valeurs observées sont dans les bornes calculées.") else: print("\n⚠️ ATTENTION: Violations détectées !") print(" Certaines valeurs observées sont hors des bornes calculées.") print("\n" + "-"*70) print("STATISTIQUES PAR CLASSE") print("-"*70) print(f"\n{'Classe':<15} | {'Min obs':<10} | {'Max obs':<10} | {'Borne inf':<10} | {'Borne sup':<10} | {'Violations'}") print("-"*90) for class_name, stats in results['statistics_per_class'].items(): class_idx = int(class_name.split('_')[1]) violations = stats['violations'] violation_mark = '✗' if violations > 0 else '✓' print(f"{class_name:<15} | {stats['empirical_min']:<10.4f} | " f"{stats['empirical_max']:<10.4f} | {stats['formal_lower']:<10.4f} | " f"{stats['formal_upper']:<10.4f} | {violations:>5} {violation_mark}") # Analyse des marges print("\n" + "-"*70) print("MARGES DE SÉCURITÉ") print("-"*70) print(f"\n{'Classe':<15} | {'Marge inf':<15} | {'Marge sup':<15} | {'Statut'}") print("-"*70) for class_name, stats in results['statistics_per_class'].items(): lower_margin = stats['lower_margin'] upper_margin = stats['upper_margin'] if lower_margin < 0 or upper_margin < 0: status = "⚠️ VIOLATION" elif lower_margin < 1e-3 or upper_margin < 1e-3: status = "⚠️ Marge faible" else: status = "✓ OK" print(f"{class_name:<15} | {lower_margin:<15.6f} | {upper_margin:<15.6f} | {status}") # Affiche les violations si présentes if not results['is_sound'] and self.violations: print("\n" + "-"*70) print("DÉTAILS DES VIOLATIONS (premières 10)") print("-"*70) for violation in self.violations[:10]: lower, upper = violation['bounds'] print(f"\nÉchantillon #{violation['sample']} - Classe {violation['class']}") print(f" Valeur observée: {violation['actual']:.6f}") print(f" Bornes formelles: [{lower:.6f}, {upper:.6f}]") print(f" Type de violation: {violation['violation_type']}") print("\n" + "="*70)
[docs] def monte_carlo_robustness_test(model_path: str, image_path: str, noise_level: float, num_samples: int = 10000) -> Dict: """ Test Monte Carlo de robustesse empirique Compare avec l'évaluation formelle pour valider la conservativité Args: model_path: Chemin vers le modèle ONNX image_path: Chemin vers l'image noise_level: Niveau de bruit num_samples: Nombre d'échantillons Monte Carlo Returns: Statistiques de robustesse empirique """ session = ort.InferenceSession(model_path) input_name = session.get_inputs()[0].name # Charge l'image img = Image.open(image_path).convert('L') img = img.resize((28, 28)) base_image = np.array(img, dtype=np.float32) / 255.0 base_image = base_image.reshape(1, 1, 28, 28) # Prédiction sur l'image originale original_output = session.run(None, {input_name: base_image})[0][0] original_class = np.argmax(original_output) # Échantillonnage Monte Carlo class_changes = 0 predictions = [] print(f"\nTest Monte Carlo avec {num_samples} échantillons...") for i in range(num_samples): # Force float32 pour éviter les erreurs de type noise = np.random.uniform(-noise_level, noise_level, base_image.shape).astype(np.float32) perturbed_image = np.clip(base_image + noise, 0.0, 1.0).astype(np.float32) output = session.run(None, {input_name: perturbed_image})[0][0] predicted_class = np.argmax(output) predictions.append(predicted_class) if predicted_class != original_class: class_changes += 1 if (i + 1) % 2000 == 0: print(f" Progression: {i + 1}/{num_samples}") # Statistiques predictions = np.array(predictions) empirical_robustness = 1.0 - (class_changes / num_samples) class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] results = { 'num_samples': num_samples, 'original_class': int(original_class), 'original_class_name': class_names[original_class], 'class_changes': class_changes, 'empirical_robustness': empirical_robustness, 'prediction_distribution': { i: int(np.sum(predictions == i)) for i in range(10) } } print(f"\n{'='*60}") print("RÉSULTATS MONTE CARLO") print(f"{'='*60}") print(f"Classe originale: {class_names[original_class]}") print(f"Changements de classe: {class_changes}/{num_samples}") print(f"Robustesse empirique: {empirical_robustness*100:.2f}%") print(f"\nDistribution des prédictions:") for i, count in results['prediction_distribution'].items(): if count > 0: print(f" {class_names[i]}: {count} ({count/num_samples*100:.2f}%)") return results