diff --git a/README.md b/README.md
index 95aea12..adf6d72 100644
--- a/README.md
+++ b/README.md
@@ -39,7 +39,7 @@
| Группа | Дата |
| :--- | :---: |
-| А-01-19 | 06.03.2023 |
+| А-01-19 | 13.03.2023 |
| А-03-19 | 06.03.2023 |
* [Задание](labs/OATD_LR3.md)
diff --git a/labs/OATD_LR3.md b/labs/OATD_LR3.md
index f1fded4..efec3be 100644
--- a/labs/OATD_LR3.md
+++ b/labs/OATD_LR3.md
@@ -63,7 +63,7 @@
Обратить внимание, что разные виды регуляризации работают с разными методами нахождения экстремума.
**Метод опорных векторов (SVM):**
-* функция потерь (параметр loss: ‘hinge’, ‘squared_hinge’),
+* функция потерь (параметр kernel: ‘linear’, ‘rbf’),
* регуляризация (параметр penalty: ‘L1’, ‘L2’)
Обратить внимание, что разные виды регуляризации работают с разными функциями потерь
diff --git a/lections/notebooks/lec5_text2vec_classifier.ipynb b/lections/notebooks/lec5_text2vec_classifier.ipynb
index 8d2999f..d202ef1 100644
--- a/lections/notebooks/lec5_text2vec_classifier.ipynb
+++ b/lections/notebooks/lec5_text2vec_classifier.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 1,
"id": "3dda6a69",
"metadata": {},
"outputs": [],
@@ -18,7 +18,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 2,
"id": "7fd6636b",
"metadata": {},
"outputs": [
@@ -45,7 +45,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 3,
"id": "be870586",
"metadata": {},
"outputs": [],
@@ -55,7 +55,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 4,
"id": "599d6406",
"metadata": {},
"outputs": [
@@ -84,7 +84,7 @@
" ('lady', 0.8845519423484802)]"
]
},
- "execution_count": 7,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -96,7 +96,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 5,
"id": "2db71cfb",
"metadata": {},
"outputs": [
@@ -106,7 +106,7 @@
"0.60927683"
]
},
- "execution_count": 8,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -117,7 +117,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 6,
"id": "7788acf5",
"metadata": {},
"outputs": [],
@@ -139,7 +139,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 7,
"id": "0565dd1a",
"metadata": {},
"outputs": [],
@@ -149,7 +149,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 8,
"id": "a681a1d6",
"metadata": {},
"outputs": [
@@ -355,7 +355,7 @@
"[5 rows x 23297 columns]"
]
},
- "execution_count": 11,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -369,7 +369,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 9,
"id": "b20aef46",
"metadata": {},
"outputs": [
@@ -410,10 +410,129 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"id": "0d6af65a",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 00 | \n",
+ " 000 | \n",
+ " 0000 | \n",
+ " 00000 | \n",
+ " 000000 | \n",
+ " 000005102000 | \n",
+ " 000062david42 | \n",
+ " 000100255pixel | \n",
+ " 00041032 | \n",
+ " 0004136 | \n",
+ " ... | \n",
+ " zurbrin | \n",
+ " zurich | \n",
+ " zus | \n",
+ " zvi | \n",
+ " zwaartepunten | \n",
+ " zwak | \n",
+ " zwakke | \n",
+ " zware | \n",
+ " zwarte | \n",
+ " zyxel | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " ... | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " ... | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
2 rows × 23297 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 00 000 0000 00000 000000 000005102000 000062david42 000100255pixel \\\n",
+ "0 0 0 0 0 0 0 0 0 \n",
+ "1 1 1 0 0 0 0 0 0 \n",
+ "\n",
+ " 00041032 0004136 ... zurbrin zurich zus zvi zwaartepunten zwak \\\n",
+ "0 0 0 ... 0 0 0 0 0 0 \n",
+ "1 0 0 ... 0 0 0 0 0 0 \n",
+ "\n",
+ " zwakke zware zwarte zyxel \n",
+ "0 0 0 0 0 \n",
+ "1 0 0 0 1 \n",
+ "\n",
+ "[2 rows x 23297 columns]"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"text_data = ['Hello world I love python', 'This is a great computer game! 00 000 zyxel']\n",
"# Векторизуем с помощью обученного CountVectorizer\n",
@@ -424,10 +543,82 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"id": "11dda58a",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "hello : [-0.77069 0.12827 0.33137 0.0050893 -0.47605 -0.50116\n",
+ " 1.858 1.0624 -0.56511 0.13328 -0.41918 -0.14195\n",
+ " -2.8555 -0.57131 -0.13418 -0.44922 0.48591 -0.6479\n",
+ " -0.84238 0.61669 -0.19824 -0.57967 -0.65885 0.43928\n",
+ " -0.50473 ]\n",
+ "love : [-0.62645 -0.082389 0.070538 0.5782 -0.87199 -0.14816 2.2315\n",
+ " 0.98573 -1.3154 -0.34921 -0.8847 0.14585 -4.97 -0.73369\n",
+ " -0.94359 0.035859 -0.026733 -0.77538 -0.30014 0.48853 -0.16678\n",
+ " -0.016651 -0.53164 0.64236 -0.10922 ]\n",
+ "python : [-0.25645 -0.22323 0.025901 0.22901 0.49028 -0.060829 0.24563\n",
+ " -0.84854 1.5882 -0.7274 0.60603 0.25205 -1.8064 -0.95526\n",
+ " 0.44867 0.013614 0.60856 0.65423 0.82506 0.99459 -0.29403\n",
+ " -0.27013 -0.348 -0.7293 0.2201 ]\n",
+ "world : [ 0.10301 0.095666 -0.14789 -0.22383 -0.14775 -0.11599 1.8513\n",
+ " 0.24886 -0.41877 -0.20384 -0.08509 0.33246 -4.6946 0.84096\n",
+ " -0.46666 -0.031128 -0.19539 -0.037349 0.58949 0.13941 -0.57667\n",
+ " -0.44426 -0.43085 -0.52875 0.25855 ]\n",
+ "Hello world I love python : [ -1.55058002 -0.081683 0.27991899 0.58846928 -1.00551002\n",
+ " -0.82613902 6.18642995 1.44844997 -0.71108004 -1.14717001\n",
+ " -0.78294002 0.58841 -14.32649982 -1.41929996 -1.09575997\n",
+ " -0.430875 0.87234702 -0.806399 0.27203003 2.23921998\n",
+ " -1.23571999 -1.31071102 -1.96934 -0.17641005 -0.1353 ]\n",
+ "computer : [ 0.64005 -0.019514 0.70148 -0.66123 1.1723 -0.58859 0.25917\n",
+ " -0.81541 1.1708 1.1413 -0.15405 -0.11369 -3.8414 -0.87233\n",
+ " 0.47489 1.1541 0.97678 1.1107 -0.14572 -0.52013 -0.52234\n",
+ " -0.92349 0.34651 0.061939 -0.57375 ]\n",
+ "game : [ 1.146 0.3291 0.26878 -1.3945 -0.30044 0.77901 1.3537\n",
+ " 0.37393 0.50478 -0.44266 -0.048706 0.51396 -4.3136 0.39805\n",
+ " 1.197 0.10287 -0.17618 -1.2881 -0.59801 0.26131 -1.2619\n",
+ " 0.39202 0.59309 -0.55232 0.005087]\n",
+ "great : [-8.4229e-01 3.6512e-01 -3.8841e-01 -4.6118e-01 2.4301e-01 3.2412e-01\n",
+ " 1.9009e+00 -2.2630e-01 -3.1335e-01 -1.0970e+00 -4.1494e-03 6.2074e-01\n",
+ " -5.0964e+00 6.7418e-01 5.0080e-01 -6.2119e-01 5.1765e-01 -4.4122e-01\n",
+ " -1.4364e-01 1.9130e-01 -7.4608e-01 -2.5903e-01 -7.8010e-01 1.1030e-01\n",
+ " -2.7928e-01]\n",
+ "zyxel : [ 0.79234 0.067376 -0.22639 -2.2272 0.30057 -0.85676 -1.7268\n",
+ " -0.78626 1.2042 -0.92348 -0.83987 -0.74233 0.29689 -1.208\n",
+ " 0.98706 -1.1624 0.61415 -0.27825 0.27813 1.5838 -0.63593\n",
+ " -0.10225 1.7102 -0.95599 -1.3867 ]\n",
+ "This is a great computer game! 00 000 zyxel : [ 1.73610002 0.74208201 0.35545996 -4.74411008 1.41543998\n",
+ " -0.34222007 1.78697008 -1.45404002 2.56643 -1.32184002\n",
+ " -1.04677537 0.27867999 -12.95450976 -1.00809997 3.15975004\n",
+ " -0.52662008 1.93239999 -0.89686999 -0.60924001 1.51628\n",
+ " -3.16624993 -0.89275002 1.86969995 -1.33607102 -2.23464306]\n",
+ "glove_data: 0 1 2 3 4 5 6 7 \\\n",
+ "0 -1.55058 -0.081683 0.279919 0.588469 -1.00551 -0.826139 6.18643 1.44845 \n",
+ "0 1.73610 0.742082 0.355460 -4.744110 1.41544 -0.342220 1.78697 -1.45404 \n",
+ "\n",
+ " 8 9 ... 15 16 17 18 19 \\\n",
+ "0 -0.71108 -1.14717 ... -0.430875 0.872347 -0.806399 0.27203 2.23922 \n",
+ "0 2.56643 -1.32184 ... -0.526620 1.932400 -0.896870 -0.60924 1.51628 \n",
+ "\n",
+ " 20 21 22 23 24 \n",
+ "0 -1.23572 -1.310711 -1.96934 -0.176410 -0.135300 \n",
+ "0 -3.16625 -0.892750 1.86970 -1.336071 -2.234643 \n",
+ "\n",
+ "[2 rows x 25 columns]\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\Андрей\\AppData\\Local\\Temp\\ipykernel_29476\\129113310.py:17: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
+ " glove_data=glove_data.append(pd.DataFrame([one_doc]))\n"
+ ]
+ }
+ ],
"source": [
"# Создадим датафрейм, в который будем сохранять вектор документа\n",
"glove_data=pd.DataFrame()\n",
@@ -451,7 +642,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"id": "ff68d8dc",
"metadata": {},
"outputs": [],
@@ -485,10 +676,129 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"id": "b778776c",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 7 | \n",
+ " 8 | \n",
+ " 9 | \n",
+ " ... | \n",
+ " 15 | \n",
+ " 16 | \n",
+ " 17 | \n",
+ " 18 | \n",
+ " 19 | \n",
+ " 20 | \n",
+ " 21 | \n",
+ " 22 | \n",
+ " 23 | \n",
+ " 24 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " -1.55058 | \n",
+ " -0.081683 | \n",
+ " 0.279919 | \n",
+ " 0.588469 | \n",
+ " -1.00551 | \n",
+ " -0.826139 | \n",
+ " 6.18643 | \n",
+ " 1.44845 | \n",
+ " -0.71108 | \n",
+ " -1.14717 | \n",
+ " ... | \n",
+ " -0.430875 | \n",
+ " 0.872347 | \n",
+ " -0.806399 | \n",
+ " 0.27203 | \n",
+ " 2.23922 | \n",
+ " -1.23572 | \n",
+ " -1.310711 | \n",
+ " -1.96934 | \n",
+ " -0.176410 | \n",
+ " -0.135300 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 1.73610 | \n",
+ " 0.742082 | \n",
+ " 0.355460 | \n",
+ " -4.744110 | \n",
+ " 1.41544 | \n",
+ " -0.342220 | \n",
+ " 1.78697 | \n",
+ " -1.45404 | \n",
+ " 2.56643 | \n",
+ " -1.32184 | \n",
+ " ... | \n",
+ " -0.526620 | \n",
+ " 1.932400 | \n",
+ " -0.896870 | \n",
+ " -0.60924 | \n",
+ " 1.51628 | \n",
+ " -3.16625 | \n",
+ " -0.892750 | \n",
+ " 1.86970 | \n",
+ " -1.336071 | \n",
+ " -2.234643 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
2 rows × 25 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4 5 6 7 \\\n",
+ "0 -1.55058 -0.081683 0.279919 0.588469 -1.00551 -0.826139 6.18643 1.44845 \n",
+ "0 1.73610 0.742082 0.355460 -4.744110 1.41544 -0.342220 1.78697 -1.45404 \n",
+ "\n",
+ " 8 9 ... 15 16 17 18 19 \\\n",
+ "0 -0.71108 -1.14717 ... -0.430875 0.872347 -0.806399 0.27203 2.23922 \n",
+ "0 2.56643 -1.32184 ... -0.526620 1.932400 -0.896870 -0.60924 1.51628 \n",
+ "\n",
+ " 20 21 22 23 24 \n",
+ "0 -1.23572 -1.310711 -1.96934 -0.176410 -0.135300 \n",
+ "0 -3.16625 -0.892750 1.86970 -1.336071 -2.234643 \n",
+ "\n",
+ "[2 rows x 25 columns]"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"\n",
"glove_data\n"
@@ -496,20 +806,412 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 14,
"id": "cb6edbdf",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([ 1.73610002, 0.74208201, 0.35545996, -4.74411008,\n",
+ " 1.41543998, -0.34222007, 1.78697008, -1.45404002,\n",
+ " 2.56643 , -1.32184002, -1.04677537, 0.27867999,\n",
+ " -12.95450976, -1.00809997, 3.15975004, -0.52662008,\n",
+ " 1.93239999, -0.89686999, -0.60924001, 1.51628 ,\n",
+ " -3.16624993, -0.89275002, 1.86969995, -1.33607102,\n",
+ " -2.23464306])"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"one_doc"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 15,
"id": "1bdb459e",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 7 | \n",
+ " 8 | \n",
+ " 9 | \n",
+ " ... | \n",
+ " 15 | \n",
+ " 16 | \n",
+ " 17 | \n",
+ " 18 | \n",
+ " 19 | \n",
+ " 20 | \n",
+ " 21 | \n",
+ " 22 | \n",
+ " 23 | \n",
+ " 24 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " -8.521142 | \n",
+ " 2.020376 | \n",
+ " -10.802921 | \n",
+ " 3.167636 | \n",
+ " 0.252469 | \n",
+ " 15.544048 | \n",
+ " 17.631184 | \n",
+ " -32.581192 | \n",
+ " 9.696540 | \n",
+ " -11.103087 | \n",
+ " ... | \n",
+ " 2.810453 | \n",
+ " 7.900215 | \n",
+ " 0.962129 | \n",
+ " 17.691130 | \n",
+ " -1.252574 | \n",
+ " -10.098049 | \n",
+ " 0.500113 | \n",
+ " 1.348694 | \n",
+ " 2.186150 | \n",
+ " -16.556824 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 6.576228 | \n",
+ " 20.336350 | \n",
+ " -32.675150 | \n",
+ " -9.073872 | \n",
+ " 17.515655 | \n",
+ " -6.488794 | \n",
+ " 59.458419 | \n",
+ " -75.384298 | \n",
+ " 13.323775 | \n",
+ " -14.443218 | \n",
+ " ... | \n",
+ " 21.407738 | \n",
+ " 23.525118 | \n",
+ " 0.325680 | \n",
+ " 19.871444 | \n",
+ " -27.585188 | \n",
+ " -4.559155 | \n",
+ " -7.417482 | \n",
+ " -16.694553 | \n",
+ " -0.197711 | \n",
+ " -58.948193 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 1.329914 | \n",
+ " 3.060870 | \n",
+ " -1.868484 | \n",
+ " 1.392735 | \n",
+ " -1.335277 | \n",
+ " -5.014955 | \n",
+ " 12.859476 | \n",
+ " -9.978156 | \n",
+ " -0.869613 | \n",
+ " -2.031490 | \n",
+ " ... | \n",
+ " 2.925134 | \n",
+ " 2.872930 | \n",
+ " 2.184486 | \n",
+ " 3.831770 | \n",
+ " -0.877866 | \n",
+ " -0.927770 | \n",
+ " 0.700101 | \n",
+ " -9.855365 | \n",
+ " -5.419429 | \n",
+ " -2.279330 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " -4.866150 | \n",
+ " -0.273176 | \n",
+ " 3.515124 | \n",
+ " -5.008165 | \n",
+ " -1.236789 | \n",
+ " -7.951168 | \n",
+ " -11.015882 | \n",
+ " -3.496241 | \n",
+ " 16.024286 | \n",
+ " -9.388742 | \n",
+ " ... | \n",
+ " -0.471141 | \n",
+ " 3.575378 | \n",
+ " 6.193222 | \n",
+ " 0.349430 | \n",
+ " 15.040248 | \n",
+ " -10.369132 | \n",
+ " -0.848717 | \n",
+ " -0.564796 | \n",
+ " -1.114126 | \n",
+ " -7.844431 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " -3.115007 | \n",
+ " -1.805252 | \n",
+ " -5.419340 | \n",
+ " -0.393406 | \n",
+ " -0.406461 | \n",
+ " -2.724340 | \n",
+ " 7.898330 | \n",
+ " -15.619113 | \n",
+ " 0.231822 | \n",
+ " -3.628156 | \n",
+ " ... | \n",
+ " 5.944151 | \n",
+ " 8.309932 | \n",
+ " -0.656084 | \n",
+ " 12.178709 | \n",
+ " -6.118551 | \n",
+ " -3.286376 | \n",
+ " 3.450946 | \n",
+ " 2.055343 | \n",
+ " 0.463787 | \n",
+ " -12.644626 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " -0.930954 | \n",
+ " 4.974043 | \n",
+ " -8.147008 | \n",
+ " -5.147130 | \n",
+ " 3.960455 | \n",
+ " -1.344022 | \n",
+ " 7.818063 | \n",
+ " -25.427420 | \n",
+ " 4.624732 | \n",
+ " -7.218097 | \n",
+ " ... | \n",
+ " 3.623038 | \n",
+ " 4.453189 | \n",
+ " 2.405320 | \n",
+ " 8.032963 | \n",
+ " -8.029539 | \n",
+ " 0.838867 | \n",
+ " -4.757457 | \n",
+ " -5.755052 | \n",
+ " -9.496197 | \n",
+ " -21.542710 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " -0.770690 | \n",
+ " 0.128270 | \n",
+ " 0.331370 | \n",
+ " 0.005089 | \n",
+ " -0.476050 | \n",
+ " -0.501160 | \n",
+ " 1.858000 | \n",
+ " 1.062400 | \n",
+ " -0.565110 | \n",
+ " 0.133280 | \n",
+ " ... | \n",
+ " -0.449220 | \n",
+ " 0.485910 | \n",
+ " -0.647900 | \n",
+ " -0.842380 | \n",
+ " 0.616690 | \n",
+ " -0.198240 | \n",
+ " -0.579670 | \n",
+ " -0.658850 | \n",
+ " 0.439280 | \n",
+ " -0.504730 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 1.491177 | \n",
+ " 6.992638 | \n",
+ " -7.921970 | \n",
+ " -7.157521 | \n",
+ " 6.641657 | \n",
+ " -2.958020 | \n",
+ " 12.820770 | \n",
+ " -18.502946 | \n",
+ " 6.838083 | \n",
+ " -2.717310 | \n",
+ " ... | \n",
+ " -1.344873 | \n",
+ " 4.170405 | \n",
+ " -0.178030 | \n",
+ " 5.699992 | \n",
+ " -7.295038 | \n",
+ " -3.683306 | \n",
+ " -2.718006 | \n",
+ " -0.117608 | \n",
+ " -7.205832 | \n",
+ " -13.863438 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 2.523770 | \n",
+ " 5.817394 | \n",
+ " 2.184340 | \n",
+ " -2.996497 | \n",
+ " -0.267181 | \n",
+ " -10.059634 | \n",
+ " 6.344402 | \n",
+ " -2.047127 | \n",
+ " 2.679123 | \n",
+ " -7.642505 | \n",
+ " ... | \n",
+ " -1.230296 | \n",
+ " 1.409746 | \n",
+ " -3.322040 | \n",
+ " -5.068259 | \n",
+ " -0.648718 | \n",
+ " 0.753010 | \n",
+ " -6.220990 | \n",
+ " -5.012004 | \n",
+ " -1.518542 | \n",
+ " -10.156440 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " -0.118691 | \n",
+ " 11.860546 | \n",
+ " -2.567264 | \n",
+ " -10.955913 | \n",
+ " -4.239322 | \n",
+ " -9.340552 | \n",
+ " 21.189778 | \n",
+ " -10.895375 | \n",
+ " 2.659030 | \n",
+ " -3.848115 | \n",
+ " ... | \n",
+ " 0.726191 | \n",
+ " 11.634998 | \n",
+ " -5.447248 | \n",
+ " 1.293007 | \n",
+ " -7.882002 | \n",
+ " -2.527453 | \n",
+ " 0.298939 | \n",
+ " -6.107062 | \n",
+ " 3.365051 | \n",
+ " -15.641826 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
1657 rows × 25 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4 5 \\\n",
+ "0 -8.521142 2.020376 -10.802921 3.167636 0.252469 15.544048 \n",
+ "0 6.576228 20.336350 -32.675150 -9.073872 17.515655 -6.488794 \n",
+ "0 1.329914 3.060870 -1.868484 1.392735 -1.335277 -5.014955 \n",
+ "0 -4.866150 -0.273176 3.515124 -5.008165 -1.236789 -7.951168 \n",
+ "0 -3.115007 -1.805252 -5.419340 -0.393406 -0.406461 -2.724340 \n",
+ ".. ... ... ... ... ... ... \n",
+ "0 -0.930954 4.974043 -8.147008 -5.147130 3.960455 -1.344022 \n",
+ "0 -0.770690 0.128270 0.331370 0.005089 -0.476050 -0.501160 \n",
+ "0 1.491177 6.992638 -7.921970 -7.157521 6.641657 -2.958020 \n",
+ "0 2.523770 5.817394 2.184340 -2.996497 -0.267181 -10.059634 \n",
+ "0 -0.118691 11.860546 -2.567264 -10.955913 -4.239322 -9.340552 \n",
+ "\n",
+ " 6 7 8 9 ... 15 16 \\\n",
+ "0 17.631184 -32.581192 9.696540 -11.103087 ... 2.810453 7.900215 \n",
+ "0 59.458419 -75.384298 13.323775 -14.443218 ... 21.407738 23.525118 \n",
+ "0 12.859476 -9.978156 -0.869613 -2.031490 ... 2.925134 2.872930 \n",
+ "0 -11.015882 -3.496241 16.024286 -9.388742 ... -0.471141 3.575378 \n",
+ "0 7.898330 -15.619113 0.231822 -3.628156 ... 5.944151 8.309932 \n",
+ ".. ... ... ... ... ... ... ... \n",
+ "0 7.818063 -25.427420 4.624732 -7.218097 ... 3.623038 4.453189 \n",
+ "0 1.858000 1.062400 -0.565110 0.133280 ... -0.449220 0.485910 \n",
+ "0 12.820770 -18.502946 6.838083 -2.717310 ... -1.344873 4.170405 \n",
+ "0 6.344402 -2.047127 2.679123 -7.642505 ... -1.230296 1.409746 \n",
+ "0 21.189778 -10.895375 2.659030 -3.848115 ... 0.726191 11.634998 \n",
+ "\n",
+ " 17 18 19 20 21 22 23 \\\n",
+ "0 0.962129 17.691130 -1.252574 -10.098049 0.500113 1.348694 2.186150 \n",
+ "0 0.325680 19.871444 -27.585188 -4.559155 -7.417482 -16.694553 -0.197711 \n",
+ "0 2.184486 3.831770 -0.877866 -0.927770 0.700101 -9.855365 -5.419429 \n",
+ "0 6.193222 0.349430 15.040248 -10.369132 -0.848717 -0.564796 -1.114126 \n",
+ "0 -0.656084 12.178709 -6.118551 -3.286376 3.450946 2.055343 0.463787 \n",
+ ".. ... ... ... ... ... ... ... \n",
+ "0 2.405320 8.032963 -8.029539 0.838867 -4.757457 -5.755052 -9.496197 \n",
+ "0 -0.647900 -0.842380 0.616690 -0.198240 -0.579670 -0.658850 0.439280 \n",
+ "0 -0.178030 5.699992 -7.295038 -3.683306 -2.718006 -0.117608 -7.205832 \n",
+ "0 -3.322040 -5.068259 -0.648718 0.753010 -6.220990 -5.012004 -1.518542 \n",
+ "0 -5.447248 1.293007 -7.882002 -2.527453 0.298939 -6.107062 3.365051 \n",
+ "\n",
+ " 24 \n",
+ "0 -16.556824 \n",
+ "0 -58.948193 \n",
+ "0 -2.279330 \n",
+ "0 -7.844431 \n",
+ "0 -12.644626 \n",
+ ".. ... \n",
+ "0 -21.542710 \n",
+ "0 -0.504730 \n",
+ "0 -13.863438 \n",
+ "0 -10.156440 \n",
+ "0 -15.641826 \n",
+ "\n",
+ "[1657 rows x 25 columns]"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"train_data_glove = text2vec(twenty_train['data']);\n",
"train_data_glove"
@@ -517,17 +1219,29 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 16,
"id": "3a7ea7c6",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<1657x23297 sparse matrix of type ''\n",
+ "\twith 106580 stored elements in Compressed Sparse Row format>"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"train_data\n"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 17,
"id": "5ac20e79",
"metadata": {},
"outputs": [],
@@ -537,17 +1251,31 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 18,
"id": "08164a25",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "KNeighborsClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
+ ],
+ "text/plain": [
+ "KNeighborsClassifier()"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"clf.fit(train_data_glove, twenty_train['target'])"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 19,
"id": "e459faaf",
"metadata": {},
"outputs": [],
@@ -557,17 +1285,392 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 20,
"id": "d8144e75",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 7 | \n",
+ " 8 | \n",
+ " 9 | \n",
+ " ... | \n",
+ " 15 | \n",
+ " 16 | \n",
+ " 17 | \n",
+ " 18 | \n",
+ " 19 | \n",
+ " 20 | \n",
+ " 21 | \n",
+ " 22 | \n",
+ " 23 | \n",
+ " 24 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " -6.760635 | \n",
+ " 5.063863 | \n",
+ " -2.779060 | \n",
+ " 3.699120 | \n",
+ " -2.858086 | \n",
+ " 0.135230 | \n",
+ " 20.811229 | \n",
+ " -19.425567 | \n",
+ " 7.302950 | \n",
+ " -5.826012 | \n",
+ " ... | \n",
+ " 3.833378 | \n",
+ " 6.794452 | \n",
+ " -0.921720 | \n",
+ " 12.187404 | \n",
+ " -5.547615 | \n",
+ " -4.133999 | \n",
+ " 3.588260 | \n",
+ " -0.497106 | \n",
+ " -2.542142 | \n",
+ " -11.362855 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 1.632616 | \n",
+ " 2.512300 | \n",
+ " -0.745513 | \n",
+ " -3.081154 | \n",
+ " 2.182067 | \n",
+ " -1.988816 | \n",
+ " 7.533100 | \n",
+ " -1.015740 | \n",
+ " -0.829598 | \n",
+ " -2.764237 | \n",
+ " ... | \n",
+ " 0.791851 | \n",
+ " 2.114150 | \n",
+ " -2.249193 | \n",
+ " -0.163590 | \n",
+ " -1.177710 | \n",
+ " -2.496928 | \n",
+ " -5.074085 | \n",
+ " -2.666947 | \n",
+ " 0.662050 | \n",
+ " -3.590550 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 2.115766 | \n",
+ " 2.142060 | \n",
+ " -0.445607 | \n",
+ " -3.229030 | \n",
+ " 1.154580 | \n",
+ " -2.877278 | \n",
+ " 6.399954 | \n",
+ " -10.445769 | \n",
+ " 2.230760 | \n",
+ " -3.299899 | \n",
+ " ... | \n",
+ " 4.388870 | \n",
+ " 8.515056 | \n",
+ " -0.766260 | \n",
+ " 3.549431 | \n",
+ " -1.643443 | \n",
+ " -0.825730 | \n",
+ " -2.968016 | \n",
+ " -0.808924 | \n",
+ " -0.000160 | \n",
+ " -7.468189 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " -0.802784 | \n",
+ " 5.199443 | \n",
+ " 4.294071 | \n",
+ " -7.390966 | \n",
+ " 2.747166 | \n",
+ " -1.359952 | \n",
+ " 15.032628 | \n",
+ " -1.601590 | \n",
+ " 1.474406 | \n",
+ " 2.570105 | \n",
+ " ... | \n",
+ " 3.043432 | \n",
+ " 6.176236 | \n",
+ " -6.193988 | \n",
+ " -3.990476 | \n",
+ " -2.345854 | \n",
+ " -5.534376 | \n",
+ " -8.925422 | \n",
+ " 1.553300 | \n",
+ " 0.905790 | \n",
+ " -12.824533 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 29.926489 | \n",
+ " 65.324993 | \n",
+ " -25.059592 | \n",
+ " -64.080130 | \n",
+ " 77.565282 | \n",
+ " -34.614604 | \n",
+ " 75.643770 | \n",
+ " -115.600859 | \n",
+ " 90.847175 | \n",
+ " -42.971146 | \n",
+ " ... | \n",
+ " 40.956031 | \n",
+ " 50.322156 | \n",
+ " -19.537098 | \n",
+ " 28.903925 | \n",
+ " -34.643949 | \n",
+ " -69.894146 | \n",
+ " -94.992145 | \n",
+ " -48.601895 | \n",
+ " -29.098555 | \n",
+ " -91.934770 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 1.829235 | \n",
+ " 4.513807 | \n",
+ " 2.916520 | \n",
+ " 2.237308 | \n",
+ " -1.704831 | \n",
+ " -1.811192 | \n",
+ " 22.196895 | \n",
+ " -12.858912 | \n",
+ " -4.054810 | \n",
+ " -3.130457 | \n",
+ " ... | \n",
+ " 6.019246 | \n",
+ " 8.949456 | \n",
+ " -4.682214 | \n",
+ " -5.648911 | \n",
+ " -1.026898 | \n",
+ " 3.719006 | \n",
+ " 2.449941 | \n",
+ " -6.487197 | \n",
+ " 1.340930 | \n",
+ " -7.325196 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " -0.963815 | \n",
+ " 5.491164 | \n",
+ " 3.567377 | \n",
+ " -6.048021 | \n",
+ " -5.059298 | \n",
+ " -0.977958 | \n",
+ " 15.131499 | \n",
+ " -0.904470 | \n",
+ " 2.185990 | \n",
+ " -1.459807 | \n",
+ " ... | \n",
+ " 0.968499 | \n",
+ " 4.725793 | \n",
+ " -0.726944 | \n",
+ " 1.328612 | \n",
+ " -3.144209 | \n",
+ " 1.643127 | \n",
+ " -1.259245 | \n",
+ " -0.880740 | \n",
+ " -6.713165 | \n",
+ " -3.115454 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 6.801324 | \n",
+ " 15.348126 | \n",
+ " -17.051718 | \n",
+ " 5.030998 | \n",
+ " 9.332448 | \n",
+ " -5.716691 | \n",
+ " 56.409175 | \n",
+ " -56.250411 | \n",
+ " -4.028209 | \n",
+ " -11.687558 | \n",
+ " ... | \n",
+ " 22.884424 | \n",
+ " 12.940570 | \n",
+ " 1.058664 | \n",
+ " 21.879058 | \n",
+ " -20.897253 | \n",
+ " 2.537755 | \n",
+ " 3.774890 | \n",
+ " -11.495336 | \n",
+ " -2.609774 | \n",
+ " -36.597559 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " 1.054090 | \n",
+ " 0.764524 | \n",
+ " 1.958340 | \n",
+ " -1.085245 | \n",
+ " -0.441392 | \n",
+ " -0.421970 | \n",
+ " 6.139770 | \n",
+ " -0.612219 | \n",
+ " -2.251460 | \n",
+ " -0.465165 | \n",
+ " ... | \n",
+ " 0.377958 | \n",
+ " 1.957450 | \n",
+ " -1.705220 | \n",
+ " -0.509700 | \n",
+ " 0.016110 | \n",
+ " 1.461620 | \n",
+ " 1.589069 | \n",
+ " 2.267340 | \n",
+ " 0.447919 | \n",
+ " -0.469250 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " -18.387286 | \n",
+ " 13.274879 | \n",
+ " -7.895913 | \n",
+ " -1.831442 | \n",
+ " -10.424961 | \n",
+ " -12.248442 | \n",
+ " 32.153890 | \n",
+ " -40.169293 | \n",
+ " 13.089525 | \n",
+ " -21.306493 | \n",
+ " ... | \n",
+ " 6.497279 | \n",
+ " 8.340729 | \n",
+ " 4.996109 | \n",
+ " 23.442078 | \n",
+ " -3.701088 | \n",
+ " -11.671505 | \n",
+ " 9.209790 | \n",
+ " -10.002501 | \n",
+ " -0.815266 | \n",
+ " -17.024052 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
1102 rows × 25 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4 5 \\\n",
+ "0 -6.760635 5.063863 -2.779060 3.699120 -2.858086 0.135230 \n",
+ "0 1.632616 2.512300 -0.745513 -3.081154 2.182067 -1.988816 \n",
+ "0 2.115766 2.142060 -0.445607 -3.229030 1.154580 -2.877278 \n",
+ "0 -0.802784 5.199443 4.294071 -7.390966 2.747166 -1.359952 \n",
+ "0 29.926489 65.324993 -25.059592 -64.080130 77.565282 -34.614604 \n",
+ ".. ... ... ... ... ... ... \n",
+ "0 1.829235 4.513807 2.916520 2.237308 -1.704831 -1.811192 \n",
+ "0 -0.963815 5.491164 3.567377 -6.048021 -5.059298 -0.977958 \n",
+ "0 6.801324 15.348126 -17.051718 5.030998 9.332448 -5.716691 \n",
+ "0 1.054090 0.764524 1.958340 -1.085245 -0.441392 -0.421970 \n",
+ "0 -18.387286 13.274879 -7.895913 -1.831442 -10.424961 -12.248442 \n",
+ "\n",
+ " 6 7 8 9 ... 15 16 \\\n",
+ "0 20.811229 -19.425567 7.302950 -5.826012 ... 3.833378 6.794452 \n",
+ "0 7.533100 -1.015740 -0.829598 -2.764237 ... 0.791851 2.114150 \n",
+ "0 6.399954 -10.445769 2.230760 -3.299899 ... 4.388870 8.515056 \n",
+ "0 15.032628 -1.601590 1.474406 2.570105 ... 3.043432 6.176236 \n",
+ "0 75.643770 -115.600859 90.847175 -42.971146 ... 40.956031 50.322156 \n",
+ ".. ... ... ... ... ... ... ... \n",
+ "0 22.196895 -12.858912 -4.054810 -3.130457 ... 6.019246 8.949456 \n",
+ "0 15.131499 -0.904470 2.185990 -1.459807 ... 0.968499 4.725793 \n",
+ "0 56.409175 -56.250411 -4.028209 -11.687558 ... 22.884424 12.940570 \n",
+ "0 6.139770 -0.612219 -2.251460 -0.465165 ... 0.377958 1.957450 \n",
+ "0 32.153890 -40.169293 13.089525 -21.306493 ... 6.497279 8.340729 \n",
+ "\n",
+ " 17 18 19 20 21 22 \\\n",
+ "0 -0.921720 12.187404 -5.547615 -4.133999 3.588260 -0.497106 \n",
+ "0 -2.249193 -0.163590 -1.177710 -2.496928 -5.074085 -2.666947 \n",
+ "0 -0.766260 3.549431 -1.643443 -0.825730 -2.968016 -0.808924 \n",
+ "0 -6.193988 -3.990476 -2.345854 -5.534376 -8.925422 1.553300 \n",
+ "0 -19.537098 28.903925 -34.643949 -69.894146 -94.992145 -48.601895 \n",
+ ".. ... ... ... ... ... ... \n",
+ "0 -4.682214 -5.648911 -1.026898 3.719006 2.449941 -6.487197 \n",
+ "0 -0.726944 1.328612 -3.144209 1.643127 -1.259245 -0.880740 \n",
+ "0 1.058664 21.879058 -20.897253 2.537755 3.774890 -11.495336 \n",
+ "0 -1.705220 -0.509700 0.016110 1.461620 1.589069 2.267340 \n",
+ "0 4.996109 23.442078 -3.701088 -11.671505 9.209790 -10.002501 \n",
+ "\n",
+ " 23 24 \n",
+ "0 -2.542142 -11.362855 \n",
+ "0 0.662050 -3.590550 \n",
+ "0 -0.000160 -7.468189 \n",
+ "0 0.905790 -12.824533 \n",
+ "0 -29.098555 -91.934770 \n",
+ ".. ... ... \n",
+ "0 1.340930 -7.325196 \n",
+ "0 -6.713165 -3.115454 \n",
+ "0 -2.609774 -36.597559 \n",
+ "0 0.447919 -0.469250 \n",
+ "0 -0.815266 -17.024052 \n",
+ "\n",
+ "[1102 rows x 25 columns]"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"test_data_glove"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 21,
"id": "a69830f0",
"metadata": {},
"outputs": [],
@@ -577,10 +1680,30 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 22,
"id": "9ac5cf20",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[225 35 59]\n",
+ " [ 26 313 50]\n",
+ " [ 56 98 240]]\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.73 0.71 0.72 319\n",
+ " 1 0.70 0.80 0.75 389\n",
+ " 2 0.69 0.61 0.65 394\n",
+ "\n",
+ " accuracy 0.71 1102\n",
+ " macro avg 0.71 0.71 0.70 1102\n",
+ "weighted avg 0.71 0.71 0.70 1102\n",
+ "\n"
+ ]
+ }
+ ],
"source": [
"print (confusion_matrix(twenty_test['target'], predict))\n",
"print(classification_report(twenty_test['target'], predict))"