diff --git a/README.md b/README.md
new file mode 100644
index 0000000..3ca09f9
--- /dev/null
+++ b/README.md
@@ -0,0 +1,118 @@
+# Meta-Polyp: a baseline for efficient Polyp segmentation, CBMS 2023
+
+This repo is the official implementation for paper:
+
+Meta-Polyp: a baseline for efficient Polyp segmentation.
+
+Authors: Quoc-Huy Trinh
+
+In the IEEE 36th International Symposium on Computer Based Medical Systems (CBMS) 2023.
+
+Details of each model module can be found in the original paper. Please cite our paper if you use this implementation for research purposes.
+
+## Overall architecture
+
+Architecture of the Meta-Polyp baseline model:
+
+
+
+## Installation
+
+Our implementation targets `Python 3.9`; please make sure your environment is compatible with the requirements.
+
+To install all packages, use the `requirements.txt` file with `pip`, using the following command:
+
+```
+pip install -r requirements.txt
+```
+
+All packages will be automatically installed.
+
+## Config
+
+All configs for training and benchmarking are in the `./config/` folder. Please take a look at them when tuning.
+
+## Training
+
+To train the model, run the `train.py` script.
+
+The following command should be used:
+
+```
+python train.py
+```
+
+## Benchmark
+
+To benchmark the model, run the `benchmark.py` script.
+
+The following command should be used:
+
+```
+python benchmark.py
+```
+
+### Note:
+You should set `model_path` to the path of your trained model and point the dataset directory to your benchmark dataset.
+
+## Pretrained weights
+
+The weights will be released later.
+
+## Dataset
+
+In our experiments, we use the dataset configuration from [PraNet](https://github.com/DengPingFan/PraNet), with a training set drawn from 50% of the Kvasir-SEG dataset and 50% of the ClinicDB dataset.
+
+For testing, we use the following datasets:
+
+In-distribution:
+
+- Kvasir SEG
+
+- ClinicDB
+
+
+Out-of-distribution:
+
+- Etis dataset
+
+- ColonDB
+
+- CVC300
+
+
+## Results
+
+The IoU scores compared with SOTA methods on all 5 datasets:
+
+
+
+Qualitative results visualized against other SOTA methods:
+
+
+
+## Weights
+
+Coming soon
+
+## Customize
+You can change the backbone from CAFormer to PVT or another backbone to get different results.
+
+## Citation
+
+```
+Coming soon
+```
+
diff --git a/benchmark.py b/benchmark.py
new file mode 100644
index 0000000..a770796
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,54 @@
+# from save_model.pvt_CAM_channel_att_upscale import build_model
+import os
+import tensorflow as tf
+# from metrics.metrics_last import iou_metric, MAE, WFbetaMetric, SMeasure, Emeasure, dice_coef, iou_metric
+from metrics.segmentation_metrics import dice_coeff, bce_dice_loss, IoU, zero_IoU, dice_loss
+from dataloader.dataloader import build_augmenter, build_dataset, build_decoder
+from tensorflow.keras.utils import get_custom_objects
+from model_research import build_model
+
+os.environ["CUDA_VISIBLE_DEVICES"]="0"
+
+def load_dataset(route):
+ X_path = '{}/images/'.format(route)
+ Y_path = '{}/masks/'.format(route)
+ X_full = sorted(os.listdir(f'{route}/images'))
+ Y_full = sorted(os.listdir(f'{route}/masks'))
+
+ X_train = [X_path + x for x in X_full]
+ Y_train = [Y_path + x for x in Y_full]
+
+ test_decoder = build_decoder(with_labels=False, target_size=(img_size, img_size), ext='jpg',
+ segment=True, ext2='jpg')
+ test_dataset = build_dataset(X_train, Y_train, bsize=BATCH_SIZE, decode_fn=test_decoder,
+ augmentAdv=False, augment=False, augmentAdvSeg=False)
+ return test_dataset, len(X_train)
+
+def benchmark(route, model, BATCH_SIZE = 32, save_file_name = "benchmark_result.txt"):
+
+ list_of_datasets = os.listdir(route)
+ f = open(save_file_name,"a")
+ f.write("\n")
+ for datasets in list_of_datasets:
+ print(datasets, ":")
+ test_dataset, len_data = load_dataset(os.path.join(route,datasets))
+ steps_per_epoch = len_data // BATCH_SIZE
+ loss, dice_coeff, bce_dice_loss, IoU, zero_IoU, mae = model.evaluate(test_dataset, steps=steps_per_epoch)
+ f.write("{}:".format(datasets))
+ f.write("dice_coeff: {}, bce_didce_loss: {}, IoU: {}, zero_IoU: {}, mae: {}".format(dice_coeff, bce_dice_loss, IoU, zero_IoU, mae))
+ f.write('\n')
+
+if __name__ == "__main__":
+
+ img_size = 256
+ BATCH_SIZE = 1
+ SEED = 1024
+ save_path = "best_model.h5"
+ route_data = "./TestDataset/"
+ path_to_test_dataset = "./TestDataset/"
+ model = build_model(img_size)
+ model.load_weights(save_path)
+
+ model.compile(metrics=[dice_coeff, bce_dice_loss, IoU, zero_IoU, tf.keras.metrics.MeanSquaredError()])
+
+ benchmark(path_to_test_dataset, model)
diff --git a/callbacks/callbacks.py b/callbacks/callbacks.py
new file mode 100644
index 0000000..a4e1fef
--- /dev/null
+++ b/callbacks/callbacks.py
@@ -0,0 +1,68 @@
+import math
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+def cosine_annealing_with_warmup(epochIdx):
+ aMax, aMin = max_lr, min_lr
+ warmupEpochs, stagnateEpochs, cosAnnealingEpochs = 0, 0, cos_anne_ep
+ epochIdx = epochIdx % (warmupEpochs + stagnateEpochs + cosAnnealingEpochs)
+ if(epochIdx < warmupEpochs):
+ return aMin + (aMax - aMin) / (warmupEpochs - 1) * epochIdx
+ else:
+ epochIdx -= warmupEpochs
+ if(epochIdx < stagnateEpochs):
+ return aMax
+ else:
+ epochIdx -= stagnateEpochs
+ return aMin + 0.5 * (aMax - aMin) * (1 + math.cos((epochIdx + 1) / (cosAnnealingEpochs + 1) * math.pi))
+
+def plt_lr(step, schedulers):
+ x = range(step)
+ y = [schedulers(_) for _ in x]
+
+ plt.plot(x, y)
+ plt.xlabel('Epoch')
+ plt.ylabel('Learning Rate')
+ plt.legend()
+
+def get_callbacks(monitor, mode, save_path, _max_lr, _min_lr, _cos_anne_ep, save_weights_only):
+ global max_lr
+ max_lr = _max_lr
+ global min_lr
+ min_lr = _min_lr
+ global cos_anne_ep
+ cos_anne_ep = _cos_anne_ep
+
+ early_stopping = tf.keras.callbacks.EarlyStopping(
+ monitor=monitor,
+ patience=60,
+ restore_best_weights=True,
+ mode=mode
+ )
+
+ reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
+ monitor=monitor,
+ factor=0.2,
+ patience=50,
+ verbose=1,
+ mode=mode,
+ min_lr=1e-5,
+ )
+
+ checkpoint = tf.keras.callbacks.ModelCheckpoint(
+ filepath=save_path,
+ monitor=monitor,
+ verbose=1,
+ save_best_only=True,
+ save_weights_only=save_weights_only,
+ mode=mode,
+ save_freq="epoch",
+ )
+
+ lr_schedule = tf.keras.callbacks.LearningRateScheduler(cosine_annealing_with_warmup, verbose=0)
+
+ csv_logger = tf.keras.callbacks.CSVLogger('training.csv')
+
+ callbacks = [checkpoint, csv_logger, reduce_lr]
+# , reduce_lr
+ return callbacks
\ No newline at end of file
diff --git a/dataloader/dataloader.py b/dataloader/dataloader.py
new file mode 100644
index 0000000..2d58b23
--- /dev/null
+++ b/dataloader/dataloader.py
@@ -0,0 +1,132 @@
+import tensorflow as tf
+import os
+
+# def auto_select_accelerator():
+# try:
+# tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
+# tf.config.experimental_connect_to_cluster(tpu)
+# tf.tpu.experimental.initialize_tpu_system(tpu)
+# strategy = tf.distribute.experimental.TPUStrategy(tpu)
+# print("Running on TPU:", tpu.master())
+# except ValueError:
+# strategy = tf.distribute.get_strategy()
+# print(f"Running on {strategy.num_replicas_in_sync} replicas")
+
+# return strategy
+
+def default_augment_seg(input_image, input_mask):
+
+ input_image = tf.image.random_brightness(input_image, 0.1)
+ input_image = tf.image.random_contrast(input_image, 0.9, 1.1)
+ input_image = tf.image.random_saturation(input_image, 0.9, 1.1)
+ input_image = tf.image.random_hue(input_image, 0.01)
+
+ # flipping random horizontal or vertical
+ if tf.random.uniform(()) > 0.5:
+ input_image = tf.image.flip_left_right(input_image)
+ input_mask = tf.image.flip_left_right(input_mask)
+ if tf.random.uniform(()) > 0.5:
+ input_image = tf.image.flip_up_down(input_image)
+ input_mask = tf.image.flip_up_down(input_mask)
+
+ return input_image, input_mask
+
+def BatchAdvAugmentSeg(imagesT, masksT):
+
+ images, masks = default_augment_seg(imagesT, masksT)
+
+ return images, masks
+
+def build_decoder(with_labels=True, target_size=(256, 256), ext='png', segment=False, ext2='png'):
+
+ def decode(path):
+ file_bytes = tf.io.read_file(path)
+ if ext == 'png':
+ img = tf.image.decode_png(file_bytes, channels=3, dct_method='INTEGER_ACCURATE')
+ elif ext in ['jpg', 'jpeg']:
+ img = tf.image.decode_jpeg(file_bytes, channels=3, dct_method='INTEGER_ACCURATE')
+ else:
+ raise ValueError("Image extension not supported")
+
+ img = tf.image.resize(img, target_size)
+ # img = tf.cast(img, tf.float32) / 255.0
+
+ return img
+
+ def decode_mask(path, gray=True):
+ file_bytes = tf.io.read_file(path)
+ if ext2 == 'png':
+ img = tf.image.decode_png(file_bytes, channels=3)
+ elif ext2 in ['jpg', 'jpeg']:
+ img = tf.image.decode_jpeg(file_bytes, channels=3)
+ else:
+ raise ValueError("Image extension not supported")
+
+ img = tf.image.rgb_to_grayscale(img) if gray else img
+ img = tf.image.resize(img, target_size)
+ img = tf.cast(img, tf.float32) / 255.0
+
+ return img
+
+ def decode_with_labels(path, label):
+ return decode(path), label
+
+ def decode_with_segments(path, path2, gray=True):
+ return decode(path), decode_mask(path2, gray)
+
+ if segment:
+ return decode_with_segments
+
+ return decode_with_labels if with_labels else decode
+
+
+def build_augmenter(with_labels=True):
+ def augment(img):
+
+ img = tf.image.random_flip_up_down(img)
+ img = tf.image.random_flip_left_right(img)
+# img = tf.image.rot90(img, k=tf.random.uniform([],0,4,tf.int32))
+
+ img = tf.image.random_brightness(img, 0.1)
+ img = tf.image.random_contrast(img, 0.9, 1.1)
+ img = tf.image.random_saturation(img, 0.9, 1.1)
+ img = tf.image.random_hue(img, 0.02)
+
+ # img = transform_mat(img)
+
+ return img
+
+ def augment_with_labels(img, label):
+ return augment(img), label
+
+ return augment_with_labels if with_labels else augment
+
+
+def build_dataset(paths, labels=None, bsize=32, cache=True,
+ decode_fn=None, augment_fn=None,
+ augment=True, augmentAdv=False, augmentAdvSeg=False, repeat=True, shuffle=1024,
+ cache_dir=""):
+ if cache_dir != "" and cache is True:
+ os.makedirs(cache_dir, exist_ok=True)
+
+ if decode_fn is None:
+ decode_fn = build_decoder(labels is not None)
+
+ if augment_fn is None:
+ augment_fn = build_augmenter(labels is not None)
+
+ AUTO = tf.data.experimental.AUTOTUNE
+ slices = paths if labels is None else (paths, labels)
+
+ dset = tf.data.Dataset.from_tensor_slices(slices)
+ dset = dset.map(decode_fn, num_parallel_calls=AUTO)
+ dset = dset.cache(cache_dir) if cache else dset
+ dset = dset.map(augment_fn, num_parallel_calls=AUTO) if augment else dset
+ dset = dset.repeat() if repeat else dset
+ dset = dset.shuffle(shuffle) if shuffle else dset
+ dset = dset.batch(bsize)
+ # dset = dset.map(BatchAdvAugment, num_parallel_calls=AUTO) if augmentAdv else dset
+ dset = dset.map(BatchAdvAugmentSeg, num_parallel_calls=AUTO) if augmentAdvSeg else dset
+ dset = dset.prefetch(AUTO)
+
+ return dset
\ No newline at end of file
diff --git a/img/MetaPolyp.png b/img/MetaPolyp.png
new file mode 100644
index 0000000..c57ed67
Binary files /dev/null and b/img/MetaPolyp.png differ
diff --git a/img/miccaivis.png b/img/miccaivis.png
new file mode 100644
index 0000000..fccf475
Binary files /dev/null and b/img/miccaivis.png differ
diff --git a/img/res.png b/img/res.png
new file mode 100644
index 0000000..0d23fdc
Binary files /dev/null and b/img/res.png differ
diff --git a/layers/convformer.py b/layers/convformer.py
new file mode 100644
index 0000000..b7b04b4
--- /dev/null
+++ b/layers/convformer.py
@@ -0,0 +1,14 @@
+import tensorflow as tf
+
+def convformer(input_tensor, filters, padding = "same"):
+
+ x = tf.keras.layers.LayerNormalization()(input_tensor)
+ x = tf.keras.layers.SeparableConv2D(filters, kernel_size = (3,3), padding = padding)(x)
+ # x = x1 + x2 + x3
+ x = tf.keras.layers.Attention()([x, x, x])
+ out = tf.keras.layers.Add()([x, input_tensor])
+
+ x1 = tf.keras.layers.Dense(filters, activation = "gelu")(out)
+ x1 = tf.keras.layers.Dense(filters)(x1)
+ out_tensor = tf.keras.layers.Add()([out, x1])
+ return out_tensor
\ No newline at end of file
diff --git a/layers/upsampling.py b/layers/upsampling.py
new file mode 100644
index 0000000..d32da54
--- /dev/null
+++ b/layers/upsampling.py
@@ -0,0 +1,34 @@
+import tensorflow as tf
+from tensorflow.keras.layers import Conv2D
+
+def bn_act(inputs, activation='swish'):
+
+ x = tf.keras.layers.BatchNormalization()(inputs)
+ if activation:
+ x = tf.keras.layers.Activation(activation)(x)
+
+ return x
+
+def decode(input_tensor, filters, scale = 2, activation = 'relu'):
+ """Upsample ``input_tensor`` by ``scale`` after fusing a 1x1 and a 3x3 convolution.
+
+ Args:
+ input_tensor: Feature map to decode.
+ filters: Number of output channels of every convolution in the block.
+ scale: Upsampling factor applied to both spatial dimensions.
+ activation: Activation used by the convolutions and after batch norm.
+
+ Returns:
+ The upsampled, batch-normalised and activated feature map.
+ """
+ 
+ # Two parallel branches over the same input: pointwise and 3x3 convolution.
+ x1 = tf.keras.layers.Conv2D(filters, (1, 1), activation=activation, use_bias=False,
+ kernel_initializer='he_normal', padding = 'same')(input_tensor)
+
+ x2 = tf.keras.layers.Conv2D(filters, (3, 3), activation=activation,
+ use_bias=False, padding = 'same')(input_tensor)
+ 
+ merge = tf.keras.layers.Add()([x1, x2])
+ # The returned tensor is derived from this upsampling of the first `merge`.
+ x = tf.keras.layers.UpSampling2D((scale, scale))(merge)
+ 
+ # NOTE(review): `skip_feature` and the re-assigned `merge` below are never
+ # used to compute `x`, so they do not contribute to the returned tensor.
+ # Confirm whether the upsampling was meant to run after this residual step.
+ skip_feature = tf.keras.layers.Conv2D(filters, (3, 3), activation=activation, use_bias=False,
+ kernel_initializer='he_normal', padding = 'same')(merge)
+ 
+ skip_feature = tf.keras.layers.Conv2D(filters, (1, 1), activation=activation, use_bias=False,
+ kernel_initializer='he_normal', padding = 'same')(skip_feature)
+
+ merge = tf.keras.layers.Add()([merge, skip_feature])
+
+ x = bn_act(x, activation = activation)
+ 
+
+ return x
\ No newline at end of file
diff --git a/layers/util_layers.py b/layers/util_layers.py
new file mode 100644
index 0000000..0d90d55
--- /dev/null
+++ b/layers/util_layers.py
@@ -0,0 +1,28 @@
+import tensorflow as tf
+from tensorflow.keras.layers import Conv2D
+
+def bn_act(inputs, activation='swish'):
+
+ x = tf.keras.layers.BatchNormalization()(inputs)
+ if activation:
+ x = tf.keras.layers.Activation(activation)(x)
+
+ return x
+
+def conv_bn_act(inputs, filters, kernel_size, strides=(1, 1), activation='relu', padding='same'):
+
+ x = Conv2D(filters, kernel_size=kernel_size, padding=padding)(inputs)
+ x = bn_act(x, activation=activation)
+
+ return x
+
+def merge(l, filters=None):
+ if filters is None:
+ channel_axis = 1 if K.image_data_format() == "channels_first" else -1
+ filters = l[0].shape[channel_axis]
+
+ x = tf.keras.layers.Add()([l[0],l[1]])
+
+ # x = block(x, filters)
+
+ return x
\ No newline at end of file
diff --git a/metrics/segmentation_metrics.py b/metrics/segmentation_metrics.py
new file mode 100644
index 0000000..1408d14
--- /dev/null
+++ b/metrics/segmentation_metrics.py
@@ -0,0 +1,57 @@
+import tensorflow as tf
+import tensorflow.keras.backend as K
+from tensorflow.keras.losses import binary_crossentropy
+import numpy as np
+def dice_coeff(y_true, y_pred):
+
+ _epsilon = 10 ** -7
+ intersections = tf.reduce_sum(y_true * y_pred)
+ unions = tf.reduce_sum(y_true + y_pred)
+ dice_scores = (2.0 * intersections + _epsilon) / (unions + _epsilon)
+
+ return dice_scores
+
+def dice_loss(y_true, y_pred):
+
+ loss = 1 - dice_coeff(y_true, y_pred)
+
+ return loss
+
+def total_loss(y_true, y_pred):
+ return 0.5*binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
+
+def IoU(y_true, y_pred, eps=1e-6):
+
+ intersection = K.sum(y_true * y_pred, axis=[1,2,3])
+ union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3]) - intersection
+
+ return K.mean( (intersection + eps) / (union + eps), axis=0)
+
+def zero_IoU(y_true, y_pred):
+
+ return IoU(1-y_true, 1-y_pred)
+
+def bce_dice_loss(y_true, y_pred):
+
+ return binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
+
+def tversky(y_true, y_pred, smooth=1, alpha=0.7):
+
+ y_true_pos = tf.reshape(y_true,[-1])
+ y_pred_pos = tf.reshape(y_pred,[-1])
+ true_pos = tf.reduce_sum(y_true_pos * y_pred_pos)
+ false_neg = tf.reduce_sum(y_true_pos * (1 - y_pred_pos))
+ false_pos = tf.reduce_sum((1 - y_true_pos) * y_pred_pos)
+
+ return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth)
+
+
+def tversky_loss(y_true, y_pred):
+
+ return 1 - tversky(y_true, y_pred)
+
+def focal_tversky_loss(y_true, y_pred, gamma=0.75):
+
+ tv = tversky(y_true, y_pred)
+
+ return K.pow((1 - tv), gamma)
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..417043b
--- /dev/null
+++ b/model.py
@@ -0,0 +1,49 @@
+import tensorflow as tf
+from keras_cv_attention_models import caformer
+from layers.upsampling import decode
+from layers.convformer import convformer
+from layers.util_layers import merge, conv_bn_act
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Conv2D
+import tensorflow.keras.backend as K
+
+def build_model(img_size = 256, num_classes = 1):
+ """Build the segmentation network on top of a CAFormer-S18 backbone.
+
+ Args:
+ img_size: Accepted for API symmetry with the scripts, but not used below:
+ the backbone input shape is fixed to (256, 256, 3).
+ num_classes: Number of channels of the final sigmoid output layer.
+
+ Returns:
+ A Keras ``Model`` mapping the backbone input to the sigmoid output.
+ """
+ # ImageNet-pretrained backbone; num_classes = 0 is passed to the backbone constructor.
+ backbone = caformer.CAFormerS18(input_shape=(256, 256, 3), pretrained="imagenet", num_classes = 0)
+ 
+ # Backbone layers tapped as features, listed from stack4 down to stack1.
+ layer_names = ['stack4_block3_mlp_Dense_1', 'stack3_block9_mlp_Dense_1', 'stack2_block3_mlp_Dense_1', 'stack1_block3_mlp_Dense_1']
+ layers = [backbone.get_layer(x).output for x in layer_names]
+ 
+ channel_axis = 1 if K.image_data_format() == "channels_first" else -1
+ # NOTE(review): h_axis and w_axis are assigned here but not read afterwards.
+ h_axis, w_axis = [2, 3] if K.image_data_format() == "channels_first" else [1, 2]
+
+ # Decoding starts from the first listed feature (stack4).
+ x = layers[0]
+
+ # Side branch: the starting feature upsampled by 4, fused in later on.
+ upscale_feature = decode(x, scale = 4, filters = x.shape[channel_axis])
+
+ for i, layer in enumerate(layers[1:]):
+ 
+ # Upsample by 2 and take over the channel count of the current skip feature.
+ x = decode(x, scale = 2, filters = layer.shape[channel_axis])
+ 
+ # NOTE(review): convformer is called with the tensor only, while
+ # layers/convformer.py declares `filters` without a default - confirm.
+ layer_fusion = convformer(layer) 
+ 
+ ## Doing multi-level concatenation
+ if (i%2 == 1):
+ upscale_feature = tf.keras.layers.Conv2D(layer.shape[channel_axis], (1, 1), activation = "relu", padding = "same")(upscale_feature)
+ x = tf.keras.layers.Add()([x, upscale_feature])
+ x = tf.keras.layers.Conv2D(x.shape[channel_axis], (1, 1), activation = "relu", padding = "same")(x)
+ 
+ x = merge([x, layer_fusion], layer.shape[channel_axis])
+ x = conv_bn_act(x, layer.shape[channel_axis], (1, 1))
+ 
+ ## Upscale for next level feature
+ if (i%2 == 1):
+ upscale_feature = decode(x, scale = 8, filters = layer.shape[channel_axis])
+ 
+ # Head: halve the channels, upsample by 4, add the side branch, then predict.
+ # NOTE(review): this Add needs `x` and `upscale_feature` to have matching
+ # shapes - verify the channel counts agree at this point.
+ filters = x.shape[channel_axis] //2
+ x = decode(x, filters, 4)
+ x = tf.keras.layers.Add()([x, upscale_feature])
+ x = conv_bn_act(x, filters, 1)
+ x = Conv2D(num_classes, kernel_size=1, padding='same', activation='sigmoid')(x)
+ model = Model(backbone.input, x)
+
+ return model
\ No newline at end of file
diff --git a/optimizers/lion_opt.py b/optimizers/lion_opt.py
new file mode 100644
index 0000000..1cb8ed4
--- /dev/null
+++ b/optimizers/lion_opt.py
@@ -0,0 +1,86 @@
+import tensorflow as tf
+class Lion(tf.keras.optimizers.legacy.Optimizer):
+ r"""Optimizer that implements the Lion algorithm.
+
+ Dense update performed by this class, per variable ``var`` with momentum
+ slot ``m`` and gradient ``grad``:
+
+ var <- var - lr * (sign(beta_1 * m + (1 - beta_1) * grad) + wd * var)
+ m <- beta_2 * m + (1 - beta_2) * grad
+ """
+
+ def __init__(self,
+ learning_rate=0.0001,
+ beta_1=0.9,
+ beta_2=0.99,
+ wd=0,
+ name='lion',
+ **kwargs):
+ """Construct a new Lion optimizer.
+
+ Args:
+ learning_rate: Step size; a legacy ``lr`` keyword, when given, takes precedence.
+ beta_1: Interpolation factor used for the update direction.
+ beta_2: Decay factor used when updating the momentum slot.
+ wd: Weight-decay coefficient applied to the variable inside the update.
+ name: Optimizer name passed to the base class.
+ **kwargs: Forwarded to the base optimizer constructor.
+ """
+
+ super(Lion, self).__init__(name, **kwargs)
+ self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
+ self._set_hyper('beta_1', beta_1)
+ self._set_hyper('beta_2', beta_2)
+ self._set_hyper('wd', wd)
+
+ def _create_slots(self, var_list):
+ # Create the single momentum slot 'm' for every variable.
+ for var in var_list:
+ self.add_slot(var, 'm')
+
+ def _prepare_local(self, var_device, var_dtype, apply_state):
+ # Extend the base class' per-(device, dtype) state with the Lion coefficients.
+ super(Lion, self)._prepare_local(var_device, var_dtype, apply_state)
+
+ beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype))
+ beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype))
+ wd_t = tf.identity(self._get_hyper('wd', var_dtype))
+ # 'lr_t' is read from the state prepared by the base class and also stored as 'lr'.
+ lr = apply_state[(var_device, var_dtype)]['lr_t']
+ apply_state[(var_device, var_dtype)].update(
+ dict(
+ lr=lr,
+ beta_1_t=beta_1_t,
+ one_minus_beta_1_t=1 - beta_1_t,
+ beta_2_t=beta_2_t,
+ one_minus_beta_2_t=1 - beta_2_t,
+ wd_t=wd_t))
+
+ @tf.function(jit_compile=True)
+ def _resource_apply_dense(self, grad, var, apply_state=None):
+ # Use the prepared coefficients when available, otherwise build them on demand.
+ var_device, var_dtype = var.device, var.dtype.base_dtype
+ coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
+ self._fallback_apply_state(var_device, var_dtype))
+
+ m = self.get_slot(var, 'm')
+ # Step along the sign of the beta_1-interpolated momentum, plus weight decay.
+ var_t = var.assign_sub(
+ coefficients['lr_t'] *
+ (tf.math.sign(m * coefficients['beta_1_t'] +
+ grad * coefficients['one_minus_beta_1_t']) +
+ var * coefficients['wd_t']))
+ # The momentum is updated (with beta_2) only after the variable update.
+ with tf.control_dependencies([var_t]):
+ m.assign(m * coefficients['beta_2_t'] +
+ grad * coefficients['one_minus_beta_2_t'])
+
+ @tf.function(jit_compile=True)
+ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
+ var_device, var_dtype = var.device, var.dtype.base_dtype
+ coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
+ self._fallback_apply_state(var_device, var_dtype))
+
+ m = self.get_slot(var, 'm')
+ # Temporarily turn the slot into beta_1 * m + (1 - beta_1) * grad (at `indices`).
+ m_t = m.assign(m * coefficients['beta_1_t'])
+ m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
+ m_t = m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
+ var_t = var.assign_sub(coefficients['lr'] *
+ (tf.math.sign(m_t) + var * coefficients['wd_t']))
+
+ with tf.control_dependencies([var_t]):
+ # Undo the temporary interpolation, rescale beta_1 * m to beta_2 * m,
+ # then add the (1 - beta_2) * grad contribution at `indices`.
+ m_t = m_t.scatter_add(tf.IndexedSlices(-m_scaled_g_values, indices))
+ m_t = m_t.assign(m_t * coefficients['beta_2_t'] /
+ coefficients['beta_1_t'])
+ m_scaled_g_values = grad * coefficients['one_minus_beta_2_t']
+ m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
+
+ def get_config(self):
+ # Serialise the four hyperparameters on top of the base configuration.
+ config = super(Lion, self).get_config()
+ config.update({
+ 'learning_rate': self._serialize_hyperparameter('learning_rate'),
+ 'beta_1': self._serialize_hyperparameter('beta_1'),
+ 'beta_2': self._serialize_hyperparameter('beta_2'),
+ 'wd': self._serialize_hyperparameter('wd'),
+ })
+ return config
\ No newline at end of file
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..ce43352
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,25 @@
+import cv2
+from model import build_model
+import numpy as np
+
+def load_model(model_path):
+ model = build_model()
+ model.load_weights(model_path)
+ return model
+
+def predict_single(model, imgPath):
+ img = cv2.imread(imgPath)
+ img /= 255
+ result = model.predict(img)
+ return result
+
+if __name__ == "__main__":
+ save_path = "best_model.h5"
+ img_in = "test.png"
+ img_out = "out.png"
+
+ model = load_model(save_path)
+
+ mask = predict_single(model, img_out)
+ mask_out = np.dstack([mask, mask, mask])
+ cv2.imwrite(img_out, mask_out)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..55b6111
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+opencv-python
+tensorflow==2.11.0
+tensorflow-addons  # train.py imports tensorflow_addons
+keras-cv-attention-models  # model.py imports keras_cv_attention_models
+Pillow
+tqdm
+scikit-learn
+numpy
+matplotlib
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..3fab03a
--- /dev/null
+++ b/train.py
@@ -0,0 +1,126 @@
+from sklearn.model_selection import train_test_split
+import tensorflow as tf
+from metrics.segmentation_metrics import dice_coeff, bce_dice_loss, IoU, zero_IoU, dice_loss, total_loss
+from tensorflow.keras.utils import get_custom_objects
+import os
+from callbacks.callbacks import get_callbacks, cosine_annealing_with_warmup
+from dataloader.dataloader import build_augmenter, build_dataset, build_decoder
+# from supervision.dataloader import build_augmenter, build_dataset, build_decoder
+from model import create_segment_model, build_model
+import os
+import tensorflow_addons as tfa
+from optimizers.lion_opt import Lion
+
+os.environ["CUDA_VISIBLE_DEVICES"]="2"
+
+img_size = 256
+BATCH_SIZE = 8
+SEED = 42
+save_path = "best_model.h5"
+
+valid_size = 0.1
+test_size = 0.15
+epochs = 350
+save_weights_only = True
+max_lr = 1e-4
+min_lr = 1e-6
+
+# lr_schedule = tf.keras.callbacks.LearningRateScheduler(cosine_annealing_with_warmup, verbose=0)
+# opts = tfa.optimizers.AdamW(lr= 1e-3, weight_decay = lr_schedule)
+# opts = tf.keras.optimizers.SGD(lr=1e-4)
+# route = "./Kvasir-SEG/"
+# X_path = '/root/tqhuy/Polyp/PEFNet-main/Kvasir-SEG/images/'
+# Y_path = '/root/tqhuy/Polyp/PEFNet-main/Kvasir-SEG/masks/'
+
+model = build_model(img_size)
+def myprint(s):
+ with open('modelsummary.txt','a') as f:
+ print(s, file=f)
+
+model.summary(print_fn=myprint)
+model.summary()
+# model = create_segment_model()
+starter_learning_rate = 1e-4
+end_learning_rate = 1e-6
+decay_steps = 1000
+learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
+ starter_learning_rate,
+ decay_steps,
+ end_learning_rate,
+ power=0.2)
+
+opts = tfa.optimizers.AdamW(learning_rate = 1e-4, weight_decay = learning_rate_fn)
+
+get_custom_objects().update({"dice": dice_loss})
+model.compile(optimizer = opts,
+ loss='dice',
+ metrics=[dice_coeff,bce_dice_loss, IoU, zero_IoU])
+
+# model.summary()
+route = './TrainDataset'
+X_path = './TrainDataset/images/'
+Y_path = './TrainDataset/masks/'
+
+# route = './Kvasir-SEG'
+# X_path = './Kvasir-SEG/images/'
+# Y_path = './Kvasir-SEG/masks/'
+
+X_full = sorted(os.listdir(f'{route}/images'))
+Y_full = sorted(os.listdir(f'{route}/masks'))
+
+print(len(X_full))
+
+# valid_size = 0.1
+# test_size = 0.
+
+X_train, X_valid = train_test_split(X_full, test_size=valid_size, random_state=SEED)
+Y_train, Y_valid = train_test_split(Y_full, test_size=valid_size, random_state=SEED)
+
+X_train, X_test = train_test_split(X_train, test_size=test_size, random_state=SEED)
+Y_train, Y_test = train_test_split(Y_train, test_size=test_size, random_state=SEED)
+
+X_train = [X_path + x for x in X_train]
+X_valid = [X_path + x for x in X_valid]
+X_test = [X_path + x for x in X_test]
+
+Y_train = [Y_path + x for x in Y_train]
+Y_valid = [Y_path + x for x in Y_valid]
+Y_test = [Y_path + x for x in Y_test]
+
+print("N Train:", len(X_train))
+print("N Valid:", len(X_valid))
+print("N test:", len(X_test))
+# print(X_train)
+train_decoder = build_decoder(with_labels=True, target_size=(img_size, img_size), ext='jpg', segment=True, ext2='jpg')
+train_dataset = build_dataset(X_train, Y_train, bsize=BATCH_SIZE, decode_fn=train_decoder,
+ augmentAdv=False, augment=False, augmentAdvSeg=True)
+
+valid_decoder = build_decoder(with_labels=True, target_size=(img_size, img_size), ext='jpg', segment=True, ext2='jpg')
+valid_dataset = build_dataset(X_valid, Y_valid, bsize=BATCH_SIZE, decode_fn=valid_decoder,
+ augmentAdv=False, augment=False, repeat=False, shuffle=False,
+ augmentAdvSeg=False)
+
+test_decoder = build_decoder(with_labels=True, target_size=(img_size, img_size), ext='jpg', segment=True, ext2='jpg')
+test_dataset = build_dataset(X_test, Y_test, bsize=BATCH_SIZE, decode_fn=test_decoder,
+ augmentAdv=False, augment=False, repeat=False, shuffle=False,
+ augmentAdvSeg=False)
+
+callbacks = get_callbacks(monitor = 'val_loss', mode = 'min', save_path = save_path, _max_lr = max_lr
+ , _min_lr = min_lr , _cos_anne_ep = 1000, save_weights_only = save_weights_only)
+
+steps_per_epoch = len(X_train) // BATCH_SIZE
+
+print("START TRAINING:")
+
+print(train_dataset)
+his = model.fit(train_dataset,
+ epochs=epochs,
+ verbose=1,
+ callbacks=callbacks,
+ steps_per_epoch=steps_per_epoch,
+ validation_data=valid_dataset)
+
+model.load_weights(save_path)
+
+model.evaluate(test_dataset)
+model.save("final_model.h5")
\ No newline at end of file
diff --git a/visualize.py b/visualize.py
new file mode 100644
index 0000000..4609142
--- /dev/null
+++ b/visualize.py
@@ -0,0 +1,67 @@
+import os
+import tensorflow as tf
+from metrics.segmentation_metrics import dice_coeff, bce_dice_loss, IoU, zero_IoU, dice_loss
+from dataloader.dataloader import build_augmenter, build_dataset, build_decoder
+from tensorflow.keras.utils import get_custom_objects
+from model import build_model
+import cv2
+import numpy as np
+# from save_model.best_up_ca import build_model
+# from save_model.ca_best_msf import build_model
+import matplotlib.pyplot as plt
+os.environ["CUDA_VISIBLE_DEVICES"]="2"
+
+
+
+def load_dataset(route, img_size = 256):
+ BATCH_SIZE = 1
+ X_path = '{}/images/'.format(route)
+ Y_path = '{}/masks/'.format(route)
+ X_full = sorted(os.listdir(f'{route}/images'))
+ Y_full = sorted(os.listdir(f'{route}/masks'))
+
+ X_train = [X_path + x for x in X_full]
+ Y_train = [Y_path + x for x in Y_full]
+
+ test_decoder = build_decoder(with_labels=True, target_size=(img_size, img_size), ext='jpg', segment=True, ext2='jpg')
+ test_dataset = build_dataset(X_train, Y_train, bsize=BATCH_SIZE, decode_fn=test_decoder,
+ augmentAdv=False, augment=False, augmentAdvSeg=False, shuffle = None)
+ return test_dataset, len(X_train)
+
+def predict(model, dataset, len_data, outdir ="./save_vis/Etis/"):
+ steps_per_epoch = len_data//1
+ masks = model.predict(dataset, steps=steps_per_epoch)
+ # print(masks.shape)
+ i = 0
+ for x, y in dataset:
+ print(y[0].shape)
+ # print(i, masks[i].shape)
+ a = masks[i]
+ mask_new = np.dstack([a, a, a])
+ # print(x.shape, y.shape)
+ gt = np.dstack([y[0], y[0], y[0]])
+ # gt = cv2.cvtColor(y[0], cv2.COLOR_GRAY2RGB)
+ # true = cv2.cvtColor(x[0], cv2.COLOR_BGR2RGB)
+ im_h = np.concatenate([x[0], gt * 255, mask_new *255], axis = 1)
+ cv2.imwrite("{}/{}.jpg".format(outdir, i), im_h)
+ i+=1
+
+def visualize(src_dir, model, outdir ="./save_vis/Etis/"):
+ dataset, len_data = load_dataset(src_dir)
+ predict(model, dataset, len_data, outdir)
+
+if __name__ == "__main__":
+
+ BATCH_SIZE = 16
+ img_size = 256
+ SEED = 1024
+ save_path = "best_model.h5"
+ route_data = "./TestDataset/"
+ outdir ="./save_vis/cvc300/"
+ src_dir = "./TestDataset/CVC-300"
+
+ model = build_model(img_size)
+ model.load_weights(save_path)
+
+ visualize(src_dir, model, outdir)
+
\ No newline at end of file