"""
TB Bacilli Analysis System - Hugging Face Spaces Deployment
Complete version with all 6 tabs
"""

import math
import os

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import SegformerForSemanticSegmentation


# EDSR Model for Super-Resolution
class ResidualBlockNoBN(nn.Module):
    """Residual block without batch norm: conv -> ReLU -> conv, scaled skip-add.

    Args:
        n_feats: Number of input and output channels of both convolutions.
        res_scale: Factor applied to the residual branch before it is added
            back to the input.
    """

    def __init__(self, n_feats, res_scale=0.1):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
        self.conv2 = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x + res_scale * conv2(relu(conv1(x)))``."""
        res = self.conv2(self.relu(self.conv1(x)))
        return x + res * self.res_scale


class EDSR(nn.Module):
    """EDSR-style super-resolution network built from ``ResidualBlockNoBN``.

    The tail upsamples by stacking ``log2(scale)`` stages of
    conv -> PixelShuffle(2) -> ReLU followed by a final 3x3 convolution.

    Args:
        in_channels: Channels of the input image tensor.
        out_channels: Channels of the produced image tensor.
        scale: Upscaling factor; must be a positive power of two.
        n_resblocks: Number of residual blocks in the body.
        n_feats: Feature channels used throughout the network.
        rgb_range: Input is multiplied by this value before the network and
            the output is divided by it before clamping to ``[0, 1]``.

    Raises:
        ValueError: If ``scale`` is not a positive power of two. Such a scale
            cannot be built from x2 PixelShuffle stages; the previous code
            silently produced a network with a different factor.
    """

    def __init__(self, in_channels=3, out_channels=3, scale=4,
                 n_resblocks=32, n_feats=128, rgb_range=1.0):
        super().__init__()
        if scale < 1:
            raise ValueError(f"scale must be a positive power of two, got {scale}")
        n_upscale = int(math.log2(scale))
        if 2 ** n_upscale != scale:
            raise ValueError(f"scale must be a positive power of two, got {scale}")

        self.scale = scale
        self.rgb_range = rgb_range

        self.conv_head = nn.Conv2d(in_channels, n_feats, 3, 1, 1)
        self.body = nn.Sequential(
            *[ResidualBlockNoBN(n_feats) for _ in range(n_resblocks)]
        )
        self.conv_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1)

        # Layer order inside ``tail`` determines the state-dict keys, so it is
        # kept exactly as conv / PixelShuffle / ReLU per stage + final conv.
        tail = []
        for _ in range(n_upscale):
            tail.append(nn.Conv2d(n_feats, n_feats * 4, 3, 1, 1))
            tail.append(nn.PixelShuffle(2))
            tail.append(nn.ReLU(inplace=True))
        tail.append(nn.Conv2d(n_feats, out_channels, 3, 1, 1))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Super-resolve ``x`` and return the result clamped to ``[0, 1]``."""
        x = x * self.rgb_range
        x = self.conv_head(x)
        res = self.conv_body(self.body(x))
        x = x + res  # global skip connection around the residual body
        x = self.tail(x)
        return torch.clamp(x / self.rgb_range, 0.0, 1.0)


# SegFormer Model for TB Bacilli Segmentation
class SegFormerTB(torch.nn.Module):
    """SegFormer wrapper that returns logits at the input resolution.

    Args:
        model_name: Identifier passed to
            ``SegformerForSemanticSegmentation.from_pretrained``.
        num_classes: Number of output labels of the segmentation head.
    """

    def __init__(self, model_name, num_classes=1):
        super().__init__()
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

    def forward(self, pixel_values):
        """Return logits bilinearly resized to ``pixel_values``' spatial size."""
        logits = self.segformer(pixel_values=pixel_values).logits
        return F.interpolate(
            logits,
            size=pixel_values.shape[-2:],
            mode='bilinear',
            align_corners=False,
        )


# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

EDSR_SCALE = 4
EDSR_PATCH_SIZE = 256
EDSR_OVERLAP = 32

