| """ |
| TB Bacilli Analysis System - Hugging Face Spaces Deployment |
| Complete version with all 6 tabs |
| """ |
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import cv2 |
| from PIL import Image |
| from transformers import SegformerForSemanticSegmentation |
| import math |
| import os |
|
|
| |
class ResidualBlockNoBN(nn.Module):
    """EDSR-style residual block without batch normalization.

    Two 3x3 same-padding convolutions with a ReLU in between; the residual
    branch is scaled by ``res_scale`` before being added back to the input.
    """

    def __init__(self, n_feats, res_scale=0.1):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
        self.conv2 = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # conv -> ReLU -> conv, then a scaled skip connection.
        branch = self.conv2(self.relu(self.conv1(x)))
        return x + branch * self.res_scale
|
|
class EDSR(nn.Module):
    """EDSR super-resolution network: head conv, residual body with a long
    skip connection, and a pixel-shuffle upsampling tail.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output image channels.
        scale: upscaling factor; must be a positive power of two because the
            tail upsamples by repeated 2x PixelShuffle stages.
        n_resblocks: number of residual blocks in the body.
        n_feats: feature width of the conv layers.
        rgb_range: value range the network operates in; the input (expected
            in [0, 1]) is scaled into this range and the output scaled back.

    Raises:
        ValueError: if ``scale`` is not a positive power of two.
    """

    def __init__(self, in_channels=3, out_channels=3, scale=4, n_resblocks=32,
                 n_feats=128, rgb_range=1.0):
        super().__init__()
        # Guard against silently building the wrong network: the previous
        # int(math.log2(scale)) truncation turned e.g. scale=3 into a 2x
        # model while self.scale still reported 3.
        if scale < 1 or (scale & (scale - 1)) != 0:
            raise ValueError(f"scale must be a positive power of two, got {scale}")
        self.scale = scale
        self.rgb_range = rgb_range
        self.conv_head = nn.Conv2d(in_channels, n_feats, 3, 1, 1)
        body = [ResidualBlockNoBN(n_feats) for _ in range(n_resblocks)]
        self.body = nn.Sequential(*body)
        self.conv_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
        tail = []
        n_upscale = int(math.log2(scale))  # exact now that scale is 2**k
        for _ in range(n_upscale):
            # Each stage quadruples the channels, then PixelShuffle(2)
            # trades them for a 2x spatial upscale.
            tail.append(nn.Conv2d(n_feats, n_feats * 4, 3, 1, 1))
            tail.append(nn.PixelShuffle(2))
            tail.append(nn.ReLU(inplace=True))
        tail.append(nn.Conv2d(n_feats, out_channels, 3, 1, 1))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Upscale ``x`` (N, C, H, W, values in [0, 1]) to
        (N, C, H*scale, W*scale), clamped to [0, 1]."""
        x = x * self.rgb_range
        x = self.conv_head(x)
        res = self.body(x)
        res = self.conv_body(res)
        x = x + res  # long skip connection around the residual body
        x = self.tail(x)
        return torch.clamp(x / self.rgb_range, 0.0, 1.0)
|
|
| |
class SegFormerTB(torch.nn.Module):
    """Binary TB-bacilli segmentation model built on a pretrained SegFormer."""

    def __init__(self, model_name, num_classes=1):
        super().__init__()
        # ignore_mismatched_sizes lets the pretrained classification head be
        # replaced with one producing ``num_classes`` logits.
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        )

    def forward(self, pixel_values):
        """Return per-pixel logits upsampled to the input's spatial size."""
        raw_logits = self.segformer(pixel_values=pixel_values).logits
        # The backbone emits logits at a reduced resolution; bilinearly
        # upsample them back to match pixel_values (H, W).
        upsampled = F.interpolate(
            raw_logits,
            size=pixel_values.shape[-2:],
            mode='bilinear',
            align_corners=False
        )
        return upsampled
