CDS1011_A2/main.ipynb

573 lines
111 KiB
Plaintext
Raw Normal View History

2024-12-02 19:42:00 +01:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Data Processing\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"# Data Visualization\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Training / Evaluation\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.metrics import f1_score"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Loading Data\n",
"# Root folder of the extracted UCI HAR dataset. The loading cells build file paths by plain\n",
"# string concatenation (e.g. f'{dataset_path}features.txt'), so the trailing slash is required.\n",
"dataset_path: str = 'datasets/UCI HAR Dataset/'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load feature names (features.txt holds one '<feature_id> <feature_name>' pair per line).\n",
"# sep is a raw string: '\\s' in a plain literal is an invalid escape sequence\n",
"# (DeprecationWarning on Python 3.11, SyntaxWarning from 3.12 on).\n",
"features = pd.read_csv(f'{dataset_path}features.txt', sep=r'\\s+', names=['feature_id', 'feature_name'])\n",
"feature_names = features['feature_name']\n",
"\n",
"# Some feature names repeat (example: fBodyAcc-bandsEnergy()-1,8), but DataFrame column\n",
"# names must be unique: the first occurrence keeps its name, later repeats get _1, _2, ...\n",
"name_count = {}\n",
"unique_feature_names = []\n",
"\n",
"for name in feature_names:\n",
"    if name in name_count:\n",
"        name_count[name] += 1\n",
"        unique_feature_names.append(f\"{name}_{name_count[name]}\")\n",
"    else:\n",
"        name_count[name] = 0\n",
"        unique_feature_names.append(name)\n",
"\n",
"# A generated suffix could collide with a genuine feature name; fail early if it does.\n",
"assert len(set(unique_feature_names)) == len(unique_feature_names), 'feature names are not unique'"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Loading training data (raw-string sep avoids the invalid '\\s' escape warning)\n",
"X_train = pd.read_csv(f'{dataset_path}train/X_train.txt', sep=r'\\s+', names=unique_feature_names)\n",
"y_train = pd.read_csv(f'{dataset_path}train/y_train.txt', sep=r'\\s+', names=['Activity'])\n",
"\n",
"# Loading testing data; the label column uses the same name ('Activity') as the training labels\n",
"X_test = pd.read_csv(f'{dataset_path}test/X_test.txt', sep=r'\\s+', names=unique_feature_names)\n",
"y_test = pd.read_csv(f'{dataset_path}test/y_test.txt', sep=r'\\s+', names=['Activity'])\n",
"\n",
"# Every feature row needs exactly one label\n",
"assert len(X_train) == len(y_train), 'train features/labels length mismatch'\n",
"assert len(X_test) == len(y_test), 'test features/labels length mismatch'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Activity</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Activity\n",
"0 5\n",
"1 5\n",
"2 5\n",
"3 5\n",
"4 5"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# First rows of the training labels\n",
"y_train.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tBodyAcc-mean()-X</th>\n",
" <th>tBodyAcc-mean()-Y</th>\n",
" <th>tBodyAcc-mean()-Z</th>\n",
" <th>tBodyAcc-std()-X</th>\n",
" <th>tBodyAcc-std()-Y</th>\n",
" <th>tBodyAcc-std()-Z</th>\n",
" <th>tBodyAcc-mad()-X</th>\n",
" <th>tBodyAcc-mad()-Y</th>\n",
" <th>tBodyAcc-mad()-Z</th>\n",
" <th>tBodyAcc-max()-X</th>\n",
" <th>...</th>\n",
" <th>fBodyBodyGyroJerkMag-meanFreq()</th>\n",
" <th>fBodyBodyGyroJerkMag-skewness()</th>\n",
" <th>fBodyBodyGyroJerkMag-kurtosis()</th>\n",
" <th>angle(tBodyAccMean,gravity)</th>\n",
" <th>angle(tBodyAccJerkMean),gravityMean)</th>\n",
" <th>angle(tBodyGyroMean,gravityMean)</th>\n",
" <th>angle(tBodyGyroJerkMean,gravityMean)</th>\n",
" <th>angle(X,gravityMean)</th>\n",
" <th>angle(Y,gravityMean)</th>\n",
" <th>angle(Z,gravityMean)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.288585</td>\n",
" <td>-0.020294</td>\n",
" <td>-0.132905</td>\n",
" <td>-0.995279</td>\n",
" <td>-0.983111</td>\n",
" <td>-0.913526</td>\n",
" <td>-0.995112</td>\n",
" <td>-0.983185</td>\n",
" <td>-0.923527</td>\n",
" <td>-0.934724</td>\n",
" <td>...</td>\n",
" <td>-0.074323</td>\n",
" <td>-0.298676</td>\n",
" <td>-0.710304</td>\n",
" <td>-0.112754</td>\n",
" <td>0.030400</td>\n",
" <td>-0.464761</td>\n",
" <td>-0.018446</td>\n",
" <td>-0.841247</td>\n",
" <td>0.179941</td>\n",
" <td>-0.058627</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.278419</td>\n",
" <td>-0.016411</td>\n",
" <td>-0.123520</td>\n",
" <td>-0.998245</td>\n",
" <td>-0.975300</td>\n",
" <td>-0.960322</td>\n",
" <td>-0.998807</td>\n",
" <td>-0.974914</td>\n",
" <td>-0.957686</td>\n",
" <td>-0.943068</td>\n",
" <td>...</td>\n",
" <td>0.158075</td>\n",
" <td>-0.595051</td>\n",
" <td>-0.861499</td>\n",
" <td>0.053477</td>\n",
" <td>-0.007435</td>\n",
" <td>-0.732626</td>\n",
" <td>0.703511</td>\n",
" <td>-0.844788</td>\n",
" <td>0.180289</td>\n",
" <td>-0.054317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.279653</td>\n",
" <td>-0.019467</td>\n",
" <td>-0.113462</td>\n",
" <td>-0.995380</td>\n",
" <td>-0.967187</td>\n",
" <td>-0.978944</td>\n",
" <td>-0.996520</td>\n",
" <td>-0.963668</td>\n",
" <td>-0.977469</td>\n",
" <td>-0.938692</td>\n",
" <td>...</td>\n",
" <td>0.414503</td>\n",
" <td>-0.390748</td>\n",
" <td>-0.760104</td>\n",
" <td>-0.118559</td>\n",
" <td>0.177899</td>\n",
" <td>0.100699</td>\n",
" <td>0.808529</td>\n",
" <td>-0.848933</td>\n",
" <td>0.180637</td>\n",
" <td>-0.049118</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.279174</td>\n",
" <td>-0.026201</td>\n",
" <td>-0.123283</td>\n",
" <td>-0.996091</td>\n",
" <td>-0.983403</td>\n",
" <td>-0.990675</td>\n",
" <td>-0.997099</td>\n",
" <td>-0.982750</td>\n",
" <td>-0.989302</td>\n",
" <td>-0.938692</td>\n",
" <td>...</td>\n",
" <td>0.404573</td>\n",
" <td>-0.117290</td>\n",
" <td>-0.482845</td>\n",
" <td>-0.036788</td>\n",
" <td>-0.012892</td>\n",
" <td>0.640011</td>\n",
" <td>-0.485366</td>\n",
" <td>-0.848649</td>\n",
" <td>0.181935</td>\n",
" <td>-0.047663</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.276629</td>\n",
" <td>-0.016570</td>\n",
" <td>-0.115362</td>\n",
" <td>-0.998139</td>\n",
" <td>-0.980817</td>\n",
" <td>-0.990482</td>\n",
" <td>-0.998321</td>\n",
" <td>-0.979672</td>\n",
" <td>-0.990441</td>\n",
" <td>-0.942469</td>\n",
" <td>...</td>\n",
" <td>0.087753</td>\n",
" <td>-0.351471</td>\n",
" <td>-0.699205</td>\n",
" <td>0.123320</td>\n",
" <td>0.122542</td>\n",
" <td>0.693578</td>\n",
" <td>-0.615971</td>\n",
" <td>-0.847865</td>\n",
" <td>0.185151</td>\n",
" <td>-0.043892</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 561 columns</p>\n",
"</div>"
],
"text/plain": [
" tBodyAcc-mean()-X tBodyAcc-mean()-Y tBodyAcc-mean()-Z tBodyAcc-std()-X \\\n",
"0 0.288585 -0.020294 -0.132905 -0.995279 \n",
"1 0.278419 -0.016411 -0.123520 -0.998245 \n",
"2 0.279653 -0.019467 -0.113462 -0.995380 \n",
"3 0.279174 -0.026201 -0.123283 -0.996091 \n",
"4 0.276629 -0.016570 -0.115362 -0.998139 \n",
"\n",
" tBodyAcc-std()-Y tBodyAcc-std()-Z tBodyAcc-mad()-X tBodyAcc-mad()-Y \\\n",
"0 -0.983111 -0.913526 -0.995112 -0.983185 \n",
"1 -0.975300 -0.960322 -0.998807 -0.974914 \n",
"2 -0.967187 -0.978944 -0.996520 -0.963668 \n",
"3 -0.983403 -0.990675 -0.997099 -0.982750 \n",
"4 -0.980817 -0.990482 -0.998321 -0.979672 \n",
"\n",
" tBodyAcc-mad()-Z tBodyAcc-max()-X ... fBodyBodyGyroJerkMag-meanFreq() \\\n",
"0 -0.923527 -0.934724 ... -0.074323 \n",
"1 -0.957686 -0.943068 ... 0.158075 \n",
"2 -0.977469 -0.938692 ... 0.414503 \n",
"3 -0.989302 -0.938692 ... 0.404573 \n",
"4 -0.990441 -0.942469 ... 0.087753 \n",
"\n",
" fBodyBodyGyroJerkMag-skewness() fBodyBodyGyroJerkMag-kurtosis() \\\n",
"0 -0.298676 -0.710304 \n",
"1 -0.595051 -0.861499 \n",
"2 -0.390748 -0.760104 \n",
"3 -0.117290 -0.482845 \n",
"4 -0.351471 -0.699205 \n",
"\n",
" angle(tBodyAccMean,gravity) angle(tBodyAccJerkMean),gravityMean) \\\n",
"0 -0.112754 0.030400 \n",
"1 0.053477 -0.007435 \n",
"2 -0.118559 0.177899 \n",
"3 -0.036788 -0.012892 \n",
"4 0.123320 0.122542 \n",
"\n",
" angle(tBodyGyroMean,gravityMean) angle(tBodyGyroJerkMean,gravityMean) \\\n",
"0 -0.464761 -0.018446 \n",
"1 -0.732626 0.703511 \n",
"2 0.100699 0.808529 \n",
"3 0.640011 -0.485366 \n",
"4 0.693578 -0.615971 \n",
"\n",
" angle(X,gravityMean) angle(Y,gravityMean) angle(Z,gravityMean) \n",
"0 -0.841247 0.179941 -0.058627 \n",
"1 -0.844788 0.180289 -0.054317 \n",
"2 -0.848933 0.180637 -0.049118 \n",
"3 -0.848649 0.181935 -0.047663 \n",
"4 -0.847865 0.185151 -0.043892 \n",
"\n",
"[5 rows x 561 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# First rows of the training features (wide frame: pandas elides the middle columns)\n",
"X_train.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6gAAAIMCAYAAAD4u4FkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/GU6VOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC/x0lEQVR4nOzdeXxU1f3/8de9M9khCwQCgUjCDrJKEHFlE9yXWsUVRUVFccPWn1oFUautVapWrN+qCFarVuuOxQUBtSA4EFbZN4GQkABJIPvMvb8/JnOTIQkkMTEJvJ+PxzwgZ86c+XzuuffOnDl3MWzbthERERERERFpZGZjByAiIiIiIiICGqCKiIiIiIhIE6EBqoiIiIiIiDQJGqCKiIiIiIhIk6ABqoiIiIiIiDQJGqCKiIiIiIhIk6ABqoiIiIiIiDQJGqCKiIiIiIhIk6ABqoiIiIiIiDQJGqCKSLNkGAbDhg1r9u9RU9u3b8cwDG644YbGDkWk2Zs1axaGYTBr1qzGDuWo6iPW5pSviIgGqCLHkWXLlnHTTTfRrVs3oqKiiIiIoEuXLlx33XV89dVXjR3er27YsGEYhtHYYTQZWh5+jz76KIZhVPu45JJLfpU4brjhBgzDYPv27b/K+0nNfPvtt8668N577/3i9hrzx6em9COciEiAu7EDEJGGZ1kWv/vd7/jrX/+K2+1mxIgRXHTRRYSEhLB161bmzJnDm2++yWOPPcYjjzzS2OE2GevWrSMyMrKxw5BGctlll9GnT59K5T179myEaKSpeO211wD/4G7mzJlcfvnlDfp+l156Kaeccgrt27dv1DZERH4tGqCKHAcefvhh/vrXvzJgwADef/99unTpEvR8YWEhL774Ivv27WukCJsmDUSOb7/97W+58sorGzsMaULy8vJ4//336devHwkJCXz55Zfs3LmTpKSkBnvPmJgYYmJiGr0NEZFfiw7xFTnGbd68maeffprWrVszd+7cSoNTgIiICH7/+98zbdq0oPLs7GzuueceUlJSCAsLo23btlxxxRWsWbOmUhuBwxG3bt3Ks88+S+/evQkLC3MOW0tOTiY5OZmcnBwmTZpEUlISbrc76JyoVatWceWVV9K+fXtCQ0Pp1KkTd955Z40Hzhs3buT+++/npJNOonXr1oSHh9O9e3ceeOABDh06FFTXMAwWLlzo/D/wqHiYXXWHv9VluWzbto0XXniBnj17EhYWRqdOnZg2bRqWZdUot4rWrl3L+eefT2xsLC1atGD06NEsW7asyroHDx5k6tSpnHjiiURERBAbG8uYMWP4/vvva7w8LMuidevWlWYT9+/fj2maGIbB119/XWXeO3bsCCqvbR/XtH7FwyQ3b97MpZdeSlxcHFFRUYwaNYqVK1cefcHWgW3bzJw5k9NOO43o6GgiIyNJTU1l5syZleqmp6czdepUTjnlFNq2bUtYWBjJycncfvvt7N27N6hucnIys2fPBiAlJcXpj8D6eLTDQqtadwOHcBcVFfHwww/TpUsXQkJCePTRR50627Zt4+abb+aEE04gLCyM9u3bc8MNN1TqR4Dly5fz29/+1qnbpk0bBg8ezB//+McaLbvabK8V4y8tLeXRRx8lOTmZsLAwunfvzksvvVTle+zfv5/bbruNhIQEIiMjGTx4MB9++GGN4qvK22+/TUFBAePGjWPcuHFYlnXE8zr37t3LfffdR48ePYiIiKBVq1YMGTKEZ555BvCfG5qSkgLA7Nmzg7a9BQsWOHUqnj9aUFBAy5Ytq9yXB/Tr14+IiAjy8vKqbGPBggXO4fwLFy4Met9Zs2bx6quvYhgGTz/9dJXtf/PNNxiGwa233lrTRSciUmOaQRU5xs2aNQufz8ett95KQkLCEeuGhYU5/8/KymLo0KFs2bKFYcOGceWVV7Jt2zbef/995syZwxdffMHpp59eqY0777yTH374gfPPP58LL7yQtm3bOs8VFxczYsQIDh
06xEUXXYTb7XZi+uSTT7jiiiswTZOLL76YpKQkfvrpJ1588UW++OILlixZQlxc3BHj/+CDD3jttdcYPnw4w4YNw7IsfvjhB/785z+zcOFCvv32W0JCQgCYOnUqs2bNYseOHUydOtVpY8CAAUd8j7oul9///vcsXLiQCy64gDFjxvDRRx/x6KOPUlJSUuMv9ABbt27ltNNO46STTmLixIns2LGD9957jzPPPJNvvvmGIUOGOHX379/PmWeeydq1aznttNO47bbbyMvL4+OPP2b48OG89957zvmUR1oepmly1lln8eGHH7J3716nTxcuXIht2wDMnz+fUaNGOa+bP38+KSkpdOrUySmrbR/XZZ3Yvn07p5xyCieeeCI33ngjW7ZscfJdt27dUbeB2rBtm2uuuYa3336bbt26cfXVVxMaGspXX33FTTfdxE8//eQMRMB/7uKzzz7LyJEjGTJkCCEhIaSlpfH3v/+dL774guXLlzuzXPfccw+zZs1i5cqV3H333cTGxgL+gesvddlll7Fy5UrOOeccYmNjnQHSkiVLGDNmDPn5+VxwwQV069aN7du389Zbb/Hf//6XxYsX07lzZwBWrFjBqaeeisvl4uKLL6ZTp07k5OTw008/8Y9//IM//OEPR42jNttrRVdddRVLly7l3HPPxeVy8e9//5s77riDkJAQJkyY4NQrKChg2LBhrF69mqFDh3LWWWexc+dOxo4dy+jRo+u07F577TVcLhfXXHMN0dHRTJw4kddff52HH3640vnbGzZsYPjw4ezZs4fTTz+dSy65hPz8fNauXcuTTz7J7373OwYMGMDdd9/N888/T//+/YPOb66uryMjI7nsssuYPXs2ixYt4tRTTw16fuXKlaxevZqxY8cSHR1dZRvJyclMnTqVadOm0alTp6AfOgYMGEC3bt247777eO2117j//vsrvf6VV14BCFreIiL1xhaRY9qwYcNswP76669r9brx48fbgP3ggw8Glc+ZM8cG7K5du9o+n88pv/76623A7tixo71jx45K7XXq1MkG7DFjxtgFBQVBz2VnZ9vR0dF2hw4d7O3btwc99/bbb9uAPWnSpKBywD7rrLOCynbt2mUXFxdXeu9p06bZgP3mm28GlZ911ln2kXaDVb1HXZdLSkqKnZ6e7pRnZWXZsbGxdsuWLauM+XDbtm2zARuwH3jggaDn5s6dawN23759g8qvvvpqG7BfeeWVoPLMzEw7KSnJbtOmjV1YWOiUH2l5vPDCCzZgv/vuu07ZnXfeaUdFRdmnnHKKPXToUKd8y5YtNmDfeOONTllt+7i29Ssunz/96U9B9R9++GEbsJ966qkqczvc1KlTbcC+7LLL7KlTp1Z6BJbZP/7xDxuwx48fb5eUlDivLy4uti+88EIbsD0ej1OemZlpHzx4sNL7zZ492wbsJ554Iqg8sO5s27at0msC+V5//fVV5lDVuhvo3wEDBtj79u0Leq6kpMROTk62W7ZsaS9fvjzoue+++852uVz2BRdc4JRNnjzZBuyPPvqo0ntnZ2dXGdPh6rq9DhkyxM7NzXXK169fb7vdbrtHjx5B9QP9OGHChKDywPYC2K+//nqNYrVt2161apWzDwsYN25ctfvX1NRUG7D/8Y9/VHpu586dzv+P1pevv/56pVi//vprG7AnTpxYqf59991nA/Znn312xDZsu+r1JGDixIk2YC9YsCCofN++fXZYWJg9YMCAKl8nIvJLaYAqcozr2bOnDdjr16+v8WuKi4vt8PBwu3Xr1nZ+fn6l588++2wbsL/99lunLPBl+vnnn6+yzcAAdeXKlZWemz59ug3Yb7zxRpWvPemkk+z4+PigsiN9sTrcvn37bMC+4YYbgsprO0D9Jctl5syZleoHnlu1atVRcwh8iY2Nja1ykDNy5MigAVFWVpbtcrnsESNGVNleYMD56aefOmVHWh6rV6+2AfvWW291yvr06WOPGTPGnjJliu12u524Xn
311Ur9Wds+rm39wPJJSUkJ+oGg4nO/+c1vqmzrcIGBTXWPAwcO2LZt2/369bOjoqIq/eBi2+WDmfvuu++o72dZlh0
"text/plain": [
"<Figure size 1200x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Feature Selection: correlation of every feature with the activity label\n",
"# NOTE(review): the label is an integer activity code; Pearson correlation with it is only a\n",
"# meaningful ranking if that code ordering is meaningful - confirm this is intended.\n",
"feature_correlations = X_train.corrwith(y_train['Activity'])\n",
"\n",
"# One shared normalisation for the bar colours and the colorbar, so the colorbar ticks show\n",
"# real correlation values (a ScalarMappable without a norm is drawn on a 0..1 scale instead).\n",
"corr_norm = plt.Normalize(vmin=feature_correlations.min(), vmax=feature_correlations.max())\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"ax.bar(range(len(feature_correlations)), feature_correlations,\n",
"       color=plt.cm.coolwarm(corr_norm(feature_correlations)))\n",
"\n",
"ax.set_title('Correlation between Features and Activity', fontsize=14)\n",
"ax.set_xlabel('Features')\n",
"ax.set_ylabel('Correlation')\n",
"\n",
"ax.set_xlim([0, len(feature_correlations)])\n",
"ax.set_ylim([-1, 1])\n",
"ax.set_xticks([])\n",
"\n",
"fig.colorbar(\n",
"    plt.cm.ScalarMappable(norm=corr_norm, cmap='coolwarm'),\n",
"    ax=ax,\n",
"    orientation='vertical',\n",
"    label='Correlation',\n",
")\n",
"\n",
"ax.grid(axis='y', linestyle='--', alpha=0.7)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Selection parameters (values kept from the original analysis; note the thresholds are asymmetric)\n",
"POSITIVE_THRESHOLD = 0.7\n",
"NEGATIVE_THRESHOLD = -0.75\n",
"TOP_K = 20\n",
"\n",
"# Selecting features with correlation > 0.7 or < -0.75\n",
"positive_correlations = feature_correlations[feature_correlations > POSITIVE_THRESHOLD]\n",
"negative_correlations = feature_correlations[feature_correlations < NEGATIVE_THRESHOLD]\n",
"\n",
"# Keep the TOP_K strongest features by |correlation|. Sorting the signed values would rank every\n",
"# negative correlation below all positive ones, so negatives would only survive when fewer than\n",
"# TOP_K positive features pass their threshold.\n",
"selected_features = pd.concat([positive_correlations, negative_correlations])\n",
"selected_features = selected_features.sort_values(ascending=False, key=lambda s: s.abs()).head(TOP_K)\n",
"\n",
"feature_indices = selected_features.index\n",
"assert len(feature_indices) > 0, 'no feature passed the correlation thresholds'"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Restrict both splits to the selected feature columns (identical columns, identical order)\n",
"X_train, X_test = X_train.loc[:, feature_indices], X_test.loc[:, feature_indices]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Candidate classifiers, keyed by the display name used in the results plot.\n",
"# random_state is pinned for the tree-based models so reruns reproduce the same fitted trees.\n",
"models = {\n",
"    'Random Forest': RandomForestClassifier(random_state=0, n_estimators=100),\n",
"    'K-Nearest Neighbors': KNeighborsClassifier(n_neighbors=5),\n",
"    'Decision Tree': DecisionTreeClassifier(random_state=0),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Training the models: every model is fitted on the same features and the same 1-D label\n",
"# vector, so the label frame is flattened once instead of on every iteration.\n",
"train_labels = y_train.values.ravel()\n",
"for model in models.values():\n",
"    model.fit(X_train, train_labels)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Evaluating the models: support-weighted F1 on the held-out test split\n",
"model_scores = {\n",
"    model_name: f1_score(y_test, model.predict(X_test), average='weighted')\n",
"    for model_name, model in models.items()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA+kAAAIRCAYAAAA2i/y/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/GU6VOAAAACXBIWXMAAA9hAAAPYQGoP6dpAABUOklEQVR4nO3deVxU9f7H8feZGQFFEVfcSFxyK8UFNTNzycIsyzbNLM1M66alaWWruNwyu2m267XUureb3spc0lwitVKvO5ql5q5pqGRKogIz8/394Y+BkUFBQY7xej4ePGo+8z1nvp+Z4XjenDNnLGOMEQAAAAAAKHSOwp4AAAAAAAA4g5AOAAAAAIBNENIBAAAAALAJQjoAAAAAADZBSAcAAAAAwCYI6QAAAAAA2AQhHQAAAAAAmyCkAwAAAABgE4R0AAAAAABsgpAOAEAhmzZtmizL0rRp0y5qPZZlqV27dvkyJwAAUDgI6QCAImfPnj2yLEuWZalSpUpyu90Bx23ZssU3Lioq6tJOsoAtXbrU11ugn/DwcL/xc+fO1eOPP67WrVsrNDRUlmVpxIgRF/TY+/bt02OPPaYrr7xSISEhKlmypGrUqKFbbrlFY8eOVUpKysU3CADAZcpV2BMAAKCwuFwuHTp0SPPnz9dtt92W7f4PP/xQDsdf++/ZzZo106233pqtHhIS4nd73LhxWrZsmcLCwlSlShXt2LHjgh5v48aNateunY4dO6bWrVvr5ptvVsmSJbVv3z59//33mj9/vu666y7Vrl37gtYPAMDljpAOACiyrr32Wm3cuFFTpkzJFtLdbrf+/e9/q2PHjlq2bFkhzbDgxcTE5OqI+OjRo1WpUiXVrl1bM2bMUI8ePS7o8YYMGaJjx47p448/1gMPPJDt/pUrV6p8+fIXtG4AAP4K/tqHBwAAOIfixYvr3nvv1bx583T48GG/+7766isdOnRIDz30UI7Lp6SkKC4uTvXq1VNISIjKli2rW265RcuXLw84/ujRo3r00UcVERGhEiVKqHnz5vryyy/POcdNmzbp3nvvVeXKlRUUFKTq1avr8ccf1++//573hi9CmzZtdOWVV8qyrItaz8qVKxUeHh4woEtSq1atsp1qL505At+zZ09Vq1ZNwcHBqly5sjp16qS5c+f6jXO73Ro/fryio6NVvHhxlS5dWu3bt882TvK/FsDcuXPVunVrlSpVyu+jDWlpaRo/fryaNm2q0NBQlSpVSm3atNGcOXOyre/48eMaPny4GjRooJIlSyosLEy1a9dW7969tXfv3rw9UQCAIouQDgAo0h566CG53W7961//8qtPmTJFZcuWVdeuXQMud/r0aXXo0EGjRo1SaGioBg8erNtvv11LlixR27Zt9dlnn/mNP3nypNq1a6dJkyapVq1aGjRokOrWravu3bvr888/D/gYc+bMUYsWLTRnzhy1a9dOgwcPVsOGDfXOO++oVatW+uOPP/LlObiUypUrpxMnTujgwYO5XuaLL75QixYt9Nlnn6lly5YaOnSobrnlFh04cEAffvihb5wxRnfffbeGDh2q06dPa8CAAbrvvvu0ceNG3XbbbXrjjTcCrv+zzz7TnXfeqYoVK+qxxx7TzTffLElKTU1VbGyshg4dKmOM+vbtq/vvv1979+7V7bffrnfeecfvsWNjYzV69GiVLVtW/fv3V//+/dWkSRPNmTNH27dvv8BnDABQ5BgAAIqY3bt3G0kmNjbWGGPM1Vdfba666irf/b/99ptxuVzm8ccfN8YYExwcbKpXr+63jpEjRxpJpmfPnsbr9frq69evN0FBQSY8PNwkJyf76nFxcUaS6devn996FixYYCQZSWbq1Km+elJSkgkLCzNVq1Y1e/bs8Vvm008/NZLMwIED/eqSTNu2bXP1HCxZssRIMs2aNTNxcXHZfrZs2ZLjshmPHxcXl6vHymrIkCFGkqlRo4YZO3asWbFihUlJSclxfGJiog
kNDTWhoaFm/fr12e7fv3+/7/8/+ugj33OQmprqq+/du9eUL1/euFwus3PnTl996tSpRpJxOBxm8eLF2db9/PPPG0nmpZde8nuNk5OTTUxMjAkKCjIHDhwwxhizadMmI8l07do123pOnz5t/vzzz/M8MwAAnEFIBwAUOWeH9PHjxxtJ5n//+58xxphXX33VSDIbNmwwxgQO6TVr1jTFihXzC4kZ+vXrZySZjz/+2FerUaOGCQoKMr/99lu28TfccEO2kJ4xp6zryKpp06amfPnyfrULCek5/Xz55Zc5LnsxIf3UqVPmwQcfNA6Hw/dYTqfTNG3a1IwePdr88ccffuPHjh1rJJnhw4efd90dOnQwksyqVauy3ffyyy8bSWbUqFG+WkZIv+OOO7KN93g8pkyZMqZWrVp+AT3DnDlzjCTz9ttvG2MyQ3qPHj3OO08AAM6FC8cBAIq8+++/X8OGDdOUKVPUsmVLTZ06VU2aNFHjxo0Djk9OTtauXbtUv359VatWLdv97du31+TJk5WQkKAHHnhAycnJ2r17txo0aKBKlSplG9+mTRvFx8f71f73v/9JklatWqWdO3dmW+b06dNKSkpSUlLSRV1o7ZFHHtHEiRMvePm8CgkJ0dSpUzV69GjNnz9fq1ev1urVq7V+/XqtX79ekyZN0rJly1SzZk1J0urVqyVJN91003nXvWHDBpUoUUItWrTIdl/79u0lSQkJCdnuCzR+27Zt+uOPP1SlShWNHDky2/1HjhyRJG3dulWSVL9+fTVq1Eiffvqpfv31V3Xt2lXt2rVT48aN//LfEAAAyF+EdABAkVehQgV16dJF06dP1z333KNt27bp7bffznF8cnKyJCkiIiLg/ZUrV/Ybl/HfihUrBhwfaD1Hjx6VJL377rvnnHtKSspleTX0atWq+T63LUk7d+7UQw89pO+++05PPvmkZs+eLenMxdgkqWrVquddZ3JysiIjIwPed/ZrktW5nv+ffvpJP/30U46PmfGd7i6XS99++61GjBihL774QkOHDpV05r01cOBAvfDCC3I6neftAQAA/rQLAICkvn37Kjk5WQ8++KBCQkLUs2fPHMeGhYVJkg4dOhTw/sTERL9xGf89+wryGQKtJ2OZH3/8UebMx9MC/lSvXj2XHdpbrVq1NG3aNEnSt99+66tnXOn9wIED511HWFhYjs/x2a9JVoGuWJ8x7q677jrn8z916lTfMuXKldPbb7+tAwcO6Oeff9Y777yjsmXLKi4uTq+99tp55w8AgERIBwBAkhQbG6uqVavqwIED6tq1q8qUKZPj2LCwMNWsWVM7duwIGB6XLl0qSb7T5cPCwlSjRg3t2LHDFxaz+v7777PVWrZsKenMV5YVFSVLlsxWyzgVfdGiReddvkmTJjp58qTvFPmszn5Nzqd+/foKCwvT2rVrlZ6enqtlMliWpfr162vAgAFavHixJAX8yjYAAAIhpAMAIMnpdGrWrFn68ssvNWbMmPOO7927t9LT0/Xcc8/JGOOrb9q0SdOmTVPp0qX9vr7tgQceUFpamoYPH+63nkWLFmX7PLok9enTR6VKldILL7wQ8HTrkydP+j63fjkZNWqU9u/fn61ujNGrr74qSbruuut89d69e6tkyZIaN25cwM+TZ/0jSe/evSVJzz33nF+w3r9/v8aPHy+Xy3XOMySycrlc+tvf/qa9e/fqqaeeChjUN2/e7Dtyv2fPHu3ZsyfbmIyzJEJCQnL1uAAA8Jl0AAD+X0xMjGJiYnI19plnntG8efP0r3/9S1u2bNENN9ygw4cPa8aMGXK73Zo8ebJKlSrlN37mzJmaPHmyfvrpJ11//fXav3+//vvf/+qWW27RvHnz/NZfoUIFffrpp7rnnnsUHR2tTp06qV69ekpNTdWePXu0bNkyXXvttVqwYEG+Pgc5mTVrlmbNmiVJ2r17t6+WEUzr1aunZ5999rzrGT9+vEaMGKGYmBg1a9ZMZcuW1e+//64lS5
bol19+Ubly5TRu3Djf+IoVK+rjjz/WvffeqxYtWui2225T3bp1lZSUpFWrVikqKso3rwceeEAzZ87U7Nmz1ahRI91
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualization of results (a Series, so it can be sorted and queried with idxmax()/max())\n",
"model_scores = pd.Series(model_scores).sort_values(ascending=False)\n",
"\n",
"# plt.Normalize tolerates identical scores (e.g. a single model); manual min/max scaling\n",
"# divides by zero there, producing NaN colours that render as transparent (invisible) bars.\n",
"score_norm = plt.Normalize(vmin=model_scores.min(), vmax=model_scores.max())\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"bars = ax.bar(model_scores.index, model_scores,\n",
"              color=plt.cm.coolwarm(score_norm(model_scores)))\n",
"ax.bar_label(bars, fmt='%.3f')\n",
"ax.set_title('Model F1 Scores', fontsize=14)\n",
"ax.set_ylabel('F1 Score')\n",
"ax.set_ylim([0, 1])\n",
"ax.grid(axis='y', linestyle='--', alpha=0.7)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The best model is \"Random Forest\" with an F1 score of 0.6806382183305906\n"
]
}
],
"source": [
"# Report the top-scoring model (idxmax/max do not depend on the Series being sorted)\n",
"print('The best model is \"{}\" with an F1 score of {}'.format(model_scores.idxmax(), model_scores.max()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For Debugging: export the training labels + features to CSV (X_train holds only the\n",
"# selected columns once the feature-selection cell has run). Opt-in, so 'Run All' does not\n",
"# write files as a side effect.\n",
"EXPORT_DEBUG_CSV = False\n",
"\n",
"if EXPORT_DEBUG_CSV:\n",
"    debug_data = pd.concat([y_train, X_train], axis=1)\n",
"    debug_data.to_csv('datasets/data.csv', index=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}