MODEL_NAME = "nvidia/segformer-b3-finetuned-ade-512-512"
NUM_CLASSES = 1
IMAGE_SIZE = 512
CONFIDENCE_THRESHOLD = 0.5
MIN_BACILLI_AREA = 3

EDSR_WEIGHTS_PATH = "models/edsr_ft_best.pth"
SEGFORMER_WEIGHTS_PATH = "models/best_model.pth"

# Message shown when a callback is triggered before an image was uploaded.
_NO_IMAGE_MSG = "⚠️ Please upload an image first"

# Load models
# NOTE(security): ``weights_only=False`` lets torch.load unpickle arbitrary
# Python objects. That is only acceptable because the checkpoints are files
# shipped with this app; never point these paths at untrusted uploads.
print("Loading EDSR model...")
if os.path.exists(EDSR_WEIGHTS_PATH):
    edsr_model = EDSR(in_channels=3, out_channels=3, scale=EDSR_SCALE,
                      n_resblocks=16, n_feats=64).to(device)
    edsr_checkpoint = torch.load(EDSR_WEIGHTS_PATH, map_location=device,
                                 weights_only=False)
    edsr_model.load_state_dict(edsr_checkpoint["state_dict"])
    edsr_model.eval()
    print("EDSR model loaded!")
else:
    print("⚠️ EDSR model not found. SR features will be disabled.")
    edsr_model = None

print("Loading SegFormer model...")
# Check for the checkpoint first so the pretrained backbone is not fetched
# and instantiated only to be thrown away.
if os.path.exists(SEGFORMER_WEIGHTS_PATH):
    segformer_model = SegFormerTB(MODEL_NAME, NUM_CLASSES).to(device)
    seg_checkpoint = torch.load(SEGFORMER_WEIGHTS_PATH, map_location=device,
                                weights_only=False)
    segformer_model.load_state_dict(seg_checkpoint['model_state_dict'])
    segformer_model.eval()
    print("SegFormer model loaded!")
else:
    print("⚠️ SegFormer model not found. Segmentation features will be disabled.")
    segformer_model = None


# Helper functions
def detect_bacilli(mask, min_area=3):
    """Find 8-connected components in ``mask`` with at least ``min_area`` pixels.

    Args:
        mask: 2-D array; it is cast to ``uint8`` before labelling.
        min_area: Minimum component area (in pixels) to keep.

    Returns:
        A list of dicts with keys ``'bbox'`` (``(x, y, w, h)`` as Python
        ints), ``'area'`` (Python int) and ``'centroid'`` (the centroid row
        returned by OpenCV).
    """
    num_labels, _labels, stats, centroids = cv2.connectedComponentsWithStats(
        mask.astype(np.uint8), connectivity=8
    )

    detections = []
    for label in range(1, num_labels):  # label 0 is the background component
        area = int(stats[label, cv2.CC_STAT_AREA])
        if area < min_area:
            continue
        # Plain ints keep the OpenCV drawing calls happy regardless of the
        # NumPy integer dtype used for ``stats``.
        bbox = (
            int(stats[label, cv2.CC_STAT_LEFT]),
            int(stats[label, cv2.CC_STAT_TOP]),
            int(stats[label, cv2.CC_STAT_WIDTH]),
            int(stats[label, cv2.CC_STAT_HEIGHT]),
        )
        detections.append({
            'bbox': bbox,
            'area': area,
            'centroid': centroids[label],
        })
    return detections


