# NOTE(review): removed scraped web-page header residue ("Spaces: Running Running") — not part of the module.
| # ============================================================ | |
| # PhishGuard AI - gnn/domain_graph_builder.py | |
| # Builds graph representations for GNN inference + training. | |
| # | |
| # Node features (12-dim per URL): | |
| # [url_len_norm, domain_len_norm, subdomain_count_norm, | |
| # shannon_entropy_norm, digit_ratio, hyphen_count_norm, | |
| # phishing_keyword_hits_norm, suspicious_tld_binary, | |
| # ip_as_hostname_binary, has_https_binary, | |
| # path_depth_norm, query_string_len_norm] | |
| # | |
| # Edges: shared suspicious TLD + shared IP (async DNS) | |
| # ============================================================ | |
| from __future__ import annotations | |
| import re | |
| import math | |
| import asyncio | |
| import logging | |
| import socket | |
| from typing import Dict, List, Optional, Tuple | |
| from urllib.parse import urlparse | |
| import numpy as np | |
| logger = logging.getLogger("phishguard.gnn.graph_builder") | |
# ── Constants ────────────────────────────────────────────────────────

# TLDs disproportionately observed in phishing campaigns.
SUSPICIOUS_TLDS = frozenset({
    ".xyz", ".tk", ".ml", ".ga", ".cf",
    ".gq", ".pw", ".top", ".click",
})

# Substrings whose presence anywhere in a URL suggests credential phishing.
PHISHING_KEYWORDS = frozenset({
    "login", "verify", "secure", "update", "account",
    "banking", "signin", "reset", "confirm", "suspend",
    "webscr", "cmd", "payment", "alert",
})

# Dotted-quad hostname matcher (format only; octet ranges are not checked,
# so e.g. "999.1.1.1" still matches).
_re_ip = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")


