Spaces:
Build error
Build error
| import os | |
| import gradio as gr | |
| import torch | |
| from monai import bundle | |
| from monai.transforms import ( | |
| Compose, | |
| LoadImaged, | |
| EnsureChannelFirstd, | |
| Orientationd, | |
| NormalizeIntensityd, | |
| Activationsd, | |
| AsDiscreted, | |
| ScaleIntensityd, | |
| ) | |
# Name of the MONAI bundle used by this demo. It must match the Hugging Face
# repository that is loaded (katielink/brats_mri_segmentation_v0.1.0): the
# bundle loader stores the download under <torch hub dir>/bundle/<name>, and
# BUNDLE_PATH below must point at that same directory to find inference.json.
BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'
# Local directory of the downloaded bundle (<torch hub dir>/bundle/<BUNDLE_NAME>).
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)
# ---------------------------------------------------------------------------
# Static UI text and example inputs
# ---------------------------------------------------------------------------
# Page heading (inline HTML), passed to gr.Markdown when the UI is built.
title = '<h1 style="text-align: center;">Segment Brain Tumors with MONAI! 🧠 </h1>'

# Usage instructions and disclaimer, one markdown line per list entry; the
# empty first/last entries give the leading and trailing newline.
description = "\n".join([
    "",
    "## 🚀 To run",
    "Upload a brain MRI image file, or try out one of the examples below!",
    "If you want to see a different slice, update the slider.",
    "More details on the model can be found [here!](https://huggingface.co/katielink/brats_mri_segmentation_v0.1.0)",
    "## ⚠️ Disclaimer",
    "This is an example, not to be used for diagnostic purposes.",
    "",
])

# Citations rendered at the bottom of the page.
references = "\n".join([
    "",
    "## 👀 References",
    "1. Myronenko, Andriy. \"3D MRI brain tumor segmentation using autoencoder regularization.\" International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.",
    "2. Menze BH, et al. \"The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)\", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694",
    "3. Bakas S, et al. \"Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features\", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117",
    "",
])

# Example rows for gr.Examples: [volume file path, initial slice index].
examples = [
    ['examples/BRATS_485.nii.gz', 65],
    ['examples/BRATS_486.nii.gz', 80],
]
# Arguments for fetching the pretrained bundle from the Hugging Face Hub and
# loading its TorchScript module.
_hub_bundle_args = dict(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)
# The call yields a 3-element result; only the network itself is kept.
model, _, _ = bundle.load(**_hub_bundle_args)

# Prefer the first CUDA GPU; fall back to the CPU when none is available.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Parsed 'inference.json' of the downloaded bundle; only its 'inferer' entry is used.
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')

# Preprocessing chain: read the volume, move channels to the front, reorient
# to RAS, then normalise intensities per channel over non-zero voxels.
_preproc_steps = [
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
]
preproc_transforms = Compose(_preproc_steps)

# Inferer instantiated from the bundle's own inference configuration.
inferer = parser.get_parsed_content('inferer', lazy=True, eval_expr=True, instantiate=True)

# Postprocessing chain: sigmoid + 0.5 threshold turn the network output into
# binary masks; the image is rescaled to [0, 1] so its slices can be displayed.
_postproc_steps = [
    Activationsd(keys='pred', sigmoid=True),
    AsDiscreted(keys='pred', threshold=0.5),
    ScaleIntensityd(keys='image', minv=0., maxv=1.),
]
post_transforms = Compose(_postproc_steps)
# Define the predict function for the demo
def predict(input_file, z_axis, model=model, device=device):
    """Segment a multi-sequence brain MRI volume and return one slice of it.

    Args:
        input_file: the value of the ``gr.File`` input. This is either an
            object whose ``.name`` attribute holds the uploaded file's path
            (older Gradio) or the path string itself.
        z_axis: index of the slice along the last spatial axis, as sent by the
            slider or an example row. It is converted to ``int`` and clamped
            to the depth of the loaded volume.
        model: network that is moved to ``device``, put in eval mode and run
            through the bundle's inferer.
        device: device string used for inference.

    Returns:
        A pair of lists of 2-D numpy arrays. The first holds the four input
        sequences (T1 post-contrast, T1, T2, FLAIR); the second holds the
        three predicted masks (tumor core, whole tumor, enhancing tumor).

    Raises:
        ValueError: if no input file was provided.
    """
    if input_file is None:
        raise ValueError('Please upload an input file or choose an example first.')
    # Older Gradio passes a tempfile wrapper, newer Gradio passes the path itself.
    path = getattr(input_file, 'name', input_file)

    # Load and process data in MONAI format
    data = preproc_transforms({'image': [path]})

    # Run inference and post-process predicted labels
    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        # Add a leading batch dimension for the inferer.
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
    data = post_transforms(data)

    # Convert tensors back to numpy arrays
    image = data['image'].numpy()
    pred = data['pred'].cpu().detach().numpy()

    # The UI slider's range can exceed the volume depth, and slider values may
    # arrive as floats: force a valid integer slice index instead of crashing.
    depth = image.shape[-1]
    z = min(max(int(z_axis), 0), depth - 1)

    # Magnetic resonance imaging sequences, in channel order:
    # T1-weighted post contrast, T1-weighted pre contrast, T2-weighted, FLAIR.
    sequences = [image[channel, :, :, z] for channel in range(4)]
    # BraTS labels, in channel order: tumor core, whole tumor, enhancing tumor.
    labels = [pred[0, channel, :, :, z] for channel in range(3)]
    return sequences, labels
# Use blocks to set up a more complex demo layout.
with gr.Blocks() as demo:
    # Show title and description
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # Get the input file and slice slider as inputs; step=1 keeps the
        # slice index integral.
        input_file = gr.File(label='input file')
        z_axis = gr.Slider(0, 200, step=1, label='slice', value=50)
    with gr.Row():
        # Show the button with custom label
        button = gr.Button("Segment Tumor!")
    with gr.Row():
        with gr.Column():
            # Show the input image with different MR sequences
            input_image = gr.Gallery(label='input MRI sequences (T1+, T1, T2, FLAIR)')
        with gr.Column():
            # Show the segmentation labels
            output_segmentation = gr.Gallery(label='output segmentations (TC, WT, ET)')
    # Run prediction on button click
    button.click(
        predict,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
    )
    # Example rows the user can try. Bound to its own name so the module-level
    # `examples` list is not overwritten by the component object.
    examples_component = gr.Examples(
        examples=examples,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
        fn=predict,
        cache_examples=False,
    )
    # Show references at the bottom of the demo
    gr.Markdown(references)

# Launch the demo only when run as a script (e.g. `python app.py`), not when
# the module is imported by another tool.
if __name__ == '__main__':
    demo.launch()