From ba30949e3b193024a17d6d2948da9e2ca8e1a261 Mon Sep 17 00:00:00 2001
From: Sohom Paul <2020spaul@tjhsst.edu>
Date: Fri, 19 Jun 2020 14:09:25 -0400
Subject: [PATCH] added GIN with Spektral disjoint training mode

---
 multitask/QM9GNN2_Multitask.ipynb | 423 ++++++++++++++++++++++++------
 1 file changed, 345 insertions(+), 78 deletions(-)

diff --git a/multitask/QM9GNN2_Multitask.ipynb b/multitask/QM9GNN2_Multitask.ipynb
index a322923..661d9ea 100644
--- a/multitask/QM9GNN2_Multitask.ipynb
+++ b/multitask/QM9GNN2_Multitask.ipynb
@@ -2,75 +2,102 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 42,
    "metadata": {},
    "outputs": [],
    "source": [
+    "import itertools\n",
     "import pickle\n",
     "from os import path\n",
     "from time import time\n",
-    "import itertools\n",
     "\n",
     "import numpy as np\n",
+    "import tensorflow as tf\n",
     "from sklearn.model_selection import train_test_split\n",
+    "from sklearn.preprocessing import StandardScaler, PowerTransformer\n",
     "from tensorflow.keras.layers import Input, Dense\n",
+    "from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError\n",
     "from tensorflow.keras.models import Model\n",
     "from tensorflow.keras.optimizers import Adam\n",
     "from tensorflow.keras.callbacks import EarlyStopping\n",
     "from tensorflow.keras.backend import mean, square\n",
     "\n",
     "from spektral.datasets import qm9\n",
-    "from spektral.layers import EdgeConditionedConv, GlobalSumPool, GlobalAttentionPool\n",
-    "from spektral.utils import label_to_one_hot\n",
-    "\n",
-    "from sklearn.preprocessing import StandardScaler, PowerTransformer"
+    "from spektral.layers import EdgeConditionedConv, GINConv, GatedGraphConv\n",
+    "from spektral.layers import ops, GlobalSumPool, GlobalAttentionPool\n",
+    "from spektral.utils import batch_iterator, numpy_to_disjoint\n",
+    "from spektral.utils import label_to_one_hot"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 43,
    "metadata": {
     "scrolled": true
    },
    "outputs": [],
    "source": [
-    "def load_data(amount=None):\n",
+    "def load_data(amount=None, mode='batch'):\n",
+    "    if mode not in ['batch', 'disjoint']:\n",
+    "        raise ValueError(f\"mode {mode} not recognized; \"\n",
+    "                         \"choose 'batch' or 'disjoint'\")\n",
+    "    \n",
     "    A_all, X_all, E_all, y_all = qm9.load_data(return_type='numpy',\n",
     "                                               nf_keys='atomic_num',\n",
     "                                               ef_keys='type',\n",
     "                                               self_loops=True,\n",
     "                                               amount=amount) # None for entire dataset\n",
     "    # Preprocessing\n",
-    "    X_uniq = np.unique(X_all)\n",
-    "    X_uniq = X_uniq[X_uniq != 0]\n",
-    "    E_uniq = np.unique(E_all)\n",
-    "    E_uniq = E_uniq[E_uniq != 0]\n",
-    "    \n",
-    "    X_all = label_to_one_hot(X_all, X_uniq)\n",
-    "    E_all = label_to_one_hot(E_all, E_uniq)\n",
+    "    if mode == 'batch':\n",
+    "        X_uniq = np.unique(X_all)\n",
+    "        X_uniq = X_uniq[X_uniq != 0]\n",
+    "        E_uniq = np.unique(E_all)\n",
+    "        E_uniq = E_uniq[E_uniq != 0]\n",
+    "\n",
+    "        X_all = label_to_one_hot(X_all, X_uniq)\n",
+    "        E_all = label_to_one_hot(E_all, E_uniq)\n",
+    "    elif mode == 'disjoint':\n",
+    "        X_uniq = np.unique([v for x in X_all for v in np.unique(x)])\n",
+    "        E_uniq = np.unique([v for e in E_all for v in np.unique(e)])\n",
+    "        X_uniq = X_uniq[X_uniq != 0]\n",
+    "        E_uniq = E_uniq[E_uniq != 0]\n",
+    "\n",
+    "        X_all = [label_to_one_hot(x, labels=X_uniq) for x in X_all]\n",
+    "        E_all = [label_to_one_hot(e, labels=E_uniq) for e in E_all]\n",
     "    \n",
     "    return A_all, X_all, E_all, y_all"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 44,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def sample_from_data(sample_size, A_all, X_all, E_all, y_all):\n",
-    "    indices = np.random.choice(X_all.shape[0], sample_size, replace=False)\n",
-    "    A = A_all[indices, :, :]\n",
-    "    X = X_all[indices, :, :]\n",
-    "    E = E_all[indices, :, :, :]\n",
-    "    y = y_all.iloc[indices, :].copy()\n",
+    "def sample_from_data(sample_size, A_all, X_all, E_all, y_all, mode='batch'):\n",
+    "    if mode not in ['batch', 'disjoint']:\n",
+    "        raise ValueError(f\"mode {mode} not recognized; \"\n",
+    "                         \"choose 'batch' or 'disjoint'\")\n",
+    "    if mode == 'batch':\n",
+    "        indices = np.random.choice(X_all.shape[0], sample_size, replace=False)\n",
+    "        A = A_all[indices, :, :]\n",
+    "        X = X_all[indices, :, :]\n",
+    "        E = E_all[indices, :, :, :]\n",
+    "        y = y_all.iloc[indices, :].copy()\n",
+    "    \n",
+    "    elif mode == 'disjoint':\n",
+    "        indices = np.random.choice(len(X_all), sample_size, replace=False)\n",
+    "        A = [A_all[i] for i in indices]\n",
+    "        X = [X_all[i] for i in indices]\n",
+    "        E = [E_all[i] for i in indices]\n",
+    "        y = y_all.iloc[indices, :].copy()\n",
     "    \n",
     "    return A, X, E, y"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 45,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -85,99 +112,178 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 46,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def get_shape_params(*, A, X, E):\n",
-    "    N = X.shape[-2] # Number of nodes in the graphs\n",
-    "    F = X[0].shape[-1] # Dimension of node features\n",
-    "    S = E[0].shape[-1] # Dimension of edge features\n",
-    "    \n",
-    "    return N, F, S"
+    "def get_shape_params(*, A, X, E, mode='batch'):\n",
+    "    if mode not in ['batch', 'disjoint']:\n",
+    "        raise ValueError(f\"mode {mode} not recognized; \"\n",
+    "                         \"choose 'batch' or 'disjoint'\")\n",
+    "    F = X[0].shape[-1] # Dimension of node features\n",
+    "    S = E[0].shape[-1] # Dimension of edge features\n",
+    "    if mode == 'batch':\n",
+    "        N = X.shape[-2] # Number of nodes in the graphs\n",
+    "        return N, F, S\n",
+    "    if mode == 'disjoint':\n",
+    "        return F, S"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 47,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def get_input_tensors(*, A, X, E):\n",
-    "    N, F, S = get_shape_params(A=A, X=X, E=E)\n",
-    "    X_in = Input(shape=(N, F))\n",
-    "    A_in = Input(shape=(N, N))\n",
-    "    E_in = Input(shape=(N, N, S))\n",
+    "def get_input_tensors(*, A, X, E, mode='batch'):\n",
+    "    if mode not in ['batch', 'disjoint']:\n",
+    "        raise ValueError(f\"mode {mode} not recognized; \"\n",
+    "                         \"choose 'batch' or 'disjoint'\")\n",
+    "    if mode == 'batch':\n",
+    "        N, F, S = get_shape_params(A=A, X=X, E=E, mode=mode)\n",
+    "        X_in = Input(shape=(N, F), name='X_in')\n",
+    "        A_in = Input(shape=(N, N), name='A_in')\n",
+    "        E_in = Input(shape=(N, N, S), name='E_in')\n",
+    "\n",
+    "        return X_in, A_in, E_in\n",
     "    \n",
-    "    return X_in, A_in, E_in"
+    "    if mode == 'disjoint':\n",
+    "        F, S = get_shape_params(A=A, X=X, E=E, mode=mode)\n",
+    "        X_in = Input(shape=(F,), name='X_in')\n",
+    "        A_in = Input(shape=(None,), sparse=True, name='A_in')\n",
+    "        E_in = Input(shape=(S,), name='E_in')\n",
+    "        I_in = Input(shape=(), name='segment_ids_in', dtype=tf.int32)\n",
+    "    \n",
+    "        return X_in, A_in, E_in, I_in"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 48,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def build_single_task_model(*, A, X, E, learning_rate=1e-3, loss='mse'):\n",
-    "    X_in, A_in, E_in = get_input_tensors(A=A, X=X, E=E)\n",
+    "def build_single_task_model(*, A, X, E, learning_rate=1e-3, conv='ecc', mode='batch'):\n",
+    "    if mode not in ['batch', 'disjoint']:\n",
+    "        raise ValueError(f\"mode {mode} not recognized; \"\n",
+    "                         \"choose 'batch' or 'disjoint'\")\n",
+    "    if conv not in ['ecc', 'gin']:\n",
+    "        raise ValueError(f\"convolution layer {conv} not recognized; \"\n",
+    "                         \"choose 'ecc' or 'gin'\")\n",
+    "    \n",
+    "    if mode == 'batch':\n",
+    "        X_in, A_in, E_in = get_input_tensors(A=A, X=X, E=E, mode=mode)\n",
+    "    if mode == 'disjoint':\n",
+    "        X_in, A_in, E_in, I_in = get_input_tensors(A=A, X=X, E=E, mode=mode)\n",
     "\n",
-    "    gc1 = EdgeConditionedConv(64, activation='relu')([X_in, A_in, E_in])\n",
-    "    gc2 = EdgeConditionedConv(128, activation='relu')([gc1, A_in, E_in])\n",
-    "    pool = GlobalAttentionPool(256)(gc2)\n",
+    "    if conv == 'ecc':\n",
+    "        gc1 = EdgeConditionedConv(64, activation='relu')([X_in, A_in, E_in])\n",
+    "        gc2 = EdgeConditionedConv(128, activation='relu')([gc1, A_in, E_in])\n",
+    "    if conv == 'gin':\n",
+    "        assert mode == 'disjoint', 'cannot run GIN in batch mode'\n",
+    "        gc1 = GINConv(64, activation='relu')([X_in, A_in, E_in])\n",
+    "        gc2 = GINConv(128, activation='relu')([gc1, A_in, E_in])\n",
+    "    if mode == 'batch':\n",
+    "        pool = GlobalAttentionPool(256)(gc2)\n",
+    "    if mode == 'disjoint':\n",
+    "        pool = GlobalAttentionPool(256)([gc2, I_in])\n",
     "    dense = Dense(256, activation='relu')(pool)\n",
     "    output = Dense(1)(dense)\n",
     "\n",
     "    # Build model\n",
-    "    model = Model(inputs=[X_in, A_in, E_in], outputs=output)\n",
     "    optimizer = Adam(lr=learning_rate)\n",
-    "    model.compile(optimizer=optimizer, loss=loss)\n",
-    "\n",
-    "    return model"
+    "    loss_fn = MeanSquaredError()\n",
+    "    if mode == 'batch':\n",
+    "        model = Model(inputs=[X_in, A_in, E_in], outputs=output)\n",
+    "        model.compile(optimizer=optimizer, loss=loss_fn)\n",
+    "        return model\n",
+    "    if mode == 'disjoint':\n",
+    "        model = Model(inputs=[X_in, A_in, E_in, I_in], outputs=output)\n",
+    "        return model, loss_fn"
    ]
   },
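+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note on return conventions: in `batch` mode the builders return a compiled Keras model (trained with `model.fit`), while in `disjoint` mode they return an uncompiled `(model, loss_fn)` pair, consumed by the custom `tf.GradientTape` training loop defined below. The next cell is a minimal illustrative sketch of the disjoint call pattern; the `*_demo` names are placeholders and not part of the pipeline."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sketch only: in 'disjoint' mode the builder returns an\n",
+    "# uncompiled model together with its loss function.\n",
+    "A_demo, X_demo, E_demo, y_demo = load_data(amount=100, mode='disjoint')\n",
+    "model_demo, loss_fn_demo = build_single_task_model(\n",
+    "    A=A_demo, X=X_demo, E=E_demo, conv='gin', mode='disjoint')"
+   ]
+  },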
  {
   "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_hard_sharing_model(*, A, X, E, num_tasks, \n",
-    "                             learning_rate=1e-3, loss='mse'):\n",
-    "    X_in, A_in, E_in = get_input_tensors(A=A, X=X, E=E)\n",
+    "                             learning_rate=1e-3, conv='ecc', mode='batch'):\n",
+    "    if mode not in ['batch', 'disjoint']:\n",
+    "        raise ValueError(f\"mode {mode} not recognized; \"\n",
+    "                         \"choose 'batch' or 'disjoint'\")\n",
+    "    if conv not in ['ecc', 'gin']:\n",
+    "        raise ValueError(f\"convolution layer {conv} not recognized; \"\n",
+    "                         \"choose 'ecc' or 'gin'\")\n",
+    "    if mode == 'batch':\n",
+    "        X_in, A_in, E_in = get_input_tensors(A=A, X=X, E=E, mode=mode)\n",
+    "    if mode == 'disjoint':\n",
+    "        X_in, A_in, E_in, I_in = get_input_tensors(A=A, X=X, E=E, mode=mode)\n",
     "\n",
-    "    gc1 = EdgeConditionedConv(64, activation='relu')([X_in, A_in, E_in])\n",
-    "    gc2 = EdgeConditionedConv(128, activation='relu')([gc1, A_in, E_in])\n",
-    "    pool = GlobalAttentionPool(256)(gc2)\n",
+    "    \n",
+    "    if conv == 'ecc':\n",
+    "        gc1 = EdgeConditionedConv(64, activation='relu')([X_in, A_in, E_in])\n",
+    "        gc2 = EdgeConditionedConv(128, activation='relu')([gc1, A_in, E_in])\n",
+    "    if conv == 'gin':\n",
+    "        assert mode == 'disjoint', 'cannot run GIN in batch mode'\n",
+    "        gc1 = GINConv(64, activation='relu')([X_in, A_in, E_in])\n",
+    "        gc2 = GINConv(128, activation='relu')([gc1, A_in, E_in])\n",
+    "    if mode == 'batch':\n",
+    "        pool = GlobalAttentionPool(256)(gc2)\n",
+    "    if mode == 'disjoint':\n",
+    "        pool = GlobalAttentionPool(256)([gc2, I_in])\n",
    "    dense_list = [Dense(256, activation='relu')(pool) \n",
    "                  for i in range(num_tasks)]\n",
    "    output_list = [Dense(1)(dense_layer) for dense_layer in dense_list]\n",
    "\n",
-    "    model = Model(inputs=[X_in, A_in, E_in], outputs=output_list)\n",
    "    optimizer = Adam(lr=learning_rate)\n",
-    "    model.compile(optimizer=optimizer, loss=loss)\n",
-    "\n",
-    "    return model"
+    "    loss_fn = MeanSquaredError()\n",
+    "    if mode == 'batch':\n",
+    "        model = Model(inputs=[X_in, A_in, E_in], outputs=output_list)\n",
+    "        model.compile(optimizer=optimizer, loss=loss_fn)\n",
+    "        return model\n",
+    "    if mode == 'disjoint':\n",
+    "        model = Model(inputs=[X_in, A_in, E_in, I_in], outputs=output_list)\n",
+    "        return model, loss_fn"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_soft_sharing_model(*, A, X, E, num_tasks, share_param, \n",
-    "                             learning_rate=1e-3, loss='mse'):\n",
-    "    X_in, A_in, E_in = get_input_tensors(A=A, X=X, E=E)\n",
+    "                             learning_rate=1e-3, conv='ecc', mode='batch'):\n",
+    "    if mode not in ['batch', 'disjoint']:\n",
+    "        raise ValueError(f\"mode {mode} not recognized; \"\n",
+    "                         \"choose 'batch' or 'disjoint'\")\n",
+    "    if conv not in ['ecc', 'gin']:\n",
+    "        raise ValueError(f\"convolution layer {conv} not recognized; \"\n",
+    "                         \"choose 'ecc' or 'gin'\")\n",
+    "    if mode == 'batch':\n",
+    "        X_in, A_in, E_in = get_input_tensors(A=A, X=X, E=E, mode=mode)\n",
+    "    if mode == 'disjoint':\n",
+    "        X_in, A_in, E_in, I_in = get_input_tensors(A=A, X=X, E=E, mode=mode)\n",
+    "    \n",
+    "    if conv == 'ecc':\n",
+    "        conv_layer = EdgeConditionedConv\n",
+    "    if conv == 'gin':\n",
+    "        assert mode == 'disjoint', 'cannot run GIN in batch mode'\n",
+    "        conv_layer = GINConv\n",
    "\n",
-    "    gc1_list = [EdgeConditionedConv(64, activation='relu')([X_in, A_in, E_in]) \n",
+    "    gc1_list = [conv_layer(64, activation='relu')([X_in, A_in, E_in]) \n",
    "                for i in range(num_tasks)]\n",
-    "    gc2_list = [EdgeConditionedConv(128, activation='relu')([gc1, A_in, E_in]) \n",
+    "    gc2_list = [conv_layer(128, activation='relu')([gc1, A_in, E_in]) \n",
    "                for gc1 in gc1_list]\n",
-    "    pool_list = [GlobalAttentionPool(256)(gc2) for gc2 in gc2_list]\n",
+    "    if mode == 'batch':\n",
+    "        pool_list = [GlobalAttentionPool(256)(gc2) for gc2 in gc2_list]\n",
+    "    if mode == 'disjoint':\n",
+    "        pool_list = [GlobalAttentionPool(256)([gc2, I_in]) for gc2 in gc2_list]\n",
    "    dense_list = [Dense(256, activation='relu')(pool) for pool in pool_list]\n",
    "    output_list = [Dense(1)(dense) for dense in dense_list]\n",
    "\n",
-    "    def loss(y_actual, y_pred):\n",
+    "    def loss_fn(y_actual, y_pred):\n",
    "        avg_layer_diff = 0\n",
    "        for i, j in itertools.combinations(range(num_tasks), 2):\n",
    "            for gc in [gc1_list, gc2_list]:\n",
@@ -186,21 +292,24 @@
    "        avg_layer_diff /= (num_tasks)*(num_tasks-1)/2 \n",
    "        return mean(square(y_actual - y_pred)) + share_param*avg_layer_diff\n",
    "\n",
-    "    model = Model(inputs=[X_in, A_in, E_in], outputs=output_list)\n",
    "    optimizer = Adam(lr=learning_rate)\n",
-    "    model.compile(optimizer=optimizer, loss=loss)\n",
-    "\n",
-    "    return model"
+    "    if mode == 'batch':\n",
+    "        model = Model(inputs=[X_in, A_in, E_in], outputs=output_list)\n",
+    "        model.compile(optimizer=optimizer, loss=loss_fn)\n",
+    "        return model\n",
+    "    if mode == 'disjoint':\n",
+    "        model = Model(inputs=[X_in, A_in, E_in, I_in], outputs=output_list)\n",
+    "        return model, loss_fn"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
-    "def generate_model_filename(tasks, folder_path='demo_models'):\n",
-    "    filename = \"\".join(sorted(tasks))\n",
+    "def generate_model_filename(tasks, conv='ecc', mode='batch', folder_path='demo_models'):\n",
+    "    filename = \"\".join(sorted(tasks)) + '_' + conv + '_' + mode\n",
    "    return path.join(folder_path, f'{filename}.h5')\n",
    "\n",
    "def generate_task_scaler_filename(task, folder_path='demo_models'):\n",
@@ -209,7 +318,7 @@
  },
  {
   "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -221,9 +330,11 @@
    "        scaler = task_to_scaler[task]\n",
    "        pickle.dump(obj=scaler, file=f)\n",
    "\n",
-    "def load_hard_sharing_model(*, A, X, E, tasks, task_to_scaler=dict()):\n",
-    "    model = build_hard_sharing_model(A=A, X=X, E=E, num_tasks=len(tasks))\n",
-    "    model.load_weights(generate_model_filename(tasks))\n",
+    "def load_hard_sharing_model(*, A, X, E, tasks, conv='ecc', \n",
+    "                            mode='batch', task_to_scaler=dict()):\n",
+    "    built = build_hard_sharing_model(A=A, X=X, E=E, conv=conv, mode=mode,\n",
+    "                                     num_tasks=len(tasks))\n",
+    "    # in disjoint mode the builder returns (model, loss_fn); keep the model\n",
+    "    model = built[0] if mode == 'disjoint' else built\n",
+    "    model.load_weights(generate_model_filename(tasks, conv=conv, mode=mode))\n",
    "    for task in tasks:\n",
    "        if task not in task_to_scaler:\n",
    "            with open(generate_task_scaler_filename(task), 'rb') as f:\n",
@@ -231,6 +342,80 @@
    "    return model, task_to_scaler"
   ]
  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train_multitask_disjoint(model, cluster, *, opt, loss_fn, batch_size, epochs, A_train, X_train, E_train, y_train):\n",
+    "    F, S = get_shape_params(A=A_train, X=X_train, E=E_train, mode='disjoint')\n",
+    "    @tf.function(\n",
+    "        input_signature=(tf.TensorSpec((None, F), dtype=tf.float64),\n",
+    "                         tf.SparseTensorSpec((None, None), dtype=tf.float64),\n",
+    "                         tf.TensorSpec((None, S), dtype=tf.float64),\n",
+    "                         tf.TensorSpec((None,), dtype=tf.int32),\n",
+    "                         tf.TensorSpec((None, len(cluster)), dtype=tf.float64)),\n",
+    "        experimental_relax_shapes=True)\n",
+    "    def train_step(X_, A_, E_, I_, y_):\n",
+    "        with tf.GradientTape() as tape:\n",
+    "            predictions = model([X_, A_, E_, I_], training=True)\n",
+    "            loss = loss_fn(y_, predictions)\n",
+    "            loss += sum(model.losses)\n",
+    "        gradients = tape.gradient(loss, model.trainable_variables)\n",
+    "        opt.apply_gradients(zip(gradients, model.trainable_variables))\n",
+    "        return loss\n",
+    "    \n",
+    "    current_batch = 0\n",
+    "    model_loss = 0\n",
+    "    batches_in_epoch = np.ceil(len(A_train) / batch_size)\n",
+    "\n",
+    "    print('Fitting model')\n",
+    "    batches_train = batch_iterator([X_train, A_train, E_train, y_train[cluster].values],\n",
+    "                                   batch_size=batch_size, epochs=epochs)\n",
+    "    for b in batches_train:\n",
+    "        X_, A_, E_, I_ = numpy_to_disjoint(*b[:-1])\n",
+    "        A_ = ops.sp_matrix_to_sp_tensor(A_)\n",
+    "        y_ = b[-1]\n",
+    "        outs = train_step(X_, A_, E_, I_, y_)\n",
+    "\n",
+    "        model_loss += outs.numpy()\n",
+    "        current_batch += 1\n",
+    "        if current_batch == batches_in_epoch:\n",
+    "            print('Loss: {}'.format(model_loss / batches_in_epoch))\n",
+    "            model_loss = 0\n",
+    "            current_batch = 0"
+   ]
+  },
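+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The training loop above relies on `numpy_to_disjoint` to pack each mini-batch of small graphs into one disjoint union: node features are stacked, adjacency matrices become blocks of a single sparse block-diagonal matrix, and a segment-id vector `I` maps every node back to its source graph (this is what `GlobalAttentionPool` receives via `I_in`). The next cell is a minimal illustrative sketch of that packing; the `*_pack` and `*_d` names are placeholders, not part of the pipeline."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sketch of disjoint packing on two sample graphs.\n",
+    "A_pack, X_pack, E_pack, y_pack = load_data(amount=50, mode='disjoint')\n",
+    "X_d, A_d, E_d, I_d = numpy_to_disjoint(X_pack[:2], A_pack[:2], E_pack[:2])\n",
+    "print(X_d.shape)       # (n1 + n2, F): node features of both graphs stacked\n",
+    "print(A_d.shape)       # (n1 + n2, n1 + n2): block-diagonal adjacency\n",
+    "print(np.unique(I_d))  # [0 1]: one segment id per input graph"
+   ]
+  },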
+  {
+   "cell_type": "code",
+   "execution_count": 58,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def test_multitask_disjoint(model, cluster, *, loss_fn, batch_size, A_test, X_test, E_test, y_test):\n",
+    "    print('Testing model')\n",
+    "    model_loss = 0\n",
+    "    batches_in_epoch = np.ceil(len(A_test) / batch_size)\n",
+    "    batches_test = batch_iterator([X_test, A_test, E_test, y_test[cluster].values], batch_size=batch_size)\n",
+    "    for b in batches_test:\n",
+    "        X_, A_, E_, I_ = numpy_to_disjoint(*b[:-1])\n",
+    "        A_ = ops.sp_matrix_to_sp_tensor(A_)\n",
+    "        y_ = b[-1]\n",
+    "\n",
+    "        predictions = model([X_, A_, E_, I_], training=False)\n",
+    "        model_loss += loss_fn(y_, predictions)\n",
+    "    model_loss /= batches_in_epoch\n",
+    "    print('Done. Test loss: {}'.format(model_loss))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## TODO: Rewrite code below to allow for disjoint mode."
+   ]
+  },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -266,8 +451,7 @@
   "source": [
    "if __name__ == '__main__' and '__file__' not in globals(): \n",
    "    A_all, X_all, E_all, y_all = load_data()\n",
-    "    N, F, S = get_shape_params(A=A_all, X=X_all, E=E_all)\n",
-    "    # n_out = y_all.shape[-1] # Dimension of the target"
+    "    N, F, S = get_shape_params(A=A_all, X=X_all, E=E_all, mode='batch')"
   ]
  },
@@ -360,6 +544,89 @@
    "        errors.append(err[0])\n",
    "    print(f'Avg error of {prop} is {sum(errors)/len(errors):.2%}')"
   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### The following code tests if disjoint mode is working correctly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading QM9 dataset.\n",
+      "Reading SDF\n",
+      "Fitting model\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/sohompaul/anaconda3/envs/senior_research/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:434: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
+      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loss: 4.391317373719708\n",
+      "Loss: 0.9110169924538711\n",
+      "Loss: 0.839188085547809\n",
+      "Loss: 0.8122989535331726\n",
+      "Loss: 0.8126783185991747\n",
+      "Loss: 0.7306714705352125\n",
+      "Loss: 0.7450922830351467\n",
+      "Loss: 0.7062880828462798\n",
+      "Loss: 0.6563559421177568\n",
+      "Loss: 0.607425882898528\n",
+      "Testing model\n",
+      "Done. Test loss: 0.5633670687675476\n"
+     ]
+    }
+   ],
+   "source": [
+    "mode = 'disjoint'\n",
+    "conv = 'gin'\n",
+    "A_all, X_all, E_all, y_all = load_data(amount=2000, mode=mode)\n",
+    "A, X, E, y = sample_from_data(1000, A_all, X_all, E_all, y_all, mode=mode)\n",
+    "task_to_scaler = standardize(y)\n",
+    "\n",
+    "cluster = ['cv', 'r2']\n",
+    "A_train, A_test, \\\n",
+    "    X_train, X_test, \\\n",
+    "    E_train, E_test, \\\n",
+    "    y_train, y_test = train_test_split(A, X, E, y, test_size=0.1)\n",
+    "\n",
+    "model, loss_fn = build_hard_sharing_model(A=A_train, X=X_train, E=E_train, \n",
+    "                                          num_tasks=len(cluster), conv=conv, mode=mode)\n",
+    "train_multitask_disjoint(model, \n",
+    "                         cluster, \n",
+    "                         opt=Adam(lr=1e-3),\n",
+    "                         loss_fn=loss_fn,\n",
+    "                         batch_size=32, \n",
+    "                         epochs=10, \n",
+    "                         A_train=A_train, \n",
+    "                         X_train=X_train,\n",
+    "                         E_train=E_train,\n",
+    "                         y_train=y_train)\n",
+    "test_multitask_disjoint(model,\n",
+    "                        cluster,\n",
+    "                        loss_fn=loss_fn,\n",
+    "                        batch_size=32,\n",
+    "                        A_test=A_test,\n",
+    "                        X_test=X_test,\n",
+    "                        E_test=E_test,\n",
+    "                        y_test=y_test)"
+   ]
  }
 ],
 "metadata": {