def preprocess_image(image):
    """Convert ``image`` to an ``H x W x 3`` RGB NumPy array.

    PIL images are converted with ``Image.convert("RGB")`` so that every PIL
    mode (palette, bilevel, grayscale, RGBA, ...) ends up as real RGB values.
    Other inputs are passed through ``np.array``; 2-D arrays are expanded from
    gray to RGB and 4-channel arrays have their alpha channel dropped.
    """
    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))

    img = np.array(image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.ndim == 3 and img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    return img


def _predict_mask(img):
    """Run SegFormer on an RGB image array and return a 0/1 ``uint8`` mask.

    The image is resized to ``IMAGE_SIZE`` for inference, the sigmoid output
    is thresholded at ``CONFIDENCE_THRESHOLD`` and the mask is resized back to
    the image's own height and width with nearest-neighbour interpolation.
    """
    height, width = img.shape[:2]
    resized = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = torch.from_numpy(resized.astype(np.float32) / 255.0)
    tensor = tensor.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = segformer_model(tensor)

    prob = torch.sigmoid(logits).squeeze().cpu().numpy()
    mask = (prob > CONFIDENCE_THRESHOLD).astype(np.uint8)
    return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)


def _draw_detections(img, detections, label_area=False):
    """Return a copy of ``img`` with a green box per detection.

    When ``label_area`` is true the component area is written above each box.
    """
    canvas = img.copy()
    for det in detections:
        x, y, w, h = det['bbox']
        cv2.rectangle(canvas, (x, y), (x + w, y + h), (0, 255, 0), 2)
        if label_area:
            cv2.putText(canvas, f"{det['area']}", (x, y - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
    return canvas


def _blend_overlay(img, mask, alpha):
    """Blend ``img`` with a copy whose masked pixels are painted green."""
    tinted = img.copy()
    tinted[mask > 0] = [0, 255, 0]
    return cv2.addWeighted(img, 1 - alpha, tinted, alpha, 0)


def _run_edsr(img):
    """Super-resolve an RGB ``uint8`` array; return float ``H x W x 3`` in [0, 1]."""
    tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
    tensor = tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        sr = edsr_model(tensor)
    return sr.clamp(0.0, 1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()


def _tile_starts(length, patch, step):
    """Return patch start offsets that cover ``[0, length)``.

    Offsets advance by ``step``; a final offset is added so the last patch
    ends exactly at ``length`` instead of producing a thin leftover sliver.
    """
    if length <= patch:
        return [0]
    starts = list(range(0, length - patch + 1, step))
    if starts[-1] + patch < length:
        starts.append(length - patch)
    return starts


def _feather_profile(length, fade, fade_start, fade_end):
    """Return a 1-D blending weight profile of ``length`` samples.

    The profile is 1 everywhere except for a linear ramp of up to ``fade``
    samples on the sides selected by ``fade_start`` / ``fade_end``. The ramp
    never reaches zero, so every sample keeps a strictly positive weight.
    """
    profile = np.ones(length, dtype=np.float32)
    fade = min(fade, length)
    if fade > 0:
        ramp = np.arange(1, fade + 1, dtype=np.float32) / (fade + 1)
        if fade_start:
            profile[:fade] *= ramp
        if fade_end:
            profile[length - fade:] *= ramp[::-1]
    return profile


# Feature functions
def lr_to_sr(image):
    """Upscale ``image`` with EDSR.

    Images larger than ``EDSR_PATCH_SIZE`` in either dimension are processed
    as overlapping patches that are feather-blended together. Only patch edges
    that border another patch are feathered; edges on the image border keep
    full weight, so every output pixel has a positive accumulated weight.

    Returns:
        ``(PIL.Image, info_string)`` on success, or ``(None, message)`` when
        the EDSR model is unavailable or no image was supplied.
    """
    if edsr_model is None:
        return None, "⚠️ EDSR model not available"
    if image is None:
        return None, _NO_IMAGE_MSG

    img = preprocess_image(image)
    h_orig, w_orig = img.shape[:2]

    # Patch-wise processing for large images
    patch_size = EDSR_PATCH_SIZE
    overlap = EDSR_OVERLAP
    scale = EDSR_SCALE

    if h_orig > patch_size or w_orig > patch_size:
        sr_h, sr_w = h_orig * scale, w_orig * scale
        sr_acc = np.zeros((sr_h, sr_w, 3), dtype=np.float32)
        weight_acc = np.zeros((sr_h, sr_w, 1), dtype=np.float32)

        step = max(patch_size - overlap, 1)
        fade = overlap * scale

        for y in _tile_starts(h_orig, patch_size, step):
            y_end = min(y + patch_size, h_orig)
            for x in _tile_starts(w_orig, patch_size, step):
                x_end = min(x + patch_size, w_orig)

                sr_patch = _run_edsr(img[y:y_end, x:x_end])
                patch_h, patch_w = sr_patch.shape[:2]

                # Feather only where a neighbouring patch exists.
                weight_y = _feather_profile(patch_h, fade, y > 0, y_end < h_orig)
                weight_x = _feather_profile(patch_w, fade, x > 0, x_end < w_orig)
                weight = (weight_y[:, None] * weight_x[None, :])[..., None]

                sr_y, sr_x = y * scale, x * scale
                sr_acc[sr_y:sr_y + patch_h, sr_x:sr_x + patch_w] += sr_patch * weight
                weight_acc[sr_y:sr_y + patch_h, sr_x:sr_x + patch_w] += weight

        # Every pixel is covered by at least one strictly positive weight, so
        # the division is well defined everywhere.
        sr_float = np.clip(sr_acc / weight_acc, 0.0, 1.0)
        sr_img = (sr_float * 255).astype(np.uint8)
        info = f"Input: {w_orig}x{h_orig} → Output: {sr_w}x{sr_h} (Patch-wise, Scale: {scale}x)"
    else:
        sr_img = (_run_edsr(img) * 255).astype(np.uint8)
        h_sr, w_sr = sr_img.shape[:2]
        info = f"Input: {w_orig}x{h_orig} → Output: {w_sr}x{h_sr} (Scale: {scale}x)"

    return Image.fromarray(sr_img), info


def lr_to_sr_comparison(image):
    """Return a side-by-side image: bicubic upscaling (left) vs EDSR (right)."""
    if edsr_model is None:
        return None, "⚠️ EDSR model not available"

    sr_img_pil, info = lr_to_sr(image)
    if sr_img_pil is None:
        return None, info

    sr_img = np.array(sr_img_pil)
    img = preprocess_image(image)
    h_orig, w_orig = img.shape[:2]
    h_sr, w_sr = sr_img.shape[:2]

    img_upscaled = cv2.resize(img, (w_sr, h_sr), interpolation=cv2.INTER_CUBIC)
    comparison = np.hstack([img_upscaled, sr_img])

    info = (f"Left: Bicubic ({w_sr}x{h_sr}) | Right: EDSR SR ({w_sr}x{h_sr}) "
            f"| Original: {w_orig}x{h_orig}")
    return Image.fromarray(comparison), info


def segment_image(image):
    """Return the binary segmentation mask (0/255) of ``image`` and an info string."""
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"
    if image is None:
        return None, _NO_IMAGE_MSG

    img = preprocess_image(image)
    h_orig, w_orig = img.shape[:2]
    pred_mask = _predict_mask(img) * 255

    info = f"Input: {w_orig}x{h_orig} | Segmentation completed"
    return Image.fromarray(pred_mask), info


def segment_and_detect(image, min_area=3):
    """Segment ``image``, box every component of at least ``min_area`` pixels.

    Returns:
        ``(PIL.Image with boxes and area labels, info_string)``; the info
        string reports the detection count, the number of mask pixels and the
        share of mask pixels that belong to kept detections.
    """
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"
    if image is None:
        return None, _NO_IMAGE_MSG

    img = preprocess_image(image)
    pred_mask = _predict_mask(img)
    detections = detect_bacilli(pred_mask, min_area=min_area)
    result_img = _draw_detections(img, detections, label_area=True)

    total_mask_pixels = np.sum(pred_mask > 0)
    detected_pixels = sum(det['area'] for det in detections)
    coverage = (detected_pixels / total_mask_pixels * 100) if total_mask_pixels > 0 else 0

    info = (f"Detected {len(detections)} TB bacilli | Mask pixels: {total_mask_pixels} "
            f"| Coverage: {coverage:.1f}%")
    return Image.fromarray(result_img), info


def segment_comparison(image):
    """Return the original image and its segmentation mask side by side."""
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"
    if image is None:
        return None, _NO_IMAGE_MSG

    img = preprocess_image(image)
    h_orig, w_orig = img.shape[:2]
    pred_mask = _predict_mask(img) * 255

    pred_mask_rgb = cv2.cvtColor(pred_mask, cv2.COLOR_GRAY2RGB)
    comparison = np.hstack([img, pred_mask_rgb])

    info = f"Left: Original Image ({w_orig}x{h_orig}) | Right: Segmentation Mask"
    return Image.fromarray(comparison), info


def segment_overlay(image, alpha=0.5):
    """Return ``image`` with the segmentation mask blended in green.

    Args:
        image: Input image.
        alpha: Blend weight of the green-tinted copy; the original image gets
            ``1 - alpha``.
    """
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"
    if image is None:
        return None, _NO_IMAGE_MSG

    img = preprocess_image(image)
    pred_mask = _predict_mask(img)
    result = _blend_overlay(img, pred_mask, alpha)

    bacilli_pixels = np.sum(pred_mask > 0)
    info = f"Overlay with {bacilli_pixels} bacilli pixels | Alpha: {alpha}"
    return Image.fromarray(result), info


def full_segmentation_pipeline(image, min_area=3):
    """Return ``(mask, detection image, overlay image)`` for ``image``.

    All three are ``None`` when the SegFormer model is unavailable or no image
    was supplied.
    """
    if segformer_model is None or image is None:
        return None, None, None

    img = preprocess_image(image)
    pred_mask = _predict_mask(img)

    seg_mask = (pred_mask * 255).astype(np.uint8)
    detections = detect_bacilli(pred_mask, min_area=min_area)
    detection_img = _draw_detections(img, detections)
    overlay_img = _blend_overlay(img, pred_mask, 0.4)

    return (Image.fromarray(seg_mask),
            Image.fromarray(detection_img),
            Image.fromarray(overlay_img))


def complete_pipeline(image, min_area=3):
    """Full pipeline: LR to SR, then segmentation and detection.

    Returns:
        ``(SR image, segmentation mask, detection image)``, or three ``None``
        values when a model is unavailable or super-resolution produced no
        image.
    """
    if edsr_model is None or segformer_model is None:
        return None, None, None

    # Step 1: LR to SR
    sr_img_pil, _sr_info = lr_to_sr(image)
    if sr_img_pil is None:
        return None, None, None
    sr_img = np.array(sr_img_pil)

    # Step 2: Segmentation on SR image (mask comes back at SR size)
    pred_mask = _predict_mask(sr_img)

    # Step 3: Detection
    detections = detect_bacilli(pred_mask, min_area=min_area)

    # Create outputs: SR image, segmentation mask, detection with boxes
    sr_output = Image.fromarray(sr_img)
    seg_mask = (pred_mask * 255).astype(np.uint8)
    detection_img = _draw_detections(sr_img, detections)

    return sr_output, Image.fromarray(seg_mask), Image.fromarray(detection_img)


# Gradio Interface
with gr.Blocks(title="TB Bacilli Analysis System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔬 TB Bacilli Analysis System")
    gr.Markdown("Upload a microscopy image for tuberculosis bacilli detection and analysis")

    if device.type == "cpu":
        gr.Markdown("⚠️ **Running on CPU** - Processing will be slower. For best performance, use GPU.")

    with gr.Tab("🎯 Basic Segmentation"):
        with gr.Row():
            seg_input = gr.Image(type="pil", label="Input Microscopy Image")
            seg_output = gr.Image(type="pil", label="Segmentation Mask")
        seg_info = gr.Textbox(label="Info", interactive=False)
        with gr.Row():
            seg_btn = gr.Button("Segment Image", variant="primary")
            seg_compare_btn = gr.Button("Compare Side-by-Side")
        seg_btn.click(segment_image, inputs=seg_input, outputs=[seg_output, seg_info])
        seg_compare_btn.click(segment_comparison, inputs=seg_input, outputs=[seg_output, seg_info])

    with gr.Tab("📊 Detection & Analysis"):
        with gr.Row():
            det_input = gr.Image(type="pil", label="Input Microscopy Image")
            det_output = gr.Image(type="pil", label="Detection Result")
        det_info = gr.Textbox(label="Detection Info", interactive=False)
        min_area_slider = gr.Slider(minimum=1, maximum=20, value=3, step=1,
                                    label="Minimum Bacilli Area (pixels)")
        det_btn = gr.Button("Detect Bacilli", variant="primary")
        det_btn.click(segment_and_detect, inputs=[det_input, min_area_slider],
                      outputs=[det_output, det_info])

    with gr.Tab("🎨 Overlay Visualization"):
        with gr.Row():
            overlay_input = gr.Image(type="pil", label="Input Microscopy Image")
            overlay_output = gr.Image(type="pil", label="Overlay Result")
        overlay_info = gr.Textbox(label="Info", interactive=False)
        alpha_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.1,
                                 label="Overlay Transparency")
        overlay_btn = gr.Button("Create Overlay", variant="primary")
        overlay_btn.click(segment_overlay, inputs=[overlay_input, alpha_slider],
                          outputs=[overlay_output, overlay_info])

    with gr.Tab("⚙️ Segmentation Pipeline"):
        with gr.Row():
            pipe_input = gr.Image(type="pil", label="Input Microscopy Image")
            pipe_min_area = gr.Slider(minimum=1, maximum=20, value=3, step=1,
                                      label="Minimum Bacilli Area (pixels)")
        with gr.Row():
            pipe_seg_output = gr.Image(type="pil", label="Segmentation Mask")
            pipe_det_output = gr.Image(type="pil", label="Detection with Bounding Boxes")
            pipe_overlay_output = gr.Image(type="pil", label="Overlay Visualization")
        pipe_btn = gr.Button("Run Segmentation Analysis", variant="primary")
        pipe_btn.click(full_segmentation_pipeline, inputs=[pipe_input, pipe_min_area],
                       outputs=[pipe_seg_output, pipe_det_output, pipe_overlay_output])

    with gr.Tab("🔍 LR to SR Conversion"):
        gr.Markdown("### Super-Resolution Enhancement")
        with gr.Row():
            lr_input = gr.Image(type="pil", label="Input Low-Resolution Image")
            sr_output = gr.Image(type="pil", label="Super-Resolution Output")
        sr_info = gr.Textbox(label="Info", interactive=False)
        with gr.Row():
            sr_btn = gr.Button("Convert to SR", variant="primary")
            compare_btn = gr.Button("Compare with Bicubic")
        sr_btn.click(lr_to_sr, inputs=lr_input, outputs=[sr_output, sr_info])
        compare_btn.click(lr_to_sr_comparison, inputs=lr_input, outputs=[sr_output, sr_info])

    with gr.Tab("🚀 Complete Pipeline (LR→SR→Segmentation)"):
        gr.Markdown("### Full end-to-end pipeline: Low-Resolution → Super-Resolution → Segmentation → Detection")
        with gr.Row():
            full_input = gr.Image(type="pil", label="Input Low-Resolution Image")
            full_min_area = gr.Slider(minimum=1, maximum=20, value=3, step=1,
                                      label="Minimum Bacilli Area (pixels)")
        with gr.Row():
            full_sr_output = gr.Image(type="pil", label="Super-Resolution Result")
            full_seg_output = gr.Image(type="pil", label="Segmentation Mask")
            full_det_output = gr.Image(type="pil", label="Detection with Bounding Boxes")
        full_btn = gr.Button("Run Complete Pipeline", variant="primary", size="lg")
        full_btn.click(complete_pipeline, inputs=[full_input, full_min_area],
                       outputs=[full_sr_output, full_seg_output, full_det_output])

if __name__ == "__main__":
    demo.launch()