0 parents
commit bc85cba
Showing 17 changed files with 865 additions and 0 deletions.
@@ -0,0 +1,118 @@
# Meta-Polyp: a baseline for efficient Polyp segmentation, CBMS 2023

This repo is the official implementation of the paper:

Meta-Polyp: a baseline for efficient Polyp segmentation.

Authors: Quoc-Huy Trinh

In the IEEE 36th International Symposium on Computer-Based Medical Systems (CBMS) 2023.

Details of each model module can be found in the original paper. Please cite the paper if you use our implementation for research purposes.

## Overall architecture

Architecture of the Meta-Polyp baseline model:

<div align="center">
<a href="./">
<img src="img/MetaPolyp.png" width="79%"/>
</a>
</div>

## Installation

Our implementation uses ```Python 3.9```. Please make sure your environment is compatible with the requirements.

To install all packages, use the ```requirements.txt``` file with ```pip```:

```
pip install -r requirements.txt
```

All packages will be installed automatically.

## Config

All configs for training and benchmarking are in the ```./config/``` folder. Review them before tuning.

## Training

To start training, use the ```train.py``` file:

```
python train.py
```

## Benchmark

To start benchmarking, use the ```benchmark.py``` file:

```
python benchmark.py
```

### Note:
Before running the benchmark, set the model path to your trained weights and the dataset directory to your benchmark dataset. In ```benchmark.py``` these are the ```save_path``` and ```path_to_test_dataset``` variables.

## Pretrained weights

The weights will be released later.

## Dataset

In our experiments, we use the dataset configuration from [PraNet](https://github.com/DengPingFan/PraNet), with a training set drawn from 50% of Kvasir-SEG and 50% of ClinicDB.

For testing, we use the following datasets:

In-distribution:

- Kvasir-SEG
- ClinicDB

Out-of-distribution:

- Etis dataset
- ColonDB
- CVC300

## Results

IoU scores compared with SOTA methods on all 5 datasets:

<div align="center">
<a href="./">
<img src="img/res.png" width="79%"/>
</a>
</div>

Qualitative comparison with other SOTA methods:

<div align="center">
<a href="./">
<img src="img/miccaivis.png" width="79%"/>
</a>
</div>

## Weights

Coming soon

## Customize
You can change the backbone from Ca-former to PVT or another backbone to get different results.

## Citation

```
Coming soon
```

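The Results section of the README reports IoU, and the benchmark script also logs a Dice coefficient. The repo's own `metrics.segmentation_metrics` module is not shown in this part of the commit, so the following is only a minimal sketch of the standard definitions in plain Python. The function name `iou_and_dice`, the 0.5 threshold and the `eps` smoothing term are assumptions made for illustration, not the author's implementation.

```python
def iou_and_dice(pred, target, threshold=0.5, eps=1e-7):
    """Return (IoU, Dice) for two equally sized masks given as flat lists of floats.

    Hypothetical helper: threshold and eps are illustrative choices.
    """
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    # Binarise both masks so that each pixel is either 0 or 1.
    p = [1 if v >= threshold else 0 for v in pred]
    t = [1 if v >= threshold else 0 for v in target]
    intersection = sum(a & b for a, b in zip(p, t))
    pred_area, target_area = sum(p), sum(t)
    union = pred_area + target_area - intersection
    # eps keeps both ratios defined when the two masks are empty.
    iou = (intersection + eps) / (union + eps)
    dice = (2 * intersection + eps) / (pred_area + target_area + eps)
    return iou, dice


pred = [0.9, 0.8, 0.2, 0.1]    # predicted probabilities for four pixels
target = [1.0, 0.0, 1.0, 0.0]  # ground-truth mask for the same pixels
iou, dice = iou_and_dice(pred, target)
print(round(iou, 3), round(dice, 3))  # prints: 0.333 0.5
```

Here one pixel is shared, each mask covers two pixels, and the union covers three, which gives an IoU of about 1/3 and a Dice score of about 1/2.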
@@ -0,0 +1,54 @@
# from save_model.pvt_CAM_channel_att_upscale import build_model
import os
import tensorflow as tf
# from metrics.metrics_last import iou_metric, MAE, WFbetaMetric, SMeasure, Emeasure, dice_coef, iou_metric
from metrics.segmentation_metrics import dice_coeff, bce_dice_loss, IoU, zero_IoU, dice_loss
from dataloader.dataloader import build_augmenter, build_dataset, build_decoder
from tensorflow.keras.utils import get_custom_objects
from model_research import build_model

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def load_dataset(route, img_size, batch_size):
    """Build an evaluation dataset from <route>/images and <route>/masks."""
    X_path = '{}/images/'.format(route)
    Y_path = '{}/masks/'.format(route)
    # Sorting both listings pairs each image with its mask, assuming matching file names.
    X_full = sorted(os.listdir(X_path))
    Y_full = sorted(os.listdir(Y_path))

    X_test = [X_path + x for x in X_full]
    Y_test = [Y_path + x for x in Y_full]

    test_decoder = build_decoder(with_labels=False, target_size=(img_size, img_size), ext='jpg',
                                 segment=True, ext2='jpg')
    # Evaluate every sample exactly once: no augmentation, no repeat, no shuffle.
    test_dataset = build_dataset(X_test, Y_test, bsize=batch_size, decode_fn=test_decoder,
                                 augmentAdv=False, augment=False, augmentAdvSeg=False,
                                 repeat=False, shuffle=False)
    return test_dataset, len(X_test)

def benchmark(route, model, img_size=256, batch_size=1, save_file_name="benchmark_result.txt"):
    """Evaluate `model` on every dataset folder under `route` and append the scores to a file."""
    list_of_datasets = sorted(os.listdir(route))
    # The context manager closes the result file even if evaluation fails.
    with open(save_file_name, "a") as f:
        f.write("\n")
        for dataset_name in list_of_datasets:
            print(dataset_name, ":")
            test_dataset, len_data = load_dataset(os.path.join(route, dataset_name), img_size, batch_size)
            # Ceiling division, using the same batch size the dataset was built with.
            steps = -(-len_data // batch_size)
            # Result names differ from the imported metric functions to avoid shadowing them.
            loss, dice, bce_dice, iou, zero_iou, mse = model.evaluate(test_dataset, steps=steps)
            f.write("{}:".format(dataset_name))
            f.write("dice_coeff: {}, bce_dice_loss: {}, IoU: {}, zero_IoU: {}, mse: {}".format(dice, bce_dice, iou, zero_iou, mse))
            f.write('\n')

if __name__ == "__main__":

    img_size = 256
    BATCH_SIZE = 1
    SEED = 1024
    save_path = "best_model.h5"
    route_data = "./TestDataset/"
    path_to_test_dataset = "./TestDataset/"
    model = build_model(img_size)
    model.load_weights(save_path)

    # The last metric is mean squared error, so benchmark() logs it as "mse".
    model.compile(metrics=[dice_coeff, bce_dice_loss, IoU, zero_IoU, tf.keras.metrics.MeanSquaredError()])

    benchmark(path_to_test_dataset, model, img_size=img_size, batch_size=BATCH_SIZE)
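When `model.evaluate` is given a `steps` argument, the number of samples actually scored is `steps` multiplied by the batch size the dataset was built with, so the step count has to be derived from that same batch size. The sketch below illustrates the arithmetic in plain Python. The helper names `evaluated_samples` and `ceil_steps` are hypothetical and not part of the repo, and the first helper assumes a dataset that repeats, as `build_dataset` does by default.

```python
def evaluated_samples(num_samples, dataset_batch_size, steps_batch_size):
    """Samples scored when steps come from floor division by `steps_batch_size`
    but the (repeating) dataset is batched with `dataset_batch_size`."""
    steps = num_samples // steps_batch_size
    return steps * dataset_batch_size


def ceil_steps(num_samples, batch_size):
    """Number of batches needed to cover every sample, including a partial last batch."""
    return -(-num_samples // batch_size)


# 100 test images, dataset batched one image at a time, steps computed with 32:
print(evaluated_samples(100, dataset_batch_size=1, steps_batch_size=32))   # prints: 3
# Matching batch sizes, floor division still drops the last partial batch:
print(evaluated_samples(100, dataset_batch_size=32, steps_batch_size=32))  # prints: 96
# Ceiling division covers all 100 images in 4 batches:
print(ceil_steps(100, 32))                                                 # prints: 4
```

Ceiling division is only safe if the dataset either repeats or yields its final partial batch; with a non-repeating dataset the `steps` argument can also simply be left out.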
@@ -0,0 +1,68 @@
import math
import tensorflow as tf
import matplotlib.pyplot as plt

def cosine_annealing_with_warmup(epochIdx):
    # max_lr, min_lr and cos_anne_ep are module-level globals set by get_callbacks().
    aMax, aMin = max_lr, min_lr
    warmupEpochs, stagnateEpochs, cosAnnealingEpochs = 0, 0, cos_anne_ep
    epochIdx = epochIdx % (warmupEpochs + stagnateEpochs + cosAnnealingEpochs)
    if epochIdx < warmupEpochs:
        # Linear warmup; max(..., 1) avoids a division by zero when warmupEpochs == 1.
        return aMin + (aMax - aMin) / max(warmupEpochs - 1, 1) * epochIdx
    epochIdx -= warmupEpochs
    if epochIdx < stagnateEpochs:
        # Hold the maximum learning rate for the stagnation phase.
        return aMax
    epochIdx -= stagnateEpochs
    # Cosine decay from aMax towards aMin over cosAnnealingEpochs epochs.
    return aMin + 0.5 * (aMax - aMin) * (1 + math.cos((epochIdx + 1) / (cosAnnealingEpochs + 1) * math.pi))

def plt_lr(step, schedulers):
    x = range(step)
    y = [schedulers(epoch) for epoch in x]

    # A labelled line gives plt.legend() an entry to display.
    plt.plot(x, y, label='Learning Rate')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.legend()

def get_callbacks(monitor, mode, save_path, _max_lr, _min_lr, _cos_anne_ep, save_weights_only):
    # Store the schedule parameters as globals read by cosine_annealing_with_warmup().
    global max_lr
    max_lr = _max_lr
    global min_lr
    min_lr = _min_lr
    global cos_anne_ep
    cos_anne_ep = _cos_anne_ep

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor=monitor,
        patience=60,
        restore_best_weights=True,
        mode=mode
    )

    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor=monitor,
        factor=0.2,
        patience=50,
        verbose=1,
        mode=mode,
        min_lr=1e-5,
    )

    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=save_path,
        monitor=monitor,
        verbose=1,
        save_best_only=True,
        save_weights_only=save_weights_only,
        mode=mode,
        save_freq="epoch",
    )

    lr_schedule = tf.keras.callbacks.LearningRateScheduler(cosine_annealing_with_warmup, verbose=0)

    csv_logger = tf.keras.callbacks.CSVLogger('training.csv')

    # Only these three callbacks are returned; early_stopping and lr_schedule
    # are built above but are not part of the list.
    callbacks = [checkpoint, csv_logger, reduce_lr]
    return callbacks
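The cosine schedule in this file reads its parameters from module-level globals. As a point of comparison, the same formula can be written as a pure function that takes the parameters explicitly, which makes it easy to inspect without building any callbacks. This is a minimal sketch of the warmup-free case (warmup and stagnation both 0, as in the code above); the function name `cosine_annealing` and the example values are assumptions for illustration.

```python
import math


def cosine_annealing(epoch, max_lr, min_lr, cycle_epochs):
    """Learning rate for `epoch`, restarting every `cycle_epochs` epochs.

    Hypothetical stand-alone variant of the schedule, without globals.
    """
    if cycle_epochs <= 0:
        raise ValueError("cycle_epochs must be positive")
    pos = epoch % cycle_epochs  # position inside the current cycle
    return min_lr + 0.5 * (max_lr - min_lr) * (
        1 + math.cos((pos + 1) / (cycle_epochs + 1) * math.pi)
    )


# Two cycles of four epochs each, with illustrative learning-rate bounds.
schedule = [cosine_annealing(e, max_lr=1e-3, min_lr=1e-5, cycle_epochs=4) for e in range(8)]
for epoch, lr in enumerate(schedule):
    print(epoch, f"{lr:.6f}")
```

Within each cycle the rate decays from just below `max_lr` to just above `min_lr`, then restarts. A function like this could be handed to `tf.keras.callbacks.LearningRateScheduler` through a wrapper such as `lambda epoch, lr: cosine_annealing(epoch, 1e-3, 1e-5, 4)`.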
@@ -0,0 +1,132 @@
import tensorflow as tf
import os

# def auto_select_accelerator():
#     try:
#         tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
#         tf.config.experimental_connect_to_cluster(tpu)
#         tf.tpu.experimental.initialize_tpu_system(tpu)
#         strategy = tf.distribute.experimental.TPUStrategy(tpu)
#         print("Running on TPU:", tpu.master())
#     except ValueError:
#         strategy = tf.distribute.get_strategy()
#     print(f"Running on {strategy.num_replicas_in_sync} replicas")

#     return strategy

def default_augment_seg(input_image, input_mask):

    # Photometric changes are applied to the image only; the mask is left untouched.
    input_image = tf.image.random_brightness(input_image, 0.1)
    input_image = tf.image.random_contrast(input_image, 0.9, 1.1)
    input_image = tf.image.random_saturation(input_image, 0.9, 1.1)
    input_image = tf.image.random_hue(input_image, 0.01)

    # Random horizontal and vertical flips, applied to image and mask together
    # so that they stay aligned.
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_up_down(input_image)
        input_mask = tf.image.flip_up_down(input_mask)

    return input_image, input_mask

def BatchAdvAugmentSeg(imagesT, masksT):

    # Mapped after batching in build_dataset(), so each flip decision
    # applies to a whole batch at once.
    images, masks = default_augment_seg(imagesT, masksT)

    return images, masks

def build_decoder(with_labels=True, target_size=(256, 256), ext='png', segment=False, ext2='png'):

    def decode(path):
        file_bytes = tf.io.read_file(path)
        if ext == 'png':
            # decode_png has no dct_method argument; that option exists only for JPEG.
            img = tf.image.decode_png(file_bytes, channels=3)
        elif ext in ['jpg', 'jpeg']:
            img = tf.image.decode_jpeg(file_bytes, channels=3, dct_method='INTEGER_ACCURATE')
        else:
            raise ValueError("Image extension not supported")

        img = tf.image.resize(img, target_size)
        # img = tf.cast(img, tf.float32) / 255.0

        return img

    def decode_mask(path, gray=True):
        file_bytes = tf.io.read_file(path)
        if ext2 == 'png':
            img = tf.image.decode_png(file_bytes, channels=3)
        elif ext2 in ['jpg', 'jpeg']:
            img = tf.image.decode_jpeg(file_bytes, channels=3)
        else:
            raise ValueError("Image extension not supported")

        img = tf.image.rgb_to_grayscale(img) if gray else img
        img = tf.image.resize(img, target_size)
        # Masks are scaled to [0, 1]; images from decode() keep their [0, 255] range.
        img = tf.cast(img, tf.float32) / 255.0

        return img

    def decode_with_labels(path, label):
        return decode(path), label

    def decode_with_segments(path, path2, gray=True):
        return decode(path), decode_mask(path2, gray)

    if segment:
        return decode_with_segments

    return decode_with_labels if with_labels else decode

def build_augmenter(with_labels=True):
    def augment(img):

        img = tf.image.random_flip_up_down(img)
        img = tf.image.random_flip_left_right(img)
        # img = tf.image.rot90(img, k=tf.random.uniform([],0,4,tf.int32))

        img = tf.image.random_brightness(img, 0.1)
        img = tf.image.random_contrast(img, 0.9, 1.1)
        img = tf.image.random_saturation(img, 0.9, 1.1)
        img = tf.image.random_hue(img, 0.02)

        # img = transform_mat(img)

        return img

    def augment_with_labels(img, label):
        # Only the image is flipped here. If the label is a segmentation mask,
        # use the paired augmentation in default_augment_seg() instead.
        return augment(img), label

    return augment_with_labels if with_labels else augment

def build_dataset(paths, labels=None, bsize=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, augmentAdv=False, augmentAdvSeg=False, repeat=True, shuffle=1024,
                  cache_dir=""):
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)

    if decode_fn is None:
        decode_fn = build_decoder(labels is not None)

    if augment_fn is None:
        augment_fn = build_augmenter(labels is not None)

    AUTO = tf.data.experimental.AUTOTUNE
    slices = paths if labels is None else (paths, labels)

    # Pipeline order: decode, cache, per-sample augment, repeat, shuffle, batch,
    # per-batch segmentation augment, prefetch.
    dset = tf.data.Dataset.from_tensor_slices(slices)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)
    dset = dset.cache(cache_dir) if cache else dset
    dset = dset.map(augment_fn, num_parallel_calls=AUTO) if augment else dset
    dset = dset.repeat() if repeat else dset
    dset = dset.shuffle(shuffle) if shuffle else dset
    dset = dset.batch(bsize)
    # dset = dset.map(BatchAdvAugment, num_parallel_calls=AUTO) if augmentAdv else dset
    dset = dset.map(BatchAdvAugmentSeg, num_parallel_calls=AUTO) if augmentAdvSeg else dset
    dset = dset.prefetch(AUTO)

    return dset
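For segmentation, geometric augmentations have to be applied identically to the image and its mask, which is why `default_augment_seg` draws one random number per flip and uses it for both tensors. The sketch below shows the same idea in plain Python on nested lists, as a stand-in for the `tf.image` calls. The helper names and the tiny 2x2 example are hypothetical and only meant to make the alignment property easy to check.

```python
import random


def flip_left_right(grid):
    """Mirror every row of a 2-D list."""
    return [row[::-1] for row in grid]


def flip_up_down(grid):
    """Reverse the order of the rows of a 2-D list."""
    return grid[::-1]


def paired_random_flip(image, mask, rng):
    """Apply the same random flips to an image and its mask (both lists of rows)."""
    if rng.random() > 0.5:  # one draw decides the horizontal flip for both
        image, mask = flip_left_right(image), flip_left_right(mask)
    if rng.random() > 0.5:  # one draw decides the vertical flip for both
        image, mask = flip_up_down(image), flip_up_down(mask)
    return image, mask


image = [[10, 20], [30, 40]]
mask = [[1, 0], [0, 0]]  # marks the pixel that holds the value 10
aug_image, aug_mask = paired_random_flip(image, mask, random.Random(0))

# Whatever flips were drawn, the mask still marks the pixel with value 10.
marked = [aug_image[r][c] for r in range(2) for c in range(2) if aug_mask[r][c] == 1]
print(marked)  # prints: [10]
```

Flipping the image with one random draw and the mask with another would break this property, because the two draws can disagree.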
@@ -0,0 +1,14 @@
import tensorflow as tf

def convformer(input_tensor, filters, padding = "same"):
    # Both residual Add layers require `filters` to equal the channel count of input_tensor.

    # Token mixing: layer norm, separable 3x3 convolution, then self-attention.
    x = tf.keras.layers.LayerNormalization()(input_tensor)
    x = tf.keras.layers.SeparableConv2D(filters, kernel_size = (3,3), padding = padding)(x)
    # x = x1 + x2 + x3
    x = tf.keras.layers.Attention()([x, x, x])
    out = tf.keras.layers.Add()([x, input_tensor])

    # Channel mixing: two-layer MLP with GELU and a second residual connection.
    x1 = tf.keras.layers.Dense(filters, activation = "gelu")(out)
    x1 = tf.keras.layers.Dense(filters)(x1)
    out_tensor = tf.keras.layers.Add()([out, x1])
    return out_tensor
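Read as a formula, the `convformer` function above computes two residual stages. The notation below is only a summary of the code in this file, not an equation taken from the paper: $x$ is the input tensor, $\operatorname{LN}$ is layer normalization, $\operatorname{SepConv}_{3\times 3}$ is the separable convolution, $\operatorname{Attn}(z)$ stands for `tf.keras.layers.Attention()([z, z, z])` with query, value and key all equal to $z$, and $W_1, b_1, W_2, b_2$ are the weights and biases of the two `Dense` layers.

```latex
\begin{aligned}
u &= x + \operatorname{Attn}\big(\operatorname{SepConv}_{3\times 3}(\operatorname{LN}(x))\big), \\
y &= u + W_2\,\operatorname{GELU}(W_1 u + b_1) + b_2 .
\end{aligned}
```

Both sums are element-wise, so the output of the separable convolution and of the second `Dense` layer must have the same number of channels as $x$, which means `filters` has to match the input channel count.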