VoxelMorph Atlas is just a slightly blurred version of the fixed image #589

Open
wilcamsdington opened this issue Feb 22, 2024 · 23 comments


@wilcamsdington

Task (what are you trying to do/register?)

I am trying to build a template from a 3D brain T1 MRI dataset of around 100 images, each of size (224, 224, 224).

What have you tried

I ran the unconditional template code for 3D brain data, following the code in this colab.

I used one of the brain MRI images as the template (fixed image) during training.
Otherwise, I encounter the same issue as described in #587.

Details of experiments

import numpy as np
import voxelmorph as vxm
from tensorflow.keras.callbacks import ModelCheckpoint

def template_gen(x, batch_size):
    vol_shape = list(x.shape[1:-1])
    # dummy zero target for the loss terms that ignore y_true
    zero = np.zeros([batch_size] + vol_shape + [2])
    # mean_atlas is replaced by a selected MRI image (the fixed image);
    # `tem` holds that volume (definition not shown)
    mean_atlas = np.repeat(np.mean(tem, 0, keepdims=True), batch_size, 0)
    while True:
        idx = np.random.randint(0, x.shape[0], batch_size)
        img = x[idx, ...]
        inputs = [mean_atlas, img]
        outputs = [img, zero, zero, zero]
        yield inputs, outputs

# get the model
vol_shape = list(x_vols.shape[1:-1])
enc_nf = [16, 32, 32, 32]
dec_nf = [32, 32, 32, 32, 32, 16, 16]
# NOTE: src_feats/atlas_feats of 224 treat the last image axis as channels,
# which makes this effectively a 2D model; a single-channel 3D setup would
# use the full 3D vol_shape with src_feats=1, atlas_feats=1
model = vxm.networks.TemplateCreation(vol_shape, nb_unet_features=[enc_nf, dec_nf],
                                      src_feats=224, atlas_feats=224)

# prepare losses
image_loss_func = vxm.losses.MSE().loss
neg_loss_func = lambda _, y_pred: image_loss_func(model.references.atlas_tensor, y_pred)
losses = [image_loss_func, neg_loss_func, vxm.losses.MSE().loss,
          vxm.losses.Grad('l2', loss_mult=2).loss]
loss_weights = [0.5, 0.5, 1, 0.01]

# train
gen = template_gen(x_vols, batch_size=2)
save_callback = ModelCheckpoint('model_checkpoint_{epoch:02d}.h5', period=50)  # `period` is deprecated in newer tf; use save_freq
model.compile('adam', loss=losses, loss_weights=loss_weights)
hist = model.fit(gen, epochs=1400, steps_per_epoch=32, callbacks=[save_callback],
                 verbose=0)

vxm.py.utils.save_volfile(model.get_atlas(), 'slice_224.nii.gz')

[image]

The final VoxelMorph template is just a slightly blurred version of the fixed image.

[image]

@vsleios

vsleios commented Mar 1, 2024

I am facing exactly the same problem in my application! My outputs are blurred versions of the images I provide.

@adalca
Collaborator

adalca commented Mar 2, 2024

Two quick questions -- are the images normalized in [0, 1], and are they affinely aligned? @wilcamsdington, in your example, originals 1 and 2 seem to be very out of alignment.

@wilcamsdington
Copy link
Author

> Two quick questions -- are the images normalized in [0, 1], and are they affinely aligned? @wilcamsdington, in your example, originals 1 and 2 seem to be very out of alignment.

Yes, all images are normalized in [0, 1] and affinely aligned, and all images are 3D. In my example, originals 1 and 2 are different slices from the same 3D image.
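
As a quick sanity check of that normalization claim, something like the following can be run on the training array (x_vols from the code above):

import numpy as np

# verify intensities lie in [0, 1] across the whole training array
print('min:', x_vols.min(), 'max:', x_vols.max())

# if any volume falls outside that range, a simple per-volume min-max rescale:
def minmax_normalize(vol, eps=1e-8):
    vol = vol.astype('float32')
    return (vol - vol.min()) / (vol.max() - vol.min() + eps)

x_vols = np.stack([minmax_normalize(v) for v in x_vols])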

@adalca
Collaborator

adalca commented Mar 2, 2024

In order to help, I need to see a few images to understand what is happening.

Could you show a middle slice of the atlas and maybe 2-3 subjects at that slice, along with their deformation fields? Are they super stiff?

@wilcamsdington
Author

[image]
Here are some middle slices of the atlas.
I am not sure how to get the corresponding deformation fields.
code (.ipynb) is my training code, and the model checkpoint is model_checkpoint.h5.
I hope it helps.

@adalca
Collaborator

adalca commented Mar 2, 2024

@wilcamsdington I'm sorry, we don't have the bandwidth to go through other people's Jupyter notebooks and help debug these things, although I'd love to.

It would really help us if you gave us as much information as possible.

We need to see a few of the subjects so we understand the diversity, and the size of the images so we understand the deformations you expect. We have a tutorial on visualizing deformation fields here: https://colab.research.google.com/drive/1F8f1imh5WfyBv-crllfeJBFY16-KHl9c?usp=sharing
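
For readers stuck on the same question ("not sure how to get the corresponding deformation fields"), here is a minimal sketch of one way to obtain and plot a flow field with plain matplotlib. It assumes a hypothetical reg_model, a separate vxm.networks.VxmDense registration model trained to register subjects to the saved atlas, whose predict returns the moved image and the dense warp (as in the linked tutorial); subject and atlas are arrays of shape (1, H, W, D, 1):

import numpy as np
import matplotlib.pyplot as plt

# hypothetical: reg_model.predict returns [moved_image, dense_flow]
moved, flow = reg_model.predict([subject, atlas])

# middle slice along the third spatial axis; keep the two in-plane components
mid = flow.shape[3] // 2
flow2d = flow[0, :, :, mid, :2]  # (H, W, 2) voxel displacements

# subsample for a readable quiver plot (the row component is negated because
# image-style axes put row 0 at the top)
step = 8
ys, xs = np.meshgrid(np.arange(0, flow2d.shape[0], step),
                     np.arange(0, flow2d.shape[1], step), indexing='ij')
plt.quiver(xs, ys, flow2d[::step, ::step, 1], -flow2d[::step, ::step, 0])
plt.gca().invert_yaxis()
plt.title('deformation field, middle slice')
plt.show()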

@wilcamsdington
Author

[image]
Here is the middle slice of the atlas, 3 subjects at that slice, and their deformation fields.

@adalca
Collaborator

adalca commented Mar 2, 2024

Thanks. I only see two subjects and it's hard to understand what those deformations are doing -- I personally can read the grid visualization much better. My suspicion is that the regularization is too strong.

Also note that your images include everything, not just the brain, which will make it difficult for any algorithm to do a good job since it will try to match all kinds of anatomy that has substantial deformations, like the neck and such. I would recommend running synthstrip on the images first.
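
Since the grid visualization came up, here is a small sketch of rendering a displacement field as a warped grid (flow2d as in the flow sketch above: a (H, W, 2) slice of voxel displacements; scipy is assumed to be available):

import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import map_coordinates

def plot_warped_grid(flow2d, spacing=8):
    # draw grid lines every `spacing` voxels, then pull the image back
    # through the displacement field (the same backward-warp convention as
    # a spatial transformer: warped(p) = grid(p + flow(p)))
    h, w = flow2d.shape[:2]
    grid = np.zeros((h, w))
    grid[::spacing, :] = 1.0
    grid[:, ::spacing] = 1.0
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    coords = np.stack([ys + flow2d[..., 0], xs + flow2d[..., 1]])
    warped = map_coordinates(grid, coords, order=1)
    plt.imshow(warped, cmap='gray')
    plt.axis('off')
    plt.show()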

@wilcamsdington
Author

> Thanks. I only see two subjects and it's hard to understand what those deformations are doing -- I personally can read the grid visualization much better. My suspicion is that the regularization is too strong.
>
> Also note that your images include everything, not just the brain, which will make it difficult for any algorithm to do a good job since it will try to match all kinds of anatomy that has substantial deformations, like the neck and such. I would recommend running synthstrip on the images first.

Thank you for your suggestion. I've attempted running synthstrip on the images, but encountered similar outcomes. I share your suspicion that the strong regularization might be the cause. I appreciate your advice, and I'll try reducing the regularization term to see if it yields better results.
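
For reference, the regularization strength in the training code above is the weight on the Grad('l2') loss, the last entry of loss_weights; a hedged example of weakening it (the right value is data-dependent and worth sweeping):

# weaker smoothness penalty on the deformation field (was 0.01 above);
# the right value is data-dependent and worth sweeping
loss_weights = [0.5, 0.5, 1, 0.001]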

@wilcamsdington
Author

[image]
I ran synthstrip on the images first and tried switching the loss from vxm.losses.MSE().loss to vxm.losses.NCC().loss during training, and I modified some of the loss weights. The results are shown in the figure above.

image_loss_func = vxm.losses.NCC().loss # vxm.losses.MSE().loss

neg_loss_func = lambda _, y_pred: image_loss_func(model.references.atlas_tensor, y_pred)
losses = [image_loss_func, neg_loss_func, vxm.losses.MSE().loss, vxm.losses.Grad('l2', loss_mult=2).loss]
loss_weights = [0.5, 0.5, 1, 0.1]

Do you have any suggestions or recommendations?
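
One related knob worth checking when switching to NCC is the local window size; assuming the installed voxelmorph's vxm.losses.NCC accepts a win argument (check your version), a larger window can be tried:

# hedged: assumes vxm.losses.NCC accepts a `win` argument (a 9^3 window here)
image_loss_func = vxm.losses.NCC(win=[9, 9, 9]).loss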

@liamburrows

liamburrows commented Mar 15, 2024

Hi,
I am also working on this problem, and when reproducing the MNIST tutorial from this link (https://colab.research.google.com/drive/1PJ-aRZrkU-2SfVEIBlg8kHRJZrjdnmOT#scrollTo=9OylH0vLjBNz, in the 'unconditional template (MNIST)' section), I was also only getting a blurry 5, rather than the sharp 5 shown in the colab.

After looking at the voxelmorph code, I note that TemplateCreation expects a single input, whereas the generator provided in all tutorials provides a list of length two as input: [mean_atlas, img], where mean_atlas is an average of all images in the dataset. I guess this can serve as an initial guess for an atlas, but I'm not sure what purpose it serves.
Now, I may be misunderstanding how Keras works, but I think that when the network needs a single input and you supply a list of length two, Keras treats each entry of the list as a separate batch. If that is the case, then half of the images in the training set will be identical to this mean_atlas, so you would expect a blurry average, since the training set is dominated by a blurry average.

I have been able to reproduce the MNIST tutorial successfully by changing the code in the generator from:

  • inputs = [mean_atlas, img]
    to:
  • inputs = img

Please let me know if this works on your end as I am also working on a similar problem of atlas creation currently!
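
For concreteness, the full generator with that one-line change applied looks like this (a sketch mirroring the code posted earlier in the thread; the zeros placeholder keeps one channel per spatial dimension, matching 2D or 3D flow fields):

import numpy as np

def template_gen(x, batch_size):
    # x: (N, *vol_shape, 1) array of training volumes
    vol_shape = list(x.shape[1:-1])
    # dummy zero target for the loss terms that ignore y_true
    zero = np.zeros([batch_size] + vol_shape + [len(vol_shape)])
    while True:
        idx = np.random.randint(0, x.shape[0], batch_size)
        img = x[idx, ...]
        inputs = img  # single input: TemplateCreation provides its own atlas
        outputs = [img, zero, zero, zero]
        yield inputs, outputs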

@wilcamsdington
Author

> Hi, I am also working on this problem [...] After looking at the voxelmorph code I note that TemplateCreation expects a single input, whereas the generator provided in all tutorials provides a list of length two as input: [mean_atlas, img] [...]
>
> I have been able to reproduce the MNIST tutorial successfully by changing the code in the generator from inputs = [mean_atlas, img] to inputs = img.
>
> Please let me know if this works on your end as I am also working on a similar problem of atlas creation currently!

I have tried it; the results are still blurry and lack detail (like #587).
[image]

@adalca
Collaborator

adalca commented Mar 16, 2024

This is very weird. I wonder if it's a change in tf, or something we pushed but missed.
I will try to run the MNIST tutorial with @LRB13's edit in mind and see if it still works for me.
@wilcamsdington, did you try the MNIST tutorial?

@vsleios

vsleios commented Apr 4, 2024

Do we have any updates regarding the blurring problem?

@adalca
Collaborator

adalca commented Apr 4, 2024

Hmm, I'm not quite sure what to say -- I went back through our template atlas notebook, which you can see here.

The MNIST template looks good to me -- can you verify you can replicate this?
[image]

I am still running the 2D OASIS brain atlas in that jupyter right now (I don't have access to fancy GPUs on google colab, so it's slow) and will update here when it's done.

@adalca
Collaborator

adalca commented Apr 4, 2024

Here is the 2D T1w brain with an untuned model. Seems reasonable (not optimal, but not horrible) to me?

[image]

@adalca
Collaborator

adalca commented Apr 4, 2024

@LRB13 @vsleios @wilcamsdington can you all check out this tutorial and tell me if it works for you?

The only updates I made were:

  • change the generator to give only the image as input, instead of the image and the initial blur, just as @LRB13 suggested
  • change the tf version

So it works for both MNIST and the brain for me. Does it still pose a problem for you all?

@liamburrows

liamburrows commented Apr 4, 2024 via email

@wilcamsdington
Author

Thank you for your updates and responses! I've tested the updated tutorial, and it seems to be working fine for me as well; both MNIST and the brain dataset run smoothly without any issues.

For my case, however, the VoxelMorph (unconditional) atlas is still blurry and lacks detail. I guess the code/model itself is not the problem; it might be because the MRI data is 3D, or perhaps due to the variation within the data itself (large and small brains).
[image]

@adalca
Collaborator

adalca commented Apr 5, 2024

It would be great to see a few brain images from your dataset side by side. Are they very different? Are their intensities in different ranges?

The same code should work decently well for 3D atlases.

@wilcamsdington
Author

Here are a few examples of our dataset; the intensities are in the same range.

@adalca
Collaborator

adalca commented Apr 6, 2024

@wilcamsdington They are not affinely aligned, which is likely causing the problem. I'd recommend the following (a rough pipeline sketch follows below):

  • run each subject through synthstrip to extract the brain. Should work out of the box.
  • affinely register each subject to the first subject using synthmorph
  • then take the resulting affinely-aligned brains and (1) show them to us :) and (2) run voxelmorph!
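
A rough sketch of that preprocessing pipeline as a Python driver; the FreeSurfer CLI flags for mri_synthstrip and mri_synthmorph are assumptions here and vary by version, so check each tool's --help before running:

import glob
import os
import subprocess

os.makedirs('stripped', exist_ok=True)
os.makedirs('aligned', exist_ok=True)

subjects = sorted(glob.glob('data/*.nii.gz'))  # hypothetical input layout
ref = None  # the first stripped subject becomes the affine reference

for path in subjects:
    base = os.path.basename(path)
    stripped = os.path.join('stripped', base)
    aligned = os.path.join('aligned', base)

    # 1) skull-strip (flags assumed; see mri_synthstrip --help)
    subprocess.run(['mri_synthstrip', '-i', path, '-o', stripped], check=True)

    if ref is None:
        ref = stripped  # register everyone else to the first subject
        continue

    # 2) affine registration to the reference (flags assumed; see
    #    mri_synthmorph --help)
    subprocess.run(['mri_synthmorph', '-m', 'affine', '-o', aligned,
                    stripped, ref], check=True)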

@wilcamsdington
Author

Thank you for your guidance. I followed your suggestions, and the results indeed showed improvement (more detail). However, some blurriness persists in the overall output. I will try adjusting the parameters to further optimize the process and see if it improves.
Your support has been greatly appreciated.

Examples of aligned brains:
[image]

[image]
