# Hugging Face Space app.py — author: praveen-solanki (commit fdcb596, verified)
"""
TB Bacilli Analysis System - Hugging Face Spaces Deployment
Complete version with all 6 tabs
"""
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
from transformers import SegformerForSemanticSegmentation
import math
import os
# EDSR Model for Super-Resolution
class ResidualBlockNoBN(nn.Module):
    """Residual block without batch normalisation (conv → ReLU → conv).

    The residual branch is scaled by ``res_scale`` before being added back
    to the input — the standard EDSR trick for stabilising deep trunks.
    """

    def __init__(self, n_feats, res_scale=0.1):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
        self.conv2 = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Identity skip plus a scaled conv-ReLU-conv branch.
        branch = self.conv2(self.relu(self.conv1(x)))
        return x + branch * self.res_scale
class EDSR(nn.Module):
    """EDSR super-resolution network (no batch norm, PixelShuffle upsampling).

    Args:
        in_channels: channels of the low-resolution input (3 for RGB).
        out_channels: channels of the super-resolved output.
        scale: upscaling factor. Must be a power of two because the tail is
            built from repeated 2x PixelShuffle stages.
        n_resblocks: number of residual blocks in the trunk.
        n_feats: feature width of the trunk.
        rgb_range: dynamic range multiplier applied around the network.

    Raises:
        ValueError: if ``scale`` is not a positive power of two.
    """

    def __init__(self, in_channels=3, out_channels=3, scale=4, n_resblocks=32, n_feats=128, rgb_range=1.0):
        super().__init__()
        # int(math.log2(scale)) silently truncates for non-powers of two
        # (e.g. scale=3 would build a 2x tail and return the wrong size) —
        # reject such values explicitly instead.
        if scale < 1 or scale & (scale - 1) != 0:
            raise ValueError(f"scale must be a positive power of 2, got {scale}")
        self.scale = scale
        self.rgb_range = rgb_range
        self.conv_head = nn.Conv2d(in_channels, n_feats, 3, 1, 1)
        body = [ResidualBlockNoBN(n_feats) for _ in range(n_resblocks)]
        self.body = nn.Sequential(*body)
        self.conv_body = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
        tail = []
        n_upscale = int(math.log2(scale))
        for _ in range(n_upscale):
            # 4x channel expansion feeds PixelShuffle(2), which trades the
            # extra channels for a 2x spatial upscale.
            tail.append(nn.Conv2d(n_feats, n_feats * 4, 3, 1, 1))
            tail.append(nn.PixelShuffle(2))
            tail.append(nn.ReLU(inplace=True))
        tail.append(nn.Conv2d(n_feats, out_channels, 3, 1, 1))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Map an LR tensor (N, C, H, W) to an SR tensor clamped to [0, 1]."""
        x = x * self.rgb_range
        x = self.conv_head(x)
        # Global residual connection around the trunk.
        res = self.body(x)
        res = self.conv_body(res)
        x = x + res
        x = self.tail(x)
        x = torch.clamp(x / self.rgb_range, 0.0, 1.0)
        return x
# SegFormer Model for TB Bacilli Segmentation
class SegFormerTB(torch.nn.Module):
    """Wrapper around HF SegFormer that upsamples logits to the input size.

    SegFormer's head predicts at a reduced resolution; ``forward`` bilinearly
    interpolates the logits back to the pixel grid of the input.
    """

    def __init__(self, model_name, num_classes=1):
        super().__init__()
        # ignore_mismatched_sizes allows replacing the pretrained head with a
        # num_classes-channel one.
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

    def forward(self, pixel_values):
        raw = self.segformer(pixel_values=pixel_values).logits
        target_size = pixel_values.shape[-2:]
        return F.interpolate(raw, size=target_size, mode='bilinear', align_corners=False)
# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Super-resolution settings
EDSR_SCALE = 4          # upscaling factor of the EDSR checkpoint
EDSR_PATCH_SIZE = 256   # tile size (LR pixels) for patch-wise SR of large images
EDSR_OVERLAP = 32       # tile overlap (LR pixels) blended between patches

# Segmentation settings
MODEL_NAME = "nvidia/segformer-b3-finetuned-ade-512-512"
NUM_CLASSES = 1              # single foreground class (bacilli)
IMAGE_SIZE = 512             # fixed SegFormer input resolution
CONFIDENCE_THRESHOLD = 0.5   # sigmoid cutoff for the binary mask
MIN_BACILLI_AREA = 3         # default minimum blob area in pixels

# Load models. Either checkpoint may be missing, in which case the global is
# set to None and the UI callbacks degrade gracefully.
print("Loading EDSR model...")
# NOTE(review): built with 16 blocks / 64 feats, smaller than the class
# defaults (32 / 128) — presumably matching the shipped checkpoint's shapes;
# verify against the training config.
edsr_model = EDSR(in_channels=3, out_channels=3, scale=4, n_resblocks=16, n_feats=64).to(device)
if os.path.exists("models/edsr_ft_best.pth"):
    # weights_only=False runs arbitrary pickle code — acceptable only because
    # the checkpoint ships with this Space; never load untrusted files this way.
    edsr_checkpoint = torch.load("models/edsr_ft_best.pth", map_location=device, weights_only=False)
    edsr_model.load_state_dict(edsr_checkpoint["state_dict"])
    edsr_model.eval()
    print("EDSR model loaded!")
else:
    print("⚠️ EDSR model not found. SR features will be disabled.")
    edsr_model = None

print("Loading SegFormer model...")
segformer_model = SegFormerTB(MODEL_NAME, NUM_CLASSES).to(device)
if os.path.exists("models/best_model.pth"):
    seg_checkpoint = torch.load("models/best_model.pth", map_location=device, weights_only=False)
    segformer_model.load_state_dict(seg_checkpoint['model_state_dict'])
    segformer_model.eval()
    print("SegFormer model loaded!")
else:
    print("⚠️ SegFormer model not found. Segmentation features will be disabled.")
    segformer_model = None
# Helper functions
def detect_bacilli(mask, min_area=3):
    """Find connected components in a binary mask, filtered by area.

    Args:
        mask: 2-D array; non-zero pixels are foreground.
        min_area: smallest component area (pixels) to keep.

    Returns:
        List of dicts with keys 'bbox' (x, y, w, h), 'area' and 'centroid'.
    """
    count, _labels, stats, centroids = cv2.connectedComponentsWithStats(
        mask.astype(np.uint8), connectivity=8
    )
    results = []
    for idx in range(1, count):  # label 0 is the background
        component_area = stats[idx, cv2.CC_STAT_AREA]
        if component_area < min_area:
            continue
        bbox = (
            stats[idx, cv2.CC_STAT_LEFT],
            stats[idx, cv2.CC_STAT_TOP],
            stats[idx, cv2.CC_STAT_WIDTH],
            stats[idx, cv2.CC_STAT_HEIGHT],
        )
        results.append({
            'bbox': bbox,
            'area': component_area,
            'centroid': centroids[idx],
        })
    return results
def preprocess_image(image):
    """Return *image* as an RGB numpy array of shape (H, W, 3).

    Grayscale inputs are expanded to three channels, RGBA inputs lose their
    alpha channel, and 3-channel inputs pass through unchanged.
    """
    arr = np.array(image)
    if arr.ndim == 2:
        # Single channel → replicate to RGB.
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
    if arr.shape[2] == 4:
        # Drop the alpha channel.
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
    return arr
# Feature functions
def lr_to_sr(image):
    """Super-resolve *image* by EDSR_SCALE with the global EDSR model.

    Inputs larger than EDSR_PATCH_SIZE are processed tile-by-tile with
    EDSR_OVERLAP pixels of overlap; overlapping regions are linearly
    cross-faded to hide seams.

    Returns:
        (PIL.Image, info string) on success, or (None, warning string) when
        the EDSR model is not loaded.
    """
    if edsr_model is None:
        return None, "⚠️ EDSR model not available"

    img = preprocess_image(image)
    h_orig, w_orig = img.shape[:2]
    patch_size = EDSR_PATCH_SIZE
    overlap = EDSR_OVERLAP
    scale = EDSR_SCALE

    def _upscale(arr):
        # Run one (H, W, 3) uint8 array through EDSR; returns float32 in [0, 1].
        t = torch.from_numpy(arr).permute(2, 0, 1).float() / 255.0
        t = t.unsqueeze(0).to(device)
        with torch.no_grad():
            out = edsr_model(t)
        return out.clamp(0.0, 1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()

    if h_orig > patch_size or w_orig > patch_size:
        h, w = img.shape[:2]
        sr_h, sr_w = h * scale, w * scale
        sr_img = np.zeros((sr_h, sr_w, 3), dtype=np.float32)
        weight_map = np.zeros((sr_h, sr_w, 3), dtype=np.float32)
        step = patch_size - overlap
        for y in range(0, h, step):
            for x in range(0, w, step):
                y_end = min(y + patch_size, h)
                x_end = min(x + patch_size, w)
                sr_patch = _upscale(img[y:y_end, x:x_end])
                ph, pw = sr_patch.shape[:2]
                sr_y, sr_x = y * scale, x * scale

                # Linear cross-fade, but ONLY on edges shared with a
                # neighbouring patch. Fading the outer image borders (as the
                # previous version did) drives weight_map to 0 there, leaving
                # black border lines after normalization. The `i < ph/pw`
                # guards avoid an IndexError when an edge patch is smaller
                # than the fade width.
                weight = np.ones_like(sr_patch)
                if overlap > 0:
                    fade = overlap * scale
                    for i in range(fade):
                        alpha = i / fade
                        if y > 0 and i < ph:
                            weight[i, :, :] *= alpha
                        if y_end < h and i < ph:
                            weight[ph - 1 - i, :, :] *= alpha
                        if x > 0 and i < pw:
                            weight[:, i, :] *= alpha
                        if x_end < w and i < pw:
                            weight[:, pw - 1 - i, :] *= alpha

                sr_img[sr_y:sr_y + ph, sr_x:sr_x + pw] += sr_patch * weight
                weight_map[sr_y:sr_y + ph, sr_x:sr_x + pw] += weight

        # `out=` is required: np.divide with `where` alone leaves the masked
        # (weight_map == 0) positions uninitialized.
        sr_img = np.divide(sr_img, weight_map,
                           out=np.zeros_like(sr_img), where=weight_map > 0)
        sr_img = (sr_img * 255).astype(np.uint8)
        info = f"Input: {w_orig}x{h_orig} → Output: {sr_w}x{sr_h} (Patch-wise, Scale: {scale}x)"
    else:
        # Small image — single forward pass, no tiling needed.
        sr_img = (_upscale(img) * 255).astype(np.uint8)
        h_sr, w_sr = sr_img.shape[:2]
        info = f"Input: {w_orig}x{h_orig} → Output: {w_sr}x{h_sr} (Scale: {scale}x)"

    return Image.fromarray(sr_img), info
def lr_to_sr_comparison(image):
    """Show bicubic upscaling next to the EDSR result for the same input."""
    if edsr_model is None:
        return None, "⚠️ EDSR model not available"
    sr_pil, info = lr_to_sr(image)
    if sr_pil is None:
        return None, info
    sr = np.array(sr_pil)
    original = preprocess_image(image)
    orig_h, orig_w = original.shape[:2]
    sr_h, sr_w = sr.shape[:2]
    # Bicubic baseline at the same output resolution for a fair comparison.
    bicubic = cv2.resize(original, (sr_w, sr_h), interpolation=cv2.INTER_CUBIC)
    side_by_side = np.hstack([bicubic, sr])
    info = f"Left: Bicubic ({sr_w}x{sr_h}) | Right: EDSR SR ({sr_w}x{sr_h}) | Original: {orig_w}x{orig_h}"
    return Image.fromarray(side_by_side), info
def segment_image(image):
    """Run SegFormer on *image* and return the binary mask as a PIL image."""
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"
    rgb = preprocess_image(image)
    orig_h, orig_w = rgb.shape[:2]
    # The model expects a fixed-size, normalized float tensor (N, C, H, W).
    resized = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = (
        torch.from_numpy(resized.astype(np.float32) / 255.0)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )
    with torch.no_grad():
        probs = torch.sigmoid(segformer_model(tensor)).squeeze().cpu().numpy()
    # Threshold to 0/255 and restore the original resolution.
    mask = np.where(probs > CONFIDENCE_THRESHOLD, 255, 0).astype(np.uint8)
    mask = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    info = f"Input: {orig_w}x{orig_h} | Segmentation completed"
    return Image.fromarray(mask), info
def segment_and_detect(image, min_area=3):
    """Segment *image*, locate bacilli blobs and draw labelled boxes.

    Returns the annotated image and a summary string with the count,
    foreground pixel total and detection coverage.
    """
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"
    rgb = preprocess_image(image)
    orig_h, orig_w = rgb.shape[:2]
    resized = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = (
        torch.from_numpy(resized.astype(np.float32) / 255.0)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )
    with torch.no_grad():
        probs = torch.sigmoid(segformer_model(tensor)).squeeze().cpu().numpy()
    mask = (probs > CONFIDENCE_THRESHOLD).astype(np.uint8)
    mask = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    detections = detect_bacilli(mask, min_area=min_area)
    annotated = rgb.copy()
    for det in detections:
        x, y, w, h = det['bbox']
        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 2)
        # Label each box with the blob's pixel area.
        cv2.putText(annotated, f"{det['area']}", (x, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

    total_mask_pixels = np.sum(mask > 0)
    detected_pixels = sum(d['area'] for d in detections)
    coverage = (detected_pixels / total_mask_pixels * 100) if total_mask_pixels > 0 else 0
    info = f"Detected {len(detections)} TB bacilli | Mask pixels: {total_mask_pixels} | Coverage: {coverage:.1f}%"
    return Image.fromarray(annotated), info
def segment_comparison(image):
    """Return the original image and its segmentation mask side by side."""
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"
    rgb = preprocess_image(image)
    orig_h, orig_w = rgb.shape[:2]
    resized = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = (
        torch.from_numpy(resized.astype(np.float32) / 255.0)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )
    with torch.no_grad():
        probs = torch.sigmoid(segformer_model(tensor)).squeeze().cpu().numpy()
    mask = np.where(probs > CONFIDENCE_THRESHOLD, 255, 0).astype(np.uint8)
    mask = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    # Replicate the mask to 3 channels so it can sit beside the RGB input.
    mask_rgb = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
    side_by_side = np.hstack([rgb, mask_rgb])
    info = f"Left: Original Image ({orig_w}x{orig_h}) | Right: Segmentation Mask"
    return Image.fromarray(side_by_side), info
def segment_overlay(image, alpha=0.5):
    """Blend a green bacilli mask over the input image.

    Args:
        image: input microscopy image.
        alpha: overlay opacity in [0, 1]; 0 shows only the original,
            1 shows the solid green tint.
    """
    if segformer_model is None:
        return None, "⚠️ SegFormer model not available"
    rgb = preprocess_image(image)
    orig_h, orig_w = rgb.shape[:2]
    resized = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = (
        torch.from_numpy(resized.astype(np.float32) / 255.0)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )
    with torch.no_grad():
        probs = torch.sigmoid(segformer_model(tensor)).squeeze().cpu().numpy()
    mask = (probs > CONFIDENCE_THRESHOLD).astype(np.uint8)
    mask = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    tinted = rgb.copy()
    tinted[mask > 0] = [0, 255, 0]  # paint bacilli pixels solid green
    blended = cv2.addWeighted(rgb, 1 - alpha, tinted, alpha, 0)

    bacilli_pixels = np.sum(mask > 0)
    info = f"Overlay with {bacilli_pixels} bacilli pixels | Alpha: {alpha}"
    return Image.fromarray(blended), info
def full_segmentation_pipeline(image, min_area=3):
    """Produce mask, boxed detections and overlay for a single image.

    Returns (mask, detections, overlay) as PIL images, or (None, None, None)
    when the segmentation model is unavailable.
    """
    if segformer_model is None:
        return None, None, None
    rgb = preprocess_image(image)
    orig_h, orig_w = rgb.shape[:2]

    # Inference at the model's fixed resolution, then restore original size.
    resized = cv2.resize(rgb, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = (
        torch.from_numpy(resized.astype(np.float32) / 255.0)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )
    with torch.no_grad():
        probs = torch.sigmoid(segformer_model(tensor)).squeeze().cpu().numpy()
    mask = (probs > CONFIDENCE_THRESHOLD).astype(np.uint8)
    mask = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    # Output 1 — binary mask rescaled to 0/255 for display.
    mask_img = (mask * 255).astype(np.uint8)

    # Output 2 — bounding boxes around each detected blob.
    boxed = rgb.copy()
    for det in detect_bacilli(mask, min_area=min_area):
        x, y, w, h = det['bbox']
        cv2.rectangle(boxed, (x, y), (x + w, y + h), (0, 255, 0), 2)

    # Output 3 — fixed 60/40 blend of the original with a green tint.
    tinted = rgb.copy()
    tinted[mask > 0] = [0, 255, 0]
    blended = cv2.addWeighted(rgb, 0.6, tinted, 0.4, 0)

    return Image.fromarray(mask_img), Image.fromarray(boxed), Image.fromarray(blended)
def complete_pipeline(image, min_area=3):
    """End-to-end pipeline: super-resolve, segment the SR image, detect blobs.

    Returns (SR image, segmentation mask, detection image) as PIL images,
    or (None, None, None) when either model is unavailable.
    """
    if edsr_model is None or segformer_model is None:
        return None, None, None

    # Stage 1 — super-resolution.
    sr_pil, _ = lr_to_sr(image)
    if sr_pil is None:
        return None, None, None
    sr = np.array(sr_pil)
    sr_h, sr_w = sr.shape[:2]

    # Stage 2 — segmentation on the SR image.
    resized = cv2.resize(sr, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = (
        torch.from_numpy(resized.astype(np.float32) / 255.0)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )
    with torch.no_grad():
        probs = torch.sigmoid(segformer_model(tensor)).squeeze().cpu().numpy()
    mask = (probs > CONFIDENCE_THRESHOLD).astype(np.uint8)
    # Restore the mask to the SR image's resolution.
    mask = cv2.resize(mask, (sr_w, sr_h), interpolation=cv2.INTER_NEAREST)

    # Stage 3 — blob detection and bounding boxes on the SR image.
    annotated = sr.copy()
    for det in detect_bacilli(mask, min_area=min_area):
        x, y, w, h = det['bbox']
        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 2)

    mask_img = (mask * 255).astype(np.uint8)
    return sr_pil, Image.fromarray(mask_img), Image.fromarray(annotated)
# Gradio Interface — six tabs: basic segmentation, detection & analysis,
# overlay visualization, segmentation pipeline, LR→SR conversion, and the
# complete end-to-end pipeline.
with gr.Blocks(title="TB Bacilli Analysis System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔬 TB Bacilli Analysis System")
    gr.Markdown("Upload a microscopy image for tuberculosis bacilli detection and analysis")
    if device.type == "cpu":
        gr.Markdown("⚠️ **Running on CPU** - Processing will be slower. For best performance, use GPU.")

    # Tab 1: raw segmentation mask, optionally side-by-side with the input.
    with gr.Tab("🎯 Basic Segmentation"):
        with gr.Row():
            seg_input = gr.Image(type="pil", label="Input Microscopy Image")
            seg_output = gr.Image(type="pil", label="Segmentation Mask")
        seg_info = gr.Textbox(label="Info", interactive=False)
        with gr.Row():
            seg_btn = gr.Button("Segment Image", variant="primary")
            seg_compare_btn = gr.Button("Compare Side-by-Side")
        # Both buttons write to the same output image and info box.
        seg_btn.click(segment_image, inputs=seg_input, outputs=[seg_output, seg_info])
        seg_compare_btn.click(segment_comparison, inputs=seg_input, outputs=[seg_output, seg_info])

    # Tab 2: bounding-box detection with a tunable minimum blob area.
    with gr.Tab("📊 Detection & Analysis"):
        with gr.Row():
            det_input = gr.Image(type="pil", label="Input Microscopy Image")
            det_output = gr.Image(type="pil", label="Detection Result")
        det_info = gr.Textbox(label="Detection Info", interactive=False)
        min_area_slider = gr.Slider(minimum=1, maximum=20, value=3, step=1,
                                    label="Minimum Bacilli Area (pixels)")
        det_btn = gr.Button("Detect Bacilli", variant="primary")
        det_btn.click(segment_and_detect, inputs=[det_input, min_area_slider],
                      outputs=[det_output, det_info])

    # Tab 3: semi-transparent green overlay with adjustable opacity.
    with gr.Tab("🎨 Overlay Visualization"):
        with gr.Row():
            overlay_input = gr.Image(type="pil", label="Input Microscopy Image")
            overlay_output = gr.Image(type="pil", label="Overlay Result")
        overlay_info = gr.Textbox(label="Info", interactive=False)
        alpha_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.1,
                                 label="Overlay Transparency")
        overlay_btn = gr.Button("Create Overlay", variant="primary")
        overlay_btn.click(segment_overlay, inputs=[overlay_input, alpha_slider],
                          outputs=[overlay_output, overlay_info])

    # Tab 4: mask + boxes + overlay produced in a single click.
    with gr.Tab("⚙️ Segmentation Pipeline"):
        with gr.Row():
            pipe_input = gr.Image(type="pil", label="Input Microscopy Image")
            pipe_min_area = gr.Slider(minimum=1, maximum=20, value=3, step=1,
                                      label="Minimum Bacilli Area (pixels)")
        with gr.Row():
            pipe_seg_output = gr.Image(type="pil", label="Segmentation Mask")
            pipe_det_output = gr.Image(type="pil", label="Detection with Bounding Boxes")
            pipe_overlay_output = gr.Image(type="pil", label="Overlay Visualization")
        pipe_btn = gr.Button("Run Segmentation Analysis", variant="primary")
        pipe_btn.click(full_segmentation_pipeline, inputs=[pipe_input, pipe_min_area],
                       outputs=[pipe_seg_output, pipe_det_output, pipe_overlay_output])

    # Tab 5: EDSR super-resolution, with an optional bicubic comparison.
    with gr.Tab("🔍 LR to SR Conversion"):
        gr.Markdown("### Super-Resolution Enhancement")
        with gr.Row():
            lr_input = gr.Image(type="pil", label="Input Low-Resolution Image")
            sr_output = gr.Image(type="pil", label="Super-Resolution Output")
        sr_info = gr.Textbox(label="Info", interactive=False)
        with gr.Row():
            sr_btn = gr.Button("Convert to SR", variant="primary")
            compare_btn = gr.Button("Compare with Bicubic")
        sr_btn.click(lr_to_sr, inputs=lr_input, outputs=[sr_output, sr_info])
        compare_btn.click(lr_to_sr_comparison, inputs=lr_input, outputs=[sr_output, sr_info])

    # Tab 6: full LR → SR → segmentation → detection pipeline.
    with gr.Tab("🚀 Complete Pipeline (LR→SR→Segmentation)"):
        gr.Markdown("### Full end-to-end pipeline: Low-Resolution → Super-Resolution → Segmentation → Detection")
        with gr.Row():
            full_input = gr.Image(type="pil", label="Input Low-Resolution Image")
            full_min_area = gr.Slider(minimum=1, maximum=20, value=3, step=1,
                                      label="Minimum Bacilli Area (pixels)")
        with gr.Row():
            full_sr_output = gr.Image(type="pil", label="Super-Resolution Result")
            full_seg_output = gr.Image(type="pil", label="Segmentation Mask")
            full_det_output = gr.Image(type="pil", label="Detection with Bounding Boxes")
        full_btn = gr.Button("Run Complete Pipeline", variant="primary", size="lg")
        full_btn.click(complete_pipeline, inputs=[full_input, full_min_area],
                       outputs=[full_sr_output, full_seg_output, full_det_output])

if __name__ == "__main__":
    demo.launch()