Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
acse-ww721 committed Sep 12, 2023
1 parent 669fd4d commit 063d0af
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 106 deletions.
40 changes: 21 additions & 19 deletions src/models/bilinear_interpolation_1x.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# The function implementation below is a modification version from Tensorflow
# Original code link: https://github.com/ashesh6810/DDWP-DA/blob/master/layers.py


from keras import backend as K
from keras.engine.topology import Layer

if K.backend() == 'tensorflow':
if K.backend() == "tensorflow":
import tensorflow as tf

def K_meshgrid(x, y):
Expand Down Expand Up @@ -29,7 +33,7 @@ def __init__(self, output_size, **kwargs):

def get_config(self):
return {
'output_size': self.output_size,
"output_size": self.output_size,
}

def compute_output_shape(self, input_shapes):
Expand All @@ -43,21 +47,20 @@ def call(self, tensors, mask=None):
return output

def _interpolate(self, image, sampled_grids, output_size):

batch_size = K.shape(image)[0]
height = K.shape(image)[1]
width = K.shape(image)[2]
num_channels = K.shape(image)[3]

x = K.cast(K.flatten(sampled_grids[:, 0:1, :]), dtype='float32')
y = K.cast(K.flatten(sampled_grids[:, 1:2, :]), dtype='float32')
x = K.cast(K.flatten(sampled_grids[:, 0:1, :]), dtype="float32")
y = K.cast(K.flatten(sampled_grids[:, 1:2, :]), dtype="float32")

x = .5 * (x + 1.0) * K.cast(width, dtype='float32')
y = .5 * (y + 1.0) * K.cast(height, dtype='float32')
x = 0.5 * (x + 1.0) * K.cast(width, dtype="float32")
y = 0.5 * (y + 1.0) * K.cast(height, dtype="float32")

x0 = K.cast(x, 'int32')
x0 = K.cast(x, "int32")
x1 = x0 + 1
y0 = K.cast(y, 'int32')
y0 = K.cast(y, "int32")
y1 = y0 + 1

max_x = int(K.int_shape(image)[2] - 1)
Expand Down Expand Up @@ -87,16 +90,16 @@ def _interpolate(self, image, sampled_grids, output_size):
indices_d = base_y1 + x1

flat_image = K.reshape(image, shape=(-1, num_channels))
flat_image = K.cast(flat_image, dtype='float32')
flat_image = K.cast(flat_image, dtype="float32")
pixel_values_a = K.gather(flat_image, indices_a)
pixel_values_b = K.gather(flat_image, indices_b)
pixel_values_c = K.gather(flat_image, indices_c)
pixel_values_d = K.gather(flat_image, indices_d)

x0 = K.cast(x0, 'float32')
x1 = K.cast(x1, 'float32')
y0 = K.cast(y0, 'float32')
y1 = K.cast(y1, 'float32')
x0 = K.cast(x0, "float32")
x1 = K.cast(x1, "float32")
y0 = K.cast(y0, "float32")
y1 = K.cast(y1, "float32")

area_a = K.expand_dims(((x1 - x) * (y1 - y)), 1)
area_b = K.expand_dims(((x1 - x) * (y - y0)), 1)
Expand All @@ -111,8 +114,8 @@ def _interpolate(self, image, sampled_grids, output_size):

def _make_regular_grids(self, batch_size, height, width):
# making a single regular grid
x_linspace = K_linspace(-1., 1., width)
y_linspace = K_linspace(-1., 1., height)
x_linspace = K_linspace(-1.0, 1.0, width)
y_linspace = K_linspace(-1.0, 1.0, height)
x_coordinates, y_coordinates = K_meshgrid(x_linspace, y_linspace)
x_coordinates = K.flatten(x_coordinates)
y_coordinates = K.flatten(y_coordinates)
Expand All @@ -126,12 +129,11 @@ def _make_regular_grids(self, batch_size, height, width):

def _transform(self, X, affine_transformation, output_size):
batch_size, num_channels = K.shape(X)[0], K.shape(X)[3]
transformations = K.reshape(affine_transformation,
shape=(batch_size, 2, 3))
transformations = K.reshape(affine_transformation, shape=(batch_size, 2, 3))
# transformations = K.cast(affine_transformation[:, 0:2, :], 'float32')
regular_grids = self._make_regular_grids(batch_size, *output_size)
sampled_grids = K.batch_dot(transformations, regular_grids)
interpolated_image = self._interpolate(X, sampled_grids, output_size)
new_shape = (batch_size, output_size[0], output_size[1], num_channels)
interpolated_image = K.reshape(interpolated_image, new_shape)
return interpolated_image
return interpolated_image
3 changes: 3 additions & 0 deletions src/models/model_unet_stn_1x.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# The function implementation below is a modification version from Tensorflow
# Original code link: https://github.com/ashesh6810/DDWP-DA/blob/master/EnKF_DD_all_time.py

import tensorflow
import keras.backend as K

Expand Down
99 changes: 12 additions & 87 deletions src/models/tf_model_1x_test.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "481a4d5d-dbe6-479c-b26b-343102d0617d",
"metadata": {},
"outputs": [],
"source": [
"# The function implementation below is a modification version from Tensorflow\n",
"# Original code link: https://github.com/ashesh6810/DDWP-DA/blob/master/Unet_STN_lead12.ipynb"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down Expand Up @@ -4074,92 +4085,6 @@
"\n",
" count=count+1"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "7a01b664",
"metadata": {
"ExecuteTime": {
"end_time": "2023-08-30T20:58:55.759994Z",
"start_time": "2023-08-30T20:58:55.744559Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.13.1\n",
"GPU Available: False\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"print(tf.__version__)\n",
"print(\"GPU Available:\", tf.test.is_gpu_available())\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "37c676e6",
"metadata": {
"ExecuteTime": {
"end_time": "2023-08-30T21:57:53.380598Z",
"start_time": "2023-08-30T21:57:53.374600Z"
}
},
"outputs": [],
"source": [
"import h5py\n",
"\n",
"def modify_h5_file(filepath):\n",
" \"\"\"\n",
" Check and modify the 'keras_version' attribute of the HDF5 file.\n",
" \"\"\"\n",
" with h5py.File(filepath, 'r+') as f:\n",
" if isinstance(f.attrs['keras_version'], bytes):\n",
" f.attrs['keras_version'] = f.attrs['keras_version'].decode('utf8')\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "08816071",
"metadata": {
"ExecuteTime": {
"end_time": "2023-08-30T21:58:48.284673Z",
"start_time": "2023-08-30T21:58:48.269675Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'str'> 2.2.4\n"
]
}
],
"source": [
"import h5py\n",
"\n",
"with h5py.File('best_weights_lead12.h5', 'r') as f:\n",
" keras_version = f.attrs['keras_version']\n",
" print(type(keras_version), keras_version)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7087d02d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -4179,7 +4104,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
"version": "3.9.17"
},
"toc": {
"base_numbering": 1,
Expand Down

0 comments on commit 063d0af

Please sign in to comment.