|
|
| |
# ---------------------------------------------------------------------------
# Runtime setup: pick a device, define inference constants, and load the two
# model checkpoints. Each model is optional — a missing checkpoint file
# disables the corresponding feature instead of crashing the app.
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# Super-resolution settings.
EDSR_SCALE = 4          # upscaling factor of the EDSR model
EDSR_PATCH_SIZE = 256   # patch side length used for tiled SR on large images
EDSR_OVERLAP = 32       # overlap between neighbouring patches (blend region)
# Segmentation settings.
MODEL_NAME = "nvidia/segformer-b3-finetuned-ade-512-512"  # HF backbone id
NUM_CLASSES = 1             # single "bacilli" foreground class (binary task)
IMAGE_SIZE = 512            # resolution images are resized to for SegFormer
CONFIDENCE_THRESHOLD = 0.5  # sigmoid cutoff for the binary mask
MIN_BACILLI_AREA = 3        # default minimum connected-component area (px)


print("Loading EDSR model...")
# NOTE(review): constructed with n_resblocks=16 / n_feats=64, smaller than
# the class defaults (32 / 128) — this must match the architecture the
# checkpoint was trained with; presumably it does, verify against training.
edsr_model = EDSR(in_channels=3, out_channels=3, scale=4, n_resblocks=16, n_feats=64).to(device)
if os.path.exists("models/edsr_ft_best.pth"):
    # weights_only=False deserializes arbitrary pickled objects — acceptable
    # only because the checkpoint ships with this app; never load untrusted
    # files this way.
    edsr_checkpoint = torch.load("models/edsr_ft_best.pth", map_location=device, weights_only=False)
    edsr_model.load_state_dict(edsr_checkpoint["state_dict"])
    edsr_model.eval()
    print("EDSR model loaded!")
else:
    print("⚠️ EDSR model not found. SR features will be disabled.")
    edsr_model = None


print("Loading SegFormer model...")
segformer_model = SegFormerTB(MODEL_NAME, NUM_CLASSES).to(device)
if os.path.exists("models/best_model.pth"):
    # Same pickle caveat as above; note this checkpoint stores its weights
    # under 'model_state_dict' rather than 'state_dict'.
    seg_checkpoint = torch.load("models/best_model.pth", map_location=device, weights_only=False)
    segformer_model.load_state_dict(seg_checkpoint['model_state_dict'])
    segformer_model.eval()
    print("SegFormer model loaded!")
else:
    print("⚠️ SegFormer model not found. Segmentation features will be disabled.")
    segformer_model = None
|
|
| |
def detect_bacilli(mask, min_area=3):
    """Find connected components in a binary mask and keep those whose pixel
    area is at least ``min_area``.

    Returns a list of dicts with 'bbox' (x, y, w, h), 'area' (pixel count)
    and 'centroid' (from OpenCV, (x, y) order)."""
    count, _labels, stats, centroids = cv2.connectedComponentsWithStats(
        mask.astype(np.uint8), connectivity=8
    )
    # Label 0 is the background; keep every sufficiently large foreground blob.
    return [
        {
            'bbox': (
                stats[label, cv2.CC_STAT_LEFT],
                stats[label, cv2.CC_STAT_TOP],
                stats[label, cv2.CC_STAT_WIDTH],
                stats[label, cv2.CC_STAT_HEIGHT],
            ),
            'area': stats[label, cv2.CC_STAT_AREA],
            'centroid': centroids[label],
        }
        for label in range(1, count)
        if stats[label, cv2.CC_STAT_AREA] >= min_area
    ]
|
|
def preprocess_image(image):
    """Coerce a PIL image (or array-like) into a 3-channel RGB ndarray.

    Grayscale (2-D) inputs gain channels via GRAY2RGB, RGBA inputs drop the
    alpha channel, and 3-channel inputs pass through unchanged."""
    arr = np.array(image)
    if arr.ndim == 2:
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
    if arr.shape[2] == 4:
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
    return arr
