{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Categorical Encoding.ipynb", "provenance": [], "collapsed_sections": [], "authorship_tag": "ABX9TyPxp4DnWs7o6MmCborBmiSx", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "source": [ "# %pip installs into the environment of the running kernel; the versions are\n", "# pinned to the ones recorded in this cell's output.\n", "%pip install tensorflow_addons==0.16.1\n", "%pip install category_encoders==2.4.0" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dV7sBRyA0laj", "outputId": "e75f1c7f-ddc6-4c36-a2d0-41ac2cd8c0e7" }, "execution_count": 4, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: tensorflow_addons in /usr/local/lib/python3.7/dist-packages (0.16.1)\n", "Requirement already satisfied: typeguard>=2.7 in /usr/local/lib/python3.7/dist-packages (from tensorflow_addons) (2.7.1)\n", "Requirement already satisfied: category_encoders in /usr/local/lib/python3.7/dist-packages (2.4.0)\n", "Requirement already satisfied: patsy>=0.5.1 in /usr/local/lib/python3.7/dist-packages (from category_encoders) (0.5.2)\n", "Requirement already satisfied: statsmodels>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from category_encoders) (0.10.2)\n", "Requirement already satisfied: scipy>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from category_encoders) (1.4.1)\n", "Requirement already satisfied: scikit-learn>=0.20.0 in /usr/local/lib/python3.7/dist-packages (from category_encoders) (1.0.2)\n", "Requirement already satisfied: pandas>=0.21.1 in /usr/local/lib/python3.7/dist-packages (from category_encoders) (1.3.5)\n", "Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.7/dist-packages (from category_encoders) (1.21.6)\n", "Requirement already satisfied: python-dateutil>=2.7.3 in 
/usr/local/lib/python3.7/dist-packages (from pandas>=0.21.1->category_encoders) (2.8.2)\n", "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.21.1->category_encoders) (2022.1)\n", "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from patsy>=0.5.1->category_encoders) (1.15.0)\n", "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.20.0->category_encoders) (1.1.0)\n", "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.20.0->category_encoders) (3.1.0)\n" ] } ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "2INjkFXMDzMg" }, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "source": [ "import pandas as pd\n", "import requests\n", "from zipfile import ZipFile\n", "\n", "dataset_url = 'https://storage.googleapis.com/kaggle-data-sets/225/498/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20220430%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20220430T104007Z&X-Goog-Expires=259199&X-Goog-SignedHeaders=host&X-Goog-Signature=1ced48204b7889c650880becabe6f5a825ff8d6d346832d811cb4a928586f097ff2edd3f976109f5c05b3bcb7c93c644196bbf71417ed9038f23296c2310fda6539c5349471f856435c1ee13f0345cefd37aca6de9f39f454486106353681c949830c5629f62ed7551beb1e16dda1f011b4c54f9c1943e2607629e5b6849373b923fd595fdcb63a6e7a61a0d98c3753ffdafaeb4506efafe45948cb2dc577c2df8d0cf6d195c88077f050e024ffb50f3f66b0f1fa4b0c1fe3ac7c5185aa0af2907d179847e8eec2d4996428fa2b97b93c9d19247827213c65fb142e4d5f3ce20f3ba0b7fa3a45a55a17f8b975f2204e77fcc9edaff701c1c7d29e459ce6f4e25'\n", "download_filename = 'download.zip'\n", "content_filename = 'adult.csv'\n", "\n", "# The signed URL is time-limited (X-Goog-Expires); fail on an HTTP error here\n", "# instead of saving the error response as a corrupt zip archive.\n", "req = requests.get(dataset_url, timeout=60)\n", "req.raise_for_status()\n", "with open(download_filename, 'wb') as output_file:\n", 
"    output_file.write(req.content)\n", "print('Download completed!\\n')\n", "\n", "# Context managers close the archive and the CSV member once the frame is loaded.\n", "with ZipFile(download_filename) as zf, zf.open(content_filename) as csv_file:\n", "    data = pd.read_csv(csv_file).dropna()\n", "data['occupation'] = data['occupation'].replace({'?': 'Unknown'})\n", "data['workclass'] = data['workclass'].replace({'?': 'Unknown'})\n", "data" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 458 }, "id": "GttbPkK5EDQf", "outputId": "3c819a95-cd2c-410e-cd11-5affbd49822c" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Download completed!\n", "\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ " age workclass fnlwgt education education.num marital.status \\\n", "0 90 Unknown 77053 HS-grad 9 Widowed \n", "1 82 Private 132870 HS-grad 9 Widowed \n", "2 66 Unknown 186061 Some-college 10 Widowed \n", "3 54 Private 140359 7th-8th 4 Divorced \n", "4 41 Private 264663 Some-college 10 Separated \n", "... ... ... ... ... ... ... \n", "32556 22 Private 310152 Some-college 10 Never-married \n", "32557 27 Private 257302 Assoc-acdm 12 Married-civ-spouse \n", "32558 40 Private 154374 HS-grad 9 Married-civ-spouse \n", "32559 58 Private 151910 HS-grad 9 Widowed \n", "32560 22 Private 201490 HS-grad 9 Never-married \n", "\n", " occupation relationship race sex capital.gain \\\n", "0 Unknown Not-in-family White Female 0 \n", "1 Exec-managerial Not-in-family White Female 0 \n", "2 Unknown Unmarried Black Female 0 \n", "3 Machine-op-inspct Unmarried White Female 0 \n", "4 Prof-specialty Own-child White Female 0 \n", "... ... ... ... ... ... 
\n", "32556 Protective-serv Not-in-family White Male 0 \n", "32557 Tech-support Wife White Female 0 \n", "32558 Machine-op-inspct Husband White Male 0 \n", "32559 Adm-clerical Unmarried White Female 0 \n", "32560 Adm-clerical Own-child White Male 0 \n", "\n", " capital.loss hours.per.week native.country income \n", "0 4356 40 United-States <=50K \n", "1 4356 18 United-States <=50K \n", "2 4356 40 United-States <=50K \n", "3 3900 40 United-States <=50K \n", "4 3900 40 United-States <=50K \n", "... ... ... ... ... \n", "32556 0 40 United-States <=50K \n", "32557 0 38 United-States <=50K \n", "32558 0 40 United-States >50K \n", "32559 0 40 United-States <=50K \n", "32560 0 20 United-States <=50K \n", "\n", "[32561 rows x 15 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducation.nummarital.statusoccupationrelationshipracesexcapital.gaincapital.losshours.per.weeknative.countryincome
090Unknown77053HS-grad9WidowedUnknownNot-in-familyWhiteFemale0435640United-States<=50K
182Private132870HS-grad9WidowedExec-managerialNot-in-familyWhiteFemale0435618United-States<=50K
266Unknown186061Some-college10WidowedUnknownUnmarriedBlackFemale0435640United-States<=50K
354Private1403597th-8th4DivorcedMachine-op-inspctUnmarriedWhiteFemale0390040United-States<=50K
441Private264663Some-college10SeparatedProf-specialtyOwn-childWhiteFemale0390040United-States<=50K
................................................
3255622Private310152Some-college10Never-marriedProtective-servNot-in-familyWhiteMale0040United-States<=50K
3255727Private257302Assoc-acdm12Married-civ-spouseTech-supportWifeWhiteFemale0038United-States<=50K
3255840Private154374HS-grad9Married-civ-spouseMachine-op-inspctHusbandWhiteMale0040United-States>50K
3255958Private151910HS-grad9WidowedAdm-clericalUnmarriedWhiteFemale0040United-States<=50K
3256022Private201490HS-grad9Never-marriedAdm-clericalOwn-childWhiteMale0020United-States<=50K
\n", "

32561 rows × 15 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 6 } ] }, { "cell_type": "markdown", "source": [ "To better understand the categorical encoding techniques, let's first print the number of unique values in each categorical column." ], "metadata": { "id": "w8R9g5_MHPsP" } }, { "cell_type": "code", "source": [ "categorical_columns = ['workclass', 'education', 'marital.status', 'occupation', 'relationship', 'race', 'sex', 'native.country']\n", "\n", "for column in categorical_columns:\n", " print('column: {},\\tunique values: {}'.format(column, data[column].unique().shape[0]))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "zWTbaFc3HPNx", "outputId": "ea358583-284f-4038-9f6c-12635e9013d5" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "column: workclass,\tunique values: 9\n", "column: education,\tunique values: 16\n", "column: marital.status,\tunique values: 7\n", "column: occupation,\tunique values: 15\n", "column: relationship,\tunique values: 6\n", "column: race,\tunique values: 5\n", "column: sex,\tunique values: 2\n", "column: native.country,\tunique values: 42\n" ] } ] }, { "cell_type": "markdown", "source": [ "Categorical Column | Unique Values\n", "-------------------|------------------\n", "workclass | 9\n", "education | 16\n", "marital.status |7 \n", "occupation | 15\n", "relationship | 6 \n", "race | 5\n", "sex | 2\n", "native.country | 42\n", "income | 2\n" ], "metadata": { "id": "nvKLZDS3IVl-" } }, { "cell_type": "markdown", "source": [ "# Binary Encoding of Targets\n", "\n", "Since there are 2 categories for the **income**: *{<=50K, >50K}*, we can use binary encoding." 
], "metadata": { "id": "0wd0vZhmGG79" } }, { "cell_type": "code", "source": [ "targets = data['income'].replace({'<=50K': 0, '>50K': 1}).to_numpy()\n", "data = data.drop(columns=['income'])\n", "targets" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "I8J6nQ5aF4Qt", "outputId": "475f7e02-c7f5-4226-f905-2520a1a9ade4" }, "execution_count": 8, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0, 0, 0, ..., 1, 0, 0])" ] }, "metadata": {}, "execution_count": 8 } ] }, { "cell_type": "markdown", "source": [ "Splitting the dataset into train and test samples." ], "metadata": { "id": "-MMD1nms1wdx" } }, { "cell_type": "code", "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "test_size = 0.1\n", "random_state=0\n", "\n", "x_train, x_test, y_train, y_test = train_test_split(\n", " data,\n", " targets,\n", " test_size=test_size,\n", " random_state=random_state\n", ")\n", "x_train.shape, x_test.shape, y_train.shape, y_test.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "leINNlh91tWH", "outputId": "d8ccd756-491a-439c-b766-4a387b17cc4a" }, "execution_count": 9, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "((29304, 14), (3257, 14), (29304,), (3257,))" ] }, "metadata": {}, "execution_count": 9 } ] }, { "cell_type": "markdown", "source": [ "# Ordinal Encoding\n", "\n", "In ordinal encoding, we replace each category of a categorical variable with an integer value (e.g. 0, 1, 2, etc.). In this case, retaining the order is important." 
], "metadata": { "id": "dsyxoZWoFHlU" } }, { "cell_type": "code", "source": [ "from sklearn.preprocessing import OrdinalEncoder\n", "\n", "# Encode only the categorical columns: fitting the encoder on the whole frame\n", "# would also replace numeric columns such as age and fnlwgt with category ranks.\n", "# The encoder is fit on the training split; categories that appear only in the\n", "# test split are mapped to -1.\n", "ordinal_encoder = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)\n", "x_train = x_train.copy()\n", "x_test = x_test.copy()\n", "x_train[categorical_columns] = ordinal_encoder.fit_transform(x_train[categorical_columns])\n", "x_test[categorical_columns] = ordinal_encoder.transform(x_test[categorical_columns])\n", "x_train" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "jYqcsrIiFsK_", "outputId": "e7790dba-7634-40ad-c259-dcea745f69ea" }, "execution_count": 11, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[1.8000e+01, 3.0000e+00, 1.5732e+04, ..., 0.0000e+00, 3.9000e+01,\n", " 3.9000e+01],\n", " [1.1000e+01, 3.0000e+00, 3.2850e+03, ..., 0.0000e+00, 2.2000e+01,\n", " 3.9000e+01],\n", " [3.8000e+01, 3.0000e+00, 2.3630e+03, ..., 0.0000e+00, 3.9000e+01,\n", " 3.9000e+01],\n", " ...,\n", " [9.0000e+00, 3.0000e+00, 1.6120e+03, ..., 0.0000e+00, 5.0000e+01,\n", " 3.9000e+01],\n", " [2.7000e+01, 3.0000e+00, 5.1460e+03, ..., 0.0000e+00, 3.9000e+01,\n", " 3.9000e+01],\n", " [2.2000e+01, 3.0000e+00, 1.4687e+04, ..., 0.0000e+00, 3.9000e+01,\n", " 3.9000e+01]])" ] }, "metadata": {}, "execution_count": 11 } ] }, { "cell_type": "markdown", "source": [ "Building an MLP (Multi-Layer Perceptron) classifier." 
], "metadata": { "id": "PwpwN11Vxqzj" } }, { "cell_type": "code", "source": [ "import tensorflow as tf\n", "import tensorflow_addons as tfa\n", "\n", "\n", "def build_model(input_size):\n", "    \"\"\"Create and compile the MLP used as the binary income classifier.\n", "\n", "    Args:\n", "        input_size: Shape of one input sample, without the batch dimension.\n", "\n", "    Returns:\n", "        A compiled `tf.keras` Sequential model with a single sigmoid output.\n", "    \"\"\"\n", "    keras_layers = tf.keras.layers\n", "    network = tf.keras.models.Sequential([\n", "        keras_layers.Input(input_size),\n", "        keras_layers.GaussianNoise(stddev=0.1),\n", "        keras_layers.Dense(units=128, activation='gelu', use_bias=False),\n", "        keras_layers.BatchNormalization(),\n", "        keras_layers.Dropout(rate=0.2),\n", "        keras_layers.Dense(units=64, kernel_regularizer='l1', use_bias=False),\n", "        keras_layers.BatchNormalization(),\n", "        keras_layers.Dense(units=1, activation='sigmoid'),\n", "    ])\n", "\n", "    yogi = tfa.optimizers.Yogi(learning_rate=0.001)\n", "    smoothed_bce = tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1)\n", "    network.compile(optimizer=yogi, loss=smoothed_bce, metrics=['accuracy'])\n", "    return network\n", "\n", "\n", "model = build_model(data.shape[1:])\n", "model.summary()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "udaLpISGx5IT", "outputId": "38ed75a1-ac74-4e3a-ad38-cc896f12fdd4" }, "execution_count": 10, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " gaussian_noise (GaussianNoi (None, 14) 0 \n", " se) \n", " \n", " dense (Dense) (None, 128) 1792 \n", " \n", " batch_normalization (BatchN (None, 128) 512 \n", " ormalization) \n", " \n", " dropout (Dropout) (None, 128) 0 \n", " \n", " dense_1 (Dense) (None, 64) 8192 \n", " \n", " batch_normalization_1 (Batc (None, 64) 256 \n", " hNormalization) \n", " \n", " dense_2 (Dense) (None, 1) 65 \n", " \n", "=================================================================\n", "Total params: 10,817\n", "Trainable params: 
10,433\n", "Non-trainable params: 384\n", "_________________________________________________________________\n" ] } ] }, { "cell_type": "markdown", "source": [ "Training the neural network with the ordinal variables." ], "metadata": { "id": "Wviiljaj3zoa" } }, { "cell_type": "code", "source": [ "# Training settings shared with the later experiments in this notebook.\n", "batch_size, epochs, shuffle = 32, 20, True\n", "\n", "model.fit(\n", "    x=x_train,\n", "    y=y_train,\n", "    validation_data=(x_test, y_test),\n", "    batch_size=batch_size,\n", "    epochs=epochs,\n", "    shuffle=shuffle,\n", ")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4wrdfcFl3ylx", "outputId": "2b73ad1d-b4bd-44e1-e55e-d3bab9b6b859" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/20\n", "916/916 [==============================] - 22s 13ms/step - loss: 1.6441 - accuracy: 0.7641 - val_loss: 0.6491 - val_accuracy: 0.7710\n", "Epoch 2/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.6042 - accuracy: 0.7877 - val_loss: 0.6521 - val_accuracy: 0.7731\n", "Epoch 3/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5706 - accuracy: 0.7942 - val_loss: 0.9231 - val_accuracy: 0.5388\n", "Epoch 4/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5486 - accuracy: 0.8001 - val_loss: 0.6819 - val_accuracy: 0.6460\n", "Epoch 5/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5369 - accuracy: 0.8031 - val_loss: 0.5629 - val_accuracy: 0.7928\n", "Epoch 6/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5313 - accuracy: 0.8046 - val_loss: 0.5119 - val_accuracy: 0.8158\n", "Epoch 7/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5266 - accuracy: 0.8032 - val_loss: 0.6646 - val_accuracy: 0.6736\n", "Epoch 8/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5189 - accuracy: 0.8042 - val_loss: 0.4970 - val_accuracy: 0.8262\n", "Epoch 9/20\n", "916/916 [==============================] - 7s 
8ms/step - loss: 0.5140 - accuracy: 0.8041 - val_loss: 0.4953 - val_accuracy: 0.8155\n", "Epoch 10/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5113 - accuracy: 0.8051 - val_loss: 0.5406 - val_accuracy: 0.7839\n", "Epoch 11/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5072 - accuracy: 0.8070 - val_loss: 0.4805 - val_accuracy: 0.8268\n", "Epoch 12/20\n", "916/916 [==============================] - 8s 9ms/step - loss: 0.5071 - accuracy: 0.8057 - val_loss: 0.4838 - val_accuracy: 0.8253\n", "Epoch 13/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5051 - accuracy: 0.8053 - val_loss: 0.5125 - val_accuracy: 0.8161\n", "Epoch 14/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5037 - accuracy: 0.8069 - val_loss: 0.5076 - val_accuracy: 0.8072\n", "Epoch 15/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5039 - accuracy: 0.8063 - val_loss: 0.4861 - val_accuracy: 0.8161\n", "Epoch 16/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5009 - accuracy: 0.8072 - val_loss: 0.5607 - val_accuracy: 0.7989\n", "Epoch 17/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5018 - accuracy: 0.8067 - val_loss: 0.4826 - val_accuracy: 0.8314\n", "Epoch 18/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.4979 - accuracy: 0.8073 - val_loss: 0.4835 - val_accuracy: 0.8296\n", "Epoch 19/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.4992 - accuracy: 0.8084 - val_loss: 0.4754 - val_accuracy: 0.8330\n", "Epoch 20/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5012 - accuracy: 0.8067 - val_loss: 0.4785 - val_accuracy: 0.8207\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 9 } ] }, { "cell_type": "markdown", "source": [ "Best Validation Accuracy: **0.8330**" ], "metadata": { "id": 
"9CBpyAhU9MHa" } }, { "cell_type": "markdown", "source": [ "# One-Hot Encoding\n", "The drawback of ordinal encoding is that categorical variables become an **ordered** set of numbers. This could potentially trick the classifier into thinking that close numbers are similar (e.g. `Married` and `Divorced` have opposite meanings; however, if `Married = 0` and `Divorced = 1` are assigned, the classifier might be confused into thinking that they are similar).\n", "\n", "One-Hot encoding takes care of the above disadvantage by separating each categorical column into multiple binary columns (e.g. `Sex: {Male, Female}` becomes `s1, s2, where (s1, s2) = (1, 0) if Male and (0, 1) if Female`). This works well when a column has only a few unique values. **However, if the number of unique values is large, the dataset becomes a very large sparse matrix.**" ], "metadata": { "id": "W2Z7hTb3-H-S" } }, { "cell_type": "markdown", "source": [ "# Dummy Encoding\n", "This is an improvement over One-Hot Encoding. Dummy encoding uses only N-1 features to represent N categories. For example, `Marital Status: {Married, Divorced, Engaged}` becomes `s1, s2, where (s1, s2) = (0, 0) if Married, (0, 1) if Divorced and (1, 0) if Engaged`. However, in this dataset the cell below calls `pd.get_dummies` with its default `drop_first=False`, so every category keeps its own column (a full one-hot encoding); pass `drop_first=True` to get N-1 columns per variable." ], "metadata": { "id": "WXuYfWwMGbow" } }, { "cell_type": "code", "source": [ "dummy_data = pd.get_dummies(data)\n", "dummy_data" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 505 }, "id": "2dctwghYRQ35", "outputId": "bda43daa-5bd8-4e55-a469-d4f4d2d9a9d4" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " age fnlwgt education.num capital.gain capital.loss hours.per.week \\\n", "0 90 77053 9 0 4356 40 \n", "1 82 132870 9 0 4356 18 \n", "2 66 186061 10 0 4356 40 \n", "3 54 140359 4 0 3900 40 \n", "4 41 264663 10 0 3900 40 \n", "... ... ... ... ... ... ... 
\n", "32556 22 310152 10 0 0 40 \n", "32557 27 257302 12 0 0 38 \n", "32558 40 154374 9 0 0 40 \n", "32559 58 151910 9 0 0 40 \n", "32560 22 201490 9 0 0 20 \n", "\n", " workclass_Federal-gov workclass_Local-gov workclass_Never-worked \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "... ... ... ... \n", "32556 0 0 0 \n", "32557 0 0 0 \n", "32558 0 0 0 \n", "32559 0 0 0 \n", "32560 0 0 0 \n", "\n", " workclass_Private ... native.country_Portugal \\\n", "0 0 ... 0 \n", "1 1 ... 0 \n", "2 0 ... 0 \n", "3 1 ... 0 \n", "4 1 ... 0 \n", "... ... ... ... \n", "32556 1 ... 0 \n", "32557 1 ... 0 \n", "32558 1 ... 0 \n", "32559 1 ... 0 \n", "32560 1 ... 0 \n", "\n", " native.country_Puerto-Rico native.country_Scotland \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "... ... ... \n", "32556 0 0 \n", "32557 0 0 \n", "32558 0 0 \n", "32559 0 0 \n", "32560 0 0 \n", "\n", " native.country_South native.country_Taiwan native.country_Thailand \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "... ... ... ... \n", "32556 0 0 0 \n", "32557 0 0 0 \n", "32558 0 0 0 \n", "32559 0 0 0 \n", "32560 0 0 0 \n", "\n", " native.country_Trinadad&Tobago native.country_United-States \\\n", "0 0 1 \n", "1 0 1 \n", "2 0 1 \n", "3 0 1 \n", "4 0 1 \n", "... ... ... \n", "32556 0 1 \n", "32557 0 1 \n", "32558 0 1 \n", "32559 0 1 \n", "32560 0 1 \n", "\n", " native.country_Vietnam native.country_Yugoslavia \n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "... ... ... \n", "32556 0 0 \n", "32557 0 0 \n", "32558 0 0 \n", "32559 0 0 \n", "32560 0 0 \n", "\n", "[32561 rows x 108 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agefnlwgteducation.numcapital.gaincapital.losshours.per.weekworkclass_Federal-govworkclass_Local-govworkclass_Never-workedworkclass_Private...native.country_Portugalnative.country_Puerto-Riconative.country_Scotlandnative.country_Southnative.country_Taiwannative.country_Thailandnative.country_Trinadad&Tobagonative.country_United-Statesnative.country_Vietnamnative.country_Yugoslavia
09077053904356400000...0000000100
182132870904356180001...0000000100
2661860611004356400000...0000000100
354140359403900400001...0000000100
4412646631003900400001...0000000100
..................................................................
32556223101521000400001...0000000100
32557272573021200380001...0000000100
3255840154374900400001...0000000100
3255958151910900400001...0000000100
3256022201490900200001...0000000100
\n", "

32561 rows × 108 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 10 } ] }, { "cell_type": "code", "source": [ "# Reuses the test_size and random_state defined for the first split.\n", "x_train, x_test, y_train, y_test = train_test_split(\n", "    dummy_data, targets, test_size=test_size, random_state=random_state\n", ")\n", "\n", "# A fresh model is built because the number of input columns changed.\n", "model = build_model(dummy_data.shape[1:])\n", "model.fit(\n", "    x=x_train,\n", "    y=y_train,\n", "    validation_data=(x_test, y_test),\n", "    batch_size=batch_size,\n", "    epochs=epochs,\n", "    shuffle=shuffle,\n", ")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xd7UV0zGRbNe", "outputId": "9c7f68ba-cad9-4ee6-fe34-5702ec6f2490" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/20\n", "916/916 [==============================] - 9s 8ms/step - loss: 1.7158 - accuracy: 0.7689 - val_loss: 0.6750 - val_accuracy: 0.7749\n", "Epoch 2/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.6385 - accuracy: 0.7785 - val_loss: 0.6733 - val_accuracy: 0.7181\n", "Epoch 3/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.6073 - accuracy: 0.7844 - val_loss: 0.6152 - val_accuracy: 0.7706\n", "Epoch 4/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5902 - accuracy: 0.7875 - val_loss: 0.5793 - val_accuracy: 0.7909\n", "Epoch 5/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5867 - accuracy: 0.7870 - val_loss: 0.6048 - val_accuracy: 0.7771\n", "Epoch 6/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5743 - accuracy: 0.7913 - val_loss: 0.5732 - val_accuracy: 0.8044\n", "Epoch 7/20\n", "916/916 [==============================] - 11s 12ms/step - loss: 0.5724 - accuracy: 0.7892 - val_loss: 0.5662 - val_accuracy: 0.7872\n", "Epoch 8/20\n", "916/916 [==============================] - 13s 14ms/step - loss: 0.5685 - accuracy: 0.7911 - val_loss: 0.6454 - val_accuracy: 0.7105\n", "Epoch 9/20\n", "916/916 [==============================] - 8s 8ms/step - loss: 0.5662 - accuracy: 0.7889 - 
val_loss: 0.5841 - val_accuracy: 0.7676\n", "Epoch 10/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5653 - accuracy: 0.7877 - val_loss: 0.5618 - val_accuracy: 0.7875\n", "Epoch 11/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5623 - accuracy: 0.7875 - val_loss: 0.5626 - val_accuracy: 0.7848\n", "Epoch 12/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5605 - accuracy: 0.7895 - val_loss: 0.5613 - val_accuracy: 0.7826\n", "Epoch 13/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5588 - accuracy: 0.7884 - val_loss: 0.5644 - val_accuracy: 0.7872\n", "Epoch 14/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5600 - accuracy: 0.7890 - val_loss: 0.5581 - val_accuracy: 0.7848\n", "Epoch 15/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5568 - accuracy: 0.7883 - val_loss: 0.5595 - val_accuracy: 0.7842\n", "Epoch 16/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5546 - accuracy: 0.7913 - val_loss: 0.5514 - val_accuracy: 0.7872\n", "Epoch 17/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5584 - accuracy: 0.7879 - val_loss: 0.5564 - val_accuracy: 0.7866\n", "Epoch 18/20\n", "916/916 [==============================] - 7s 7ms/step - loss: 0.5579 - accuracy: 0.7887 - val_loss: 0.5578 - val_accuracy: 0.7839\n", "Epoch 19/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5542 - accuracy: 0.7896 - val_loss: 0.5561 - val_accuracy: 0.7860\n", "Epoch 20/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5551 - accuracy: 0.7884 - val_loss: 0.5586 - val_accuracy: 0.7842\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 11 } ] }, { "cell_type": "markdown", "source": [ "Best Validation Accuracy: **0.8044**" ], "metadata": { "id": "Wl-mpDFmSYrM" } }, { "cell_type": "markdown", 
"source": [ "# Effect/Deviation/Sum Encoding\n", "\n", "In dummy coding, we use 0 and 1 to represent the data, but in effect encoding (also known as *Deviation Encoding* or *Sum Encoding*) we use three values, i.e. 1, 0 and -1." ], "metadata": { "id": "NeDGR2KySnAI" } }, { "cell_type": "code", "source": [ "from category_encoders.sum_coding import SumEncoder\n", "\n", "x_train, x_test, y_train, y_test = train_test_split(\n", " data,\n", " targets,\n", " test_size=test_size,\n", " random_state=random_state\n", ")\n", "\n", "sum_encoder = SumEncoder()\n", "sum_encoder.fit(data)\n", "x_train = sum_encoder.transform(x_train)\n", "x_test = sum_encoder.transform(x_test)\n", "x_train" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 488 }, "id": "eUcOE9qeSbph", "outputId": "1f1dd4da-6f10-4d6b-a72f-10ec4e11afe1" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " intercept age workclass_0 workclass_1 workclass_2 workclass_3 \\\n", "9281 1 35 0.0 1.0 0.0 0.0 \n", "31884 1 28 0.0 1.0 0.0 0.0 \n", "31580 1 55 0.0 1.0 0.0 0.0 \n", "18489 1 33 0.0 1.0 0.0 0.0 \n", "21111 1 39 0.0 0.0 0.0 0.0 \n", "... ... ... ... ... ... ... \n", "13123 1 90 0.0 0.0 0.0 0.0 \n", "19648 1 36 0.0 1.0 0.0 0.0 \n", "9845 1 26 0.0 1.0 0.0 0.0 \n", "10799 1 44 0.0 1.0 0.0 0.0 \n", "2732 1 39 0.0 1.0 0.0 0.0 \n", "\n", " workclass_4 workclass_5 workclass_6 workclass_7 ... \\\n", "9281 0.0 0.0 0.0 0.0 ... \n", "31884 0.0 0.0 0.0 0.0 ... \n", "31580 0.0 0.0 0.0 0.0 ... \n", "18489 0.0 0.0 0.0 0.0 ... \n", "21111 0.0 1.0 0.0 0.0 ... \n", "... ... ... ... ... ... \n", "13123 1.0 0.0 0.0 0.0 ... \n", "19648 0.0 0.0 0.0 0.0 ... \n", "9845 0.0 0.0 0.0 0.0 ... \n", "10799 0.0 0.0 0.0 0.0 ... \n", "2732 0.0 0.0 0.0 0.0 ... \n", "\n", " native.country_31 native.country_32 native.country_33 \\\n", "9281 0.0 0.0 0.0 \n", "31884 0.0 0.0 0.0 \n", "31580 0.0 0.0 0.0 \n", "18489 0.0 0.0 0.0 \n", "21111 0.0 0.0 0.0 \n", "... ... ... ... 
\n", "13123 0.0 0.0 0.0 \n", "19648 0.0 0.0 0.0 \n", "9845 0.0 0.0 0.0 \n", "10799 0.0 0.0 0.0 \n", "2732 0.0 0.0 0.0 \n", "\n", " native.country_34 native.country_35 native.country_36 \\\n", "9281 0.0 0.0 0.0 \n", "31884 0.0 0.0 0.0 \n", "31580 0.0 0.0 0.0 \n", "18489 0.0 0.0 0.0 \n", "21111 0.0 0.0 0.0 \n", "... ... ... ... \n", "13123 0.0 0.0 0.0 \n", "19648 0.0 0.0 0.0 \n", "9845 0.0 0.0 0.0 \n", "10799 0.0 0.0 0.0 \n", "2732 0.0 0.0 0.0 \n", "\n", " native.country_37 native.country_38 native.country_39 \\\n", "9281 0.0 0.0 0.0 \n", "31884 0.0 0.0 0.0 \n", "31580 0.0 0.0 0.0 \n", "18489 0.0 0.0 0.0 \n", "21111 0.0 0.0 0.0 \n", "... ... ... ... \n", "13123 0.0 0.0 0.0 \n", "19648 0.0 0.0 0.0 \n", "9845 0.0 0.0 0.0 \n", "10799 0.0 0.0 0.0 \n", "2732 0.0 0.0 0.0 \n", "\n", " native.country_40 \n", "9281 0.0 \n", "31884 0.0 \n", "31580 0.0 \n", "18489 0.0 \n", "21111 0.0 \n", "... ... \n", "13123 0.0 \n", "19648 0.0 \n", "9845 0.0 \n", "10799 0.0 \n", "2732 0.0 \n", "\n", "[29304 rows x 101 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
interceptageworkclass_0workclass_1workclass_2workclass_3workclass_4workclass_5workclass_6workclass_7...native.country_31native.country_32native.country_33native.country_34native.country_35native.country_36native.country_37native.country_38native.country_39native.country_40
92811350.01.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
318841280.01.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
315801550.01.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
184891330.01.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
211111390.00.00.00.00.01.00.00.0...0.00.00.00.00.00.00.00.00.00.0
..................................................................
131231900.00.00.00.01.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
196481360.01.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
98451260.01.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
107991440.01.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
27321390.01.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
\n", "

29304 rows × 101 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 12 } ] }, { "cell_type": "code", "source": [ "model = build_model(x_train.shape[1:])\n", "model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=epochs, shuffle=shuffle, validation_data=(x_test, y_test))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "E5rdTtBAWbAh", "outputId": "35b841a1-c9fe-428e-993e-4bd1fb12f0ca" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/20\n", "916/916 [==============================] - 10s 9ms/step - loss: 1.6068 - accuracy: 0.7683 - val_loss: 0.8291 - val_accuracy: 0.6620\n", "Epoch 2/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.6221 - accuracy: 0.7818 - val_loss: 0.6185 - val_accuracy: 0.7651\n", "Epoch 3/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.6026 - accuracy: 0.7852 - val_loss: 0.5778 - val_accuracy: 0.7875\n", "Epoch 4/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5871 - accuracy: 0.7896 - val_loss: 0.5873 - val_accuracy: 0.7734\n", "Epoch 5/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5833 - accuracy: 0.7884 - val_loss: 0.5877 - val_accuracy: 0.7872\n", "Epoch 6/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5797 - accuracy: 0.7871 - val_loss: 0.5746 - val_accuracy: 0.7746\n", "Epoch 7/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5740 - accuracy: 0.7880 - val_loss: 0.5663 - val_accuracy: 0.7866\n", "Epoch 8/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5669 - accuracy: 0.7879 - val_loss: 0.5554 - val_accuracy: 0.7974\n", "Epoch 9/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5626 - accuracy: 0.7894 - val_loss: 0.5625 - val_accuracy: 0.7915\n", "Epoch 10/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5637 - accuracy: 0.7878 - val_loss: 0.5567 - 
val_accuracy: 0.7869\n", "Epoch 11/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5605 - accuracy: 0.7898 - val_loss: 0.5573 - val_accuracy: 0.7885\n", "Epoch 12/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5576 - accuracy: 0.7907 - val_loss: 0.5601 - val_accuracy: 0.7872\n", "Epoch 13/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5559 - accuracy: 0.7901 - val_loss: 0.5555 - val_accuracy: 0.7863\n", "Epoch 14/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5553 - accuracy: 0.7908 - val_loss: 0.5523 - val_accuracy: 0.7866\n", "Epoch 15/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5551 - accuracy: 0.7896 - val_loss: 0.5567 - val_accuracy: 0.7848\n", "Epoch 16/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5539 - accuracy: 0.7896 - val_loss: 0.5556 - val_accuracy: 0.7869\n", "Epoch 17/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5571 - accuracy: 0.7890 - val_loss: 0.5609 - val_accuracy: 0.7756\n", "Epoch 18/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5560 - accuracy: 0.7882 - val_loss: 0.5589 - val_accuracy: 0.7845\n", "Epoch 19/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5547 - accuracy: 0.7889 - val_loss: 0.5525 - val_accuracy: 0.7881\n", "Epoch 20/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5539 - accuracy: 0.7905 - val_loss: 0.5508 - val_accuracy: 0.7872\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 13 } ] }, { "cell_type": "markdown", "source": [ "Best Validation Accuracy: **0.7974**, which is similar to the accuracy of dummy encoding technique." 
], "metadata": { "id": "LnS3RbtRXVpJ" } }, { "cell_type": "markdown", "source": [ "# Binary Encoding of Inputs\n", "\n", "Binary encoding is a combination of hash encoding and one hot encoding." ], "metadata": { "id": "yJ6S7V_sJfI_" } }, { "cell_type": "code", "source": [ "from category_encoders.binary import BinaryEncoder\n", "\n", "x_train, x_test, y_train, y_test = train_test_split(\n", " data,\n", " targets,\n", " test_size=test_size,\n", " random_state=random_state\n", ")\n", "\n", "binary_encoder = BinaryEncoder()\n", "binary_encoder.fit(data)\n", "x_train = binary_encoder.transform(x_train)\n", "x_test = binary_encoder.transform(x_test)\n", "x_train" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 488 }, "id": "iniX8AEWJdaT", "outputId": "72701aff-073a-43bb-f8bc-686c43dd00f7" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " age workclass_0 workclass_1 workclass_2 workclass_3 fnlwgt \\\n", "9281 35 0 0 1 0 241126 \n", "31884 28 0 0 1 0 90547 \n", "31580 55 0 0 1 0 70088 \n", "18489 33 0 0 1 0 182423 \n", "21111 39 0 1 1 0 163057 \n", "... ... ... ... ... ... ... \n", "13123 90 0 1 0 1 282095 \n", "19648 36 0 0 1 0 279721 \n", "9845 26 0 0 1 0 51961 \n", "10799 44 0 0 1 0 115323 \n", "2732 39 0 0 1 0 224531 \n", "\n", " education_0 education_1 education_2 education_3 ... sex_1 \\\n", "9281 0 1 0 0 ... 0 \n", "31884 0 0 0 0 ... 1 \n", "31580 0 1 0 0 ... 0 \n", "18489 0 0 0 0 ... 0 \n", "21111 0 0 0 1 ... 0 \n", "... ... ... ... ... ... ... \n", "13123 0 0 0 1 ... 0 \n", "19648 0 0 0 0 ... 0 \n", "9845 0 1 1 1 ... 0 \n", "10799 0 1 0 0 ... 0 \n", "2732 0 0 0 0 ... 0 \n", "\n", " capital.gain capital.loss hours.per.week native.country_0 \\\n", "9281 0 0 40 0 \n", "31884 0 0 23 0 \n", "31580 0 0 40 0 \n", "18489 0 0 40 0 \n", "21111 0 0 99 0 \n", "... ... ... ... ... 
\n", "13123 0 0 40 0 \n", "19648 0 0 40 0 \n", "9845 0 0 51 0 \n", "10799 0 0 40 0 \n", "2732 7298 0 40 0 \n", "\n", " native.country_1 native.country_2 native.country_3 native.country_4 \\\n", "9281 0 0 0 0 \n", "31884 0 0 0 0 \n", "31580 0 0 0 0 \n", "18489 0 0 0 0 \n", "21111 0 0 0 0 \n", "... ... ... ... ... \n", "13123 0 0 0 0 \n", "19648 0 0 0 0 \n", "9845 0 0 0 0 \n", "10799 0 0 0 0 \n", "2732 0 0 0 0 \n", "\n", " native.country_5 \n", "9281 1 \n", "31884 1 \n", "31580 1 \n", "18489 1 \n", "21111 1 \n", "... ... \n", "13123 1 \n", "19648 1 \n", "9845 1 \n", "10799 1 \n", "2732 1 \n", "\n", "[29304 rows x 36 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclass_0workclass_1workclass_2workclass_3fnlwgteducation_0education_1education_2education_3...sex_1capital.gaincapital.losshours.per.weeknative.country_0native.country_1native.country_2native.country_3native.country_4native.country_5
92813500102411260100...00040000001
31884280010905470000...10023000001
31580550010700880100...00040000001
184893300101824230000...00040000001
211113901101630570001...00099000001
..................................................................
131239001012820950001...00040000001
196483600102797210000...00040000001
9845260010519610111...00051000001
107994400101153230100...00040000001
27323900102245310000...07298040000001
\n", "

29304 rows × 36 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 30 } ] }, { "cell_type": "code", "source": [ "model = build_model(x_train.shape[1:])\n", "model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=epochs, shuffle=shuffle, validation_data=(x_test, y_test))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "v_ezlxufJqDJ", "outputId": "3115c8a6-c71d-4882-b18b-b278e7eba34e" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/20\n", "916/916 [==============================] - 11s 8ms/step - loss: 1.6675 - accuracy: 0.7724 - val_loss: 0.6335 - val_accuracy: 0.7906\n", "Epoch 2/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.6336 - accuracy: 0.7812 - val_loss: 0.6054 - val_accuracy: 0.7854\n", "Epoch 3/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.6015 - accuracy: 0.7846 - val_loss: 0.6418 - val_accuracy: 0.7307\n", "Epoch 4/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5899 - accuracy: 0.7870 - val_loss: 0.5834 - val_accuracy: 0.7906\n", "Epoch 5/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5835 - accuracy: 0.7879 - val_loss: 0.6220 - val_accuracy: 0.7467\n", "Epoch 6/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5784 - accuracy: 0.7860 - val_loss: 0.5688 - val_accuracy: 0.7869\n", "Epoch 7/20\n", "916/916 [==============================] - 8s 8ms/step - loss: 0.5715 - accuracy: 0.7898 - val_loss: 0.5726 - val_accuracy: 0.7866\n", "Epoch 8/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5693 - accuracy: 0.7877 - val_loss: 0.5846 - val_accuracy: 0.7685\n", "Epoch 9/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5664 - accuracy: 0.7880 - val_loss: 0.5784 - val_accuracy: 0.7878\n", "Epoch 10/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5668 - accuracy: 0.7887 - val_loss: 0.5604 - 
val_accuracy: 0.7866\n", "Epoch 11/20\n", "916/916 [==============================] - 8s 8ms/step - loss: 0.5617 - accuracy: 0.7905 - val_loss: 0.5573 - val_accuracy: 0.7885\n", "Epoch 12/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5595 - accuracy: 0.7899 - val_loss: 0.5572 - val_accuracy: 0.7860\n", "Epoch 13/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5596 - accuracy: 0.7899 - val_loss: 0.5579 - val_accuracy: 0.7875\n", "Epoch 14/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5581 - accuracy: 0.7889 - val_loss: 0.5606 - val_accuracy: 0.7921\n", "Epoch 15/20\n", "916/916 [==============================] - 8s 8ms/step - loss: 0.5583 - accuracy: 0.7900 - val_loss: 0.5596 - val_accuracy: 0.7894\n", "Epoch 16/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5556 - accuracy: 0.7904 - val_loss: 0.5557 - val_accuracy: 0.7885\n", "Epoch 17/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5592 - accuracy: 0.7873 - val_loss: 0.5684 - val_accuracy: 0.7835\n", "Epoch 18/20\n", "916/916 [==============================] - 8s 9ms/step - loss: 0.5608 - accuracy: 0.7854 - val_loss: 0.5581 - val_accuracy: 0.7759\n", "Epoch 19/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5631 - accuracy: 0.7848 - val_loss: 0.5584 - val_accuracy: 0.7740\n", "Epoch 20/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5573 - accuracy: 0.7876 - val_loss: 0.5592 - val_accuracy: 0.7839\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 31 } ] }, { "cell_type": "markdown", "source": [ "Best Validation Accuracy: **0.7921**" ], "metadata": { "id": "aAptX4hTKhoP" } }, { "cell_type": "markdown", "source": [ "# Hash Encoder\n", "\n", "Hashing involves transforming any given categorical variable into a fixed and unique numerical value using a hash function. 
Popular hash functions for Hash encoders include *MD2, MD4, MD5*. In this example, we will use the **MD5** method. The advantage of the hash encoder over previous approaches is that it can use any number of additional columns to describe categories. 20 additional features will be used to hash the categories of this dataset." ], "metadata": { "id": "BKjgjPwez8bZ" } }, { "cell_type": "code", "source": [ "from category_encoders.hashing import HashingEncoder\n", "\n", "x_train, x_test, y_train, y_test = train_test_split(\n", " data,\n", " targets,\n", " test_size=test_size,\n", " random_state=random_state\n", ")\n", "\n", "hash_encoder = HashingEncoder(n_components=20, hash_method='md5')\n", "hash_encoder.fit(data)\n", "x_train = hash_encoder.transform(x_train)\n", "x_test = hash_encoder.transform(x_test)\n", "x_train" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "id": "TxqNkZ3cyemB", "outputId": "cc791eef-0e36-4f77-8d47-0acd83b1dc19" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " col_0 col_1 col_2 col_3 col_4 col_5 col_6 col_7 col_8 col_9 \\\n", "9281 0 0 0 0 0 0 0 0 1 0 \n", "31884 0 0 0 1 1 0 0 0 0 0 \n", "31580 0 0 0 0 0 0 0 0 1 0 \n", "18489 0 0 0 0 1 1 0 0 0 0 \n", "21111 1 0 0 0 0 1 1 0 0 0 \n", "... ... ... ... ... ... ... ... ... ... ... \n", "13123 0 0 0 0 0 0 0 0 1 0 \n", "19648 0 0 0 0 0 0 0 0 1 0 \n", "9845 0 1 1 0 0 1 0 0 0 0 \n", "10799 0 0 1 0 0 0 0 0 1 0 \n", "2732 0 0 0 0 0 0 1 0 1 0 \n", "\n", " ... col_16 col_17 col_18 col_19 age fnlwgt education.num \\\n", "9281 ... 0 2 1 1 35 241126 14 \n", "31884 ... 0 1 3 1 28 90547 9 \n", "31580 ... 0 2 1 1 55 70088 14 \n", "18489 ... 0 2 4 0 33 182423 9 \n", "21111 ... 0 1 1 1 39 163057 10 \n", "... ... ... ... ... ... ... ... ... \n", "13123 ... 0 2 1 2 90 282095 10 \n", "19648 ... 0 2 2 2 36 279721 9 \n", "9845 ... 0 2 2 0 26 51961 8 \n", "10799 ... 0 2 1 1 44 115323 14 \n", "2732 ... 
0 2 2 1 39 224531 9 \n", "\n", " capital.gain capital.loss hours.per.week \n", "9281 0 0 40 \n", "31884 0 0 23 \n", "31580 0 0 40 \n", "18489 0 0 40 \n", "21111 0 0 99 \n", "... ... ... ... \n", "13123 0 0 40 \n", "19648 0 0 40 \n", "9845 0 0 51 \n", "10799 0 0 40 \n", "2732 7298 0 40 \n", "\n", "[29304 rows x 26 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
col_0col_1col_2col_3col_4col_5col_6col_7col_8col_9...col_16col_17col_18col_19agefnlwgteducation.numcapital.gaincapital.losshours.per.week
92810000000010...021135241126140040
318840001100000...0131289054790023
315800000000010...02115570088140040
184890000110000...02403318242390040
211111000011000...011139163057100099
..................................................................
131230000000010...021290282095100040
196480000000010...02223627972190040
98450110010000...0220265196180051
107990010000010...021144115323140040
27320000001010...02213922453197298040
\n", "

29304 rows × 26 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 27 } ] }, { "cell_type": "code", "source": [ "model = build_model(x_train.shape[1:])\n", "model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=epochs, shuffle=shuffle, validation_data=(x_test, y_test))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AgCNf9zs0lc6", "outputId": "11617400-9633-4a18-f3b6-36e9edd484ac" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/20\n", "916/916 [==============================] - 9s 8ms/step - loss: 1.6417 - accuracy: 0.7719 - val_loss: 0.6505 - val_accuracy: 0.7928\n", "Epoch 2/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.6241 - accuracy: 0.7830 - val_loss: 0.6141 - val_accuracy: 0.7869\n", "Epoch 3/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5997 - accuracy: 0.7880 - val_loss: 0.6015 - val_accuracy: 0.7839\n", "Epoch 4/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5897 - accuracy: 0.7868 - val_loss: 0.5934 - val_accuracy: 0.7875\n", "Epoch 5/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5805 - accuracy: 0.7892 - val_loss: 0.5760 - val_accuracy: 0.7832\n", "Epoch 6/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5733 - accuracy: 0.7892 - val_loss: 0.5601 - val_accuracy: 0.8032\n", "Epoch 7/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5751 - accuracy: 0.7887 - val_loss: 0.5671 - val_accuracy: 0.7881\n", "Epoch 8/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5692 - accuracy: 0.7892 - val_loss: 0.5654 - val_accuracy: 0.7875\n", "Epoch 9/20\n", "916/916 [==============================] - 8s 9ms/step - loss: 0.5655 - accuracy: 0.7898 - val_loss: 0.5704 - val_accuracy: 0.7863\n", "Epoch 10/20\n", "916/916 [==============================] - 8s 9ms/step - loss: 0.5648 - accuracy: 0.7895 - val_loss: 0.5578 - 
val_accuracy: 0.7854\n", "Epoch 11/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5629 - accuracy: 0.7902 - val_loss: 0.5603 - val_accuracy: 0.8004\n", "Epoch 12/20\n", "916/916 [==============================] - 8s 8ms/step - loss: 0.5623 - accuracy: 0.7888 - val_loss: 0.5664 - val_accuracy: 0.7860\n", "Epoch 13/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5615 - accuracy: 0.7880 - val_loss: 0.5544 - val_accuracy: 0.7872\n", "Epoch 14/20\n", "916/916 [==============================] - 8s 8ms/step - loss: 0.5580 - accuracy: 0.7910 - val_loss: 0.5672 - val_accuracy: 0.7955\n", "Epoch 15/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5597 - accuracy: 0.7910 - val_loss: 0.5623 - val_accuracy: 0.7878\n", "Epoch 16/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5576 - accuracy: 0.7901 - val_loss: 0.5561 - val_accuracy: 0.7906\n", "Epoch 17/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5560 - accuracy: 0.7901 - val_loss: 0.5540 - val_accuracy: 0.7869\n", "Epoch 18/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5564 - accuracy: 0.7913 - val_loss: 0.5525 - val_accuracy: 0.7937\n", "Epoch 19/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5548 - accuracy: 0.7921 - val_loss: 0.5504 - val_accuracy: 0.7888\n", "Epoch 20/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5556 - accuracy: 0.7900 - val_loss: 0.5536 - val_accuracy: 0.7872\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 28 } ] }, { "cell_type": "markdown", "source": [ "Best Validation Accuracy: **0.8032**" ], "metadata": { "id": "Ikp7UE94Bs3P" } }, { "cell_type": "markdown", "source": [ "# BaseN Encoder\n", "\n", "Encoding the categorical variables using numbers from an arithmetic system of base \"4\"." 
], "metadata": { "id": "xCHxxWDJL9Qp" } }, { "cell_type": "code", "source": [ "from category_encoders.basen import BaseNEncoder\n", "\n", "x_train, x_test, y_train, y_test = train_test_split(\n", " data,\n", " targets,\n", " test_size=test_size,\n", " random_state=random_state\n", ")\n", "\n", "basen_encoder = BaseNEncoder(base=4)\n", "basen_encoder.fit(data)\n", "x_train = basen_encoder.transform(x_train)\n", "x_test = basen_encoder.transform(x_test)\n", "x_train" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 488 }, "id": "5S0_KWd6LsL9", "outputId": "c7221245-b1dc-4fa7-d217-0b1fd0e775a2" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " age workclass_0 workclass_1 fnlwgt education_0 education_1 \\\n", "9281 35 0 2 241126 0 2 \n", "31884 28 0 2 90547 0 0 \n", "31580 55 0 2 70088 0 2 \n", "18489 33 0 2 182423 0 0 \n", "21111 39 1 2 163057 0 0 \n", "... ... ... ... ... ... ... \n", "13123 90 1 1 282095 0 0 \n", "19648 36 0 2 279721 0 0 \n", "9845 26 0 2 51961 0 3 \n", "10799 44 0 2 115323 0 2 \n", "2732 39 0 2 224531 0 0 \n", "\n", " education_2 education.num marital.status_0 marital.status_1 ... \\\n", "9281 0 14 1 1 ... \n", "31884 1 9 1 1 ... \n", "31580 0 14 1 1 ... \n", "18489 1 9 0 2 ... \n", "21111 2 10 0 2 ... \n", "... ... ... ... ... ... \n", "13123 2 10 1 1 ... \n", "19648 1 9 1 1 ... \n", "9845 2 8 1 0 ... \n", "10799 0 14 1 1 ... \n", "2732 1 9 1 1 ... \n", "\n", " relationship_1 race_0 race_1 sex_0 capital.gain capital.loss \\\n", "9281 1 0 1 2 0 0 \n", "31884 2 0 2 1 0 0 \n", "31580 1 0 1 2 0 0 \n", "18489 2 0 2 2 0 0 \n", "21111 1 0 1 2 0 0 \n", "... ... ... ... ... ... ... 
\n", "13123 1 0 1 2 0 0 \n", "19648 1 0 1 2 0 0 \n", "9845 0 0 2 2 0 0 \n", "10799 1 0 1 2 0 0 \n", "2732 1 0 1 2 7298 0 \n", "\n", " hours.per.week native.country_0 native.country_1 native.country_2 \n", "9281 40 0 0 1 \n", "31884 23 0 0 1 \n", "31580 40 0 0 1 \n", "18489 40 0 0 1 \n", "21111 99 0 0 1 \n", "... ... ... ... ... \n", "13123 40 0 0 1 \n", "19648 40 0 0 1 \n", "9845 51 0 0 1 \n", "10799 40 0 0 1 \n", "2732 40 0 0 1 \n", "\n", "[29304 rows x 23 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclass_0workclass_1fnlwgteducation_0education_1education_2education.nummarital.status_0marital.status_1...relationship_1race_0race_1sex_0capital.gaincapital.losshours.per.weeknative.country_0native.country_1native.country_2
928135022411260201411...10120040001
31884280290547001911...20210023001
315805502700880201411...10120040001
184893302182423001902...20220040001
2111139121630570021002...10120099001
..................................................................
1312390112820950021011...10120040001
196483602279721001911...10120040001
9845260251961032810...00220051001
1079944021153230201411...10120040001
27323902224531001911...10127298040001
\n", "

29304 rows × 23 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 32 } ] }, { "cell_type": "code", "source": [ "model = build_model(x_train.shape[1:])\n", "model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=epochs, shuffle=shuffle, validation_data=(x_test, y_test))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "I8xcyjoZMaMr", "outputId": "883b9205-b7cf-4d4b-a72b-6dc914d10807" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/20\n", "916/916 [==============================] - 11s 8ms/step - loss: 1.5810 - accuracy: 0.7698 - val_loss: 0.8075 - val_accuracy: 0.6945\n", "Epoch 2/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.6178 - accuracy: 0.7839 - val_loss: 0.7165 - val_accuracy: 0.6524\n", "Epoch 3/20\n", "916/916 [==============================] - 8s 8ms/step - loss: 0.6009 - accuracy: 0.7823 - val_loss: 0.5845 - val_accuracy: 0.7710\n", "Epoch 4/20\n", "916/916 [==============================] - 9s 10ms/step - loss: 0.5899 - accuracy: 0.7833 - val_loss: 0.6068 - val_accuracy: 0.7860\n", "Epoch 5/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5821 - accuracy: 0.7823 - val_loss: 0.6083 - val_accuracy: 0.7614\n", "Epoch 6/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5706 - accuracy: 0.7872 - val_loss: 0.5674 - val_accuracy: 0.7808\n", "Epoch 7/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5676 - accuracy: 0.7863 - val_loss: 0.5960 - val_accuracy: 0.7611\n", "Epoch 8/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5678 - accuracy: 0.7867 - val_loss: 0.5724 - val_accuracy: 0.7851\n", "Epoch 9/20\n", "916/916 [==============================] - 8s 8ms/step - loss: 0.5628 - accuracy: 0.7894 - val_loss: 0.5563 - val_accuracy: 0.7921\n", "Epoch 10/20\n", "916/916 [==============================] - 7s 8ms/step - loss: 0.5622 - accuracy: 0.7885 - val_loss: 0.5606 - 
from category_encoders.backward_difference import BackwardDifferenceEncoder
from category_encoders.cat_boost import CatBoostEncoder
from category_encoders.count import CountEncoder
from category_encoders.glmm import GLMMEncoder
from category_encoders.target_encoder import TargetEncoder
from category_encoders.helmert import HelmertEncoder
from category_encoders.james_stein import JamesSteinEncoder

# Readable name -> unfitted encoder. Every encoder is evaluated with the same
# network architecture, hyper-parameters and train/test split.
encoders = {
    'Backward Difference Encoder': BackwardDifferenceEncoder(),
    'Cat Boost Encoder': CatBoostEncoder(),
    'Count Encoder': CountEncoder(),
    'Generalized Linear Mixed Model Encoder': GLMMEncoder(binomial_target=True),
    'Target Encoder': TargetEncoder(),
    'Helmert Encoder': HelmertEncoder(),
    'James-Stein Encoder': JamesSteinEncoder()
}

# The split arguments do not change between encoders, so split once up front
# instead of recomputing the same partition on every iteration.
x_train_raw, x_test_raw, y_train, y_test = train_test_split(
    data,
    targets,
    test_size=test_size,
    random_state=random_state
)

# Best validation accuracy reached by each encoder, keyed by readable name.
best_val_accuracy = {}

for encoder_name, encoder in encoders.items():
    print('Evaluating: {}'.format(encoder_name))

    # Fit on the training rows ONLY. Several of these encoders derive their
    # values from the target, so fitting on the full dataset would leak the
    # held-out labels into the features and inflate the validation score.
    # NOTE(review): categories that occur only in the test rows now go through
    # each encoder's unknown-value handling -- confirm the defaults are acceptable.
    encoder.fit(x_train_raw, y_train)
    x_train = encoder.transform(x_train_raw)
    x_test = encoder.transform(x_test_raw)

    # The encoded width differs per encoder, so the network is rebuilt each time.
    model = build_model(x_train.shape[1:])
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=epochs,
        shuffle=shuffle,
        validation_data=(x_test, y_test)
    )
    best_val_accuracy[encoder_name] = max(history.history['val_accuracy'])

# Side-by-side summary so the comparison does not require reading the epoch logs.
pd.Series(best_val_accuracy, name='best_val_accuracy').sort_values(ascending=False)
loss: 0.6051 - accuracy: 0.7862 - val_loss: 0.6074 - val_accuracy: 0.7811\n", "Epoch 4/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5926 - accuracy: 0.7853 - val_loss: 0.5854 - val_accuracy: 0.7835\n", "Epoch 5/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5846 - accuracy: 0.7857 - val_loss: 0.5915 - val_accuracy: 0.7777\n", "Epoch 6/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5760 - accuracy: 0.7887 - val_loss: 0.5742 - val_accuracy: 0.7814\n", "Epoch 7/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5713 - accuracy: 0.7882 - val_loss: 0.5753 - val_accuracy: 0.7839\n", "Epoch 8/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5685 - accuracy: 0.7882 - val_loss: 0.5769 - val_accuracy: 0.7832\n", "Epoch 9/20\n", "916/916 [==============================] - 3s 4ms/step - loss: 0.5654 - accuracy: 0.7892 - val_loss: 0.5746 - val_accuracy: 0.7842\n", "Epoch 10/20\n", "916/916 [==============================] - 3s 4ms/step - loss: 0.5631 - accuracy: 0.7883 - val_loss: 0.5668 - val_accuracy: 0.7835\n", "Epoch 11/20\n", "916/916 [==============================] - 3s 4ms/step - loss: 0.5613 - accuracy: 0.7879 - val_loss: 0.5709 - val_accuracy: 0.7756\n", "Epoch 12/20\n", "916/916 [==============================] - 3s 4ms/step - loss: 0.5612 - accuracy: 0.7891 - val_loss: 0.5608 - val_accuracy: 0.7860\n", "Epoch 13/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5639 - accuracy: 0.7883 - val_loss: 0.5717 - val_accuracy: 0.7792\n", "Epoch 14/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5651 - accuracy: 0.7851 - val_loss: 0.5617 - val_accuracy: 0.7869\n", "Epoch 15/20\n", "916/916 [==============================] - 3s 4ms/step - loss: 0.5647 - accuracy: 0.7848 - val_loss: 0.5596 - val_accuracy: 0.7866\n", "Epoch 16/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5605 - 
accuracy: 0.7872 - val_loss: 0.5616 - val_accuracy: 0.7881\n", "Epoch 17/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5602 - accuracy: 0.7870 - val_loss: 0.5587 - val_accuracy: 0.7860\n", "Epoch 18/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5578 - accuracy: 0.7886 - val_loss: 0.5611 - val_accuracy: 0.7915\n", "Epoch 19/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5547 - accuracy: 0.7901 - val_loss: 0.5511 - val_accuracy: 0.7875\n", "Epoch 20/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5533 - accuracy: 0.7914 - val_loss: 0.5547 - val_accuracy: 0.7866\n", "Evaluating: CatBoostEncoder()\n", "Epoch 1/20\n", "916/916 [==============================] - 5s 3ms/step - loss: 1.5688 - accuracy: 0.7703 - val_loss: 0.6104 - val_accuracy: 0.7952\n", "Epoch 2/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.6173 - accuracy: 0.7827 - val_loss: 0.7149 - val_accuracy: 0.6497\n", "Epoch 3/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5970 - accuracy: 0.7863 - val_loss: 0.5817 - val_accuracy: 0.7869\n", "Epoch 4/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5852 - accuracy: 0.7890 - val_loss: 0.5812 - val_accuracy: 0.7866\n", "Epoch 5/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5747 - accuracy: 0.7894 - val_loss: 0.5710 - val_accuracy: 0.7848\n", "Epoch 6/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5725 - accuracy: 0.7900 - val_loss: 0.5883 - val_accuracy: 0.7918\n", "Epoch 7/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5688 - accuracy: 0.7880 - val_loss: 0.5602 - val_accuracy: 0.7872\n", "Epoch 8/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5680 - accuracy: 0.7876 - val_loss: 0.5939 - val_accuracy: 0.7897\n", "Epoch 9/20\n", "916/916 [==============================] - 3s 3ms/step - 
loss: 0.5647 - accuracy: 0.7902 - val_loss: 0.5535 - val_accuracy: 0.7918\n", "Epoch 10/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5619 - accuracy: 0.7896 - val_loss: 0.5604 - val_accuracy: 0.7897\n", "Epoch 11/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5601 - accuracy: 0.7927 - val_loss: 0.5574 - val_accuracy: 0.7891\n", "Epoch 12/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5594 - accuracy: 0.7913 - val_loss: 0.5559 - val_accuracy: 0.7918\n", "Epoch 13/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5586 - accuracy: 0.7907 - val_loss: 0.5611 - val_accuracy: 0.7878\n", "Epoch 14/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5592 - accuracy: 0.7895 - val_loss: 0.5601 - val_accuracy: 0.7866\n", "Epoch 15/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5593 - accuracy: 0.7887 - val_loss: 0.5558 - val_accuracy: 0.7866\n", "Epoch 16/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5573 - accuracy: 0.7904 - val_loss: 0.5589 - val_accuracy: 0.7789\n", "Epoch 17/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5558 - accuracy: 0.7898 - val_loss: 0.5577 - val_accuracy: 0.7796\n", "Epoch 18/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5563 - accuracy: 0.7915 - val_loss: 0.5563 - val_accuracy: 0.7845\n", "Epoch 19/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5570 - accuracy: 0.7897 - val_loss: 0.5536 - val_accuracy: 0.7866\n", "Epoch 20/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5548 - accuracy: 0.7912 - val_loss: 0.5570 - val_accuracy: 0.7863\n", "Evaluating: CountEncoder(combine_min_nan_groups=True)\n", "Epoch 1/20\n", "916/916 [==============================] - 5s 3ms/step - loss: 1.6526 - accuracy: 0.7675 - val_loss: 0.6198 - val_accuracy: 0.7928\n", "Epoch 2/20\n", "916/916 
[==============================] - 4s 4ms/step - loss: 0.5854 - accuracy: 0.7898 - val_loss: 0.5651 - val_accuracy: 0.8035\n", "Epoch 3/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5512 - accuracy: 0.7973 - val_loss: 0.5520 - val_accuracy: 0.7992\n", "Epoch 4/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5382 - accuracy: 0.7985 - val_loss: 0.5161 - val_accuracy: 0.8099\n", "Epoch 5/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5277 - accuracy: 0.7985 - val_loss: 0.5093 - val_accuracy: 0.8056\n", "Epoch 6/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5174 - accuracy: 0.8022 - val_loss: 0.5070 - val_accuracy: 0.8115\n", "Epoch 7/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5117 - accuracy: 0.8021 - val_loss: 0.5100 - val_accuracy: 0.8112\n", "Epoch 8/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5049 - accuracy: 0.8035 - val_loss: 0.5102 - val_accuracy: 0.8072\n", "Epoch 9/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5035 - accuracy: 0.8025 - val_loss: 0.5029 - val_accuracy: 0.8096\n", "Epoch 10/20\n", "916/916 [==============================] - 2s 3ms/step - loss: 0.5013 - accuracy: 0.8047 - val_loss: 0.5038 - val_accuracy: 0.8158\n", "Epoch 11/20\n", "916/916 [==============================] - 2s 3ms/step - loss: 0.5007 - accuracy: 0.8051 - val_loss: 0.5019 - val_accuracy: 0.8050\n", "Epoch 12/20\n", "916/916 [==============================] - 2s 3ms/step - loss: 0.4982 - accuracy: 0.8044 - val_loss: 0.4882 - val_accuracy: 0.8121\n", "Epoch 13/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.4976 - accuracy: 0.8061 - val_loss: 0.4964 - val_accuracy: 0.8136\n", "Epoch 14/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.4968 - accuracy: 0.8037 - val_loss: 0.4917 - val_accuracy: 0.8081\n", "Epoch 15/20\n", "916/916 
[==============================] - 3s 3ms/step - loss: 0.4961 - accuracy: 0.8033 - val_loss: 0.4887 - val_accuracy: 0.8124\n", "Epoch 16/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.4983 - accuracy: 0.8029 - val_loss: 0.4979 - val_accuracy: 0.8093\n", "Epoch 17/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.4944 - accuracy: 0.8044 - val_loss: 0.4868 - val_accuracy: 0.8087\n", "Epoch 18/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.4950 - accuracy: 0.8042 - val_loss: 0.4937 - val_accuracy: 0.8149\n", "Epoch 19/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.4947 - accuracy: 0.8046 - val_loss: 0.4925 - val_accuracy: 0.8112\n", "Epoch 20/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.4957 - accuracy: 0.8032 - val_loss: 0.4891 - val_accuracy: 0.8115\n", "Evaluating: GLMMEncoder(binomial_target=True)\n", "Epoch 1/20\n", "916/916 [==============================] - 5s 3ms/step - loss: 1.5992 - accuracy: 0.7698 - val_loss: 0.6632 - val_accuracy: 0.7654\n", "Epoch 2/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.6290 - accuracy: 0.7802 - val_loss: 0.5877 - val_accuracy: 0.7998\n", "Epoch 3/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5949 - accuracy: 0.7849 - val_loss: 0.6756 - val_accuracy: 0.6942\n", "Epoch 4/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5858 - accuracy: 0.7851 - val_loss: 0.5940 - val_accuracy: 0.7670\n", "Epoch 5/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5817 - accuracy: 0.7841 - val_loss: 0.6014 - val_accuracy: 0.7866\n", "Epoch 6/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5744 - accuracy: 0.7861 - val_loss: 0.5658 - val_accuracy: 0.7869\n", "Epoch 7/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5701 - accuracy: 0.7860 - val_loss: 0.5736 - val_accuracy: 
0.7854\n", "Epoch 8/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5692 - accuracy: 0.7868 - val_loss: 0.5758 - val_accuracy: 0.7839\n", "Epoch 9/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5683 - accuracy: 0.7853 - val_loss: 0.5610 - val_accuracy: 0.7875\n", "Epoch 10/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5620 - accuracy: 0.7880 - val_loss: 0.5637 - val_accuracy: 0.7875\n", "Epoch 11/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5605 - accuracy: 0.7901 - val_loss: 0.5563 - val_accuracy: 0.7924\n", "Epoch 12/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5608 - accuracy: 0.7901 - val_loss: 0.5524 - val_accuracy: 0.7940\n", "Epoch 13/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5576 - accuracy: 0.7910 - val_loss: 0.5592 - val_accuracy: 0.7928\n", "Epoch 14/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5575 - accuracy: 0.7925 - val_loss: 0.5590 - val_accuracy: 0.7866\n", "Epoch 15/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5574 - accuracy: 0.7927 - val_loss: 0.5533 - val_accuracy: 0.7900\n", "Epoch 16/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5557 - accuracy: 0.7929 - val_loss: 0.5542 - val_accuracy: 0.7872\n", "Epoch 17/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5550 - accuracy: 0.7941 - val_loss: 0.5540 - val_accuracy: 0.7875\n", "Epoch 18/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5532 - accuracy: 0.7931 - val_loss: 0.5493 - val_accuracy: 0.8020\n", "Epoch 19/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5537 - accuracy: 0.7932 - val_loss: 0.5502 - val_accuracy: 0.7875\n", "Epoch 20/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5565 - accuracy: 0.7903 - val_loss: 0.5535 - val_accuracy: 0.7851\n", 
"Evaluating: TargetEncoder()\n", "Epoch 1/20\n", "916/916 [==============================] - 4s 3ms/step - loss: 1.6035 - accuracy: 0.7716 - val_loss: 0.6238 - val_accuracy: 0.8056\n", "Epoch 2/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.6275 - accuracy: 0.7811 - val_loss: 0.6185 - val_accuracy: 0.7737\n", "Epoch 3/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.6049 - accuracy: 0.7836 - val_loss: 0.6133 - val_accuracy: 0.7808\n", "Epoch 4/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5910 - accuracy: 0.7832 - val_loss: 0.6452 - val_accuracy: 0.7194\n", "Epoch 5/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5854 - accuracy: 0.7851 - val_loss: 0.6032 - val_accuracy: 0.7845\n", "Epoch 6/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5809 - accuracy: 0.7836 - val_loss: 0.5716 - val_accuracy: 0.7863\n", "Epoch 7/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5722 - accuracy: 0.7870 - val_loss: 0.5757 - val_accuracy: 0.7719\n", "Epoch 8/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5714 - accuracy: 0.7879 - val_loss: 0.5663 - val_accuracy: 0.7848\n", "Epoch 9/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5663 - accuracy: 0.7878 - val_loss: 0.6079 - val_accuracy: 0.7565\n", "Epoch 10/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5612 - accuracy: 0.7898 - val_loss: 0.5657 - val_accuracy: 0.7891\n", "Epoch 11/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5613 - accuracy: 0.7893 - val_loss: 0.5570 - val_accuracy: 0.8004\n", "Epoch 12/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5611 - accuracy: 0.7889 - val_loss: 0.5703 - val_accuracy: 0.7854\n", "Epoch 13/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5602 - accuracy: 0.7896 - val_loss: 0.5614 - 
val_accuracy: 0.7848\n", "Epoch 14/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5592 - accuracy: 0.7884 - val_loss: 0.5575 - val_accuracy: 0.7869\n", "Epoch 15/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5564 - accuracy: 0.7916 - val_loss: 0.5530 - val_accuracy: 0.7869\n", "Epoch 16/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5569 - accuracy: 0.7901 - val_loss: 0.5605 - val_accuracy: 0.7863\n", "Epoch 17/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5593 - accuracy: 0.7880 - val_loss: 0.5501 - val_accuracy: 0.7891\n", "Epoch 18/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5557 - accuracy: 0.7889 - val_loss: 0.5628 - val_accuracy: 0.7749\n", "Epoch 19/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5546 - accuracy: 0.7906 - val_loss: 0.5630 - val_accuracy: 0.7811\n", "Epoch 20/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5543 - accuracy: 0.7904 - val_loss: 0.5583 - val_accuracy: 0.7863\n", "Evaluating: HelmertEncoder()\n", "Epoch 1/20\n", "916/916 [==============================] - 5s 4ms/step - loss: 1.6688 - accuracy: 0.7675 - val_loss: 0.6207 - val_accuracy: 0.7891\n", "Epoch 2/20\n", "916/916 [==============================] - 3s 4ms/step - loss: 0.6234 - accuracy: 0.7826 - val_loss: 0.6019 - val_accuracy: 0.7983\n", "Epoch 3/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.6009 - accuracy: 0.7873 - val_loss: 0.6878 - val_accuracy: 0.6678\n", "Epoch 4/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5939 - accuracy: 0.7887 - val_loss: 0.5784 - val_accuracy: 0.8017\n", "Epoch 5/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5812 - accuracy: 0.7913 - val_loss: 0.6017 - val_accuracy: 0.7964\n", "Epoch 6/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5754 - accuracy: 0.7902 - 
val_loss: 0.5869 - val_accuracy: 0.7989\n", "Epoch 7/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5732 - accuracy: 0.7902 - val_loss: 0.5659 - val_accuracy: 0.7891\n", "Epoch 8/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5690 - accuracy: 0.7911 - val_loss: 0.5655 - val_accuracy: 0.8017\n", "Epoch 9/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5657 - accuracy: 0.7917 - val_loss: 0.5529 - val_accuracy: 0.8032\n", "Epoch 10/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5629 - accuracy: 0.7919 - val_loss: 0.5993 - val_accuracy: 0.7706\n", "Epoch 11/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5624 - accuracy: 0.7922 - val_loss: 0.5545 - val_accuracy: 0.7967\n", "Epoch 12/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5608 - accuracy: 0.7901 - val_loss: 0.5572 - val_accuracy: 0.7869\n", "Epoch 13/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5616 - accuracy: 0.7907 - val_loss: 0.5609 - val_accuracy: 0.7903\n", "Epoch 14/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5608 - accuracy: 0.7919 - val_loss: 0.5694 - val_accuracy: 0.7860\n", "Epoch 15/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5591 - accuracy: 0.7926 - val_loss: 0.5520 - val_accuracy: 0.7881\n", "Epoch 16/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5618 - accuracy: 0.7897 - val_loss: 0.5572 - val_accuracy: 0.7866\n", "Epoch 17/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5584 - accuracy: 0.7906 - val_loss: 0.5578 - val_accuracy: 0.7869\n", "Epoch 18/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5568 - accuracy: 0.7915 - val_loss: 0.5577 - val_accuracy: 0.7851\n", "Epoch 19/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5561 - accuracy: 0.7919 - val_loss: 
0.5566 - val_accuracy: 0.7866\n", "Epoch 20/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5587 - accuracy: 0.7891 - val_loss: 0.5575 - val_accuracy: 0.7866\n", "Evaluating: JamesSteinEncoder()\n", "Epoch 1/20\n", "916/916 [==============================] - 4s 3ms/step - loss: 1.6417 - accuracy: 0.7693 - val_loss: 0.6371 - val_accuracy: 0.7921\n", "Epoch 2/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.6234 - accuracy: 0.7847 - val_loss: 0.5947 - val_accuracy: 0.8050\n", "Epoch 3/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.6020 - accuracy: 0.7875 - val_loss: 0.5886 - val_accuracy: 0.7946\n", "Epoch 4/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5881 - accuracy: 0.7914 - val_loss: 0.6658 - val_accuracy: 0.7049\n", "Epoch 5/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5826 - accuracy: 0.7897 - val_loss: 0.5803 - val_accuracy: 0.7783\n", "Epoch 6/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5745 - accuracy: 0.7894 - val_loss: 0.5933 - val_accuracy: 0.7903\n", "Epoch 7/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5706 - accuracy: 0.7896 - val_loss: 0.5735 - val_accuracy: 0.7857\n", "Epoch 8/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5682 - accuracy: 0.7908 - val_loss: 0.5810 - val_accuracy: 0.7792\n", "Epoch 9/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5650 - accuracy: 0.7897 - val_loss: 0.5609 - val_accuracy: 0.7875\n", "Epoch 10/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5680 - accuracy: 0.7876 - val_loss: 0.5642 - val_accuracy: 0.7869\n", "Epoch 11/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5658 - accuracy: 0.7888 - val_loss: 0.5695 - val_accuracy: 0.7872\n", "Epoch 12/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5656 - accuracy: 
0.7863 - val_loss: 0.5667 - val_accuracy: 0.7909\n", "Epoch 13/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5635 - accuracy: 0.7890 - val_loss: 0.5594 - val_accuracy: 0.7872\n", "Epoch 14/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5591 - accuracy: 0.7901 - val_loss: 0.5565 - val_accuracy: 0.7875\n", "Epoch 15/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5578 - accuracy: 0.7909 - val_loss: 0.5753 - val_accuracy: 0.7756\n", "Epoch 16/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5591 - accuracy: 0.7911 - val_loss: 0.5692 - val_accuracy: 0.8053\n", "Epoch 17/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5608 - accuracy: 0.7881 - val_loss: 0.5549 - val_accuracy: 0.7875\n", "Epoch 18/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5564 - accuracy: 0.7913 - val_loss: 0.5543 - val_accuracy: 0.7835\n", "Epoch 19/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5562 - accuracy: 0.7903 - val_loss: 0.5473 - val_accuracy: 0.7897\n", "Epoch 20/20\n", "916/916 [==============================] - 3s 3ms/step - loss: 0.5549 - accuracy: 0.7911 - val_loss: 0.5530 - val_accuracy: 0.7891\n" ] } ] } ] }