diff --git a/README.md b/README.md new file mode 100644 index 0000000..3ca09f9 --- /dev/null +++ b/README.md @@ -0,0 +1,118 @@ +# Meta-Polyp: a baseline for efficient Polyp segmentation, CBMS 2023 + +This repo is the official implementation for paper: + +Meta-Polyp: a baseline for efficient Polyp segmentation. + +Authors: Quoc-Huy Trinh + +In the IEEE 36th International Symposium on Computer Based Medical Systems (CBMS) 2023. + +Detail of each model modules can be found in original paper. Please citation if you use our implementation for research purpose. + +## Overall architecture + +Architecutre Meta-Polyp baseline model: + +
+ + + +
+ +## Installation + +Our implementation is on ``` Python 3.9 ``` , please make sure to config your environment compatible with the requirements. + +To install all packages, use ``` requirements.txt ``` file to install. Install with ```pip ``` by the following command: + +``` +pip install -r requirements.txt +``` + +All packages will be automatically installed. + +## Config + +All of configs for training and benchmark are in ```./config/``` folder. Please take a look for tuning phase. + +## Training + +For training, use ``` train.py ``` file for start training. + +The following command should be used: + +``` +python train.py +``` + +## Benchmark + +For benchmar, use ```benchmark.py``` file for start testing. + +The following command should be used: + +``` +python benchmark.py +``` + +### Note: +you should fix model_path for your model path and directory to your benchmark dataset. + +## Pretrained weights + +The weight will be update later. + +## Dataset + +In our experiment, we use the dataset config from (PraNet)[https://github.com/DengPingFan/PraNet], with training set from 50% of Kvasir-SEG and 50% of ClinicDB dataset. + +With our test dataset, we use the following: + +In same distribution: + +- Kvasir SEG + +- ClinicDB + + +Out of distribution: + +- Etis dataset + +- ColonDB + +- CVC300 + + +## Results + +The IOU score on SOTA for both 5 datasets: + +
+ + + +
+ +We do some qualiative result with others SOTA method visualization: + +
+ + + +
+ +## Weights + +Coming soon + +## Customize +You can change the backbone from Ca-former to PVT or something else to get different results. + +## Citation + +``` +Coming soon +``` + diff --git a/benchmark.py b/benchmark.py new file mode 100644 index 0000000..a770796 --- /dev/null +++ b/benchmark.py @@ -0,0 +1,54 @@ +# from save_model.pvt_CAM_channel_att_upscale import build_model +import os +import tensorflow as tf +# from metrics.metrics_last import iou_metric, MAE, WFbetaMetric, SMeasure, Emeasure, dice_coef, iou_metric +from metrics.segmentation_metrics import dice_coeff, bce_dice_loss, IoU, zero_IoU, dice_loss +from dataloader.dataloader import build_augmenter, build_dataset, build_decoder +from tensorflow.keras.utils import get_custom_objects +from model_research import build_model + +os.environ["CUDA_VISIBLE_DEVICES"]="0" + +def load_dataset(route): + X_path = '{}/images/'.format(route) + Y_path = '{}/masks/'.format(route) + X_full = sorted(os.listdir(f'{route}/images')) + Y_full = sorted(os.listdir(f'{route}/masks')) + + X_train = [X_path + x for x in X_full] + Y_train = [Y_path + x for x in Y_full] + + test_decoder = build_decoder(with_labels=False, target_size=(img_size, img_size), ext='jpg', + segment=True, ext2='jpg') + test_dataset = build_dataset(X_train, Y_train, bsize=BATCH_SIZE, decode_fn=test_decoder, + augmentAdv=False, augment=False, augmentAdvSeg=False) + return test_dataset, len(X_train) + +def benchmark(route, model, BATCH_SIZE = 32, save_file_name = "benchmark_result.txt"): + + list_of_datasets = os.listdir(route) + f = open(save_file_name,"a") + f.write("\n") + for datasets in list_of_datasets: + print(datasets, ":") + test_dataset, len_data = load_dataset(os.path.join(route,datasets)) + steps_per_epoch = len_data // BATCH_SIZE + loss, dice_coeff, bce_dice_loss, IoU, zero_IoU, mae = model.evaluate(test_dataset, steps=steps_per_epoch) + f.write("{}:".format(datasets)) + f.write("dice_coeff: {}, bce_didce_loss: {}, IoU: {}, zero_IoU: {}, mae: {}".format(dice_coeff, bce_dice_loss, IoU, zero_IoU, mae)) + f.write('\n') + +if __name__ == "__main__": + + img_size = 256 + BATCH_SIZE = 1 + SEED = 1024 + save_path = "best_model.h5" + route_data = "./TestDataset/" + path_to_test_dataset = "./TestDataset/" + model = build_model(img_size) + model.load_weights(save_path) + + model.compile(metrics=[dice_coeff, bce_dice_loss, IoU, zero_IoU, tf.keras.metrics.MeanSquaredError()]) + + benchmark(path_to_test_dataset, model) diff --git a/callbacks/callbacks.py b/callbacks/callbacks.py new file mode 100644 index 0000000..a4e1fef --- /dev/null +++ b/callbacks/callbacks.py @@ -0,0 +1,68 @@ +import math +import tensorflow as tf +import matplotlib.pyplot as plt + +def cosine_annealing_with_warmup(epochIdx): + aMax, aMin = max_lr, min_lr + warmupEpochs, stagnateEpochs, cosAnnealingEpochs = 0, 0, cos_anne_ep + epochIdx = epochIdx % (warmupEpochs + stagnateEpochs + cosAnnealingEpochs) + if(epochIdx < warmupEpochs): + return aMin + (aMax - aMin) / (warmupEpochs - 1) * epochIdx + else: + epochIdx -= warmupEpochs + if(epochIdx < stagnateEpochs): + return aMax + else: + epochIdx -= stagnateEpochs + return aMin + 0.5 * (aMax - aMin) * (1 + math.cos((epochIdx + 1) / (cosAnnealingEpochs + 1) * math.pi)) + +def plt_lr(step, schedulers): + x = range(step) + y = [schedulers(_) for _ in x] + + plt.plot(x, y) + plt.xlabel('Epoch') + plt.ylabel('Learning Rate') + plt.legend() + +def get_callbacks(monitor, mode, save_path, _max_lr, _min_lr, _cos_anne_ep, save_weights_only): + global max_lr + max_lr = _max_lr + global min_lr + min_lr = _min_lr + global cos_anne_ep + cos_anne_ep = _cos_anne_ep + + early_stopping = tf.keras.callbacks.EarlyStopping( + monitor=monitor, + patience=60, + restore_best_weights=True, + mode=mode + ) + + reduce_lr = tf.keras.callbacks.ReduceLROnPlateau( + monitor=monitor, + factor=0.2, + patience=50, + verbose=1, + mode=mode, + min_lr=1e-5, + ) + + checkpoint = tf.keras.callbacks.ModelCheckpoint( + filepath=save_path, + monitor=monitor, + verbose=1, + save_best_only=True, + save_weights_only=save_weights_only, + mode=mode, + save_freq="epoch", + ) + + lr_schedule = tf.keras.callbacks.LearningRateScheduler(cosine_annealing_with_warmup, verbose=0) + + csv_logger = tf.keras.callbacks.CSVLogger('training.csv') + + callbacks = [checkpoint, csv_logger, reduce_lr] +# , reduce_lr + return callbacks \ No newline at end of file diff --git a/dataloader/dataloader.py b/dataloader/dataloader.py new file mode 100644 index 0000000..2d58b23 --- /dev/null +++ b/dataloader/dataloader.py @@ -0,0 +1,132 @@ +import tensorflow as tf +import os + +# def auto_select_accelerator(): +# try: +# tpu = tf.distribute.cluster_resolver.TPUClusterResolver() +# tf.config.experimental_connect_to_cluster(tpu) +# tf.tpu.experimental.initialize_tpu_system(tpu) +# strategy = tf.distribute.experimental.TPUStrategy(tpu) +# print("Running on TPU:", tpu.master()) +# except ValueError: +# strategy = tf.distribute.get_strategy() +# print(f"Running on {strategy.num_replicas_in_sync} replicas") + +# return strategy + +def default_augment_seg(input_image, input_mask): + + input_image = tf.image.random_brightness(input_image, 0.1) + input_image = tf.image.random_contrast(input_image, 0.9, 1.1) + input_image = tf.image.random_saturation(input_image, 0.9, 1.1) + input_image = tf.image.random_hue(input_image, 0.01) + + # flipping random horizontal or vertical + if tf.random.uniform(()) > 0.5: + input_image = tf.image.flip_left_right(input_image) + input_mask = tf.image.flip_left_right(input_mask) + if tf.random.uniform(()) > 0.5: + input_image = tf.image.flip_up_down(input_image) + input_mask = tf.image.flip_up_down(input_mask) + + return input_image, input_mask + +def BatchAdvAugmentSeg(imagesT, masksT): + + images, masks = default_augment_seg(imagesT, masksT) + + return images, masks + +def build_decoder(with_labels=True, target_size=(256, 256), ext='png', segment=False, ext2='png'): + + def decode(path): + file_bytes = tf.io.read_file(path) + if ext == 'png': + img = tf.image.decode_png(file_bytes, channels=3, dct_method='INTEGER_ACCURATE') + elif ext in ['jpg', 'jpeg']: + img = tf.image.decode_jpeg(file_bytes, channels=3, dct_method='INTEGER_ACCURATE') + else: + raise ValueError("Image extension not supported") + + img = tf.image.resize(img, target_size) + # img = tf.cast(img, tf.float32) / 255.0 + + return img + + def decode_mask(path, gray=True): + file_bytes = tf.io.read_file(path) + if ext2 == 'png': + img = tf.image.decode_png(file_bytes, channels=3) + elif ext2 in ['jpg', 'jpeg']: + img = tf.image.decode_jpeg(file_bytes, channels=3) + else: + raise ValueError("Image extension not supported") + + img = tf.image.rgb_to_grayscale(img) if gray else img + img = tf.image.resize(img, target_size) + img = tf.cast(img, tf.float32) / 255.0 + + return img + + def decode_with_labels(path, label): + return decode(path), label + + def decode_with_segments(path, path2, gray=True): + return decode(path), decode_mask(path2, gray) + + if segment: + return decode_with_segments + + return decode_with_labels if with_labels else decode + + +def build_augmenter(with_labels=True): + def augment(img): + + img = tf.image.random_flip_up_down(img) + img = tf.image.random_flip_left_right(img) +# img = tf.image.rot90(img, k=tf.random.uniform([],0,4,tf.int32)) + + img = tf.image.random_brightness(img, 0.1) + img = tf.image.random_contrast(img, 0.9, 1.1) + img = tf.image.random_saturation(img, 0.9, 1.1) + img = tf.image.random_hue(img, 0.02) + + # img = transform_mat(img) + + return img + + def augment_with_labels(img, label): + return augment(img), label + + return augment_with_labels if with_labels else augment + + +def build_dataset(paths, labels=None, bsize=32, cache=True, + decode_fn=None, augment_fn=None, + augment=True, augmentAdv=False, augmentAdvSeg=False, repeat=True, shuffle=1024, + cache_dir=""): + if cache_dir != "" and cache is True: + os.makedirs(cache_dir, exist_ok=True) + + if decode_fn is None: + decode_fn = build_decoder(labels is not None) + + if augment_fn is None: + augment_fn = build_augmenter(labels is not None) + + AUTO = tf.data.experimental.AUTOTUNE + slices = paths if labels is None else (paths, labels) + + dset = tf.data.Dataset.from_tensor_slices(slices) + dset = dset.map(decode_fn, num_parallel_calls=AUTO) + dset = dset.cache(cache_dir) if cache else dset + dset = dset.map(augment_fn, num_parallel_calls=AUTO) if augment else dset + dset = dset.repeat() if repeat else dset + dset = dset.shuffle(shuffle) if shuffle else dset + dset = dset.batch(bsize) + # dset = dset.map(BatchAdvAugment, num_parallel_calls=AUTO) if augmentAdv else dset + dset = dset.map(BatchAdvAugmentSeg, num_parallel_calls=AUTO) if augmentAdvSeg else dset + dset = dset.prefetch(AUTO) + + return dset \ No newline at end of file diff --git a/img/MetaPolyp.png b/img/MetaPolyp.png new file mode 100644 index 0000000..c57ed67 Binary files /dev/null and b/img/MetaPolyp.png differ diff --git a/img/miccaivis.png b/img/miccaivis.png new file mode 100644 index 0000000..fccf475 Binary files /dev/null and b/img/miccaivis.png differ diff --git a/img/res.png b/img/res.png new file mode 100644 index 0000000..0d23fdc Binary files /dev/null and b/img/res.png differ diff --git a/layers/convformer.py b/layers/convformer.py new file mode 100644 index 0000000..b7b04b4 --- /dev/null +++ b/layers/convformer.py @@ -0,0 +1,14 @@ +import tensorflow as tf + +def convformer(input_tensor, filters, padding = "same"): + + x = tf.keras.layers.LayerNormalization()(input_tensor) + x = tf.keras.layers.SeparableConv2D(filters, kernel_size = (3,3), padding = padding)(x) + # x = x1 + x2 + x3 + x = tf.keras.layers.Attention()([x, x, x]) + out = tf.keras.layers.Add()([x, input_tensor]) + + x1 = tf.keras.layers.Dense(filters, activation = "gelu")(out) + x1 = tf.keras.layers.Dense(filters)(x1) + out_tensor = tf.keras.layers.Add()([out, x1]) + return out_tensor \ No newline at end of file diff --git a/layers/upsampling.py b/layers/upsampling.py new file mode 100644 index 0000000..d32da54 --- /dev/null +++ b/layers/upsampling.py @@ -0,0 +1,34 @@ +import tensorflow as tf +from tensorflow.keras.layers import Conv2D + +def bn_act(inputs, activation='swish'): + + x = tf.keras.layers.BatchNormalization()(inputs) + if activation: + x = tf.keras.layers.Activation(activation)(x) + + return x + +def decode(input_tensor, filters, scale = 2, activation = 'relu'): + + x1 = tf.keras.layers.Conv2D(filters, (1, 1), activation=activation, use_bias=False, + kernel_initializer='he_normal', padding = 'same')(input_tensor) + + x2 = tf.keras.layers.Conv2D(filters, (3, 3), activation=activation, + use_bias=False, padding = 'same')(input_tensor) + + merge = tf.keras.layers.Add()([x1, x2]) + x = tf.keras.layers.UpSampling2D((scale, scale))(merge) + + skip_feature = tf.keras.layers.Conv2D(filters, (3, 3), activation=activation, use_bias=False, + kernel_initializer='he_normal', padding = 'same')(merge) + + skip_feature = tf.keras.layers.Conv2D(filters, (1, 1), activation=activation, use_bias=False, + kernel_initializer='he_normal', padding = 'same')(skip_feature) + + merge = tf.keras.layers.Add()([merge, skip_feature]) + + x = bn_act(x, activation = activation) + + + return x \ No newline at end of file diff --git a/layers/util_layers.py b/layers/util_layers.py new file mode 100644 index 0000000..0d90d55 --- /dev/null +++ b/layers/util_layers.py @@ -0,0 +1,28 @@ +import tensorflow as tf +from tensorflow.keras.layers import Conv2D + +def bn_act(inputs, activation='swish'): + + x = tf.keras.layers.BatchNormalization()(inputs) + if activation: + x = tf.keras.layers.Activation(activation)(x) + + return x + +def conv_bn_act(inputs, filters, kernel_size, strides=(1, 1), activation='relu', padding='same'): + + x = Conv2D(filters, kernel_size=kernel_size, padding=padding)(inputs) + x = bn_act(x, activation=activation) + + return x + +def merge(l, filters=None): + if filters is None: + channel_axis = 1 if K.image_data_format() == "channels_first" else -1 + filters = l[0].shape[channel_axis] + + x = tf.keras.layers.Add()([l[0],l[1]]) + + # x = block(x, filters) + + return x \ No newline at end of file diff --git a/metrics/segmentation_metrics.py b/metrics/segmentation_metrics.py new file mode 100644 index 0000000..1408d14 --- /dev/null +++ b/metrics/segmentation_metrics.py @@ -0,0 +1,57 @@ +import tensorflow as tf +import tensorflow.keras.backend as K +from tensorflow.keras.losses import binary_crossentropy +import numpy as np +def dice_coeff(y_true, y_pred): + + _epsilon = 10 ** -7 + intersections = tf.reduce_sum(y_true * y_pred) + unions = tf.reduce_sum(y_true + y_pred) + dice_scores = (2.0 * intersections + _epsilon) / (unions + _epsilon) + + return dice_scores + +def dice_loss(y_true, y_pred): + + loss = 1 - dice_coeff(y_true, y_pred) + + return loss + +def total_loss(y_true, y_pred): + return 0.5*binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred) + +def IoU(y_true, y_pred, eps=1e-6): + + intersection = K.sum(y_true * y_pred, axis=[1,2,3]) + union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3]) - intersection + + return K.mean( (intersection + eps) / (union + eps), axis=0) + +def zero_IoU(y_true, y_pred): + + return IoU(1-y_true, 1-y_pred) + +def bce_dice_loss(y_true, y_pred): + + return binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred) + +def tversky(y_true, y_pred, smooth=1, alpha=0.7): + + y_true_pos = tf.reshape(y_true,[-1]) + y_pred_pos = tf.reshape(y_pred,[-1]) + true_pos = tf.reduce_sum(y_true_pos * y_pred_pos) + false_neg = tf.reduce_sum(y_true_pos * (1 - y_pred_pos)) + false_pos = tf.reduce_sum((1 - y_true_pos) * y_pred_pos) + + return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth) + + +def tversky_loss(y_true, y_pred): + + return 1 - tversky(y_true, y_pred) + +def focal_tversky_loss(y_true, y_pred, gamma=0.75): + + tv = tversky(y_true, y_pred) + + return K.pow((1 - tv), gamma) diff --git a/model.py b/model.py new file mode 100644 index 0000000..417043b --- /dev/null +++ b/model.py @@ -0,0 +1,49 @@ +import tensorflow as tf +from keras_cv_attention_models import caformer +from layers.upsampling import decode +from layers.convformer import convformer +from layers.util_layers import merge, conv_bn_act +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Conv2D +import tensorflow.keras.backend as K + +def build_model(img_size = 256, num_classes = 1): + backbone = caformer.CAFormerS18(input_shape=(256, 256, 3), pretrained="imagenet", num_classes = 0) + + layer_names = ['stack4_block3_mlp_Dense_1', 'stack3_block9_mlp_Dense_1', 'stack2_block3_mlp_Dense_1', 'stack1_block3_mlp_Dense_1'] + layers = [backbone.get_layer(x).output for x in layer_names] + + channel_axis = 1 if K.image_data_format() == "channels_first" else -1 + h_axis, w_axis = [2, 3] if K.image_data_format() == "channels_first" else [1, 2] + + x = layers[0] + + upscale_feature = decode(x, scale = 4, filters = x.shape[channel_axis]) + + for i, layer in enumerate(layers[1:]): + + x = decode(x, scale = 2, filters = layer.shape[channel_axis]) + + layer_fusion = convformer(layer) + + ## Doing multi-level concatenation + if (i%2 == 1): + upscale_feature = tf.keras.layers.Conv2D(layer.shape[channel_axis], (1, 1), activation = "relu", padding = "same")(upscale_feature) + x = tf.keras.layers.Add()([x, upscale_feature]) + x = tf.keras.layers.Conv2D(x.shape[channel_axis], (1, 1), activation = "relu", padding = "same")(x) + + x = merge([x, layer_fusion], layer.shape[channel_axis]) + x = conv_bn_act(x, layer.shape[channel_axis], (1, 1)) + + ## Upscale for next level feature + if (i%2 == 1): + upscale_feature = decode(x, scale = 8, filters = layer.shape[channel_axis]) + + filters = x.shape[channel_axis] //2 + x = decode(x, filters, 4) + x = tf.keras.layers.Add()([x, upscale_feature]) + x = conv_bn_act(x, filters, 1) + x = Conv2D(num_classes, kernel_size=1, padding='same', activation='sigmoid')(x) + model = Model(backbone.input, x) + + return model \ No newline at end of file diff --git a/optimizers/lion_opt.py b/optimizers/lion_opt.py new file mode 100644 index 0000000..1cb8ed4 --- /dev/null +++ b/optimizers/lion_opt.py @@ -0,0 +1,86 @@ +import tensorflow as tf +class Lion(tf.keras.optimizers.legacy.Optimizer): + r"""Optimizer that implements the Lion algorithm.""" + + def __init__(self, + learning_rate=0.0001, + beta_1=0.9, + beta_2=0.99, + wd=0, + name='lion', + **kwargs): + """Construct a new Lion optimizer.""" + + super(Lion, self).__init__(name, **kwargs) + self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) + self._set_hyper('beta_1', beta_1) + self._set_hyper('beta_2', beta_2) + self._set_hyper('wd', wd) + + def _create_slots(self, var_list): + # Create slots for the first and second moments. + # Separate for-loops to respect the ordering of slot variables from v1. + for var in var_list: + self.add_slot(var, 'm') + + def _prepare_local(self, var_device, var_dtype, apply_state): + super(Lion, self)._prepare_local(var_device, var_dtype, apply_state) + + beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype)) + beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype)) + wd_t = tf.identity(self._get_hyper('wd', var_dtype)) + lr = apply_state[(var_device, var_dtype)]['lr_t'] + apply_state[(var_device, var_dtype)].update( + dict( + lr=lr, + beta_1_t=beta_1_t, + one_minus_beta_1_t=1 - beta_1_t, + beta_2_t=beta_2_t, + one_minus_beta_2_t=1 - beta_2_t, + wd_t=wd_t)) + + @tf.function(jit_compile=True) + def _resource_apply_dense(self, grad, var, apply_state=None): + var_device, var_dtype = var.device, var.dtype.base_dtype + coefficients = ((apply_state or {}).get((var_device, var_dtype)) or + self._fallback_apply_state(var_device, var_dtype)) + + m = self.get_slot(var, 'm') + var_t = var.assign_sub( + coefficients['lr_t'] * + (tf.math.sign(m * coefficients['beta_1_t'] + + grad * coefficients['one_minus_beta_1_t']) + + var * coefficients['wd_t'])) + with tf.control_dependencies([var_t]): + m.assign(m * coefficients['beta_2_t'] + + grad * coefficients['one_minus_beta_2_t']) + + @tf.function(jit_compile=True) + def _resource_apply_sparse(self, grad, var, indices, apply_state=None): + var_device, var_dtype = var.device, var.dtype.base_dtype + coefficients = ((apply_state or {}).get((var_device, var_dtype)) or + self._fallback_apply_state(var_device, var_dtype)) + + m = self.get_slot(var, 'm') + m_t = m.assign(m * coefficients['beta_1_t']) + m_scaled_g_values = grad * coefficients['one_minus_beta_1_t'] + m_t = m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices)) + var_t = var.assign_sub(coefficients['lr'] * + (tf.math.sign(m_t) + var * coefficients['wd_t'])) + + with tf.control_dependencies([var_t]): + m_t = m_t.scatter_add(tf.IndexedSlices(-m_scaled_g_values, indices)) + m_t = m_t.assign(m_t * coefficients['beta_2_t'] / + coefficients['beta_1_t']) + m_scaled_g_values = grad * coefficients['one_minus_beta_2_t'] + m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices)) + + def get_config(self): + config = super(Lion, self).get_config() + config.update({ + 'learning_rate': self._serialize_hyperparameter('learning_rate'), + 'beta_1': self._serialize_hyperparameter('beta_1'), + 'beta_2': self._serialize_hyperparameter('beta_2'), + 'wd': self._serialize_hyperparameter('wd'), + }) + return config \ No newline at end of file diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..ce43352 --- /dev/null +++ b/predict.py @@ -0,0 +1,25 @@ +import cv2 +from model import build_model +import numpy as np + +def load_model(model_path): + model = build_model() + model.load_weights(model_path) + return model + +def predict_single(model, imgPath): + img = cv2.imread(imgPath) + img /= 255 + result = model.predict(img) + return result + +if __name__ == "__main__": + save_path = "best_model.h5" + img_in = "test.png" + img_out = "out.png" + + model = load_model(save_path) + + mask = predict_single(model, img_out) + mask_out = np.dstack([mask, mask, mask]) + cv2.imwrite(img_out, mask_out) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..55b6111 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +opencv-python +tensorflow==2.11.0 +Pillow +tqdm +scikit-learn +numpy +matplotlib \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..3fab03a --- /dev/null +++ b/train.py @@ -0,0 +1,126 @@ +from sklearn.model_selection import train_test_split +import tensorflow as tf +from metrics.segmentation_metrics import dice_coeff, bce_dice_loss, IoU, zero_IoU, dice_loss, total_loss +from tensorflow.keras.utils import get_custom_objects +import os +from callbacks.callbacks import get_callbacks, cosine_annealing_with_warmup +from dataloader.dataloader import build_augmenter, build_dataset, build_decoder +# from supervision.dataloader import build_augmenter, build_dataset, build_decoder +from model import create_segment_model, build_model +import os +import tensorflow_addons as tfa +from optimizers.lion_opt import Lion + +os.environ["CUDA_VISIBLE_DEVICES"]="2" + +img_size = 256 +BATCH_SIZE = 8 +SEED = 42 +save_path = "best_model.h5" + +valid_size = 0.1 +test_size = 0.15 +epochs = 350 +save_weights_only = True +max_lr = 1e-4 +min_lr = 1e-6 + +# lr_schedule = tf.keras.callbacks.LearningRateScheduler(cosine_annealing_with_warmup, verbose=0) +# opts = tfa.optimizers.AdamW(lr= 1e-3, weight_decay = lr_schedule) +# opts = tf.keras.optimizers.SGD(lr=1e-4) +# route = "./Kvasir-SEG/" +# X_path = '/root/tqhuy/Polyp/PEFNet-main/Kvasir-SEG/images/' +# Y_path = '/root/tqhuy/Polyp/PEFNet-main/Kvasir-SEG/masks/' + +model = build_model(img_size) +def myprint(s): + with open('modelsummary.txt','a') as f: + print(s, file=f) + +model.summary(print_fn=myprint) +model.summary() +# model = create_segment_model() +starter_learning_rate = 1e-4 +end_learning_rate = 1e-6 +decay_steps = 1000 +learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( + starter_learning_rate, + decay_steps, + end_learning_rate, + power=0.2) + +opts = tfa.optimizers.AdamW(learning_rate = 1e-4, weight_decay = learning_rate_fn) + +get_custom_objects().update({"dice": dice_loss}) +model.compile(optimizer = opts, + loss='dice', + metrics=[dice_coeff,bce_dice_loss, IoU, zero_IoU]) + +# model.summary() +route = './TrainDataset' +X_path = './TrainDataset/images/' +Y_path = './TrainDataset/masks/' + +# route = './Kvasir-SEG' +# X_path = './Kvasir-SEG/images/' +# Y_path = './Kvasir-SEG/masks/' + +X_full = sorted(os.listdir(f'{route}/images')) +Y_full = sorted(os.listdir(f'{route}/masks')) + +print(len(X_full)) + +# valid_size = 0.1 +# test_size = 0. + +X_train, X_valid = train_test_split(X_full, test_size=valid_size, random_state=SEED) +Y_train, Y_valid = train_test_split(Y_full, test_size=valid_size, random_state=SEED) + +X_train, X_test = train_test_split(X_train, test_size=test_size, random_state=SEED) +Y_train, Y_test = train_test_split(Y_train, test_size=test_size, random_state=SEED) + +X_train = [X_path + x for x in X_train] +X_valid = [X_path + x for x in X_valid] +X_test = [X_path + x for x in X_test] + +Y_train = [Y_path + x for x in Y_train] +Y_valid = [Y_path + x for x in Y_valid] +Y_test = [Y_path + x for x in Y_test] + +print("N Train:", len(X_train)) +print("N Valid:", len(X_valid)) +print("N test:", len(X_test)) +# print(X_train) +train_decoder = build_decoder(with_labels=True, target_size=(img_size, img_size), ext='jpg', segment=True, ext2='jpg') +train_dataset = build_dataset(X_train, Y_train, bsize=BATCH_SIZE, decode_fn=train_decoder, + augmentAdv=False, augment=False, augmentAdvSeg=True) + +valid_decoder = build_decoder(with_labels=True, target_size=(img_size, img_size), ext='jpg', segment=True, ext2='jpg') +valid_dataset = build_dataset(X_valid, Y_valid, bsize=BATCH_SIZE, decode_fn=valid_decoder, + augmentAdv=False, augment=False, repeat=False, shuffle=False, + augmentAdvSeg=False) + +test_decoder = build_decoder(with_labels=True, target_size=(img_size, img_size), ext='jpg', segment=True, ext2='jpg') +test_dataset = build_dataset(X_test, Y_test, bsize=BATCH_SIZE, decode_fn=test_decoder, + augmentAdv=False, augment=False, repeat=False, shuffle=False, + augmentAdvSeg=False) + +callbacks = get_callbacks(monitor = 'val_loss', mode = 'min', save_path = save_path, _max_lr = max_lr + , _min_lr = min_lr , _cos_anne_ep = 1000, save_weights_only = save_weights_only) + +steps_per_epoch = len(X_train) // BATCH_SIZE + +print("START TRAINING:") + +print(train_dataset) +his = model.fit(train_dataset, + epochs=epochs, + verbose=1, + callbacks=callbacks, + steps_per_epoch=steps_per_epoch, + validation_data=valid_dataset) + +model.load_weights(save_path) + +model.evaluate(test_dataset) +model.save("final_model.h5") \ No newline at end of file diff --git a/visualize.py b/visualize.py new file mode 100644 index 0000000..4609142 --- /dev/null +++ b/visualize.py @@ -0,0 +1,67 @@ +import os +import tensorflow as tf +from metrics.segmentation_metrics import dice_coeff, bce_dice_loss, IoU, zero_IoU, dice_loss +from dataloader.dataloader import build_augmenter, build_dataset, build_decoder +from tensorflow.keras.utils import get_custom_objects +from model import build_model +import cv2 +import numpy as np +# from save_model.best_up_ca import build_model +# from save_model.ca_best_msf import build_model +import matplotlib.pyplot as plt +os.environ["CUDA_VISIBLE_DEVICES"]="2" + + + +def load_dataset(route, img_size = 256): + BATCH_SIZE = 1 + X_path = '{}/images/'.format(route) + Y_path = '{}/masks/'.format(route) + X_full = sorted(os.listdir(f'{route}/images')) + Y_full = sorted(os.listdir(f'{route}/masks')) + + X_train = [X_path + x for x in X_full] + Y_train = [Y_path + x for x in Y_full] + + test_decoder = build_decoder(with_labels=True, target_size=(img_size, img_size), ext='jpg', segment=True, ext2='jpg') + test_dataset = build_dataset(X_train, Y_train, bsize=BATCH_SIZE, decode_fn=test_decoder, + augmentAdv=False, augment=False, augmentAdvSeg=False, shuffle = None) + return test_dataset, len(X_train) + +def predict(model, dataset, len_data, outdir ="./save_vis/Etis/"): + steps_per_epoch = len_data//1 + masks = model.predict(dataset, steps=steps_per_epoch) + # print(masks.shape) + i = 0 + for x, y in dataset: + print(y[0].shape) + # print(i, masks[i].shape) + a = masks[i] + mask_new = np.dstack([a, a, a]) + # print(x.shape, y.shape) + gt = np.dstack([y[0], y[0], y[0]]) + # gt = cv2.cvtColor(y[0], cv2.COLOR_GRAY2RGB) + # true = cv2.cvtColor(x[0], cv2.COLOR_BGR2RGB) + im_h = np.concatenate([x[0], gt * 255, mask_new *255], axis = 1) + cv2.imwrite("{}/{}.jpg".format(outdir, i), im_h) + i+=1 + +def visualize(src_dir, model, outdir ="./save_vis/Etis/"): + dataset, len_data = load_dataset(src_dir) + predict(model, dataset, len_data, outdir) + +if __name__ == "__main__": + + BATCH_SIZE = 16 + img_size = 256 + SEED = 1024 + save_path = "best_model.h5" + route_data = "./TestDataset/" + outdir ="./save_vis/cvc300/" + src_dir = "./TestDataset/CVC-300" + + model = build_model(img_size) + model.load_weights(save_path) + + visualize(src_dir, model, outdir) + \ No newline at end of file