|
|
| |
def lr_to_sr(image):
    """Super-resolve an image with the EDSR model.

    Images at or below EDSR_PATCH_SIZE are processed in one pass; larger
    images are tiled into overlapping patches whose SR outputs are linearly
    blended to hide seams.

    Returns:
        (PIL.Image, info string), or (None, warning string) when the EDSR
        model is unavailable.
    """
    if edsr_model is None:
        return None, "⚠️ EDSR model not available"

    img = preprocess_image(image)
    h_orig, w_orig = img.shape[:2]

    patch_size = EDSR_PATCH_SIZE
    overlap = EDSR_OVERLAP
    scale = EDSR_SCALE

    if h_orig > patch_size or w_orig > patch_size:
        h, w = img.shape[:2]
        sr_h, sr_w = h * scale, w * scale
        # Weighted accumulation buffers in SR space.
        sr_img = np.zeros((sr_h, sr_w, 3), dtype=np.float32)
        weight_map = np.zeros((sr_h, sr_w, 3), dtype=np.float32)

        for y in range(0, h, patch_size - overlap):
            for x in range(0, w, patch_size - overlap):
                y_end = min(y + patch_size, h)
                x_end = min(x + patch_size, w)
                patch = img[y:y_end, x:x_end]

                patch_tensor = torch.from_numpy(patch).permute(2, 0, 1).float() / 255.0
                patch_tensor = patch_tensor.unsqueeze(0).to(device)

                with torch.no_grad():
                    sr_patch_tensor = edsr_model(patch_tensor)
                sr_patch = sr_patch_tensor.clamp(0.0, 1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()

                sr_y, sr_x = y * scale, x * scale
                sr_y_end, sr_x_end = sr_y + sr_patch.shape[0], sr_x + sr_patch.shape[1]

                # Linear fade, applied ONLY to edges that overlap a
                # neighbouring patch.
                # BUG FIX vs. the original: fades were applied to every patch
                # edge (including the image borders, which have no neighbour)
                # and started at weight 0, leaving weight_map == 0 on the
                # outermost rows/cols; np.divide(..., where=...) without out=
                # then filled those pixels with uninitialized memory. The
                # original per-row loop could also index past the end of a
                # final patch shorter than the fade width.
                weight = np.ones_like(sr_patch)
                if overlap > 0:
                    fade = overlap * scale
                    # Strictly positive ramp in (0, 1); never zeroes a weight.
                    ramp = (np.arange(fade, dtype=np.float32) + 1.0) / (fade + 1.0)
                    fh = min(fade, weight.shape[0])  # clamp to patch extent
                    fw = min(fade, weight.shape[1])
                    if y > 0:          # a patch above overlaps our top edge
                        weight[:fh] *= ramp[:fh, None, None]
                    if y_end < h:      # a patch below overlaps our bottom edge
                        weight[-fh:] *= ramp[:fh][::-1][:, None, None]
                    if x > 0:          # a patch to the left
                        weight[:, :fw] *= ramp[:fw][None, :, None]
                    if x_end < w:      # a patch to the right
                        weight[:, -fw:] *= ramp[:fw][None, ::-1, None]

                sr_img[sr_y:sr_y_end, sr_x:sr_x_end] += sr_patch * weight
                weight_map[sr_y:sr_y_end, sr_x:sr_x_end] += weight

        # Normalize the weighted accumulation. out= guarantees any pixel the
        # where-mask skips stays 0 instead of holding uninitialized memory.
        sr_img = np.divide(sr_img, weight_map,
                           out=np.zeros_like(sr_img), where=weight_map > 0)
        sr_img = (sr_img * 255).astype(np.uint8)
        info = f"Input: {w_orig}x{h_orig} → Output: {sr_w}x{sr_h} (Patch-wise, Scale: {scale}x)"
    else:
        # Small image: single forward pass, no tiling needed.
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        img_tensor = img_tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            sr_tensor = edsr_model(img_tensor)
        sr_tensor = sr_tensor.clamp(0.0, 1.0)
        sr_img = (sr_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

        h_sr, w_sr = sr_img.shape[:2]
        info = f"Input: {w_orig}x{h_orig} → Output: {w_sr}x{h_sr} (Scale: {scale}x)"

    return Image.fromarray(sr_img), info
|
|
def lr_to_sr_comparison(image):
    """Build a side-by-side comparison: bicubic upscale (left) vs EDSR (right)."""
    if edsr_model is None:
        return None, "⚠️ EDSR model not available"

    sr_pil, info = lr_to_sr(image)
    if sr_pil is None:
        return None, info

    sr_arr = np.array(sr_pil)
    original = preprocess_image(image)
    h_orig, w_orig = original.shape[:2]
    h_sr, w_sr = sr_arr.shape[:2]

    # Bicubic-upscale the original to the SR size so both halves match.
    bicubic = cv2.resize(original, (w_sr, h_sr), interpolation=cv2.INTER_CUBIC)
    side_by_side = np.hstack([bicubic, sr_arr])
    caption = f"Left: Bicubic ({w_sr}x{h_sr}) | Right: EDSR SR ({w_sr}x{h_sr}) | Original: {w_orig}x{h_orig}"

    return Image.fromarray(side_by_side), caption
|
|
def segment_image(image):
    """Run SegFormer on the image and return the binary mask as a PIL image."""
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"

    rgb = preprocess_image(image)
    h_orig, w_orig = rgb.shape[:2]

    # Resize to the model's working resolution and scale pixels to [0, 1].
    resized = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = torch.from_numpy(resized.astype(np.float32) / 255.0)
    tensor = tensor.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = segformer_model(tensor)
        probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
    mask = (probabilities > CONFIDENCE_THRESHOLD).astype(np.uint8) * 255

    # Nearest-neighbour resize keeps the mask strictly binary.
    mask = cv2.resize(mask, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)

    return Image.fromarray(mask), f"Input: {w_orig}x{h_orig} | Segmentation completed"
|
|
def segment_and_detect(image, min_area=3):
    """Segment the image, then draw a green box and area label around every
    detected bacillus whose connected-component area is >= ``min_area``."""
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"

    rgb = preprocess_image(image)
    h_orig, w_orig = rgb.shape[:2]

    resized = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = torch.from_numpy(resized.astype(np.float32) / 255.0)
    tensor = tensor.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = segformer_model(tensor)
        probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
    mask = (probabilities > CONFIDENCE_THRESHOLD).astype(np.uint8)
    mask = cv2.resize(mask, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)

    detections = detect_bacilli(mask, min_area=min_area)

    # Annotate a copy so the source image stays untouched.
    annotated = rgb.copy()
    for detection in detections:
        x, y, w, h = detection['bbox']
        cv2.rectangle(annotated, (x, y), (x+w, y+h), (0, 255, 0), 2)
        cv2.putText(annotated, f"{detection['area']}", (x, y-5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

    # Fraction of mask pixels that belong to kept (large-enough) components.
    mask_pixels = np.sum(mask > 0)
    kept_pixels = sum(detection['area'] for detection in detections)
    coverage = (kept_pixels / mask_pixels * 100) if mask_pixels > 0 else 0

    info = f"Detected {len(detections)} TB bacilli | Mask pixels: {mask_pixels} | Coverage: {coverage:.1f}%"
    return Image.fromarray(annotated), info
|
|
def segment_comparison(image):
    """Return the original image and its segmentation mask side by side."""
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"

    rgb = preprocess_image(image)
    h_orig, w_orig = rgb.shape[:2]

    resized = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = torch.from_numpy(resized.astype(np.float32) / 255.0)
    tensor = tensor.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = segformer_model(tensor)
        probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
    mask = (probabilities > CONFIDENCE_THRESHOLD).astype(np.uint8) * 255

    mask = cv2.resize(mask, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)
    # Promote the single-channel mask to RGB so it can be stacked beside
    # the colour image.
    mask_rgb = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
    combined = np.hstack([rgb, mask_rgb])

    return Image.fromarray(combined), f"Left: Original Image ({w_orig}x{h_orig}) | Right: Segmentation Mask"
|
|
def segment_overlay(image, alpha=0.5):
    """Blend a green highlight over predicted bacilli pixels.

    ``alpha`` sets the highlight opacity (0 = original image, 1 = solid
    green over the mask)."""
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"

    rgb = preprocess_image(image)
    h_orig, w_orig = rgb.shape[:2]

    resized = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = torch.from_numpy(resized.astype(np.float32) / 255.0)
    tensor = tensor.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = segformer_model(tensor)
        probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
    mask = (probabilities > CONFIDENCE_THRESHOLD).astype(np.uint8)
    mask = cv2.resize(mask, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)

    # Paint the mask green on a copy, then alpha-blend with the original.
    highlighted = rgb.copy()
    highlighted[mask > 0] = [0, 255, 0]
    blended = cv2.addWeighted(rgb, 1-alpha, highlighted, alpha, 0)

    positive_pixels = np.sum(mask > 0)
    return Image.fromarray(blended), f"Overlay with {positive_pixels} bacilli pixels | Alpha: {alpha}"
|
|
def full_segmentation_pipeline(image, min_area=3):
    """Produce three views in one pass: binary mask, bounding-box detection
    image, and a green-overlay visualization."""
    if segformer_model is None:
        return None, None, None

    rgb = preprocess_image(image)
    h_orig, w_orig = rgb.shape[:2]

    resized = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = torch.from_numpy(resized.astype(np.float32) / 255.0)
    tensor = tensor.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = segformer_model(tensor)
        probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
    mask = (probabilities > CONFIDENCE_THRESHOLD).astype(np.uint8)
    mask = cv2.resize(mask, (w_orig, h_orig), interpolation=cv2.INTER_NEAREST)

    # View 1: plain binary mask scaled to 0/255 for display.
    mask_image = (mask * 255).astype(np.uint8)

    # View 2: green bounding boxes around each kept connected component.
    boxed = rgb.copy()
    for detection in detect_bacilli(mask, min_area=min_area):
        x, y, w, h = detection['bbox']
        cv2.rectangle(boxed, (x, y), (x+w, y+h), (0, 255, 0), 2)

    # View 3: fixed 60/40 blend of the original with a green highlight.
    highlighted = rgb.copy()
    highlighted[mask > 0] = [0, 255, 0]
    blended = cv2.addWeighted(rgb, 0.6, highlighted, 0.4, 0)

    return Image.fromarray(mask_image), Image.fromarray(boxed), Image.fromarray(blended)
|
|
def complete_pipeline(image, min_area=3):
    """Full pipeline: LR to SR, then segmentation and detection"""
    if edsr_model is None or segformer_model is None:
        return None, None, None

    # Stage 1: super-resolve the low-resolution input.
    sr_pil, _sr_info = lr_to_sr(image)
    if sr_pil is None:
        return None, None, None

    sr_arr = np.array(sr_pil)
    h_sr, w_sr = sr_arr.shape[:2]

    # Stage 2: segment the SR image at the model's working resolution.
    resized = cv2.resize(sr_arr, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = torch.from_numpy(resized.astype(np.float32) / 255.0)
    tensor = tensor.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = segformer_model(tensor)
        probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
    mask = (probabilities > CONFIDENCE_THRESHOLD).astype(np.uint8)

    # Map the mask back onto the SR image's resolution.
    mask = cv2.resize(mask, (w_sr, h_sr), interpolation=cv2.INTER_NEAREST)

    # Stage 3: connected-component analysis + bounding boxes on a copy.
    boxed = sr_arr.copy()
    for detection in detect_bacilli(mask, min_area=min_area):
        x, y, w, h = detection['bbox']
        cv2.rectangle(boxed, (x, y), (x+w, y+h), (0, 255, 0), 2)

    mask_image = (mask * 255).astype(np.uint8)

    return Image.fromarray(sr_arr), Image.fromarray(mask_image), Image.fromarray(boxed)
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI: six tabs, each wiring an image input (plus optional sliders)
# to one of the inference functions above.
# ---------------------------------------------------------------------------
with gr.Blocks(title="TB Bacilli Analysis System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔬 TB Bacilli Analysis System")
    gr.Markdown("Upload a microscopy image for tuberculosis bacilli detection and analysis")

    # Warn users when no GPU is available (inference is markedly slower).
    if device.type == "cpu":
        gr.Markdown("⚠️ **Running on CPU** - Processing will be slower. For best performance, use GPU.")

    # Tab 1: raw segmentation mask, plus a side-by-side comparison option.
    with gr.Tab("🎯 Basic Segmentation"):
        with gr.Row():
            seg_input = gr.Image(type="pil", label="Input Microscopy Image")
            seg_output = gr.Image(type="pil", label="Segmentation Mask")
        seg_info = gr.Textbox(label="Info", interactive=False)
        with gr.Row():
            seg_btn = gr.Button("Segment Image", variant="primary")
            seg_compare_btn = gr.Button("Compare Side-by-Side")

        seg_btn.click(segment_image, inputs=seg_input, outputs=[seg_output, seg_info])
        seg_compare_btn.click(segment_comparison, inputs=seg_input, outputs=[seg_output, seg_info])

    # Tab 2: bounding-box detection with a tunable minimum blob area.
    with gr.Tab("📊 Detection & Analysis"):
        with gr.Row():
            det_input = gr.Image(type="pil", label="Input Microscopy Image")
            det_output = gr.Image(type="pil", label="Detection Result")
        det_info = gr.Textbox(label="Detection Info", interactive=False)
        min_area_slider = gr.Slider(minimum=1, maximum=20, value=3, step=1,
                                    label="Minimum Bacilli Area (pixels)")
        det_btn = gr.Button("Detect Bacilli", variant="primary")

        det_btn.click(segment_and_detect, inputs=[det_input, min_area_slider],
                      outputs=[det_output, det_info])

    # Tab 3: alpha-blended green overlay of the predicted mask.
    with gr.Tab("🎨 Overlay Visualization"):
        with gr.Row():
            overlay_input = gr.Image(type="pil", label="Input Microscopy Image")
            overlay_output = gr.Image(type="pil", label="Overlay Result")
        overlay_info = gr.Textbox(label="Info", interactive=False)
        alpha_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.1,
                                 label="Overlay Transparency")
        overlay_btn = gr.Button("Create Overlay", variant="primary")

        overlay_btn.click(segment_overlay, inputs=[overlay_input, alpha_slider],
                          outputs=[overlay_output, overlay_info])

    # Tab 4: one click produces mask + boxes + overlay together.
    with gr.Tab("⚙️ Segmentation Pipeline"):
        with gr.Row():
            pipe_input = gr.Image(type="pil", label="Input Microscopy Image")
            pipe_min_area = gr.Slider(minimum=1, maximum=20, value=3, step=1,
                                      label="Minimum Bacilli Area (pixels)")
        with gr.Row():
            pipe_seg_output = gr.Image(type="pil", label="Segmentation Mask")
            pipe_det_output = gr.Image(type="pil", label="Detection with Bounding Boxes")
            pipe_overlay_output = gr.Image(type="pil", label="Overlay Visualization")
        pipe_btn = gr.Button("Run Segmentation Analysis", variant="primary")

        pipe_btn.click(full_segmentation_pipeline, inputs=[pipe_input, pipe_min_area],
                       outputs=[pipe_seg_output, pipe_det_output, pipe_overlay_output])

    # Tab 5: EDSR super-resolution, with an optional bicubic comparison.
    with gr.Tab("🔍 LR to SR Conversion"):
        gr.Markdown("### Super-Resolution Enhancement")
        with gr.Row():
            lr_input = gr.Image(type="pil", label="Input Low-Resolution Image")
            sr_output = gr.Image(type="pil", label="Super-Resolution Output")
        sr_info = gr.Textbox(label="Info", interactive=False)
        with gr.Row():
            sr_btn = gr.Button("Convert to SR", variant="primary")
            compare_btn = gr.Button("Compare with Bicubic")

        sr_btn.click(lr_to_sr, inputs=lr_input, outputs=[sr_output, sr_info])
        compare_btn.click(lr_to_sr_comparison, inputs=lr_input, outputs=[sr_output, sr_info])

    # Tab 6: end-to-end LR -> SR -> segmentation -> detection pipeline.
    with gr.Tab("🚀 Complete Pipeline (LR→SR→Segmentation)"):
        gr.Markdown("### Full end-to-end pipeline: Low-Resolution → Super-Resolution → Segmentation → Detection")
        with gr.Row():
            full_input = gr.Image(type="pil", label="Input Low-Resolution Image")
            full_min_area = gr.Slider(minimum=1, maximum=20, value=3, step=1,
                                      label="Minimum Bacilli Area (pixels)")
        with gr.Row():
            full_sr_output = gr.Image(type="pil", label="Super-Resolution Result")
            full_seg_output = gr.Image(type="pil", label="Segmentation Mask")
            full_det_output = gr.Image(type="pil", label="Detection with Bounding Boxes")
        full_btn = gr.Button("Run Complete Pipeline", variant="primary", size="lg")

        full_btn.click(complete_pipeline, inputs=[full_input, full_min_area],
                       outputs=[full_sr_output, full_seg_output, full_det_output])


# Launch the app only when run as a script (Spaces imports this module).
if __name__ == "__main__":
    demo.launch()
|
|