A variational autoencoder for semantic segmentation of neuronal morphologies from 2D projections.
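The "variational" part means the encoder predicts a mean `mu` and log-variance `logvar` per latent element rather than a single code, as in the `z, mu, logvar` returned by `model.encode` further down. The NumPy sketch below shows the textbook reparameterization trick and the closed-form KL divergence to a unit Gaussian. It is the standard VAE formulation, not necessarily AxoNet's exact sampling or loss weighting. The toy 4x4 latent spatial size is a hypothetical choice.

```python
import numpy as np

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal(mu.shape)
    return mu + std * eps

def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over all latent elements of each sample."""
    kl = -0.5 * (1.0 + logvar - mu**2 - np.exp(logvar))
    return kl.reshape(kl.shape[0], -1).sum(axis=1)

# Toy latent map shaped like an encoder output: (batch, latent_channels, H, W)
rng = np.random.default_rng(0)
mu = rng.standard_normal((2, 128, 4, 4))
logvar = np.full_like(mu, -1.0)

z = reparameterize(mu, logvar, rng)
print(z.shape)                            # (2, 128, 4, 4)
print(kl_to_standard_normal(mu, logvar))  # one non-negative value per sample
```

At inference time the deterministic `mu` is typically used as the embedding, which is what the examples below do.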
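AxoNet VAE is a U-Net architecture. As configured in the loading example below, it has variational skip connections (`skip_mode="variational"`) and a 128-channel latent space, and it outputs per-pixel segmentation logits over six classes together with a depth map. The segmentation classes are: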
| Index | Class | Description |
|---|---|---|
| 0 | background | Non-neuron pixels |
| 1 | soma | Cell body |
| 2 | axon | Axonal processes |
| 3 | basal_dendrite | Basal dendritic arbor |
| 4 | apical_dendrite | Apical dendritic arbor |
| 5 | other | Unclassified neurite |
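When post-processing a predicted label map, it helps to keep the index-to-name mapping from the table in code. This NumPy sketch counts pixels per class and computes each neuron class's share of the non-background pixels. `AXONET_CLASSES`, `class_pixel_counts`, and `neurite_fraction` are hypothetical helpers, not part of the `axonet` package.

```python
import numpy as np

# Class indices and names, copied from the table above
AXONET_CLASSES = {
    0: "background",
    1: "soma",
    2: "axon",
    3: "basal_dendrite",
    4: "apical_dendrite",
    5: "other",
}

def class_pixel_counts(label_map):
    """Count pixels per class in a non-negative integer label map of any shape."""
    counts = np.bincount(np.asarray(label_map).ravel(), minlength=len(AXONET_CLASSES))
    return {name: int(counts[idx]) for idx, name in AXONET_CLASSES.items()}

def neurite_fraction(label_map):
    """Fraction of non-background pixels assigned to each neuron class (1-5)."""
    counts = class_pixel_counts(label_map)
    total = sum(v for k, v in counts.items() if k != "background")
    return {k: (v / total if total else 0.0) for k, v in counts.items() if k != "background"}

toy = np.array([[0, 0, 1],
                [2, 2, 3],
                [4, 5, 0]])
print(class_pixel_counts(toy))
# {'background': 3, 'soma': 1, 'axon': 2, 'basal_dendrite': 1, 'apical_dendrite': 1, 'other': 1}
```

The same helpers accept `segmentation[0].cpu().numpy()` from the inference example.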
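```python
import torch
from huggingface_hub import hf_hub_download

# Requires the axonet package for the model class
from axonet.models.d3_swc_vae import SegVAE2D

# Download the PyTorch state dict from the Hub
model_path = hf_hub_download(
    repo_id="broadinstitute/axonet-vae-stage1",
    filename="pytorch_model.bin",
)

# Hyperparameters must match the checkpoint
model = SegVAE2D(
    in_channels=1,         # single-channel (grayscale) input
    base_channels=64,
    num_classes=6,         # classes listed in the table above
    latent_channels=128,
    skip_mode="variational",
)

# Load onto the CPU; weights_only=True avoids executing arbitrary pickled code
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # switch layers such as dropout and batch norm to inference behavior
```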
```python
import numpy as np
import torch
from PIL import Image

# Load a 2D projection mask as grayscale and resize it to 512x512
img = Image.open("neuron_mask.png").convert("L")
img = img.resize((512, 512))

# Scale pixel values to [0, 1] and add batch and channel dimensions
tensor = torch.from_numpy(np.array(img) / 255.0).float()
tensor = tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, 512, 512)

# Run inference
with torch.no_grad():
    outputs = model(tensor, return_latent=True)

segmentation = outputs["seg_logits"].argmax(dim=1)  # (1, 512, 512) class indices
depth = outputs["depth"]                             # (1, 1, 512, 512)
embedding = outputs["mu"].mean(dim=(2, 3))           # (1, 128) latent embedding
```
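Taking `argmax(dim=1)` over `seg_logits` picks the highest-scoring class per pixel. Applying a softmax first also yields a per-pixel confidence, which can help flag ambiguous neurites. The NumPy sketch below assumes `seg_logits` has shape `(N, 6, H, W)` as in the example above, converted with `.cpu().numpy()`. The helper name `logits_to_labels` and the random stand-in logits are hypothetical.

```python
import numpy as np

def logits_to_labels(seg_logits):
    """Turn (N, C, H, W) logits into labels and max-softmax confidence, both (N, H, W)."""
    shifted = seg_logits - seg_logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    labels = probs.argmax(axis=1)   # same result as argmax over the raw logits
    confidence = probs.max(axis=1)  # probability of the chosen class
    return labels, confidence

rng = np.random.default_rng(0)
logits = rng.standard_normal((1, 6, 8, 8))  # stand-in for outputs["seg_logits"]
labels, confidence = logits_to_labels(logits)
print(labels.shape, confidence.shape)  # (1, 8, 8) (1, 8, 8)
```

Because softmax is monotonic, the labels are identical to a plain argmax. The softmax step only adds the confidence map.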
```python
# Get a neuron embedding for downstream tasks
with torch.no_grad():
    z, mu, logvar, _, _, _ = model.encode(tensor)
    embedding = mu.mean(dim=(2, 3))  # global average pooling -> (1, 128)
```
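The pooled `mu` vectors can be compared across neurons, for example to retrieve morphologically similar cells. This minimal NumPy sketch assumes one embedding per neuron has been stacked into an `(N, 128)` array (e.g. via `embedding.cpu().numpy()`). Cosine similarity is an illustrative choice here, not a metric prescribed by AxoNet. The helper names and the random embeddings are hypothetical.

```python
import numpy as np

def cosine_similarity_matrix(embeddings):
    """Pairwise cosine similarity between rows of an (N, D) embedding matrix."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)  # avoid division by zero
    return unit @ unit.T

def nearest_neighbors(embeddings, query_idx, k=3):
    """Indices of the k neurons most similar to embeddings[query_idx], excluding itself."""
    sims = cosine_similarity_matrix(embeddings)[query_idx]
    order = np.argsort(-sims)
    return [int(i) for i in order if i != query_idx][:k]

# Hypothetical stack of pooled embeddings, one 128-d row per neuron image
rng = np.random.default_rng(0)
embeddings = rng.standard_normal((10, 128))
embeddings[3] = embeddings[0] * 2.0  # same direction as neuron 0
print(nearest_neighbors(embeddings, query_idx=0, k=3)[0])  # 3
```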
| File | Description |
|---|---|
| `pytorch_model.bin` | PyTorch state dict |
| `model.safetensors` | Safetensors format (recommended) |
| `config.json` | Model configuration |
| `full_checkpoint/best.ckpt` | Full Lightning checkpoint |
This model serves as the encoder for:
```bibtex
@misc{axonet2025,
  author       = {Hall, Giles},
  title        = {AxoNet: Multimodal Neuron Morphology Embeddings via 2D Projections},
  year         = {2025},
  publisher    = {HuggingFace},
  howpublished = {\url{https://huggingface.co/broadinstitute/axonet-vae-stage1}}
}
```
MIT License