class DomainGraphBuilder:
    """
    Builds graph representations from URL lists for GNN inference/training.

    Each URL becomes a node with a 12-dim feature vector (values in [0, 1]).
    Edges connect nodes sharing a suspicious TLD and, when DNS resolution is
    enabled, nodes whose domains resolve to the same IP address.
    """

    def __init__(self) -> None:
        # Bind the module-level compiled pattern once per instance.
        self._re_ip = _re_ip

    def extract_node_features(self, url: str) -> np.ndarray:
        """
        Extract the 12-dim feature vector for a single URL.

        Returns:
            np.ndarray of shape (12,), dtype float32, each value in [0, 1].
            A zero vector is returned if the URL cannot be parsed.
        """
        try:
            # urlparse only fills .hostname when a scheme is present.
            parsed = urlparse(url if "://" in url else f"http://{url}")
        except Exception:
            return np.zeros(12, dtype=np.float32)
        hostname: str = (parsed.hostname or "").lower()
        path: str = parsed.path or ""
        query: str = parsed.query or ""
        scheme: str = parsed.scheme or ""
        # 1. url_len_norm (normalized by 500, clipped at 1.0)
        url_len_norm = min(len(url) / 500.0, 1.0)
        # 2. domain_len_norm (normalized by 100)
        domain_len_norm = min(len(hostname) / 100.0, 1.0)
        # 3. subdomain_count_norm ("a.b.example.com" -> 2 subdomains)
        subdomain_count = max(0, len(hostname.split(".")) - 2)
        subdomain_count_norm = min(subdomain_count / 10.0, 1.0)
        # 4. shannon_entropy_norm (~5 bits/char treated as maximal)
        shannon_entropy_norm = min(self._shannon_entropy(hostname) / 5.0, 1.0)
        # 5. digit_ratio (fraction of hostname characters that are digits)
        digit_ratio = 0.0
        if hostname:
            digit_ratio = sum(c.isdigit() for c in hostname) / len(hostname)
        # 6. hyphen_count_norm
        hyphen_count_norm = min(hostname.count("-") / 10.0, 1.0)
        # 7. phishing_keyword_hits_norm (substring scan over the whole URL)
        url_lower = url.lower()
        keyword_hits = sum(1 for kw in PHISHING_KEYWORDS if kw in url_lower)
        phishing_keyword_hits_norm = min(keyword_hits / 5.0, 1.0)
        # 8. suspicious_tld_binary (str.endswith accepts a tuple of suffixes)
        suspicious_tld_binary = (
            1.0 if hostname.endswith(tuple(SUSPICIOUS_TLDS)) else 0.0
        )
        # 9. ip_as_hostname_binary (hostname is a dotted-quad literal)
        ip_as_hostname_binary = 1.0 if self._re_ip.match(hostname) else 0.0
        # 10. has_https_binary
        has_https_binary = 1.0 if scheme == "https" else 0.0
        # 11. path_depth_norm (count of non-empty path segments)
        path_depth = len([s for s in path.split("/") if s])
        path_depth_norm = min(path_depth / 10.0, 1.0)
        # 12. query_string_len_norm
        query_string_len_norm = min(len(query) / 500.0, 1.0)
        return np.array([
            url_len_norm,
            domain_len_norm,
            subdomain_count_norm,
            shannon_entropy_norm,
            digit_ratio,
            hyphen_count_norm,
            phishing_keyword_hits_norm,
            suspicious_tld_binary,
            ip_as_hostname_binary,
            has_https_binary,
            path_depth_norm,
            query_string_len_norm,
        ], dtype=np.float32)

    def _shannon_entropy(self, s: str) -> float:
        """Return the Shannon entropy of *s* in bits per character (0.0 for empty)."""
        if not s:
            return 0.0
        length = len(s)
        freq: Dict[str, int] = {}
        for c in s:
            freq[c] = freq.get(c, 0) + 1
        # Every stored count is >= 1, so log2 is always defined.
        return -sum(
            (count / length) * math.log2(count / length)
            for count in freq.values()
        )

    async def _resolve_ips(self, domains: List[str]) -> Dict[str, str]:
        """
        Resolve domains to IPv4 addresses concurrently.

        Each lookup runs in the default executor with a 2 s timeout; failed
        or timed-out lookups are dropped silently (best-effort by design).

        Returns:
            Dict mapping domain -> IP for every domain that resolved.
        """
        # Must be called from within a coroutine, so a loop is running.
        loop = asyncio.get_running_loop()

        async def resolve_one(domain: str) -> Tuple[str, str]:
            try:
                ip = await asyncio.wait_for(
                    loop.run_in_executor(None, socket.gethostbyname, domain),
                    timeout=2.0,
                )
                return domain, ip
            except Exception:
                # DNS failure/timeout is "no IP", not an error.
                return domain, ""

        # Skip empty hostnames and de-duplicate to avoid redundant lookups.
        unique_domains = list(dict.fromkeys(d for d in domains if d))
        resolved = await asyncio.gather(
            *(resolve_one(d) for d in unique_domains),
            return_exceptions=True,
        )
        results: Dict[str, str] = {}
        for item in resolved:
            if isinstance(item, tuple):
                domain, ip = item
                if ip:
                    results[domain] = ip
        return results

    @staticmethod
    def _clique_edges(indices: List[int]) -> List[Tuple[int, int]]:
        """Bidirectional edges between every pair of the given node indices."""
        edges: List[Tuple[int, int]] = []
        for i in range(len(indices)):
            for j in range(i + 1, len(indices)):
                edges.append((indices[i], indices[j]))
                edges.append((indices[j], indices[i]))
        return edges

    def _add_shared_ip_edges(
        self, domains: List[str], ips: Dict[str, str]
    ) -> List[Tuple[int, int]]:
        """
        Create bidirectional edges between nodes sharing a resolved IP.

        Returns:
            List of (src, dst) node-index pairs.
        """
        # Group node indices by resolved IP, then fully connect each group.
        ip_to_indices: Dict[str, List[int]] = {}
        for idx, domain in enumerate(domains):
            ip = ips.get(domain, "")
            if ip:
                ip_to_indices.setdefault(ip, []).append(idx)
        edges: List[Tuple[int, int]] = []
        for indices in ip_to_indices.values():
            edges.extend(self._clique_edges(indices))
        return edges

    def _add_shared_tld_edges(self, domains: List[str]) -> List[Tuple[int, int]]:
        """
        Create bidirectional edges between nodes sharing a suspicious TLD.

        Returns:
            List of (src, dst) node-index pairs.
        """
        tld_to_indices: Dict[str, List[int]] = {}
        for idx, domain in enumerate(domains):
            for tld in SUSPICIOUS_TLDS:
                if domain.endswith(tld):
                    tld_to_indices.setdefault(tld, []).append(idx)
                    break  # a hostname matches at most one of these suffixes
        edges: List[Tuple[int, int]] = []
        for indices in tld_to_indices.values():
            edges.extend(self._clique_edges(indices))
        return edges

    def build_graph(self, urls: List[str], resolve_dns: bool = False) -> dict:
        """
        Build a graph dict from a list of URLs.

        Args:
            urls: URLs to turn into nodes.
            resolve_dns: if True (and no event loop is already running in
                this thread), resolve domains and add shared-IP edges.

        Returns dict with:
            - features: np.ndarray of shape (N, 12), float32
            - edges: List of (src, dst) pairs, de-duplicated
            - node_count: int (== N == features.shape[0])
            - edge_count: int
            - domains: List[str] (lowercased hostnames, "" if unparseable)
        """
        if not urls:
            # Empty graph: zero feature rows, so features.shape[0] matches
            # node_count (the old (1, 12) placeholder contradicted it).
            return {
                "features": np.zeros((0, 12), dtype=np.float32),
                "edges": [],
                "node_count": 0,
                "edge_count": 0,
                "domains": [],
            }
        # Per-node feature matrix, shape (N, 12).
        features = np.array(
            [self.extract_node_features(url) for url in urls],
            dtype=np.float32,
        )
        # Lowercased hostname per URL ("" when parsing fails).
        domains: List[str] = []
        for url in urls:
            try:
                parsed = urlparse(url if "://" in url else f"http://{url}")
                domains.append((parsed.hostname or "").lower())
            except Exception:
                domains.append("")
        # Shared-TLD edges are cheap and always computed.
        edges = self._add_shared_tld_edges(domains)
        # Optionally add shared-IP edges via async DNS resolution.
        if resolve_dns and len(domains) > 1:
            try:
                asyncio.get_running_loop()
                # Already inside an event loop: cannot block on DNS here;
                # async callers should await _resolve_ips themselves.
            except RuntimeError:
                # No running loop: safe to drive our own.
                ips = asyncio.run(self._resolve_ips(domains))
                edges.extend(self._add_shared_ip_edges(domains, ips))
        # A pair can share both a TLD and an IP; drop duplicate pairs while
        # preserving first-insertion order.
        edges = list(dict.fromkeys(edges))
        return {
            "features": features,
            "edges": edges,
            "node_count": len(urls),
            "edge_count": len(edges),
            "domains": domains,
        }

    def build_single_node_graph(self, url: str) -> dict:
        """
        Build a single-node graph for the MLP fallback path.

        Used when a graph would have fewer than 2 nodes. Returns the same
        dict layout as build_graph with node_count == 1.
        """
        features = self.extract_node_features(url).reshape(1, -1)
        try:
            parsed = urlparse(url if "://" in url else f"http://{url}")
            domain = (parsed.hostname or "").lower()
        except Exception:
            domain = ""
        return {
            "features": features,
            "edges": [],
            "node_count": 1,
            "edge_count": 0,
            # Hostname, not the raw URL, for consistency with build_graph.
            "domains": [domain],
        }
# ── Legacy compatibility wrapper ─────────────────────────────────────
# Module-level singleton so legacy callers don't pay per-call construction.
_builder = DomainGraphBuilder()


def build_domain_graph(urls: List[str], resolve_dns: bool = False) -> dict:
    """Legacy wrapper for backward compatibility.

    Delegates to a module-level DomainGraphBuilder. *resolve_dns* is
    forwarded to DomainGraphBuilder.build_graph; it defaults to False,
    matching the historical behavior of this wrapper.
    """
    return _builder.build_graph(urls, resolve_dns=resolve_dns)