# ============================================================ # PhishGuard AI - cnn/cnn_model.py # ResNet50 visual classifier for phishing screenshot detection. # # Architecture (from spec): # Backbone: ResNet50 fully frozen # Custom head: Linear(2048→512) → ReLU → Dropout(0.5) → # Linear(512→1) → Sigmoid # Input: 224×224 screenshot tensor # Output: P_cnn ∈ [0,1] # ============================================================ from __future__ import annotations import io import logging from typing import Optional import torch import torch.nn as nn import torchvision.models as models import torchvision.transforms as T from PIL import Image logger = logging.getLogger("phishguard.cnn.model") class PhishCNN(nn.Module): """ ResNet50 with frozen backbone and custom 2-layer binary classification head. Output: P_cnn ∈ [0,1] via sigmoid. """ def __init__(self, pretrained: bool = True) -> None: super().__init__() # Load pretrained ResNet50 backbone if pretrained: self.backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) else: self.backbone = models.resnet50(weights=None) # Freeze entire backbone for param in self.backbone.parameters(): param.requires_grad = False # Replace fc with custom head: 2048 → 512 → 1 → sigmoid in_features = self.backbone.fc.in_features # 2048 self.backbone.fc = nn.Sequential( nn.Linear(in_features, 512), nn.ReLU(), nn.Dropout(0.5), nn.Linear(512, 1), ) # Ensure custom head is trainable for param in self.backbone.fc.parameters(): param.requires_grad = True def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass. Input: (batch, 3, 224, 224) Output: (batch, 1) probabilities in [0, 1] """ logits = self.backbone(x) return torch.sigmoid(logits) def predict_proba(self, x: torch.Tensor) -> float: """Return P_cnn ∈ [0,1] — probability of phishing.""" self.eval() with torch.no_grad(): output = self.forward(x) return output.squeeze().item() # ── Preprocessing pipeline (matches ImageNet normalization) ────────── TRANSFORM = T.Compose([ T.Resize((224, 224)), T.ToTensor(), T.Normalize( mean=[0.485, 0.456, 0.406], # ImageNet mean std=[0.229, 0.224, 0.225], # ImageNet std ), ]) # Training augmentation transforms TRAIN_TRANSFORM = T.Compose([ T.Resize((224, 224)), T.RandomHorizontalFlip(), T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), T.RandomRotation(5), T.ToTensor(), T.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ), ]) def preprocess_screenshot(screenshot_bytes: bytes) -> torch.Tensor: """Convert raw screenshot bytes → model-ready tensor [1, 3, 224, 224].""" img = Image.open(io.BytesIO(screenshot_bytes)).convert("RGB") return TRANSFORM(img).unsqueeze(0) def load_cnn(weights_path: Optional[str] = None) -> PhishCNN: """Load CNN model with optional trained weights.""" model = PhishCNN(pretrained=True) if weights_path: try: state = torch.load(weights_path, map_location="cpu", weights_only=True) model.load_state_dict(state) logger.info(f"CNN weights loaded from {weights_path}") except Exception as e: logger.warning(f"Could not load CNN weights: {e}") logger.info("Using ImageNet features only (baseline)") model.eval() return model