Commit

full pipeline
huyquoctrinh committed May 13, 2023
0 parents commit bc85cba
Showing 17 changed files with 865 additions and 0 deletions.
118 changes: 118 additions & 0 deletions README.md
@@ -0,0 +1,118 @@
# Meta-Polyp: a baseline for efficient Polyp segmentation, CBMS 2023

This repo is the official implementation for the paper:

Meta-Polyp: a baseline for efficient Polyp segmentation.

Authors: Quoc-Huy Trinh

In the IEEE 36th International Symposium on Computer-Based Medical Systems (CBMS) 2023.

Details of each model module can be found in the original paper. Please cite our work if you use this implementation for research purposes.

## Overall architecture

Architecture of the Meta-Polyp baseline model:

<div align="center">
<a href="./">
<img src="img/MetaPolyp.png" width="79%"/>
</a>
</div>

## Installation

Our implementation uses `Python 3.9`; please make sure your environment is configured to be compatible with the requirements.

To install all packages, use the `requirements.txt` file with `pip` via the following command:

```
pip install -r requirements.txt
```

All packages will be automatically installed.

## Config

All configs for training and benchmarking are in the `./config/` folder. Please review them for the tuning phase.

## Training

To start training, use the `train.py` file.

The following command should be used:

```
python train.py
```

## Benchmark

To run the benchmark, use the `benchmark.py` file.

The following command should be used:

```
python benchmark.py
```

### Note:
You should set the model weights path and the benchmark dataset directory to match your setup (see `save_path` and `path_to_test_dataset` in `benchmark.py`).
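
For reference, these are the variables to adjust in `benchmark.py`, shown here with the repository defaults:

```
save_path = "best_model.h5"              # path to your trained model weights
path_to_test_dataset = "./TestDataset/"  # directory containing the benchmark datasets
```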

## Pretrained weights

The weights will be updated later.

## Dataset

In our experiments, we use the dataset configuration from [PraNet](https://github.com/DengPingFan/PraNet), with a training set composed of 50% of the Kvasir-SEG and 50% of the ClinicDB dataset.

For testing, we use the following datasets (the expected folder layout is shown after the list):

In-distribution:

- Kvasir-SEG

- ClinicDB


Out-of-distribution:

- ETIS dataset

- ColonDB

- CVC300
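
Based on `load_dataset` and `benchmark` in `benchmark.py`, each test set is expected to sit in its own subfolder of `./TestDataset/` with `images/` and `masks/` directories; the subfolder names below are only illustrative:

```
TestDataset/
├── Kvasir/
│   ├── images/
│   └── masks/
├── ClinicDB/
│   ├── images/
│   └── masks/
└── ...
```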


## Results

The IoU scores compared with SOTA methods on all 5 datasets:

<div align="center">
<a href="./">
<img src="img/res.png" width="79%"/>
</a>
</div>

We also provide qualitative results visualized alongside other SOTA methods:

<div align="center">
<a href="./">
<img src="img/miccaivis.png" width="79%"/>
</a>
</div>

## Weights

Coming soon

## Customize
You can change the backbone from CAFormer to PVT or another backbone to obtain different results.

## Citation

```
Coming soon
```

54 changes: 54 additions & 0 deletions benchmark.py
@@ -0,0 +1,54 @@
# from save_model.pvt_CAM_channel_att_upscale import build_model
import os
import tensorflow as tf
# from metrics.metrics_last import iou_metric, MAE, WFbetaMetric, SMeasure, Emeasure, dice_coef, iou_metric
from metrics.segmentation_metrics import dice_coeff, bce_dice_loss, IoU, zero_IoU, dice_loss
from dataloader.dataloader import build_augmenter, build_dataset, build_decoder
from tensorflow.keras.utils import get_custom_objects
from model_research import build_model

os.environ["CUDA_VISIBLE_DEVICES"]="0"

def load_dataset(route):
    # Build the test tf.data pipeline for one dataset folder containing images/ and masks/.
    # Note: relies on the global img_size and BATCH_SIZE defined in the __main__ block below.
    X_path = '{}/images/'.format(route)
    Y_path = '{}/masks/'.format(route)
    X_full = sorted(os.listdir(f'{route}/images'))
    Y_full = sorted(os.listdir(f'{route}/masks'))

    X_train = [X_path + x for x in X_full]
    Y_train = [Y_path + x for x in Y_full]

    test_decoder = build_decoder(with_labels=False, target_size=(img_size, img_size), ext='jpg',
                                 segment=True, ext2='jpg')
    test_dataset = build_dataset(X_train, Y_train, bsize=BATCH_SIZE, decode_fn=test_decoder,
                                 augmentAdv=False, augment=False, augmentAdvSeg=False)
    return test_dataset, len(X_train)

def benchmark(route, model, BATCH_SIZE=32, save_file_name="benchmark_result.txt"):
    # Evaluate the model on every dataset folder under `route` and append the metrics to a text file.
    list_of_datasets = os.listdir(route)
    f = open(save_file_name, "a")
    f.write("\n")
    for datasets in list_of_datasets:
        print(datasets, ":")
        test_dataset, len_data = load_dataset(os.path.join(route, datasets))
        steps_per_epoch = len_data // BATCH_SIZE
        loss, dice, bce_dice, iou, zero_iou, mae = model.evaluate(test_dataset, steps=steps_per_epoch)
        f.write("{}:".format(datasets))
        f.write("dice_coeff: {}, bce_dice_loss: {}, IoU: {}, zero_IoU: {}, mae: {}".format(dice, bce_dice, iou, zero_iou, mae))
        f.write('\n')
    f.close()

if __name__ == "__main__":

    img_size = 256
    BATCH_SIZE = 1
    SEED = 1024
    save_path = "best_model.h5"
    route_data = "./TestDataset/"
    path_to_test_dataset = "./TestDataset/"
    model = build_model(img_size)
    model.load_weights(save_path)

    model.compile(metrics=[dice_coeff, bce_dice_loss, IoU, zero_IoU, tf.keras.metrics.MeanSquaredError()])

    # Pass BATCH_SIZE explicitly so steps_per_epoch matches the batch size used in load_dataset.
    benchmark(path_to_test_dataset, model, BATCH_SIZE=BATCH_SIZE)
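
Based on the `f.write` calls above, each run appends one line per dataset to `benchmark_result.txt`, roughly of the form (values are placeholders):

```
<dataset>:dice_coeff: <...>, bce_dice_loss: <...>, IoU: <...>, zero_IoU: <...>, mae: <...>
```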
68 changes: 68 additions & 0 deletions callbacks/callbacks.py
@@ -0,0 +1,68 @@
import math
import tensorflow as tf
import matplotlib.pyplot as plt

def cosine_annealing_with_warmup(epochIdx):
    # Learning-rate schedule: optional linear warm-up, optional plateau, then cosine annealing
    # between max_lr and min_lr (module-level globals set by get_callbacks).
    aMax, aMin = max_lr, min_lr
    warmupEpochs, stagnateEpochs, cosAnnealingEpochs = 0, 0, cos_anne_ep
    epochIdx = epochIdx % (warmupEpochs + stagnateEpochs + cosAnnealingEpochs)
    if epochIdx < warmupEpochs:
        return aMin + (aMax - aMin) / (warmupEpochs - 1) * epochIdx
    else:
        epochIdx -= warmupEpochs
        if epochIdx < stagnateEpochs:
            return aMax
        else:
            epochIdx -= stagnateEpochs
            return aMin + 0.5 * (aMax - aMin) * (1 + math.cos((epochIdx + 1) / (cosAnnealingEpochs + 1) * math.pi))

def plt_lr(step, schedulers):
    # Plot the learning-rate schedule over `step` epochs for a quick sanity check.
    x = range(step)
    y = [schedulers(_) for _ in x]

    plt.plot(x, y, label='learning rate')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.legend()

def get_callbacks(monitor, mode, save_path, _max_lr, _min_lr, _cos_anne_ep, save_weights_only):
    # Expose the schedule bounds as module-level globals so cosine_annealing_with_warmup can read them.
    global max_lr
    max_lr = _max_lr
    global min_lr
    min_lr = _min_lr
    global cos_anne_ep
    cos_anne_ep = _cos_anne_ep

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor=monitor,
        patience=60,
        restore_best_weights=True,
        mode=mode
    )

    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor=monitor,
        factor=0.2,
        patience=50,
        verbose=1,
        mode=mode,
        min_lr=1e-5,
    )

    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=save_path,
        monitor=monitor,
        verbose=1,
        save_best_only=True,
        save_weights_only=save_weights_only,
        mode=mode,
        save_freq="epoch",
    )

    lr_schedule = tf.keras.callbacks.LearningRateScheduler(cosine_annealing_with_warmup, verbose=0)

    csv_logger = tf.keras.callbacks.CSVLogger('training.csv')

    # Only checkpointing, CSV logging and ReduceLROnPlateau are returned by default;
    # early_stopping and lr_schedule are defined above but not included in the list.
    callbacks = [checkpoint, csv_logger, reduce_lr]
    return callbacks
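
A minimal usage sketch; the monitor, learning-rate bounds, and epoch counts below are assumptions, not settings taken from the paper:

```
# Hypothetical values; adjust to your own training configuration.
from callbacks.callbacks import get_callbacks, plt_lr, cosine_annealing_with_warmup

callbacks = get_callbacks(monitor="val_loss", mode="min", save_path="best_model.h5",
                          _max_lr=1e-3, _min_lr=1e-6, _cos_anne_ep=50,
                          save_weights_only=True)
plt_lr(50, cosine_annealing_with_warmup)   # optional: visualize the schedule
# model.fit(train_dataset, epochs=300, callbacks=callbacks)
```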
132 changes: 132 additions & 0 deletions dataloader/dataloader.py
@@ -0,0 +1,132 @@
import tensorflow as tf
import os

# def auto_select_accelerator():
# try:
# tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
# tf.config.experimental_connect_to_cluster(tpu)
# tf.tpu.experimental.initialize_tpu_system(tpu)
# strategy = tf.distribute.experimental.TPUStrategy(tpu)
# print("Running on TPU:", tpu.master())
# except ValueError:
# strategy = tf.distribute.get_strategy()
# print(f"Running on {strategy.num_replicas_in_sync} replicas")

# return strategy

def default_augment_seg(input_image, input_mask):
    # Photometric jitter on the image plus random flips applied to both image and mask.
    input_image = tf.image.random_brightness(input_image, 0.1)
    input_image = tf.image.random_contrast(input_image, 0.9, 1.1)
    input_image = tf.image.random_saturation(input_image, 0.9, 1.1)
    input_image = tf.image.random_hue(input_image, 0.01)

    # flipping random horizontal or vertical
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_up_down(input_image)
        input_mask = tf.image.flip_up_down(input_mask)

    return input_image, input_mask

def BatchAdvAugmentSeg(imagesT, masksT):
    # Batch-level wrapper around default_augment_seg.
    images, masks = default_augment_seg(imagesT, masksT)

    return images, masks

def build_decoder(with_labels=True, target_size=(256, 256), ext='png', segment=False, ext2='png'):
    # Returns a decode function: image-only, (image, label), or (image, mask) depending on the flags.
    def decode(path):
        file_bytes = tf.io.read_file(path)
        if ext == 'png':
            img = tf.image.decode_png(file_bytes, channels=3)
        elif ext in ['jpg', 'jpeg']:
            img = tf.image.decode_jpeg(file_bytes, channels=3, dct_method='INTEGER_ACCURATE')
        else:
            raise ValueError("Image extension not supported")

        img = tf.image.resize(img, target_size)
        # img = tf.cast(img, tf.float32) / 255.0

        return img

    def decode_mask(path, gray=True):
        file_bytes = tf.io.read_file(path)
        if ext2 == 'png':
            img = tf.image.decode_png(file_bytes, channels=3)
        elif ext2 in ['jpg', 'jpeg']:
            img = tf.image.decode_jpeg(file_bytes, channels=3)
        else:
            raise ValueError("Image extension not supported")

        img = tf.image.rgb_to_grayscale(img) if gray else img
        img = tf.image.resize(img, target_size)
        img = tf.cast(img, tf.float32) / 255.0

        return img

    def decode_with_labels(path, label):
        return decode(path), label

    def decode_with_segments(path, path2, gray=True):
        return decode(path), decode_mask(path2, gray)

    if segment:
        return decode_with_segments

    return decode_with_labels if with_labels else decode


def build_augmenter(with_labels=True):
    def augment(img):

        img = tf.image.random_flip_up_down(img)
        img = tf.image.random_flip_left_right(img)
        # img = tf.image.rot90(img, k=tf.random.uniform([],0,4,tf.int32))

        img = tf.image.random_brightness(img, 0.1)
        img = tf.image.random_contrast(img, 0.9, 1.1)
        img = tf.image.random_saturation(img, 0.9, 1.1)
        img = tf.image.random_hue(img, 0.02)

        # img = transform_mat(img)

        return img

    def augment_with_labels(img, label):
        return augment(img), label

    return augment_with_labels if with_labels else augment


def build_dataset(paths, labels=None, bsize=32, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, augmentAdv=False, augmentAdvSeg=False, repeat=True, shuffle=1024,
                  cache_dir=""):
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)

    if decode_fn is None:
        decode_fn = build_decoder(labels is not None)

    if augment_fn is None:
        augment_fn = build_augmenter(labels is not None)

    AUTO = tf.data.experimental.AUTOTUNE
    slices = paths if labels is None else (paths, labels)

    dset = tf.data.Dataset.from_tensor_slices(slices)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)
    dset = dset.cache(cache_dir) if cache else dset
    dset = dset.map(augment_fn, num_parallel_calls=AUTO) if augment else dset
    dset = dset.repeat() if repeat else dset
    dset = dset.shuffle(shuffle) if shuffle else dset
    dset = dset.batch(bsize)
    # dset = dset.map(BatchAdvAugment, num_parallel_calls=AUTO) if augmentAdv else dset
    dset = dset.map(BatchAdvAugmentSeg, num_parallel_calls=AUTO) if augmentAdvSeg else dset
    dset = dset.prefetch(AUTO)

    return dset
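
A minimal sketch of wiring these helpers into a training pipeline. The `./TrainDataset/` path with `images/` and `masks/` subfolders is an assumption that mirrors the layout read by `load_dataset` in `benchmark.py`:

```
import os
from dataloader.dataloader import build_dataset, build_decoder

# Hypothetical training folder with images/ and masks/ subdirectories.
route = "./TrainDataset"
img_size, batch_size = 256, 8

x_paths = sorted(os.path.join(route, "images", f) for f in os.listdir(os.path.join(route, "images")))
y_paths = sorted(os.path.join(route, "masks", f) for f in os.listdir(os.path.join(route, "masks")))

decoder = build_decoder(with_labels=False, target_size=(img_size, img_size),
                        ext='jpg', segment=True, ext2='jpg')
train_dataset = build_dataset(x_paths, y_paths, bsize=batch_size, decode_fn=decoder,
                              augment=True, augmentAdv=False, augmentAdvSeg=True)
```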
Binary file added img/MetaPolyp.png
Binary file added img/miccaivis.png
Binary file added img/res.png
14 changes: 14 additions & 0 deletions layers/convformer.py
@@ -0,0 +1,14 @@
import tensorflow as tf

def convformer(input_tensor, filters, padding="same"):
    # Token mixing: LayerNorm, a separable 3x3 convolution, and self-attention over the result,
    # followed by a residual connection back to the block input.
    # Note: `filters` must equal the channel count of `input_tensor` for the residual Add to be valid.
    x = tf.keras.layers.LayerNormalization()(input_tensor)
    x = tf.keras.layers.SeparableConv2D(filters, kernel_size=(3, 3), padding=padding)(x)
    # x = x1 + x2 + x3
    x = tf.keras.layers.Attention()([x, x, x])
    out = tf.keras.layers.Add()([x, input_tensor])

    # Channel MLP: two Dense layers (GELU then linear) with a second residual connection.
    x1 = tf.keras.layers.Dense(filters, activation="gelu")(out)
    x1 = tf.keras.layers.Dense(filters)(x1)
    out_tensor = tf.keras.layers.Add()([out, x1])
    return out_tensor
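
A small usage sketch; the input shape here is arbitrary, and `filters` is set equal to the input channel count so the residual `Add` shapes match:

```
import tensorflow as tf
from layers.convformer import convformer

# Hypothetical 64x64 feature map with 32 channels.
inputs = tf.keras.Input(shape=(64, 64, 32))
outputs = convformer(inputs, filters=32)
model = tf.keras.Model(inputs, outputs)
model.summary()   # output shape matches the input: (None, 64, 64, 32)
```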