From 6f91d0da6b9024d8efd3c39b40eef09acbdeac22 Mon Sep 17 00:00:00 2001 From: namupark <58543540+namupark@users.noreply.github.com> Date: Sat, 11 Mar 2023 03:57:51 -0800 Subject: [PATCH 1/2] Code used to run the experiments for Result2 --- demo_promptehr.py | 201 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 demo_promptehr.py diff --git a/demo_promptehr.py b/demo_promptehr.py new file mode 100644 index 0000000..8c60084 --- /dev/null +++ b/demo_promptehr.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python +# coding: utf-8 + +# In[1]: + + +import transformers + + +# In[2]: + + +transformers.__version__ + + +# In[3]: + + +from pytrial.data.demo_data import load_synthetic_ehr_sequence +from pytrial.tasks.trial_simulation.data import SequencePatient + +demo = load_synthetic_ehr_sequence(n_sample=100) + + +# In[4]: + + +demo + + +# In[5]: + + +len(demo['visit']) + + +# In[6]: + + +demo.keys() + + +# In[7]: + + +# build sequence dataset +seqdata = SequencePatient(data={'v':demo['visit'], 'y':demo['y'], 'x':demo['feature'],}, + metadata={ + 'visit':{'mode':'dense'}, + 'label':{'mode':'tensor'}, + 'voc':demo['voc'], + 'max_visit':20, + } + ) + +print('visit', demo['visit'][0]) # a list of visit events +print('mortality', demo['y'][0]) # array of labels +print('feature', demo['feature'][0]) # array of patient baseline features +print('voc', demo['voc']) # dict of dicts containing the mapping from index to the original event names +print('order', demo['order']) # a list of three types of code +print('n_num_feature', demo['n_num_feature']) # int: a number of patient's numerical features +print('cat_cardinalities', demo['cat_cardinalities']) # list: a list of cardinalities of patient's categorical features + + +# In[9]: + + +demo['voc'] + + +# In[14]: + + +demo['voc']['med'].idx2word + + +# In[ ]: + + + + + +# In[18]: + + +from promptehr import PromptEHR + +# fit the model +model = PromptEHR( + code_type=demo['order'], + n_num_feature=demo['n_num_feature'], + cat_cardinalities=demo['cat_cardinalities'], + num_worker=0, + eval_step=1, + epoch=5, + device=[0], +) +model.fit( + train_data=seqdata, + val_data=seqdata, +) + + +# In[15]: + + +model.evaluate(seqdata) + + +# In[17]: + + +model + + +# In[ ]: + + + + + +# In[ ]: + + +# save the model +model.save_model('./simulation/promptEHR') + + +# In[12]: + + +# generate fake records +res = model.predict(seqdata, n_per_sample=10, n=10, verbose=True) + + +# In[13]: + + +print(res) + + +# In[ ]: + + + + + +# In[1]: + + +import os +os.chdir('../') + + +# In[16]: + + +# if you want pretrained model downloaded +from promptehr import PromptEHR +model = PromptEHR() +model.from_pretrained() + + +# In[17]: + + +model.training_args + + +# In[19]: + + +model.evaluate(seqdata) + + +# In[ ]: + + + + + +# In[ ]: + + + + + +# In[15]: + + +model.fit( + train_data=seqdata, + val_data=seqdata, +) + + +# In[ ]: + + + + From f9e7d364671392b781ebcabc5074d64cc3d0ff40 Mon Sep 17 00:00:00 2001 From: namupark <58543540+namupark@users.noreply.github.com> Date: Sat, 11 Mar 2023 04:01:22 -0800 Subject: [PATCH 2/2] Delete example directory --- example/demo_promptehr.ipynb | 433 ----------------------------------- 1 file changed, 433 deletions(-) delete mode 100644 example/demo_promptehr.ipynb diff --git a/example/demo_promptehr.ipynb b/example/demo_promptehr.ipynb deleted file mode 100644 index b0affa9..0000000 --- a/example/demo_promptehr.ipynb +++ /dev/null @@ -1,433 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "007dc61d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "visit [[[0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], [[8, 9, 10, 7], [3, 4, 1], [0, 1, 2, 3, 5, 4, 6, 7, 8, 9, 10, 11, 13, 15, 16, 17, 18]]]\n", - "mortality False\n", - "feature [-1.02022055 0. 0. ]\n", - "voc {'diag': , 'prod': , 'med': }\n", - "order ['diag', 'prod', 'med']\n", - "n_num_feature 1\n", - "cat_cardinalities [2, 10]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/zifengw2/EHR_Simulation/github/PromptEHR/pytrial/data/patient_data.py:74: UserWarning: No metadata provided. Metadata will be automatically detected from your data. This process may not be accurate. We recommend writing metadata to ensure correct data handling.\n", - " warnings.warn('No metadata provided. Metadata will be automatically '\n" - ] - } - ], - "source": [ - "import os\n", - "os.chdir('../')\n", - "\n", - "from promptehr import PromptEHR\n", - "from promptehr import load_demo_data\n", - "\n", - "# load pytrial demodata, supported by PyTrial package to load the demo EHR data\n", - "from pytrial.data.demo_data import load_mimic_ehr_sequence\n", - "from pytrial.tasks.trial_simulation.data import SequencePatient\n", - "\n", - "# see the input format\n", - "demo = load_mimic_ehr_sequence(n_sample=100)\n", - "\n", - "# build sequence dataset\n", - "seqdata = SequencePatient(data={'v':demo['visit'], 'y':demo['mortality'], 'x':demo['feature'],},\n", - " metadata={\n", - " 'visit':{'mode':'dense'},\n", - " 'label':{'mode':'tensor'}, \n", - " 'voc':demo['voc'],\n", - " 'max_visit':20,\n", - " }\n", - " )\n", - "\n", - "print('visit', demo['visit'][0]) # a list of visit events\n", - "print('mortality', demo['mortality'][0]) # array of labels\n", - "print('feature', demo['feature'][0]) # array of patient baseline features\n", - "print('voc', demo['voc']) # dict of dicts containing the mapping from index to the original event names\n", - "print('order', demo['order']) # a list of three types of code\n", - "print('n_num_feature', demo['n_num_feature']) # int: a number of patient's numerical features\n", - "print('cat_cardinalities', demo['cat_cardinalities']) # list: a list of cardinalities of patient's categorical features" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "9fc0de6e", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", - "The tokenizer class you load from this checkpoint is 'BartTokenizer'. \n", - "The class this function is called from is 'DataTokenizer'.\n", - "/home/zifengw2/miniconda3/envs/promptehr/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", - " warnings.warn(\n", - "***** Running training *****\n", - " Num examples = 100\n", - " Num Epochs = 1\n", - " Instantaneous batch size per device = 16\n", - " Total train batch size (w. parallel, distributed & accumulation) = 128\n", - " Gradient Accumulation steps = 1\n", - " Total optimization steps = 1\n", - "Token indices sequence length is longer than the specified maximum sequence length for this model (552 > 512). Running this sequence through the model will result in indexing errors\n", - "/home/zifengw2/miniconda3/envs/promptehr/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn('Was asked to gather along dimension 0, but all '\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [1/1 00:06, Epoch 1/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining LossValidation LossPpl DiagPpl ProdPpl Med
16.695700No log897.926819353.614288110.910278

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "evaluation for code diag.\n", - "***** Running Evaluation *****\n", - " Num examples = 100\n", - " Batch size = 512\n", - "evaluation for code prod.\n", - "***** Running Evaluation *****\n", - " Num examples = 100\n", - " Batch size = 512\n", - "evaluation for code med.\n", - "***** Running Evaluation *****\n", - " Num examples = 100\n", - " Batch size = 512\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Saving model checkpoint to ./promptEHR_logs/checkpoint-1\n", - "Configuration saved in ./promptEHR_logs/checkpoint-1/config.json\n", - "Model weights saved in ./promptEHR_logs/checkpoint-1/pytorch_model.bin\n", - "\n", - "\n", - "Training completed. Do not forget to share your model on huggingface.co/models =)\n", - "\n", - "\n", - "Loading best model from ./promptEHR_logs/checkpoint-1 (score: 897.9268188476562).\n" - ] - } - ], - "source": [ - "# fit the model\n", - "model = PromptEHR(\n", - " code_type=demo['order'],\n", - " n_num_feature=demo['n_num_feature'],\n", - " cat_cardinalities=demo['cat_cardinalities'],\n", - " num_worker=0,\n", - " eval_step=1,\n", - " epoch=1,\n", - " device=[1,2],\n", - ")\n", - "model.fit(\n", - " train_data=seqdata,\n", - " val_data=seqdata,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "41709250", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Configuration saved in ./simulation/promptEHR/config.json\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Save the trained model to: ./simulation/promptEHR\n" - ] - } - ], - "source": [ - "# save the model\n", - "model.save_model('./simulation/promptEHR')" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "2e5cff5c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "550it [00:47, 11.54it/s] \n" - ] - } - ], - "source": [ - "# generate fake records\n", - "res = model.predict(seqdata, n_per_sample=10, n=100, verbose=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "de881ccb", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'visit': [[[[1, 3, 4, 6, 7, 202, 235, 684, 2], [601, 226, 9, 7]], [[0, 2, 71], [153, 3, 175]], [[97, 2, 100, 4, 5, 6, 9, 74, 11, 12, 15, 18, 19, 51, 87, 93], [0, 1, 2, 3, 4, 6, 40, 8, 11, 79, 15, 16, 17, 23, 56, 30]]], [[[2, 4, 5, 6, 202, 235, 621, 530, 2], [8, 9, 202, 7]], [[0, 1, 2], [153, 71, 175]], [[64, 97, 0, 3, 4, 5, 6, 7, 8, 10, 11, 14, 19, 30], [64, 1, 2, 3, 4, 6, 7, 9, 11, 15, 80, 16, 82, 51, 17, 23]]], [[[3, 4, 6, 7, 235, 684, 530, 632, 2], [8, 601, 10, 9]], [[0, 2, 175], [3, 4, 71]], [[2, 3, 4, 8, 10, 11, 12, 13, 14, 47, 15, 82, 51, 19, 23], [64, 1, 3, 4, 6, 74, 11, 10, 13, 14, 16, 82, 51, 18, 94, 30]]], [[[2, 3, 7, 202, 684, 530, 632, 637, 2], [8, 9, 202, 226]], [[2, 19, 71], [153, 71, 175]], [[64, 1, 2, 97, 100, 5, 7, 8, 42, 10, 12, 14, 15, 18], [33, 2, 97, 4, 38, 7, 40, 8, 74, 9, 10, 14, 79, 82, 87, 56, 30]]], [[[0, 1, 3, 6, 7, 621, 530, 637, 2], [601, 202, 10, 7]], [[19, 71, 175], [153, 4, 71]], [[0, 1, 97, 2, 9, 10, 11, 12, 14, 47, 15, 18, 19, 51, 23], [0, 97, 1, 2, 4, 5, 6, 40, 8, 9, 10, 11, 13, 16, 17, 18, 23]]], [[[0, 2, 3, 235, 684, 530, 632, 542, 2], [9, 226, 542, 7]], [[0, 2, 71], [153, 4, 175]], [[64, 97, 1, 3, 2, 5, 7, 74, 11, 14, 47, 15, 82, 51], [64, 97, 2, 1, 6, 40, 8, 74, 9, 11, 13, 15, 80, 18, 82, 23, 56]]], [[[0, 1, 2, 3, 4, 6, 7, 621, 2], [10, 202, 419, 7]], [[113, 71, 175], [153, 3, 4]], [[0, 97, 1, 2, 4, 5, 9, 74, 11, 10, 12, 14, 13, 18, 82], [64, 97, 0, 3, 74, 42, 77, 13, 15, 80, 16, 51, 87, 23, 58, 27]]], [[[2, 4, 6, 235, 684, 621, 632, 542, 2], [601, 226, 9, 7]], [[0, 19, 71], [153, 3, 4]], [[64, 2, 3, 100, 4, 5, 6, 8, 42, 10, 13, 14, 47, 15, 30], [33, 2, 1, 3, 5, 38, 4, 7, 8, 74, 11, 13, 14, 15, 80, 18]]], [[[2, 7, 235, 684, 621, 530, 632, 542, 2], [9, 226, 542, 7]], [[1, 2, 71], [153, 4, 71]], [[97, 2, 3, 4, 5, 7, 74, 42, 11, 12, 14, 47, 15, 82, 19], [96, 0, 1, 2, 5, 38, 39, 7, 10, 11, 15, 16, 17, 82, 23, 56]]], [[[0, 1, 3, 4, 6, 235, 684, 542, 2], [8, 9, 202, 7]], [[153, 1, 175], [153, 3, 4]], [[0, 2, 4, 5, 7, 74, 11, 42, 77, 10, 47, 15, 14, 51, 30], [64, 0, 2, 1, 4, 38, 7, 8, 41, 9, 11, 14, 80, 17, 23, 29]]], [[[608, 2, 35, 197, 359, 103, 650, 621, 622, 14, 16, 17, 18, 19, 21, 23, 411, 542, 17], [1, 141, 14, 18, 20, 21, 25, 411, 28, 30, 32, 33, 549, 37, 38, 39, 40, 41, 444, 622]], [[99, 133, 197, 71, 6, 9, 12, 206, 175, 14, 244, 150], [17, 18, 175]], [[2, 34, 4, 5, 6, 26, 3, 42, 78, 80, 16, 18, 19, 20, 21, 58, 29, 94], [0, 1, 2, 32, 4, 33, 34, 7, 8, 41, 9, 82, 19, 23, 56, 27, 31]]], [[[197, 103, 650, 524, 622, 494, 14, 15, 690, 16, 18, 373, 19, 20, 21, 23, 26, 411, 17], [1, 18, 280, 28, 284, 29, 30, 31, 33, 35, 549, 37, 38, 40, 427, 566, 77, 485, 103, 235]], [[230, 7, 9, 170, 10, 11, 12, 13, 14, 15, 16, 153], [18, 12, 175]], [[3, 5, 9, 12, 18, 20, 21, 22, 23, 24, 26, 28, 29, 39, 47, 56, 71, 74, 78, 94], [0, 1, 2, 3, 4, 6, 7, 8, 14, 15, 22, 23, 31, 34, 35, 47, 58, 64, 78, 82]]], [[[197, 359, 27, 11, 524, 621, 622, 13, 15, 16, 690, 274, 373, 21, 23, 24, 411, 542, 17], [643, 663, 408, 25, 27, 28, 283, 30, 33, 34, 36, 677, 39, 40, 681, 41, 444, 337, 483, 621]], [[100, 133, 6, 7, 8, 9, 10, 13, 206, 14, 15, 16], [153, 18, 71]], [[2, 5, 6, 9, 14, 16, 19, 20, 23, 24, 26, 34, 56, 58, 64, 69, 71, 78, 80], [0, 32, 2, 3, 4, 33, 34, 71, 8, 9, 41, 44, 78, 80, 17, 22, 23, 31]]], [[[608, 197, 359, 650, 524, 13, 494, 622, 16, 274, 18, 19, 20, 21, 22, 25, 26, 27, 17], [1, 524, 14, 21, 663, 25, 27, 28, 411, 539, 29, 32, 33, 34, 31, 38, 39, 690, 444, 616]], [[228, 5, 6, 71, 7, 8, 10, 43, 13, 175, 150, 153], [17, 18, 12]], [[2, 3, 4, 5, 15, 16, 18, 19, 23, 25, 27, 28, 30, 47, 51, 69, 74, 78, 82], [1, 2, 3, 4, 9, 14, 15, 19, 23, 27, 33, 34, 42, 51, 56, 58, 64, 74, 80]]], [[[2, 35, 4, 359, 103, 650, 621, 622, 494, 15, 16, 274, 17, 19, 277, 24, 411, 542, 17], [524, 141, 14, 20, 21, 534, 27, 156, 411, 539, 542, 549, 38, 39, 40, 41, 695, 55, 444, 621]], [[99, 67, 133, 102, 71, 100, 9, 10, 175, 16, 150, 153], [17, 12, 175]], [[2, 3, 4, 12, 13, 16, 18, 20, 21, 22, 24, 25, 27, 30, 42, 58, 69, 78, 80, 94], [2, 3, 6, 7, 8, 9, 17, 19, 21, 26, 27, 31, 34, 35, 47, 69, 71, 74, 78, 96]]], [[[2, 35, 4, 27, 524, 621, 622, 15, 16, 17, 18, 19, 373, 277, 22, 26, 411, 542, 17], [1, 524, 14, 20, 663, 25, 411, 539, 28, 29, 32, 33, 34, 35, 38, 39, 41, 690, 616, 621]], [[132, 197, 102, 7, 8, 9, 10, 12, 13, 206, 14, 16], [18, 71, 175]], [[96, 1, 34, 3, 26, 5, 39, 74, 78, 14, 80, 19, 20, 22, 23, 24, 25, 58], [0, 1, 2, 7, 17, 18, 19, 23, 26, 27, 31, 33, 34, 35, 42, 51, 56, 64, 78]]], [[[608, 103, 11, 621, 622, 13, 14, 16, 690, 17, 18, 19, 20, 22, 23, 25, 26, 411, 17], [643, 131, 18, 663, 25, 283, 28, 31, 33, 293, 677, 549, 38, 39, 41, 52, 469, 483, 637]], [[100, 102, 71, 134, 6, 7, 10, 11, 206, 15, 16, 244], [153, 18, 175]], [[64, 0, 1, 3, 5, 6, 78, 15, 18, 19, 82, 51, 21, 23, 25, 26, 27, 30], [0, 1, 2, 3, 32, 33, 35, 7, 8, 41, 74, 78, 47, 17, 18, 19, 22]]], [[[609, 4, 103, 27, 650, 11, 621, 494, 13, 14, 15, 19, 373, 22, 23, 24, 25, 411, 17], [264, 14, 18, 21, 534, 279, 29, 32, 549, 165, 37, 681, 170, 427, 41, 433, 327, 458, 485, 235]], [[132, 133, 6, 8, 9, 11, 12, 13, 14, 175, 244, 22], [17, 18, 175]], [[2, 3, 4, 5, 6, 9, 12, 14, 19, 20, 21, 23, 24, 26, 28, 29, 47, 78, 94], [0, 1, 2, 3, 33, 34, 6, 39, 7, 35, 42, 14, 47, 17, 82, 19, 23, 56]]], [[[4, 197, 359, 650, 524, 12, 622, 14, 15, 16, 690, 17, 18, 20, 24, 25, 411, 542, 17], [1, 141, 14, 18, 20, 149, 21, 663, 408, 25, 283, 28, 29, 34, 36, 38, 40, 188, 621, 637]], [[225, 100, 5, 6, 71, 8, 9, 10, 11, 13, 15, 244], [17, 132, 175]], [[64, 1, 2, 3, 5, 39, 74, 78, 14, 80, 18, 19, 21, 22, 25, 26, 28, 30], [0, 1, 6, 7, 9, 14, 17, 18, 19, 23, 31, 32, 33, 35, 42, 51, 64, 78, 82]]], [[[608, 2, 35, 4, 359, 11, 524, 621, 12, 13, 14, 17, 19, 277, 21, 22, 25, 26, 17], [18, 20, 21, 663, 25, 29, 30, 542, 35, 36, 549, 37, 39, 40, 41, 183, 188, 483, 621, 637]], [[100, 197, 71, 8, 9, 10, 11, 12, 13, 175, 15, 16], [132, 12, 71]], [[1, 4, 5, 12, 14, 15, 16, 19, 20, 22, 25, 26, 29, 47, 51, 58, 72, 74, 78], [0, 33, 1, 3, 2, 32, 34, 41, 74, 9, 78, 14, 80, 15, 17, 19, 26]]], [[[42, 587, 43, 44, 45, 46, 47, 337, 48, 542, 46], [524, 46, 51, 20, 52, 120, 283, 542]], [[175], [67, 4, 227, 71, 15, 21, 23]], [[64, 1, 2, 26, 4, 37, 38, 74, 12, 14, 47, 19, 22, 23, 58, 29], [0, 32, 3, 5, 41, 11, 13, 14, 15, 80, 16, 51, 19, 23, 26]]], [[[608, 609, 359, 42, 587, 46, 48, 49, 695, 603, 46], [524, 621, 45, 50, 51, 663, 283, 542]], [[175], [1, 100, 15, 20, 21, 150, 153]], [[33, 34, 2, 1, 37, 6, 8, 78, 47, 14, 82, 22, 23, 29, 30, 31], [0, 33, 2, 3, 1, 69, 6, 9, 11, 12, 14, 15, 80, 16, 58]]], [[[35, 359, 7, 42, 587, 43, 46, 337, 695, 603, 46], [677, 7, 45, 46, 51, 52, 53, 283]], [[175], [1, 4, 71, 175, 15, 20, 150]], [[0, 1, 2, 3, 4, 26, 8, 74, 12, 14, 16, 22, 56, 58, 30, 31], [64, 33, 2, 3, 1, 32, 6, 7, 9, 74, 12, 13, 14, 16, 56, 26, 27]]], [[[608, 35, 359, 587, 44, 48, 337, 695, 603, 542, 46], [7, 621, 46, 51, 628, 53, 663, 283]], [[175], [227, 4, 15, 20, 150, 22, 153]], [[0, 1, 34, 3, 4, 37, 6, 38, 40, 78, 14, 48, 18, 23, 26, 60, 30], [0, 2, 3, 4, 5, 9, 11, 12, 78, 15, 14, 16, 82, 19, 58, 27]]], [[[7, 42, 43, 44, 45, 46, 47, 48, 49, 24, 46], [7, 649, 524, 45, 50, 52, 280, 542]], [[19], [132, 206, 15, 175, 23, 150, 22]], [[0, 97, 34, 2, 4, 1, 37, 41, 74, 12, 14, 80, 16, 22, 23, 26, 30], [64, 1, 2, 3, 4, 32, 6, 9, 74, 14, 47, 80, 15, 51, 19, 23, 26]]], [[[35, 7, 44, 45, 47, 48, 337, 49, 411, 542, 46], [45, 46, 20, 53, 534, 663, 283, 542]], [[19], [206, 15, 175, 20, 22, 23]], [[64, 0, 2, 34, 4, 1, 6, 38, 12, 14, 47, 80, 18, 29, 30, 31], [64, 0, 2, 1, 4, 5, 41, 11, 13, 78, 47, 14, 15, 16, 19, 23, 56, 27]]], [[[359, 7, 42, 44, 45, 46, 47, 48, 24, 542, 46], [645, 46, 50, 20, 373, 53, 283, 542]], [[19], [1, 132, 71, 175, 15, 22, 55]], [[0, 1, 3, 36, 26, 6, 74, 78, 47, 16, 14, 18, 82, 22, 87, 58, 31], [0, 33, 1, 2, 32, 5, 7, 9, 42, 11, 47, 15, 16, 82, 19, 56, 58, 27]]], [[[608, 609, 7, 587, 43, 46, 47, 48, 411, 542, 46], [7, 616, 46, 50, 51, 20, 663, 542]], [[175], [4, 100, 23, 21, 150, 22, 153]], [[1, 2, 3, 36, 26, 74, 14, 16, 18, 19, 82, 22, 23, 58, 31], [64, 32, 4, 5, 6, 7, 72, 9, 12, 13, 78, 14, 80, 15, 18, 23, 56, 26]]], [[[608, 609, 35, 359, 7, 46, 337, 411, 24, 603, 46], [7, 45, 690, 50, 20, 51, 53, 542]], [[19], [1, 132, 15, 20, 150, 22]], [[64, 0, 4, 38, 8, 74, 12, 14, 47, 80, 16, 18, 22, 23, 26, 29, 30], [64, 33, 1, 3, 4, 5, 6, 41, 42, 78, 47, 80, 15, 82, 51, 14, 16, 26]]], [[[359, 7, 42, 587, 45, 47, 48, 49, 695, 24, 46], [50, 51, 52, 53, 534, 663, 283, 542]], [[175], [1, 100, 15, 20, 22, 23, 153]], [[33, 2, 3, 4, 36, 38, 8, 74, 12, 14, 80, 18, 82, 29, 31], [64, 96, 32, 4, 5, 38, 6, 7, 9, 14, 47, 48, 15, 18, 19, 93, 30]]], [[[11, 15, 274, 20, 54, 55, 56, 57, 57], [549, 40, 61, 20, 62, 411, 637, 542]], [[24, 71, 175], [26, 27, 175]], [[64, 0, 2, 4, 13, 78, 16, 82, 22, 23, 26, 29, 31], [64, 1, 41, 74, 42, 13, 47, 15, 82, 23, 27, 28, 29, 31]]], [[[11, 587, 337, 690, 54, 55, 56, 57, 57], [616, 15, 411, 20, 58, 59, 60, 62]], [[24, 153, 175], [153, 27, 71]], [[64, 4, 38, 40, 74, 12, 78, 82, 22, 23, 56, 25, 26], [64, 1, 2, 74, 42, 13, 14, 15, 17, 82, 22, 23, 56, 31]]], [[[337, 690, 274, 18, 20, 55, 411, 542, 57], [621, 15, 411, 58, 59, 60, 61, 62]], [[57, 13, 175], [25, 26, 71]], [[2, 3, 39, 40, 74, 12, 78, 80, 82, 22, 29, 94, 31], [64, 0, 39, 74, 13, 78, 47, 15, 17, 51, 19, 27, 28, 29]]], [[[449, 609, 18, 20, 54, 55, 56, 542, 57], [449, 59, 15, 280, 58, 283, 60, 542]], [[14, 71, 175], [25, 27, 71]], [[1, 74, 13, 78, 47, 80, 14, 18, 82, 16, 22, 26, 31], [64, 32, 2, 41, 13, 14, 80, 17, 51, 56, 27, 28, 29, 31]]], [[[11, 15, 18, 20, 54, 55, 56, 57, 57], [524, 622, 15, 283, 20, 534, 59, 62]], [[13, 14, 71], [26, 27, 175]], [[0, 1, 39, 40, 12, 13, 14, 15, 80, 51, 22, 26, 31], [32, 1, 2, 41, 74, 13, 14, 15, 80, 23, 56, 26, 27, 31]]], [[[609, 11, 15, 54, 55, 56, 57, 411, 57], [524, 534, 663, 58, 411, 60, 61, 62]], [[24, 71, 175], [26, 27, 71]], [[2, 4, 74, 78, 14, 80, 15, 82, 51, 23, 56, 25, 26], [0, 1, 2, 3, 41, 74, 42, 22, 23, 26, 27, 28, 29, 94]]], [[[609, 587, 11, 337, 690, 20, 54, 411, 57], [35, 524, 534, 663, 58, 60, 61, 62]], [[24, 14, 175], [25, 26, 27]], [[2, 4, 38, 39, 40, 74, 12, 14, 47, 80, 16, 29, 31], [64, 32, 2, 41, 74, 42, 78, 47, 80, 14, 15, 19, 26, 28]]], [[[609, 11, 690, 274, 54, 55, 56, 411, 57], [59, 15, 690, 534, 663, 58, 411, 60]], [[57, 71, 175], [27, 71, 175]], [[2, 3, 39, 40, 74, 78, 16, 82, 51, 25, 23, 94, 30], [0, 1, 74, 13, 78, 47, 80, 14, 18, 19, 17, 26, 28, 30]]], [[[609, 587, 11, 15, 274, 690, 411, 542, 57], [616, 621, 690, 20, 58, 411, 60, 542]], [[24, 13, 175], [25, 26, 71]], [[0, 2, 3, 40, 42, 78, 18, 51, 23, 94, 26, 93, 30], [64, 0, 2, 38, 74, 78, 80, 18, 82, 22, 56, 28, 29, 30]]], [[[449, 587, 11, 337, 690, 18, 20, 542, 57], [387, 40, 50, 20, 534, 58, 59, 61]], [[13, 14, 71], [153, 71, 175]], [[64, 0, 2, 4, 40, 74, 14, 15, 16, 51, 22, 56, 31], [1, 2, 74, 13, 78, 47, 14, 82, 19, 51, 56, 26, 29, 31]]], [[[608, 64, 65, 35, 66, 67, 587, 690, 274, 411, 444, 542, 20], [293, 71, 73, 75, 76, 621, 46, 77, 52, 283, 637, 542], [68, 70, 46, 79, 52, 663, 283, 637, 542]], [[32, 33, 226, 228, 133, 175, 15, 50], [36, 37, 41, 108, 175, 147, 28], [1, 100, 133, 4, 43, 21, 150, 153]], [[2, 5, 6, 72, 43, 12, 45, 14, 47, 80, 49, 18, 82, 48, 22, 56, 11], [0, 2, 3, 6, 7, 10, 14, 15, 17, 18, 22, 29, 32, 38, 39, 42, 47, 64, 80], [0, 4, 5, 9, 11, 13, 14, 15, 16, 17, 18, 19, 26, 29, 33, 41, 42, 47, 58, 82]]], [[[609, 65, 35, 66, 69, 587, 462, 337, 690, 20, 444, 63, 20], [483, 643, 293, 70, 72, 681, 75, 77, 663, 283, 637, 542], [69, 70, 235, 46, 78, 20, 52, 663, 411]], [[225, 226, 2, 100, 133, 175, 30, 31], [36, 134, 38, 39, 108, 175, 63], [133, 71, 41, 14, 175, 21, 150, 153]], [[1, 2, 4, 5, 14, 15, 18, 19, 23, 30, 40, 41, 42, 45, 46, 47, 64, 71, 72], [1, 2, 3, 6, 7, 9, 15, 23, 29, 38, 41, 42, 44, 45, 47, 52, 53, 56, 80, 82], [64, 2, 3, 41, 74, 42, 13, 78, 47, 15, 48, 82, 51, 18, 53, 17, 58]]], [[[65, 66, 69, 70, 587, 462, 47, 46, 20, 52, 542, 63, 20], [483, 35, 293, 68, 681, 73, 621, 52, 663, 283, 188, 542], [483, 68, 70, 46, 78, 79, 50, 663, 542]], [[2, 228, 5, 133, 175, 15, 50, 30], [34, 67, 227, 35, 38, 150, 28], [1, 100, 40, 41, 43, 14, 175, 153]], [[2, 3, 6, 14, 15, 18, 19, 22, 23, 42, 45, 47, 48, 51, 56, 69, 71, 78, 80], [2, 6, 10, 12, 14, 15, 18, 22, 23, 28, 42, 47, 51, 53, 56, 64, 72, 78, 80], [64, 33, 0, 3, 2, 6, 7, 41, 42, 11, 10, 47, 15, 48, 16, 19, 22, 29]]], [[[608, 609, 64, 67, 68, 587, 337, 210, 690, 274, 20, 411, 20], [35, 68, 293, 70, 72, 73, 74, 76, 77, 46, 52, 663], [35, 68, 293, 46, 78, 79, 20, 283, 542]], [[32, 33, 228, 5, 133, 71, 50, 30], [36, 38, 41, 175, 147, 28, 63], [67, 227, 133, 71, 40, 175, 150, 153]], [[2, 3, 4, 7, 8, 11, 14, 15, 17, 18, 30, 40, 41, 47, 49, 56, 64, 78, 93], [0, 2, 6, 10, 14, 18, 19, 26, 32, 42, 47, 51, 52, 53, 56, 72, 74, 78, 80], [64, 33, 1, 4, 8, 41, 74, 9, 12, 78, 47, 14, 15, 51, 22, 23, 26]]], [[[608, 609, 64, 65, 67, 70, 337, 210, 690, 274, 411, 444, 20], [65, 68, 71, 74, 75, 76, 77, 494, 116, 628, 601, 542], [581, 235, 621, 46, 79, 52, 149, 534, 283]], [[226, 2, 228, 5, 133, 50, 29, 31], [34, 35, 36, 38, 39, 175, 84], [227, 133, 40, 41, 43, 175, 21, 150]], [[0, 1, 2, 9, 11, 12, 15, 17, 19, 23, 30, 47, 51, 52, 56, 64, 68, 74, 80, 82], [3, 5, 7, 10, 12, 13, 14, 17, 19, 22, 23, 32, 44, 47, 51, 72, 74, 78, 82], [64, 0, 3, 26, 69, 37, 6, 4, 42, 10, 14, 15, 17, 82, 22, 23, 56, 58]]], [[[608, 64, 66, 67, 68, 69, 70, 47, 337, 210, 542, 63, 20], [65, 35, 68, 71, 74, 75, 76, 77, 52, 283, 637, 542], [35, 483, 677, 68, 69, 621, 79, 663, 542]], [[32, 33, 5, 15, 22, 28, 30, 31], [259, 132, 36, 37, 71, 38, 22], [228, 4, 40, 41, 42, 43, 14, 21]], [[0, 1, 3, 6, 8, 15, 18, 19, 26, 41, 42, 44, 47, 50, 51, 56, 58, 74, 82], [0, 1, 3, 6, 12, 13, 17, 22, 26, 29, 32, 38, 39, 42, 45, 47, 52, 56, 74, 80], [0, 33, 1, 4, 5, 7, 8, 74, 10, 14, 15, 48, 47, 18, 51, 16, 17, 54]]], [[[608, 609, 64, 66, 67, 68, 46, 274, 20, 52, 542, 63, 20], [35, 643, 677, 70, 73, 74, 75, 77, 52, 663, 283, 188], [68, 677, 69, 70, 78, 79, 50, 52, 283]], [[32, 33, 133, 175, 15, 50, 28, 29], [35, 259, 132, 228, 37, 38, 28], [1, 40, 41, 206, 175, 14, 21, 150]], [[5, 14, 19, 20, 22, 37, 41, 42, 43, 44, 46, 48, 49, 50, 74, 76, 80, 82, 94], [64, 1, 0, 32, 6, 42, 12, 45, 15, 80, 17, 82, 19, 51, 53, 23, 28, 29], [0, 1, 2, 3, 4, 5, 7, 9, 10, 15, 16, 18, 23, 26, 29, 40, 41, 51, 53, 58]]], [[[64, 609, 65, 35, 67, 68, 69, 337, 210, 690, 411, 444, 20], [483, 70, 72, 74, 75, 76, 621, 46, 77, 52, 663, 542], [483, 68, 293, 677, 20, 52, 283, 637, 542]], [[225, 33, 2, 71, 15, 212, 29, 30], [35, 36, 37, 38, 39, 41, 28], [132, 228, 71, 40, 41, 175, 22, 55]], [[0, 2, 3, 6, 7, 8, 11, 12, 14, 18, 19, 20, 22, 28, 45, 48, 51, 56, 78], [0, 1, 2, 5, 17, 18, 19, 22, 26, 28, 29, 32, 38, 41, 42, 44, 52, 53, 56, 74], [0, 3, 6, 7, 10, 14, 15, 19, 22, 26, 39, 40, 41, 42, 47, 48, 51, 64, 74, 82]]], [[[608, 609, 67, 68, 70, 46, 337, 690, 274, 411, 444, 542, 20], [65, 420, 616, 72, 74, 75, 76, 77, 622, 46, 55, 223], [68, 198, 70, 616, 649, 46, 50, 436, 542]], [[33, 228, 133, 71, 175, 15, 50, 31], [36, 37, 134, 41, 147, 84, 63], [1, 67, 227, 133, 4, 43, 175, 153]], [[1, 4, 11, 13, 14, 15, 19, 23, 47, 48, 49, 50, 51, 56, 71, 74, 78, 93, 94], [1, 2, 3, 7, 14, 17, 18, 19, 26, 42, 47, 51, 53, 64, 72, 74, 78, 80, 82], [1, 2, 3, 6, 9, 16, 17, 22, 23, 29, 33, 40, 47, 48, 51, 64, 71, 80, 82]]], [[[65, 66, 67, 70, 587, 462, 47, 46, 210, 20, 411, 63, 20], [643, 68, 293, 70, 71, 72, 73, 74, 621, 77, 52, 542], [483, 293, 677, 69, 621, 50, 663, 637, 542]], [[32, 2, 228, 5, 50, 28, 30, 31], [36, 134, 38, 41, 108, 147, 84], [1, 67, 227, 133, 4, 43, 21, 153]], [[1, 2, 5, 6, 11, 13, 14, 15, 28, 45, 47, 50, 56, 64, 72, 74, 80, 82, 93], [0, 1, 2, 38, 39, 72, 41, 74, 11, 45, 47, 15, 17, 82, 19, 51, 56], [0, 1, 3, 6, 10, 12, 14, 15, 33, 37, 41, 47, 48, 53, 54, 56, 78, 80, 82]]], [[[521, 202, 12, 81, 83, 85, 86, 632, 411, 84], [89, 96, 98, 3, 324, 100, 99, 427, 80, 18, 149, 88, 313, 542, 95]], [[135, 44, 45, 175, 47, 153], [2, 133, 71, 175, 48, 49, 50, 15, 153]], [[64, 0, 2, 35, 4, 6, 71, 42, 76, 51, 19, 55, 56, 94, 58, 59, 28, 30], [1, 2, 6, 9, 11, 14, 15, 18, 19, 32, 42, 43, 56, 63, 64, 71, 74, 82, 94]]], [[[521, 202, 684, 83, 500, 84, 85, 283, 542, 84], [89, 97, 226, 98, 100, 645, 3, 327, 99, 18, 149, 313, 188, 93, 542, 95]], [[67, 71, 44, 175, 22, 57], [2, 259, 170, 43, 45, 13, 48, 18, 51]], [[64, 1, 2, 35, 4, 8, 11, 14, 47, 80, 46, 17, 19, 20, 55, 23, 28, 30], [1, 2, 4, 6, 7, 17, 18, 19, 20, 22, 38, 41, 56, 62, 64, 65, 68, 71, 76]]], [[[521, 684, 12, 80, 81, 83, 84, 85, 86, 84], [96, 161, 226, 353, 98, 3, 99, 429, 80, 373, 534, 90, 539, 380, 93, 95]], [[133, 44, 46, 175, 15, 50], [2, 133, 71, 45, 175, 51, 22, 88, 153]], [[2, 4, 5, 6, 9, 11, 15, 19, 20, 30, 42, 44, 46, 47, 48, 55, 56, 58, 71, 74], [2, 5, 7, 8, 9, 14, 15, 18, 19, 22, 32, 42, 43, 47, 56, 63, 65, 66, 74, 94]]], [[[202, 684, 12, 80, 500, 84, 632, 283, 542, 84], [226, 98, 100, 645, 99, 313, 327, 681, 112, 18, 149, 373, 88, 601, 285, 95]], [[71, 44, 46, 15, 22, 57], [2, 133, 45, 15, 49, 50, 150, 22, 153]], [[97, 2, 1, 6, 71, 42, 44, 14, 15, 46, 48, 82, 19, 23, 56, 28, 93, 30], [64, 97, 98, 0, 6, 7, 8, 77, 14, 47, 80, 15, 18, 19, 20, 22, 63]]], [[[521, 684, 80, 83, 84, 86, 87, 632, 411, 84], [226, 98, 100, 327, 681, 202, 80, 18, 373, 470, 534, 89, 188, 93, 542]], [[45, 46, 15, 175, 47, 22], [2, 13, 47, 48, 49, 18, 51, 15, 153]], [[0, 97, 2, 5, 71, 8, 41, 42, 11, 46, 47, 48, 82, 19, 23, 57, 59, 28], [0, 1, 2, 3, 4, 5, 7, 11, 15, 17, 18, 22, 23, 35, 42, 56, 62, 74, 93]]], [[[521, 202, 684, 12, 80, 83, 85, 87, 542, 84], [96, 100, 202, 112, 80, 93, 373, 149, 534, 94, 89, 92, 445, 542, 95]], [[226, 44, 15, 47, 22, 57], [259, 228, 170, 175, 15, 49, 50, 22, 153]], [[1, 2, 4, 5, 6, 7, 9, 42, 14, 19, 20, 59, 28, 93, 94, 95], [1, 2, 3, 8, 15, 18, 19, 20, 22, 30, 38, 64, 65, 66, 71, 74, 76, 81, 96]]], [[[521, 621, 81, 82, 83, 411, 84, 86, 283, 84], [89, 96, 97, 99, 100, 649, 91, 50, 18, 628, 149, 534, 313, 59, 188, 94]], [[226, 71, 45, 46, 47, 20], [2, 100, 102, 71, 15, 51, 22, 151, 153]], [[1, 2, 97, 4, 8, 9, 74, 14, 82, 19, 20, 55, 56, 94, 57, 93, 30, 95], [1, 6, 7, 15, 18, 23, 30, 32, 35, 39, 56, 57, 64, 65, 66, 71, 74, 76, 82, 96]]], [[[202, 621, 283, 81, 82, 500, 87, 411, 542, 84], [96, 97, 98, 3, 645, 112, 149, 534, 408, 89, 90, 88, 92, 93, 542, 95]], [[67, 44, 46, 175, 20, 57], [2, 43, 45, 13, 15, 18, 50, 51, 22]], [[1, 2, 6, 71, 8, 42, 11, 46, 15, 14, 51, 20, 19, 23, 56, 94, 30], [64, 65, 3, 4, 7, 71, 74, 11, 42, 14, 47, 80, 82, 51, 61, 94, 93, 30]]], [[[202, 684, 80, 82, 85, 87, 632, 283, 542, 84], [98, 419, 3, 327, 681, 202, 233, 50, 18, 534, 88, 89, 90, 59, 188, 95]], [[67, 175, 47, 20, 22, 57], [2, 259, 228, 47, 48, 49, 51, 22, 153]], [[3, 4, 6, 11, 14, 15, 17, 20, 28, 30, 44, 46, 48, 51, 58, 64, 74, 80, 82, 97], [0, 1, 3, 5, 7, 8, 14, 15, 17, 20, 22, 39, 43, 56, 64, 65, 66, 74, 81, 96]]], [[[521, 202, 621, 81, 84, 85, 632, 411, 542, 84], [97, 226, 3, 99, 425, 202, 91, 18, 373, 149, 470, 534, 89, 88, 92, 95]], [[67, 71, 175, 47, 151, 57], [2, 259, 45, 15, 48, 49, 50, 22, 153]], [[0, 1, 2, 3, 74, 42, 14, 15, 80, 17, 82, 19, 51, 20, 56, 59, 28], [1, 4, 5, 7, 14, 15, 18, 19, 23, 30, 42, 47, 51, 64, 71, 74, 80, 82, 93]]], [[[101, 37, 103, 73, 235, 15, 82, 54, 632, 92, 542, 88], [226, 645, 549, 101, 37, 202, 13, 15, 150, 55, 88, 542]], [[46, 175], [153, 52, 15, 175]], [[2, 8, 11, 14, 15, 22, 23, 30, 42, 51, 52, 56, 58, 64, 69, 70, 74, 77, 82], [1, 4, 6, 11, 13, 18, 22, 23, 27, 30, 48, 49, 56, 57, 65, 69, 70, 71, 97]]], [[[101, 521, 202, 13, 15, 530, 82, 54, 283, 92, 542, 88], [226, 419, 649, 202, 524, 50, 628, 84, 534, 59, 156, 542]], [[46, 71], [153, 170, 71, 175]], [[3, 4, 10, 14, 18, 22, 28, 37, 42, 44, 48, 60, 63, 64, 65, 66, 67, 69, 70], [97, 3, 68, 76, 77, 14, 13, 16, 17, 48, 51, 49, 23, 57, 60, 61, 30]]], [[[37, 103, 202, 235, 684, 13, 530, 82, 88, 283, 92, 88], [226, 37, 649, 202, 105, 106, 334, 15, 50, 628, 59, 156]], [[71, 175], [153, 2, 71, 175]], [[65, 3, 67, 37, 70, 8, 10, 11, 28, 14, 15, 48, 17, 18, 51, 19, 22, 60], [5, 7, 8, 12, 14, 16, 19, 27, 46, 48, 57, 60, 61, 68, 69, 70, 76, 81, 93, 97]]], [[[37, 102, 103, 73, 684, 621, 15, 54, 92, 637, 542, 88], [104, 649, 73, 105, 524, 13, 334, 15, 628, 534, 59, 542]], [[2, 175], [2, 52, 53, 15]], [[2, 3, 8, 13, 14, 17, 22, 23, 30, 42, 44, 51, 60, 67, 69, 70, 71, 77, 97], [64, 97, 2, 3, 34, 4, 70, 71, 1, 11, 13, 14, 77, 81, 17, 51, 27, 30]]], [[[101, 102, 37, 202, 530, 54, 55, 632, 88, 637, 542, 88], [101, 37, 649, 13, 15, 628, 84, 534, 54, 88, 380, 542]], [[2, 46], [153, 2, 53, 175]], [[2, 4, 13, 14, 15, 28, 30, 37, 38, 56, 64, 65, 68, 69, 71, 74, 77, 78, 82, 97], [1, 13, 14, 15, 16, 23, 27, 49, 61, 64, 69, 70, 71, 73, 74, 77, 80, 82, 87]]], [[[103, 521, 202, 73, 684, 13, 15, 530, 632, 283, 637, 88], [645, 101, 103, 104, 13, 15, 628, 150, 55, 88, 121, 542]], [[71, 175], [2, 15, 53, 71]], [[1, 2, 3, 4, 8, 10, 14, 17, 22, 37, 44, 48, 52, 67, 68, 71, 74, 77, 82], [33, 1, 2, 3, 69, 71, 42, 77, 14, 15, 46, 19, 87, 23, 58, 27, 61]]], [[[101, 102, 103, 73, 235, 684, 621, 82, 54, 632, 637, 88], [545, 645, 104, 425, 73, 105, 15, 208, 84, 373, 54, 542]], [[71, 175], [170, 2, 53, 71]], [[1, 4, 9, 11, 13, 14, 15, 18, 19, 30, 44, 47, 51, 52, 60, 64, 65, 67, 68, 77], [1, 2, 3, 14, 15, 27, 33, 42, 46, 48, 51, 56, 58, 65, 68, 69, 70, 80, 87, 97]]], [[[102, 103, 73, 13, 15, 530, 82, 54, 55, 632, 92, 88], [101, 37, 103, 104, 649, 106, 524, 13, 628, 84, 280, 88]], [[2, 46], [52, 15, 71, 175]], [[2, 3, 4, 5, 7, 8, 13, 18, 22, 23, 28, 38, 44, 52, 56, 60, 65, 69, 70, 71], [1, 2, 3, 4, 6, 8, 15, 22, 27, 51, 54, 56, 57, 64, 70, 73, 74, 77, 82, 97]]], [[[101, 202, 235, 621, 13, 15, 530, 55, 283, 637, 542, 88], [101, 37, 103, 73, 202, 13, 373, 150, 54, 88, 542, 223]], [[71, 175], [153, 170, 2, 71]], [[1, 2, 3, 5, 8, 14, 15, 23, 48, 56, 63, 64, 65, 67, 69, 71, 74, 88, 93, 95], [1, 2, 3, 4, 6, 8, 12, 14, 18, 23, 34, 48, 54, 57, 58, 63, 64, 71, 80]]], [[[101, 102, 202, 235, 684, 15, 530, 55, 88, 283, 92, 88], [101, 37, 104, 73, 649, 105, 106, 13, 84, 54, 88, 380]], [[2, 46], [2, 15, 71, 175]], [[2, 3, 68, 69, 70, 8, 11, 77, 14, 47, 17, 18, 82, 52, 22, 23, 28, 30], [1, 2, 5, 14, 15, 22, 27, 42, 48, 51, 54, 57, 60, 63, 65, 74, 80, 82, 87, 97]]], [[[609, 35, 197, 7, 73, 107, 108, 76, 109, 110, 111, 337, 18, 78, 47, 115, 93, 94, 110], [35, 677, 621, 77, 80, 116, 52, 117, 20, 118, 188]], [[54, 55], [132, 133, 4, 71, 40, 21, 150, 22, 56, 58, 31]], [[0, 33, 2, 34, 3, 68, 32, 72, 73, 42, 41, 75, 43, 48, 82, 26, 30, 31], [0, 1, 2, 3, 4, 7, 14, 15, 23, 25, 26, 36, 40, 41, 51, 58, 72, 74, 78, 80]]], [[[608, 35, 292, 197, 359, 7, 73, 587, 107, 621, 365, 76, 110, 78, 690, 114, 115, 542, 110], [485, 69, 663, 649, 235, 52, 20, 534, 55, 118, 283]], [[55, 175], [4, 40, 41, 147, 84, 85, 21, 22, 58, 31, 63]], [[0, 1, 2, 3, 8, 12, 15, 18, 26, 32, 42, 43, 44, 47, 48, 52, 73, 78, 80], [64, 1, 0, 36, 37, 8, 9, 74, 12, 78, 14, 16, 48, 15, 17, 23, 58, 28]]], [[[608, 33, 609, 35, 292, 197, 40, 73, 76, 621, 622, 109, 112, 337, 18, 411, 542, 575, 110], [69, 616, 621, 77, 690, 116, 52, 118, 55, 539, 542]], [[54, 175], [134, 41, 108, 47, 147, 84, 21, 56, 57, 58, 31]], [[1, 15, 18, 19, 22, 31, 32, 44, 47, 51, 68, 71, 73, 74, 75, 80, 82, 93, 94, 95], [0, 1, 2, 6, 8, 9, 10, 12, 15, 18, 23, 25, 28, 42, 51, 56, 72, 74, 80]]], [[[608, 33, 292, 197, 107, 108, 365, 78, 111, 47, 337, 690, 274, 18, 114, 411, 93, 94, 110], [483, 293, 677, 69, 621, 45, 52, 118, 283, 188, 637]], [[54, 175], [228, 133, 134, 197, 4, 14, 147, 23, 21, 215, 57]], [[2, 3, 6, 14, 15, 22, 23, 31, 32, 41, 44, 48, 51, 52, 56, 64, 68, 72, 74, 78], [0, 2, 3, 4, 7, 9, 15, 16, 17, 19, 25, 26, 27, 34, 40, 42, 51, 78, 82]]], [[[33, 35, 7, 40, 73, 108, 621, 622, 76, 109, 110, 690, 111, 112, 114, 115, 93, 94, 110], [483, 643, 35, 621, 45, 77, 112, 116, 118, 283, 637]], [[55, 175], [134, 40, 46, 21, 150, 23, 56, 57, 58, 22, 31]], [[0, 2, 14, 18, 19, 22, 23, 32, 33, 41, 42, 44, 47, 51, 52, 56, 64, 68, 75, 78], [0, 1, 3, 6, 7, 9, 11, 13, 15, 17, 26, 37, 40, 41, 48, 51, 56, 64, 78]]], [[[608, 33, 197, 359, 40, 621, 365, 109, 110, 337, 274, 18, 111, 597, 112, 114, 115, 93, 110], [621, 622, 45, 80, 77, 116, 149, 117, 55, 20, 283]], [[55, 175], [197, 133, 40, 108, 175, 15, 244, 21, 22, 57, 31]], [[1, 2, 6, 14, 18, 23, 31, 32, 34, 43, 48, 52, 72, 73, 74, 78, 80, 82, 95], [1, 2, 3, 5, 6, 8, 11, 18, 19, 23, 28, 30, 37, 40, 45, 51, 71, 72, 74, 78]]], [[[33, 609, 35, 197, 359, 7, 587, 107, 76, 622, 109, 337, 274, 113, 597, 94, 411, 542, 110], [64, 69, 45, 622, 80, 112, 20, 149, 117, 118, 283]], [[54, 71], [193, 228, 4, 108, 47, 21, 23, 56, 57, 58, 63]], [[1, 2, 8, 14, 18, 26, 30, 31, 32, 43, 47, 51, 56, 68, 69, 71, 72, 74, 75], [96, 2, 37, 38, 8, 74, 11, 78, 47, 80, 15, 18, 82, 16, 23, 25, 28, 93]]], [[[608, 609, 292, 73, 587, 107, 365, 108, 76, 111, 337, 690, 112, 113, 47, 114, 94, 575, 110], [45, 621, 77, 80, 116, 117, 20, 118, 188, 637, 542]], [[54, 175], [4, 41, 43, 108, 14, 175, 15, 147, 84, 21, 56]], [[0, 1, 2, 3, 6, 12, 14, 15, 18, 26, 31, 41, 42, 43, 44, 48, 52, 56, 72], [1, 2, 37, 6, 72, 74, 42, 12, 11, 47, 15, 17, 19, 51, 23, 56, 26, 28]]], [[[608, 33, 609, 35, 197, 40, 73, 587, 107, 365, 47, 337, 274, 114, 115, 597, 542, 575, 110], [64, 581, 235, 45, 622, 80, 20, 118, 663, 283, 542]], [[54, 55], [133, 134, 102, 46, 206, 14, 22, 23, 57, 58, 31]], [[0, 2, 6, 8, 10, 12, 18, 19, 26, 41, 42, 48, 49, 51, 56, 68, 73, 74, 78, 80], [64, 2, 4, 6, 74, 11, 78, 14, 15, 48, 82, 19, 51, 17, 23, 58]]], [[[608, 33, 609, 35, 197, 7, 40, 587, 107, 621, 76, 110, 111, 112, 78, 411, 94, 575, 110], [483, 293, 677, 69, 621, 45, 117, 118, 283, 188, 542]], [[71, 175], [134, 41, 108, 14, 175, 15, 147, 85, 22, 57, 31]], [[0, 1, 2, 6, 8, 18, 22, 23, 30, 31, 41, 42, 43, 47, 48, 56, 68, 69, 73, 80], [0, 2, 3, 6, 11, 12, 15, 17, 26, 34, 36, 37, 40, 44, 51, 56, 74, 80, 94, 96]]], [[[609, 73, 235, 620, 111, 52, 184, 542, 18], [677, 123, 46, 663, 122, 283, 637, 542]], [[130, 67, 228, 2, 71, 39, 175, 59], [153, 2, 38, 1]], [[1, 3, 4, 5, 10, 11, 15, 19, 22, 23, 32, 36, 41, 42, 48, 54, 72, 77, 80], [0, 2, 3, 4, 6, 7, 14, 15, 18, 19, 20, 22, 28, 47, 58, 69, 71, 77, 78, 80]]], [[[73, 235, 46, 18, 534, 119, 184, 58, 18], [387, 46, 280, 121, 58, 283, 124, 542]], [[64, 2, 39, 42, 116, 57, 61, 62], [2, 132, 71, 175]], [[0, 33, 2, 3, 32, 4, 7, 10, 44, 13, 14, 77, 48, 17, 18, 12, 23, 58], [0, 1, 2, 9, 14, 20, 22, 29, 41, 48, 51, 53, 54, 72, 74, 75, 77, 80, 82]]], [[[73, 46, 18, 119, 184, 58, 120, 542, 18], [387, 649, 105, 50, 280, 283, 124, 542]], [[64, 133, 39, 71, 174, 175, 61, 63], [2, 132, 22, 175]], [[0, 6, 8, 13, 16, 17, 19, 41, 42, 44, 47, 48, 51, 56, 64, 71, 74, 77, 78, 82], [0, 1, 2, 4, 8, 14, 20, 22, 28, 33, 36, 48, 53, 54, 56, 69, 71, 75, 78]]], [[[35, 37, 73, 52, 534, 119, 120, 58, 18], [35, 621, 46, 663, 121, 58, 123, 542]], [[67, 132, 100, 39, 42, 174, 175, 63], [38, 71, 150, 175]], [[0, 3, 5, 11, 12, 14, 18, 19, 22, 23, 28, 29, 32, 41, 47, 48, 56, 72, 76], [0, 1, 2, 3, 4, 6, 7, 14, 15, 18, 20, 28, 29, 41, 51, 53, 54, 58, 64, 75]]], [[[609, 35, 73, 620, 46, 52, 534, 184, 18], [387, 105, 123, 46, 50, 121, 122, 283]], [[130, 67, 228, 71, 57, 59, 61, 62], [41, 1, 38, 39]], [[0, 3, 4, 5, 7, 14, 17, 18, 23, 28, 32, 41, 48, 51, 56, 69, 72, 76, 78], [0, 1, 2, 4, 9, 14, 15, 20, 22, 28, 29, 36, 47, 56, 64, 71, 77, 80, 82]]], [[[35, 37, 73, 52, 119, 184, 58, 120, 18], [58, 621, 46, 663, 121, 122, 124, 637]], [[64, 100, 132, 39, 42, 174, 61, 63], [132, 71, 22, 175]], [[6, 8, 10, 14, 16, 19, 27, 29, 41, 42, 48, 51, 54, 58, 64, 72, 74, 80, 82], [1, 2, 3, 4, 6, 7, 8, 9, 15, 20, 22, 23, 28, 48, 56, 64, 80, 82, 95]]], [[[37, 73, 46, 111, 18, 534, 184, 58, 18], [105, 123, 621, 46, 411, 121, 122, 283]], [[64, 100, 132, 133, 39, 42, 175, 59], [2, 39, 38, 71]], [[0, 1, 12, 13, 15, 17, 18, 19, 23, 27, 28, 29, 36, 41, 47, 48, 56, 77, 78, 80], [0, 2, 7, 8, 14, 20, 22, 28, 29, 36, 51, 53, 54, 64, 72, 77, 78, 80, 82]]], [[[35, 37, 73, 111, 18, 52, 119, 58, 18], [35, 483, 3, 105, 621, 122, 637, 542]], [[2, 67, 132, 133, 39, 175, 61, 62], [153, 2, 38, 175]], [[0, 1, 2, 4, 5, 10, 14, 15, 18, 22, 23, 32, 42, 47, 54, 56, 58, 72, 74], [2, 3, 6, 9, 14, 18, 20, 22, 23, 28, 29, 33, 36, 47, 53, 64, 71, 93, 96]]], [[[609, 37, 620, 46, 52, 119, 120, 58, 18], [616, 649, 50, 280, 122, 283, 124, 542]], [[67, 100, 132, 133, 59, 61, 62, 63], [1, 2, 150, 71]], [[0, 2, 3, 4, 5, 6, 7, 10, 14, 16, 19, 22, 36, 47, 72, 74, 77, 80, 82], [2, 3, 4, 6, 8, 14, 15, 22, 23, 29, 33, 42, 48, 56, 64, 75, 77, 80, 82]]], [[[73, 235, 620, 46, 534, 119, 58, 542, 18], [387, 3, 649, 105, 50, 121, 122, 283]], [[64, 2, 67, 39, 57, 60, 62, 63], [132, 71, 22, 175]], [[2, 3, 5, 8, 11, 13, 16, 17, 19, 21, 27, 28, 34, 39, 41, 42, 58, 80, 97], [0, 1, 2, 3, 4, 15, 19, 20, 22, 42, 48, 51, 53, 56, 64, 74, 77, 78, 80]]], [[[128, 129, 4, 359, 49, 125, 126, 127, 46]], [[175]], [[64, 0, 2, 4, 5, 74, 14, 47, 80, 16, 19, 20, 56, 57]]], [[[128, 542, 359, 73, 684, 283, 126, 127, 46]], [[65]], [[64, 97, 2, 1, 38, 39, 42, 12, 14, 47, 82, 19, 20]]], [[[128, 73, 684, 621, 46, 283, 126, 127, 46]], [[175]], [[97, 1, 4, 5, 38, 6, 14, 47, 16, 49, 20, 22, 23, 57]]], [[[129, 4, 202, 684, 556, 621, 283, 125, 46]], [[175]], [[96, 0, 2, 4, 6, 38, 42, 12, 47, 16, 82, 56, 57, 93]]], [[[128, 359, 684, 556, 49, 283, 125, 542, 46]], [[175]], [[64, 0, 2, 1, 38, 74, 11, 12, 15, 47, 20, 22, 23, 56]]], [[[4, 73, 556, 621, 46, 49, 125, 127, 46]], [[65]], [[0, 97, 4, 5, 6, 38, 14, 47, 16, 18, 82, 22, 23, 56]]], [[[129, 684, 46, 49, 283, 125, 542, 127, 46]], [[175]], [[97, 5, 38, 39, 6, 42, 12, 14, 47, 80, 82, 19, 23, 57]]], [[[4, 359, 202, 556, 621, 684, 283, 542, 46]], [[65]], [[0, 1, 34, 3, 4, 74, 11, 12, 16, 51, 19, 20, 57, 94]]], [[[128, 129, 359, 556, 46, 49, 283, 125, 46]], [[65]], [[64, 97, 2, 0, 4, 1, 6, 38, 42, 11, 47, 15, 19]]], [[[129, 4, 73, 684, 46, 283, 125, 127, 46]], [[65]], [[64, 97, 2, 0, 4, 39, 74, 12, 80, 16, 82, 19, 22, 57]]]], 'feature': array([[-1.02022052, 0. , 0. ],\n", - " [-1.02022052, 0. , 0. ],\n", - " [-1.02022052, 0. , 0. ],\n", - " [-1.02022052, 0. , 0. ],\n", - " [-1.02022052, 0. , 0. ],\n", - " [-1.02022052, 0. , 0. ],\n", - " [-1.02022052, 0. , 0. ],\n", - " [-1.02022052, 0. , 0. ],\n", - " [-1.02022052, 0. , 0. ],\n", - " [-1.02022052, 0. , 0. ],\n", - " [ 1.33936059, 1. , 0. ],\n", - " [ 1.33936059, 1. , 0. ],\n", - " [ 1.33936059, 1. , 0. ],\n", - " [ 1.33936059, 1. , 0. ],\n", - " [ 1.33936059, 1. , 0. ],\n", - " [ 1.33936059, 1. , 0. ],\n", - " [ 1.33936059, 1. , 0. ],\n", - " [ 1.33936059, 1. , 0. ],\n", - " [ 1.33936059, 1. , 0. ],\n", - " [ 1.33936059, 1. , 0. ],\n", - " [ 0.61376214, 1. , 0. ],\n", - " [ 0.61376214, 1. , 0. ],\n", - " [ 0.61376214, 1. , 0. ],\n", - " [ 0.61376214, 1. , 0. ],\n", - " [ 0.61376214, 1. , 0. ],\n", - " [ 0.61376214, 1. , 0. ],\n", - " [ 0.61376214, 1. , 0. ],\n", - " [ 0.61376214, 1. , 0. ],\n", - " [ 0.61376214, 1. , 0. ],\n", - " [ 0.61376214, 1. , 0. ],\n", - " [ 1.50029671, 1. , 0. ],\n", - " [ 1.50029671, 1. , 0. ],\n", - " [ 1.50029671, 1. , 0. ],\n", - " [ 1.50029671, 1. , 0. ],\n", - " [ 1.50029671, 1. , 0. ],\n", - " [ 1.50029671, 1. , 0. ],\n", - " [ 1.50029671, 1. , 0. ],\n", - " [ 1.50029671, 1. , 0. ],\n", - " [ 1.50029671, 1. , 0. ],\n", - " [ 1.50029671, 1. , 0. ],\n", - " [ 0.28487846, 1. , 0. ],\n", - " [ 0.28487846, 1. , 0. ],\n", - " [ 0.28487846, 1. , 0. ],\n", - " [ 0.28487846, 1. , 0. ],\n", - " [ 0.28487846, 1. , 0. ],\n", - " [ 0.28487846, 1. , 0. ],\n", - " [ 0.28487846, 1. , 0. ],\n", - " [ 0.28487846, 1. , 0. ],\n", - " [ 0.28487846, 1. , 0. ],\n", - " [ 0.28487846, 1. , 0. ],\n", - " [-0.61307007, 1. , 0. ],\n", - " [-0.61307007, 1. , 0. ],\n", - " [-0.61307007, 1. , 0. ],\n", - " [-0.61307007, 1. , 0. ],\n", - " [-0.61307007, 1. , 0. ],\n", - " [-0.61307007, 1. , 0. ],\n", - " [-0.61307007, 1. , 0. ],\n", - " [-0.61307007, 1. , 0. ],\n", - " [-0.61307007, 1. , 0. ],\n", - " [-0.61307007, 1. , 0. ],\n", - " [-1.37568235, 0. , 1. ],\n", - " [-1.37568235, 0. , 1. ],\n", - " [-1.37568235, 0. , 1. ],\n", - " [-1.37568235, 0. , 1. ],\n", - " [-1.37568235, 0. , 1. ],\n", - " [-1.37568235, 0. , 1. ],\n", - " [-1.37568235, 0. , 1. ],\n", - " [-1.37568235, 0. , 1. ],\n", - " [-1.37568235, 0. , 1. ],\n", - " [-1.37568235, 0. , 1. ],\n", - " [ 0.7086606 , 1. , 0. ],\n", - " [ 0.7086606 , 1. , 0. ],\n", - " [ 0.7086606 , 1. , 0. ],\n", - " [ 0.7086606 , 1. , 0. ],\n", - " [ 0.7086606 , 1. , 0. ],\n", - " [ 0.7086606 , 1. , 0. ],\n", - " [ 0.7086606 , 1. , 0. ],\n", - " [ 0.7086606 , 1. , 0. ],\n", - " [ 0.7086606 , 1. , 0. ],\n", - " [ 0.7086606 , 1. , 0. ],\n", - " [ 0.59582597, 1. , 2. ],\n", - " [ 0.59582597, 1. , 2. ],\n", - " [ 0.59582597, 1. , 2. ],\n", - " [ 0.59582597, 1. , 2. ],\n", - " [ 0.59582597, 1. , 2. ],\n", - " [ 0.59582597, 1. , 2. ],\n", - " [ 0.59582597, 1. , 2. ],\n", - " [ 0.59582597, 1. , 2. ],\n", - " [ 0.59582597, 1. , 2. ],\n", - " [ 0.59582597, 1. , 2. ],\n", - " [-0.27521837, 0. , 3. ],\n", - " [-0.27521837, 0. , 3. ],\n", - " [-0.27521837, 0. , 3. ],\n", - " [-0.27521837, 0. , 3. ],\n", - " [-0.27521837, 0. , 3. ],\n", - " [-0.27521837, 0. , 3. ],\n", - " [-0.27521837, 0. , 3. ],\n", - " [-0.27521837, 0. , 3. ],\n", - " [-0.27521837, 0. , 3. ],\n", - " [-0.27521837, 0. , 3. ]]), 'order': ['diag', 'prod', 'med'], 'n_num_feature': 1, 'cat_cardinalties': [2, 10]}\n" - ] - } - ], - "source": [ - "print(res)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "c53197ed", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.chdir('../')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "2e4d1fd0", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", - "The tokenizer class you load from this checkpoint is 'BartTokenizer'. \n", - "The class this function is called from is 'DataTokenizer'.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Load pretrained PromptEHR model from ./simulation/pretrained_promptEHR\n", - "Load the pre-trained model from: ./simulation/pretrained_promptEHR\n" - ] - } - ], - "source": [ - "# if you want pretrained model downloaded\n", - "from promptehr import PromptEHR\n", - "model = PromptEHR()\n", - "model.from_pretrained()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "fca5038d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "we are done! :)\n" - ] - } - ], - "source": [ - "print('we are done! :)')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "67c73772", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}