Spaces:
Running
Running
# ============================================================
# PhishGuard AI - cnn/cnn_model.py
# ResNet50 visual classifier for phishing screenshot detection.
#
# Architecture (from spec):
#   Backbone: ResNet50, fully frozen
#   Custom head: Linear(2048 -> 512) -> ReLU -> Dropout(0.5) ->
#                Linear(512 -> 1) -> Sigmoid
#   Input: 224x224 screenshot tensor
#   Output: P_cnn in [0, 1]
# ============================================================
| from __future__ import annotations | |
| import io | |
| import logging | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| import torchvision.transforms as T | |
| from PIL import Image | |
# Module-level logger. The dotted name makes it a child of the "phishguard" logger,
# so handlers and levels configured there also apply here.
logger: logging.Logger = logging.getLogger("phishguard.cnn.model")
class PhishCNN(nn.Module):
    """
    ResNet50 with a frozen backbone and a custom 2-layer binary classification head.

    The backbone's original ``fc`` layer is replaced by
    ``Linear(in_features, 512) -> ReLU -> Dropout(0.5) -> Linear(512, 1)``.
    ``forward`` applies a sigmoid, so the output is P_cnn in [0, 1].

    Only the head's parameters have ``requires_grad=True``. ``train()`` is
    overridden so that the backbone's BatchNorm layers always stay in eval mode.
    ``requires_grad=False`` alone does not stop BatchNorm from updating its
    running mean/variance on training-mode forward passes, and that would let
    the "frozen" backbone's features drift.
    """

    def __init__(self, pretrained: bool = True) -> None:
        """
        Build the network.

        Args:
            pretrained: if True, initialise the backbone with torchvision's
                ``ResNet50_Weights.DEFAULT``; otherwise use random weights.
        """
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet50(weights=weights)

        # Freeze every backbone parameter; the head created below is trainable.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Replace fc with the custom head: in_features -> 512 -> 1 (sigmoid in forward).
        in_features = self.backbone.fc.in_features  # 2048 for ResNet50
        self.backbone.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),
        )
        # New layers already require grad by default; set it explicitly for clarity.
        for param in self.backbone.fc.parameters():
            param.requires_grad = True

    def train(self, mode: bool = True) -> "PhishCNN":
        """
        Set train/eval mode, keeping the frozen backbone's BatchNorm layers in eval.

        Args:
            mode: True for training mode, False for eval mode.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        if mode:
            # Stop BatchNorm running statistics from updating while the backbone is frozen.
            for module in self.backbone.modules():
                if isinstance(module, nn.BatchNorm2d):
                    module.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Input:  (batch, 3, 224, 224)
        Output: (batch, 1) probabilities in [0, 1]
        """
        logits = self.backbone(x)
        return torch.sigmoid(logits)

    def predict_proba(self, x: torch.Tensor) -> float:
        """
        Return P_cnn in [0, 1], the probability of phishing, for a single image.

        The forward pass runs in eval mode without gradient tracking. Afterwards
        the model is put back into whichever train/eval mode it was in before
        the call, so calling this mid-training does not silently disable Dropout.

        Args:
            x: input holding exactly one image, e.g. shape (1, 3, 224, 224).

        Returns:
            The probability as a Python float.

        Raises:
            RuntimeError: raised by ``Tensor.item()`` if ``x`` yields more than
                one output value (batch size > 1).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(x)
        finally:
            self.train(was_training)
        return output.squeeze().item()
# -- Preprocessing pipelines (ImageNet normalisation statistics) --------------


def _imagenet_normalize() -> "T.Normalize":
    """Return a new Normalize transform using the ImageNet per-channel mean/std."""
    return T.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean
        std=[0.229, 0.224, 0.225],  # ImageNet std
    )


# Inference pipeline: fixed 224x224 resize, PIL -> tensor, ImageNet normalisation.
TRANSFORM = T.Compose([T.Resize((224, 224)), T.ToTensor(), _imagenet_normalize()])

# Training pipeline: same resize and normalisation, with light random
# augmentation (horizontal flip, colour jitter, +/-5 degree rotation) applied
# before the tensor conversion.
TRAIN_TRANSFORM = T.Compose(
    [
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.RandomRotation(5),
        T.ToTensor(),
        _imagenet_normalize(),
    ]
)
def preprocess_screenshot(screenshot_bytes: bytes) -> torch.Tensor:
    """
    Convert raw screenshot bytes into a model-ready tensor of shape [1, 3, 224, 224].

    The processing steps are:
    1. Decode the bytes with Pillow.
    2. Convert the image to 3-channel RGB.
    3. Run it through the module-level ``TRANSFORM`` pipeline.
    4. Add a leading batch dimension.

    Args:
        screenshot_bytes: encoded image data (any format Pillow can open).

    Returns:
        Float tensor of shape [1, 3, 224, 224].

    Raises:
        PIL.UnidentifiedImageError: if the bytes cannot be decoded as an image.
    """
    # Close the decoder handle deterministically. ``convert`` returns a new,
    # fully loaded image, so it stays valid after the ``with`` block exits.
    with Image.open(io.BytesIO(screenshot_bytes)) as img:
        rgb = img.convert("RGB")
    return TRANSFORM(rgb).unsqueeze(0)
def load_cnn(weights_path: Optional[str] = None) -> PhishCNN:
    """
    Build a PhishCNN in eval mode, optionally loading trained weights.

    If ``weights_path`` is given and the checkpoint loads cleanly, the model is
    built without ImageNet weights. The default strict ``load_state_dict``
    overwrites every parameter and buffer, so downloading ImageNet weights first
    would be wasted work.

    If loading fails for any reason, a warning is logged and a *fresh* model is
    returned. That model has ImageNet backbone weights and an untrained head.
    A new instance is used because ``load_state_dict`` can copy matching
    tensors before raising on missing/unexpected keys. Reusing the instance that
    failed to load could therefore yield a partial mix of checkpoint and
    initial weights.

    Args:
        weights_path: path to a state_dict file readable by ``torch.load``.
            Falsy values (None, "") skip loading.

    Returns:
        A PhishCNN in eval mode.
    """
    if weights_path:
        try:
            state = torch.load(weights_path, map_location="cpu", weights_only=True)
            model = PhishCNN(pretrained=False)
            model.load_state_dict(state)
        except Exception as e:  # deliberate best-effort: any failure falls back to baseline
            logger.warning("Could not load CNN weights: %s", e)
            logger.info("Using ImageNet features only (baseline)")
        else:
            logger.info("CNN weights loaded from %s", weights_path)
            model.eval()
            return model

    # Baseline: ImageNet-initialised backbone with an untrained head.
    model = PhishCNN(pretrained=True)
    model.eval()
    return model