Commit

added learning rate scheduling, fixed storage of test data
sohompaul committed Jul 3, 2020
1 parent 2593f73 commit 8de3fdf
Showing 1 changed file with 65 additions and 48 deletions.
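
The headline change is stepwise learning-rate decay: training now proceeds in phases of epochs_per_schedule epochs, dividing the rate by learning_rate_scheduler between phases. Below is a minimal standalone sketch of the schedule this commit wires into the training loop, using the constants it sets (learning_rate = 1e-2, learning_rate_scheduler = 2, epochs_per_schedule = 5, epochs = 40); run_phase is a hypothetical stand-in for train_multitask_disjoint, not part of the notebook.

def step_decay(run_phase, epochs=40, epochs_per_schedule=5,
               learning_rate=1e-2, learning_rate_scheduler=2):
    # One phase per epochs_per_schedule epochs; halve the rate between phases.
    for i in range(epochs // epochs_per_schedule):
        if i > 0:
            learning_rate /= learning_rate_scheduler
        run_phase(lr=learning_rate, epochs=epochs_per_schedule,
                  epoch_num=i * epochs_per_schedule + 1)

# Eight phases: lr = 0.01, 0.005, 0.0025, ..., down to ~7.8e-5
step_decay(lambda **kwargs: print(kwargs))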
113 changes: 65 additions & 48 deletions multitask/QM9GNN2_Multitask.ipynb
@@ -20,7 +20,7 @@
 "from sklearn.model_selection import train_test_split\n",
 "from sklearn.preprocessing import StandardScaler, PowerTransformer\n",
 "from tensorflow.keras.layers import Input, Dense\n",
-"from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError\n",
+"from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError, LogCosh\n",
 "from tensorflow.keras.models import Model\n",
 "from tensorflow.keras.optimizers import Adam\n",
 "from tensorflow.keras.callbacks import Callback, EarlyStopping\n",
@@ -181,17 +181,17 @@
 "    X_in, A_in, E_in, I_in = get_input_tensors(A=A, X=X, E=E, mode=mode)\n",
 "\n",
 "    if conv == 'ecc': \n",
-"        gc1 = EdgeConditionedConv(64, activation='relu')([X_in, A_in, E_in])\n",
-"        gc2 = EdgeConditionedConv(128, activation='relu')([gc1, A_in, E_in])\n",
+"        gc1 = EdgeConditionedConv(128, activation='relu')([X_in, A_in, E_in])\n",
+"        gc2 = EdgeConditionedConv(256, activation='relu')([gc1, A_in, E_in])\n",
 "    if conv == 'gin':\n",
 "        assert mode == 'disjoint', 'cannot run GIN in batch mode'\n",
-"        gc1 = GINConv(64, activation='relu')([X_in, A_in, E_in])\n",
-"        gc2 = GINConv(128, activation='relu')([gc1, A_in, E_in])\n",
+"        gc1 = GINConv(128, activation='relu')([X_in, A_in, E_in])\n",
+"        gc2 = GINConv(256, activation='relu')([gc1, A_in, E_in])\n",
 "    if mode == 'batch':\n",
-"        pool = GlobalAttentionPool(256)(gc2)\n",
+"        pool = GlobalAttentionPool(512)(gc2)\n",
 "    if mode == 'disjoint':\n",
-"        pool = GlobalAttentionPool(256)([gc2, I_in])\n",
-"    dense = Dense(256, activation='relu')(pool)\n",
+"        pool = GlobalAttentionPool(512)([gc2, I_in])\n",
+"    dense = Dense(512, activation='relu')(pool)\n",
 "    output = Dense(1)(dense)\n",
 "\n",
 "    optimizer = Adam(lr=learning_rate)\n",
@@ -226,17 +226,17 @@
 "\n",
 "    \n",
 "    if conv == 'ecc': \n",
-"        gc1 = EdgeConditionedConv(64, activation='relu')([X_in, A_in, E_in])\n",
-"        gc2 = EdgeConditionedConv(128, activation='relu')([gc1, A_in, E_in])\n",
+"        gc1 = EdgeConditionedConv(128, activation='relu')([X_in, A_in, E_in])\n",
+"        gc2 = EdgeConditionedConv(256, activation='relu')([gc1, A_in, E_in])\n",
 "    if conv == 'gin':\n",
 "        assert mode == 'disjoint', 'cannot run GIN in batch mode'\n",
-"        gc1 = GINConv(64, activation='relu')([X_in, A_in, E_in])\n",
-"        gc2 = GINConv(128, activation='relu')([gc1, A_in, E_in])\n",
+"        gc1 = GINConv(128, activation='relu')([X_in, A_in, E_in])\n",
+"        gc2 = GINConv(256, activation='relu')([gc1, A_in, E_in])\n",
 "    if mode == 'batch':\n",
-"        pool = GlobalAttentionPool(256)(gc2)\n",
+"        pool = GlobalAttentionPool(512)(gc2)\n",
 "    if mode == 'disjoint':\n",
-"        pool = GlobalAttentionPool(256)([gc2, I_in])\n",
-"    dense_list = [Dense(256, activation='relu')(pool) \n",
+"        pool = GlobalAttentionPool(512)([gc2, I_in])\n",
+"    dense_list = [Dense(512, activation='relu')(pool) \n",
 "                  for i in range(num_tasks)]\n",
 "    output_list = [Dense(1)(dense_layer) for dense_layer in dense_list]\n",
 "\n",
@@ -275,15 +275,15 @@
 "    if conv == 'gin':\n",
 "        conv_layer = GINConv\n",
 "\n",
-"    gc1_list = [conv_layer(64, activation='relu')([X_in, A_in, E_in]) \n",
+"    gc1_list = [conv_layer(128, activation='relu')([X_in, A_in, E_in]) \n",
 "                for i in range(num_tasks)]\n",
-"    gc2_list = [conv_layer(128, activation='relu')([gc1, A_in, E_in]) \n",
+"    gc2_list = [conv_layer(256, activation='relu')([gc1, A_in, E_in]) \n",
 "                for gc1 in gc1_list]\n",
 "    if mode == 'batch':\n",
-"        pool_list = [GlobalAttentionPool(256)(gc2) for gc2 in gc2_list]\n",
+"        pool_list = [GlobalAttentionPool(512)(gc2) for gc2 in gc2_list]\n",
 "    if mode == 'disjoint':\n",
-"        pool_list = [GlobalAttentionPool(256)([gc2, I_in]) for gc2 in gc2_list]\n",
-"    dense_list = [Dense(256, activation='relu')(pool) for pool in pool_list]\n",
+"        pool_list = [GlobalAttentionPool(512)([gc2, I_in]) for gc2 in gc2_list]\n",
+"    dense_list = [Dense(512, activation='relu')(pool) for pool in pool_list]\n",
 "    output_list = [Dense(1)(dense) for dense in dense_list]\n",
 "\n",
 "    def loss_fn(y_actual, y_pred):\n",
@@ -353,7 +353,7 @@
 "source": [
 "def train_multitask_disjoint(model, cluster, *, opt, loss_fn, batch_size, \n",
 "                             epochs, A_train, X_train, E_train, y_train, \n",
-"                             loss_logger=None):\n",
+"                             epoch_num=1, loss_logger=None):\n",
 "    F, S = get_shape_params(A=A_train, X=X_train, E=E_train, mode='disjoint')\n",
 "    @tf.function(\n",
 "        input_signature=(tf.TensorSpec((None, F), dtype=tf.float64),\n",
@@ -378,7 +378,6 @@
 "    print('Fitting model')\n",
 "    batches_train = batch_iterator([X_train, A_train, E_train, y_train[cluster].values],\n",
 "                                   batch_size=batch_size, epochs=epochs)\n",
-"    epoch_num = 1\n",
 "    for b in batches_train:\n",
 "        X_, A_, E_, I_ = numpy_to_disjoint(*b[:-1])\n",
 "        A_ = ops.sp_matrix_to_sp_tensor(A_)\n",
@@ -438,10 +437,13 @@
 "    \n",
 "    cluster = [c for c in clusters if prop in c][0]\n",
 "    if model is None:\n",
-"        model, task_to_scaler = load_hard_sharing_model(\n",
-"            A=A_all, X=X_all, E=E_all, tasks=cluster, \n",
-"            mode=mode, conv=conv, task_to_scaler=task_to_scaler\n",
-"        )\n",
+"        model, task_to_scaler = load_hard_sharing_model(A=A_all, \n",
+"                                                        X=X_all, \n",
+"                                                        E=E_all, \n",
+"                                                        tasks=cluster,\n",
+"                                                        mode=mode, \n",
+"                                                        conv=conv, \n",
+"                                                        task_to_scaler=task_to_scaler)\n",
 "    i = mol_id - 1\n",
 "\n",
 "    # convert shape for batch mode\n",
@@ -512,9 +514,9 @@
 "            self.pred[task].append(pred)\n",
 "    \n",
 "    def _make_picklable(self):\n",
-"        return {'params': params, \n",
-"                'actual': actual, \n",
-"                'pred': pred, \n",
+"        return {'params': self.params, \n",
+"                'actual': self.actual, \n",
+"                'pred': self.pred, \n",
 "                'losses': self.loss_logger.losses}\n",
 "    \n",
 "    def serialize(self, dirname='model_data', filename=''):\n",
@@ -536,8 +538,10 @@
 "    conv = 'gin'\n",
 "    batch_size = 32\n",
 "    epochs = 40\n",
-"    num_sampled = 20000\n",
-"    learning_rate = 1e-3\n",
+"    num_sampled = 30000\n",
+"    learning_rate = 1e-2\n",
+"    learning_rate_scheduler = 2\n",
+"    epochs_per_schedule = 5\n",
 "    amount = None\n",
 "    A_all, X_all, E_all, y_all = load_data(amount=amount, mode=mode)"
 ]
@@ -549,8 +553,12 @@
 "outputs": [],
 "source": [
 "if __name__ == '__main__' and '__file__' not in globals(): \n",
-"    A, X, E, y = sample_from_data(num_sampled, A_all, X_all, E_all, \n",
-"                                  y_all, mode=mode)\n",
+"    A, X, E, y = sample_from_data(num_sampled, \n",
+"                                  A_all, \n",
+"                                  X_all, \n",
+"                                  E_all, \n",
+"                                  y_all, \n",
+"                                  mode=mode)\n",
 "    task_to_scaler = standardize(y)"
 ]
 },
@@ -594,18 +602,20 @@
 "    print('begin training models')\n",
 "    \n",
 "    tasks = [[task] for cluster in clusters for task in cluster]\n",
-"    tasks_and_clusters = itertools.chain(tasks, clusters)\n",
+"    clusters_alt = [['B', 'g298'], ['alpha', 'zpve'], ['C', 'u0', 'u298'], ['r2', 'cv'], ['h298', 'mu']]\n",
+"    tasks_and_clusters = itertools.chain(tasks, clusters, clusters_alt)\n",
 "    for cluster, conv in itertools.product(tasks_and_clusters,\n",
 "                                           ['ecc', 'gin']):\n",
 "        print(f'training {cluster} with {mode} mode on {conv} conv')\n",
 "        \n",
-"        model, loss_fn = build_hard_sharing_model(A=A_train, \n",
+"        model, _ = build_hard_sharing_model(A=A_train, \n",
 "                                                  X=X_train, \n",
 "                                                  E=E_train, \n",
 "                                                  num_tasks=len(cluster),\n",
 "                                                  mode=mode,\n",
 "                                                  conv=conv)\n",
 "        optimizer = Adam(lr=learning_rate)\n",
+"        loss_fn = MeanSquaredError()\n",
 "        \n",
 "        stream = io.StringIO()\n",
 "        model.summary(print_fn=lambda x: stream.write(x + '\\n'))\n",
@@ -621,7 +631,9 @@
 "                  'hard_sharing': True,\n",
 "                  'model_summary': summary,\n",
 "                  'loss_fn': type(loss_fn).__name__,\n",
-"                  'optimizer': type(optimizer).__name__}\n",
+"                  'optimizer': type(optimizer).__name__, \n",
+"                  'learning_rate_scheduler': learning_rate_scheduler, \n",
+"                  'epochs_per_schedule': epochs_per_schedule}\n",
 "        model_data = ModelData(params=params)\n",
 "        \n",
 "        if mode == 'batch':\n",
@@ -644,18 +656,23 @@
 "            cluster_pred = model.predict([X_test, A_test, E_test])\n",
 "\n",
 "        if mode == 'disjoint':\n",
-"            # training\n",
-"            train_multitask_disjoint(model,\n",
-"                                     cluster,\n",
-"                                     opt=Adam(lr=1e-3),\n",
-"                                     loss_fn=loss_fn,\n",
-"                                     batch_size=batch_size,\n",
-"                                     epochs=epochs,\n",
-"                                     A_train=A_train,\n",
-"                                     X_train=X_train,\n",
-"                                     E_train=E_train,\n",
-"                                     y_train=y_train, \n",
-"                                     loss_logger=model_data.loss_logger)\n",
+"            # training with learning rate decay\n",
+"            for i in range(epochs//epochs_per_schedule):\n",
+"                if i > 0:\n",
+"                    learning_rate /= learning_rate_scheduler\n",
+"                optimizer = Adam(lr=learning_rate)\n",
+"                train_multitask_disjoint(model,\n",
+"                                         cluster,\n",
+"                                         opt=optimizer,\n",
+"                                         loss_fn=loss_fn,\n",
+"                                         batch_size=batch_size,\n",
+"                                         epochs=epochs_per_schedule,\n",
+"                                         epoch_num=i*epochs_per_schedule+1,\n",
+"                                         A_train=A_train,\n",
+"                                         X_train=X_train,\n",
+"                                         E_train=E_train,\n",
+"                                         y_train=y_train, \n",
+"                                         loss_logger=model_data.loss_logger)\n",
 "            # testing\n",
 "            model_loss = test_multitask_disjoint(model,\n",
 "                                                 cluster,\n",