Skip to content

Commit

Permalink
added __main__ guards
Browse files Browse the repository at this point in the history
  • Loading branch information
sohompaul committed May 30, 2020
1 parent 98424fb commit c5b08bc
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions multitask/QM9GNN2_Multitask.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
" nf_keys='atomic_num',\n",
" ef_keys='type',\n",
" self_loops=True,\n",
" amount=2000) # chnage this to None to load entire dataset\n",
" amount=None) # chnage this to None to load entire dataset\n",
"# Preprocessing\n",
"X_uniq = np.unique(X_all)\n",
"X_uniq = X_uniq[X_uniq != 0]\n",
Expand Down Expand Up @@ -291,14 +291,15 @@
"metadata": {},
"outputs": [],
"source": [
"for cluster in clusters[:1]:\n",
" model = build_hard_sharing_model(N=N, F=F, S=S, num_tasks=len(cluster))\n",
" model.fit(x=[X_train, A_train, E_train], \n",
" y=y_train[cluster].values,\n",
" batch_size=batch_size,\n",
" validation_split=0.1,\n",
" epochs=3)\n",
"# save_model(model, cluster)"
"if __name__ == '__main__' and '__file__' not in globals():\n",
" for cluster in clusters[:1]:\n",
" model = build_hard_sharing_model(N=N, F=F, S=S, num_tasks=len(cluster))\n",
" model.fit(x=[X_train, A_train, E_train], \n",
" y=y_train[cluster].values,\n",
" batch_size=batch_size,\n",
" validation_split=0.1,\n",
" epochs=3)\n",
" save_model(model, cluster)"
]
},
{
Expand All @@ -307,10 +308,12 @@
"metadata": {},
"outputs": [],
"source": [
"# model, _ = load_hard_sharing_model(N=N, F=F, S=S, tasks=clusters[0])\n",
"model_loss = model.evaluate(x=[X_test, A_test, E_test],\n",
" y=y_test[cluster].values)\n",
"print(f\"Test loss: {model_loss}\")"
"if __name__ == '__main__' and '__file__' not in globals():\n",
" for cluster in clusters[:1]:\n",
" model, _ = load_hard_sharing_model(N=N, F=F, S=S, tasks=clusters[0])\n",
" model_loss = model.evaluate(x=[X_test, A_test, E_test],\n",
" y=y_test[cluster].values)\n",
" print(f\"Test loss: {model_loss}\")"
]
}
],
Expand Down

0 comments on commit c5b08bc

Please sign in to comment.