diff --git a/labs/OATD_LR2_metod.ipynb b/labs/OATD_LR2_metod.ipynb
index 98e1023..b6b20b2 100644
--- a/labs/OATD_LR2_metod.ipynb
+++ b/labs/OATD_LR2_metod.ipynb
@@ -36,14 +36,16 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer \n",
- "import numpy as np"
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_auc_score\n"
]
},
{
@@ -63,7 +65,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -82,7 +84,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -125,7 +127,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -150,7 +152,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -169,7 +171,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -193,14 +195,14 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "[('image', 489), ('don', 417), ('graphics', 410), ('god', 409), ('people', 384), ('does', 364), ('edu', 349), ('like', 329), ('just', 327), ('know', 319)]\n"
+ "[('image', np.int64(489)), ('don', np.int64(417)), ('graphics', np.int64(410)), ('god', np.int64(409)), ('people', np.int64(384)), ('does', np.int64(364)), ('edu', np.int64(349)), ('like', np.int64(329)), ('just', np.int64(327)), ('know', np.int64(319))]\n"
]
}
],
@@ -212,26 +214,6 @@
"print (x[:10])\n"
]
},
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Если для каждого класса по-отдельности получить подобные списки наиболее частотных слов, то можно оценить пересекаемость терминов двух классов, например, с помощью [меры сходства Жаккара](https://ru.wikipedia.org/wiki/Коэффициент_Жаккара).\n",
- "Ниже приведена функция, которая на вход принимает два списка слов и возвращает значение коэффициента Жаккара:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [],
- "source": [
- "def jaccard_similarity(list1, list2):\n",
- " intersection = len(list(set(list1).intersection(list2)))\n",
- " union = (len(set(list1)) + len(set(list2))) - intersection\n",
- " return float(intersection) / union\n"
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -243,27 +225,22 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Существует целый ряд алгоритмов стемминга. В работе предлагается использовать алгоритм Портера, реализация которого приведена в библиотеке [nltk](https://www.nltk.org/)\n",
+ "Существует целый ряд алгоритмов стемминга. Один из популярных - алгоритм Портера, реализация которого приведена в библиотеке [nltk](https://www.nltk.org/)\n",
"Для проведения стемминга нужно создать объект `PorterStemmer()`. Стеммер работает таким образом: у созданного объекта `PorterStemmer` есть метод `stem`, производящий стемминга. Таким образом, необходимо каждую из частей выборки (обучающую и тестовую) разбить на отдельные документы, затем, проходя в цикле по каждому слову в документе, произвести стемминг и объединить эти слова в новый документ.\n",
- "\n"
+ "\n",
+ "В ЛР проводить стемминг не требуется.\n"
]
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " i 'll take a wild guess and say freedom is object valuabl . i base thi on the assumpt that if everyon in the world were depriv utterli of their freedom ( so that their everi act wa contrari to their volit ) , almost all would want to complain . therefor i take it that to assert or believ that `` freedom is not veri valuabl '' , when almost everyon can see that it is , is everi bit as absurd as to assert `` it is not rain '' on a raini day . i take thi to be a candid for an object valu , and it it is a necessari condit for object moral that object valu such as thi exist .\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
+ "import nltk\n",
"from nltk.stem import *\n",
"from nltk import word_tokenize\n",
+ "nltk.download('punkt_tab')\n",
"\n",
"porter_stemmer = PorterStemmer()\n",
"stem_train = []\n",
@@ -293,7 +270,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -342,7 +319,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -359,7 +336,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -377,7 +354,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -393,13 +370,4448 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prediction = text_clf.predict(twenty_test.data)"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Настройка параметров с использованием grid search"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import GridSearchCV"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "С помощью конвейерной обработки (`pipelines`) стало гораздо проще указывать параметры обучения модели, однако перебирать все возможные варианты вручную даже в нашем простом случае выйдет затратно. В нашем случае имеется четыре настраиваемых параметра: `max_features`, `stop_words`, `use_idf` и `n_neighbors`.\n",
+ "\n",
+ "Вместо поиска лучших параметров в конвейере вручную, можно запустить поиск (методом полного перебора) лучших параметров в сетке возможных значений. Сделать это можно с помощью объекта класса `GridSearchCV`.\n",
+ " \n",
    "Для того чтобы задать сетку параметров, необходимо создать переменную-словарь, ключами которого являются конструкции вида: «НазваниеШагаКонвейера__НазваниеПараметра», а значениями – кортеж из значений параметра\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "parameters = {'vect__max_features': (100,500,1000,5000,10000),\n",
+ " 'vect__stop_words': ('english', None),\n",
+ " 'tfidf__use_idf': (True, False), \n",
+ " 'clf__n_neighbors': (1,3,5,7)} "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Далее необходимо создать объект класса `GridSearchCV`, передав в него объект `pipeline` или классификатор, список параметров сетки, а также при необходимости, задав прочие параметры, такие как количество задействованных ядер процессора `n_jobs`, количество фолдов кросс-валидации `cv`, метрику, по которой будем судить о качестве модели `scoring`, и другие"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gs_clf = GridSearchCV(text_clf, parameters, n_jobs=-1, cv=3, scoring = 'f1_weighted')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Теперь объект `gs_clf` можно обучить по всем параметрам, как обычный классификатор, методом `fit()`.\n",
+ "После того как прошло обучение, узнать лучшую совокупность параметров можно, обратившись к атрибуту `best_params_`. Для необученной модели атрибуты будут отсутствовать.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
GridSearchCV(cv=3,\n",
+ " estimator=Pipeline(steps=[('vect',\n",
+ " CountVectorizer(max_features=1000,\n",
+ " stop_words='english')),\n",
+ " ('tfidf', TfidfTransformer()),\n",
+ " ('clf',\n",
+ " KNeighborsClassifier(n_neighbors=1))]),\n",
+ " n_jobs=-1,\n",
+ " param_grid={'clf__n_neighbors': (1, 3, 5, 7),\n",
+ " 'tfidf__use_idf': (True, False),\n",
+ " 'vect__max_features': (100, 500, 1000, 5000, 10000),\n",
+ " 'vect__stop_words': ('english', None)},\n",
+ " scoring='f1_weighted') In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. \n",
+ "
\n",
+ "
\n",
+ " Parameters \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " estimator\n",
+ " estimator: estimator object This is assumed to implement the scikit-learn estimator interface. Either estimator needs to provide a ``score`` function, or ``scoring`` must be passed. \n",
+ " \n",
+ " \n",
+ " Pipeline(step...eighbors=1))]) \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " param_grid\n",
+ " param_grid: dict or list of dictionaries Dictionary with parameters names (`str`) as keys and lists of parameter settings to try as values, or a list of such dictionaries, in which case the grids spanned by each dictionary in the list are explored. This enables searching over any sequence of parameter settings. \n",
+ " \n",
+ " \n",
+ " {'clf__n_neighbors': (1, ...), 'tfidf__use_idf': (True, ...), 'vect__max_features': (100, ...), 'vect__stop_words': ('english', ...)} \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " scoring\n",
+ " scoring: str, callable, list, tuple or dict, default=None Strategy to evaluate the performance of the cross-validated model on the test set. If `scoring` represents a single score, one can use: - a single string (see :ref:`scoring_string_names`); - a callable (see :ref:`scoring_callable`) that returns a single value; - `None`, the `estimator`'s :ref:`default evaluation criterion ` is used. If `scoring` represents multiple scores, one can use: - a list or tuple of unique strings; - a callable returning a dictionary where the keys are the metric names and the values are the metric scores; - a dictionary with metric names as keys and callables as values. See :ref:`multimetric_grid_search` for an example. \n",
+ " \n",
+ " \n",
+ " 'f1_weighted' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " n_jobs\n",
+ " n_jobs: int, default=None Number of jobs to run in parallel. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. .. versionchanged:: v0.20 `n_jobs` default changed from 1 to None \n",
+ " \n",
+ " \n",
+ " -1 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " refit\n",
+ " refit: bool, str, or callable, default=True Refit an estimator using the best found parameters on the whole dataset. For multiple metric evaluation, this needs to be a `str` denoting the scorer that would be used to find the best parameters for refitting the estimator at the end. Where there are considerations other than maximum score in choosing a best estimator, ``refit`` can be set to a function which returns the selected ``best_index_`` given ``cv_results_``. In that case, the ``best_estimator_`` and ``best_params_`` will be set according to the returned ``best_index_`` while the ``best_score_`` attribute will not be available. The refitted estimator is made available at the ``best_estimator_`` attribute and permits using ``predict`` directly on this ``GridSearchCV`` instance. Also for multiple metric evaluation, the attributes ``best_index_``, ``best_score_`` and ``best_params_`` will only be available if ``refit`` is set and all of them will be determined w.r.t this specific scorer. See ``scoring`` parameter to know more about multiple metric evaluation. See :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_digits.py` to see how to design a custom selection strategy using a callable via `refit`. See :ref:`this example` for an example of how to use ``refit=callable`` to balance model complexity and cross-validated score. .. versionchanged:: 0.20 Support for callable added. \n",
+ " \n",
+ " \n",
+ " True \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " cv\n",
+ " cv: int, cross-validation generator or an iterable, default=None Determines the cross-validation splitting strategy. Possible inputs for cv are: - None, to use the default 5-fold cross validation, - integer, to specify the number of folds in a `(Stratified)KFold`, - :term:`CV splitter`, - An iterable yielding (train, test) splits as arrays of indices. For integer/None inputs, if the estimator is a classifier and ``y`` is either binary or multiclass, :class:`StratifiedKFold` is used. In all other cases, :class:`KFold` is used. These splitters are instantiated with `shuffle=False` so the splits will be the same across calls. Refer :ref:`User Guide ` for the various cross-validation strategies that can be used here. .. versionchanged:: 0.22 ``cv`` default value if None changed from 3-fold to 5-fold. \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " verbose\n",
+ " verbose: int Controls the verbosity: the higher, the more messages. - >1 : the computation time for each fold and parameter candidate is displayed; - >2 : the score is also displayed; - >3 : the fold and candidate parameter indexes are also displayed together with the starting time of the computation. \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " pre_dispatch\n",
+ " pre_dispatch: int, or str, default='2*n_jobs' Controls the number of jobs that get dispatched during parallel execution. Reducing this number can be useful to avoid an explosion of memory consumption when more jobs get dispatched than CPUs can process. This parameter can be: - None, in which case all the jobs are immediately created and spawned. Use this for lightweight and fast-running jobs, to avoid delays due to on-demand spawning of the jobs - An int, giving the exact number of total jobs that are spawned - A str, giving an expression as a function of n_jobs, as in '2*n_jobs' \n",
+ " \n",
+ " \n",
+ " '2*n_jobs' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " error_score\n",
+ " error_score: 'raise' or numeric, default=np.nan Value to assign to the score if an error occurs in estimator fitting. If set to 'raise', the error is raised. If a numeric value is given, FitFailedWarning is raised. This parameter does not affect the refit step, which will always raise the error. \n",
+ " \n",
+ " \n",
+ " nan \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " return_train_score\n",
+ " return_train_score: bool, default=False If ``False``, the ``cv_results_`` attribute will not include training scores. Computing training scores is used to get insights on how different parameter settings impact the overfitting/underfitting trade-off. However computing the scores on the training set can be computationally expensive and is not strictly required to select the parameters that yield the best generalization performance. .. versionadded:: 0.19 .. versionchanged:: 0.21 Default value was changed from ``True`` to ``False`` \n",
+ " \n",
+ " \n",
+ " False \n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " Parameters \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " input\n",
+ " input: {'filename', 'file', 'content'}, default='content' - If `'filename'`, the sequence passed as an argument to fit is expected to be a list of filenames that need reading to fetch the raw content to analyze. - If `'file'`, the sequence items must have a 'read' method (file-like object) that is called to fetch the bytes in memory. - If `'content'`, the input is expected to be a sequence of items that can be of type string or byte. \n",
+ " \n",
+ " \n",
+ " 'content' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " encoding\n",
+ " encoding: str, default='utf-8' If bytes or files are given to analyze, this encoding is used to decode. \n",
+ " \n",
+ " \n",
+ " 'utf-8' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " decode_error\n",
+ " decode_error: {'strict', 'ignore', 'replace'}, default='strict' Instruction on what to do if a byte sequence is given to analyze that contains characters not of the given `encoding`. By default, it is 'strict', meaning that a UnicodeDecodeError will be raised. Other values are 'ignore' and 'replace'. \n",
+ " \n",
+ " \n",
+ " 'strict' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " strip_accents\n",
+ " strip_accents: {'ascii', 'unicode'} or callable, default=None Remove accents and perform other character normalization during the preprocessing step. 'ascii' is a fast method that only works on characters that have a direct ASCII mapping. 'unicode' is a slightly slower method that works on any characters. None (default) means no character normalization is performed. Both 'ascii' and 'unicode' use NFKD normalization from :func:`unicodedata.normalize`. \n",
+ " \n",
+ " \n",
+ " None \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " lowercase\n",
+ " lowercase: bool, default=True Convert all characters to lowercase before tokenizing. \n",
+ " \n",
+ " \n",
+ " True \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " preprocessor\n",
+ " preprocessor: callable, default=None Override the preprocessing (strip_accents and lowercase) stage while preserving the tokenizing and n-grams generation steps. Only applies if ``analyzer`` is not callable. \n",
+ " \n",
+ " \n",
+ " None \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " tokenizer\n",
+ " tokenizer: callable, default=None Override the string tokenization step while preserving the preprocessing and n-grams generation steps. Only applies if ``analyzer == 'word'``. \n",
+ " \n",
+ " \n",
+ " None \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " stop_words\n",
+ " stop_words: {'english'}, list, default=None If 'english', a built-in stop word list for English is used. There are several known issues with 'english' and you should consider an alternative (see :ref:`stop_words`). If a list, that list is assumed to contain stop words, all of which will be removed from the resulting tokens. Only applies if ``analyzer == 'word'``. If None, no stop words will be used. In this case, setting `max_df` to a higher value, such as in the range (0.7, 1.0), can automatically detect and filter stop words based on intra corpus document frequency of terms. \n",
+ " \n",
+ " \n",
+ " 'english' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " token_pattern\n",
+ " token_pattern: str or None, default=r\"(?u)\\\\b\\\\w\\\\w+\\\\b\" Regular expression denoting what constitutes a \"token\", only used if ``analyzer == 'word'``. The default regexp select tokens of 2 or more alphanumeric characters (punctuation is completely ignored and always treated as a token separator). If there is a capturing group in token_pattern then the captured group content, not the entire match, becomes the token. At most one capturing group is permitted. \n",
+ " \n",
+ " \n",
+ " '(?u)\\\\b\\\\w\\\\w+\\\\b' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " ngram_range\n",
+ " ngram_range: tuple (min_n, max_n), default=(1, 1) The lower and upper boundary of the range of n-values for different word n-grams or char n-grams to be extracted. All values of n such such that min_n <= n <= max_n will be used. For example an ``ngram_range`` of ``(1, 1)`` means only unigrams, ``(1, 2)`` means unigrams and bigrams, and ``(2, 2)`` means only bigrams. Only applies if ``analyzer`` is not callable. \n",
+ " \n",
+ " \n",
+ " (1, ...) \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " analyzer\n",
+ " analyzer: {'word', 'char', 'char_wb'} or callable, default='word' Whether the feature should be made of word n-gram or character n-grams. Option 'char_wb' creates character n-grams only from text inside word boundaries; n-grams at the edges of words are padded with space. If a callable is passed it is used to extract the sequence of features out of the raw, unprocessed input. .. versionchanged:: 0.21 Since v0.21, if ``input`` is ``filename`` or ``file``, the data is first read from the file and then passed to the given callable analyzer. \n",
+ " \n",
+ " \n",
+ " 'word' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " max_df\n",
+ " max_df: float in range [0.0, 1.0] or int, default=1.0 When building the vocabulary ignore terms that have a document frequency strictly higher than the given threshold (corpus-specific stop words). If float, the parameter represents a proportion of documents, integer absolute counts. This parameter is ignored if vocabulary is not None. \n",
+ " \n",
+ " \n",
+ " 1.0 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " min_df\n",
+ " min_df: float in range [0.0, 1.0] or int, default=1 When building the vocabulary ignore terms that have a document frequency strictly lower than the given threshold. This value is also called cut-off in the literature. If float, the parameter represents a proportion of documents, integer absolute counts. This parameter is ignored if vocabulary is not None. \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " max_features\n",
+ " max_features: int, default=None If not None, build a vocabulary that only consider the top `max_features` ordered by term frequency across the corpus. Otherwise, all features are used. This parameter is ignored if vocabulary is not None. \n",
+ " \n",
+ " \n",
+ " 100 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " vocabulary\n",
+ " vocabulary: Mapping or iterable, default=None Either a Mapping (e.g., a dict) where keys are terms and values are indices in the feature matrix, or an iterable over terms. If not given, a vocabulary is determined from the input documents. Indices in the mapping should not be repeated and should not have any gap between 0 and the largest index. \n",
+ " \n",
+ " \n",
+ " None \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " binary\n",
+ " binary: bool, default=False If True, all non zero counts are set to 1. This is useful for discrete probabilistic models that model binary events rather than integer counts. \n",
+ " \n",
+ " \n",
+ " False \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " dtype\n",
+ " dtype: dtype, default=np.int64 Type of the matrix returned by fit_transform() or transform(). \n",
+ " \n",
+ " \n",
+ " <class 'numpy.int64'> \n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " Parameters \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " norm\n",
+ " norm: {'l1', 'l2'} or None, default='l2' Each output row will have unit norm, either: - 'l2': Sum of squares of vector elements is 1. The cosine similarity between two vectors is their dot product when l2 norm has been applied. - 'l1': Sum of absolute values of vector elements is 1. See :func:`~sklearn.preprocessing.normalize`. - None: No normalization. \n",
+ " \n",
+ " \n",
+ " 'l2' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " use_idf\n",
+ " use_idf: bool, default=True Enable inverse-document-frequency reweighting. If False, idf(t) = 1. \n",
+ " \n",
+ " \n",
+ " True \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " smooth_idf\n",
+ " smooth_idf: bool, default=True Smooth idf weights by adding one to document frequencies, as if an extra document was seen containing every term in the collection exactly once. Prevents zero divisions. \n",
+ " \n",
+ " \n",
+ " True \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " sublinear_tf\n",
+ " sublinear_tf: bool, default=False Apply sublinear tf scaling, i.e. replace tf with 1 + log(tf). \n",
+ " \n",
+ " \n",
+ " False \n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " Parameters \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " n_neighbors\n",
+ " n_neighbors: int, default=5 Number of neighbors to use by default for :meth:`kneighbors` queries. \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " weights\n",
+ " weights: {'uniform', 'distance'}, callable or None, default='uniform' Weight function used in prediction. Possible values: - 'uniform' : uniform weights. All points in each neighborhood are weighted equally. - 'distance' : weight points by the inverse of their distance. in this case, closer neighbors of a query point will have a greater influence than neighbors which are further away. - [callable] : a user-defined function which accepts an array of distances, and returns an array of the same shape containing the weights. Refer to the example entitled :ref:`sphx_glr_auto_examples_neighbors_plot_classification.py` showing the impact of the `weights` parameter on the decision boundary. \n",
+ " \n",
+ " \n",
+ " 'uniform' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " algorithm\n",
+ " algorithm: {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto' Algorithm used to compute the nearest neighbors: - 'ball_tree' will use :class:`BallTree` - 'kd_tree' will use :class:`KDTree` - 'brute' will use a brute-force search. - 'auto' will attempt to decide the most appropriate algorithm based on the values passed to :meth:`fit` method. Note: fitting on sparse input will override the setting of this parameter, using brute force. \n",
+ " \n",
+ " \n",
+ " 'auto' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " leaf_size\n",
+ " leaf_size: int, default=30 Leaf size passed to BallTree or KDTree. This can affect the speed of the construction and query, as well as the memory required to store the tree. The optimal value depends on the nature of the problem. \n",
+ " \n",
+ " \n",
+ " 30 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " p\n",
+ " p: float, default=2 Power parameter for the Minkowski metric. When p = 1, this is equivalent to using manhattan_distance (l1), and euclidean_distance (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used. This parameter is expected to be positive. \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " metric\n",
+ " metric: str or callable, default='minkowski' Metric to use for distance computation. Default is \"minkowski\", which results in the standard Euclidean distance when p = 2. See the documentation of `scipy.spatial.distance`_ and the metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for valid metric values. If metric is \"precomputed\", X is assumed to be a distance matrix and must be square during fit. X may be a :term:`sparse graph`, in which case only \"nonzero\" elements may be considered neighbors. If metric is a callable function, it takes two arrays representing 1D vectors as inputs and must return one value indicating the distance between those vectors. This works for Scipy's metrics, but is less efficient than passing the metric name as a string. \n",
+ " \n",
+ " \n",
+ " 'minkowski' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " metric_params\n",
+ " metric_params: dict, default=None Additional keyword arguments for the metric function. \n",
+ " \n",
+ " \n",
+ " None \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " n_jobs\n",
+ " n_jobs: int, default=None The number of parallel jobs to run for neighbors search. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. Doesn't affect :meth:`fit` method. \n",
+ " \n",
+ " \n",
+ " None \n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ "GridSearchCV(cv=3,\n",
+ " estimator=Pipeline(steps=[('vect',\n",
+ " CountVectorizer(max_features=1000,\n",
+ " stop_words='english')),\n",
+ " ('tfidf', TfidfTransformer()),\n",
+ " ('clf',\n",
+ " KNeighborsClassifier(n_neighbors=1))]),\n",
+ " n_jobs=-1,\n",
+ " param_grid={'clf__n_neighbors': (1, 3, 5, 7),\n",
+ " 'tfidf__use_idf': (True, False),\n",
+ " 'vect__max_features': (100, 500, 1000, 5000, 10000),\n",
+ " 'vect__stop_words': ('english', None)},\n",
+ " scoring='f1_weighted')"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gs_clf.fit(twenty_train.data, twenty_train.target)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': 'english'}"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gs_clf.best_params_"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'mean_fit_time': array([0.14711912, 0.1838789 , 0.16524474, 0.13485821, 0.17077629,\n",
+ " 0.171736 , 0.13478573, 0.17135119, 0.1455044 , 0.16278481,\n",
+ " 0.14029678, 0.17597159, 0.14752992, 0.1515433 , 0.2238245 ,\n",
+ " 0.16685406, 0.19492237, 0.18400868, 0.1637164 , 0.23180572,\n",
+ " 0.13470483, 0.19008255, 0.168998 , 0.14402525, 0.14959741,\n",
+ " 0.21908172, 0.12458984, 0.17676258, 0.181101 , 0.15396865,\n",
+ " 0.13980309, 0.17127609, 0.12496074, 0.21157686, 0.14910388,\n",
+ " 0.16302236, 0.20699684, 0.14392487, 0.13715124, 0.19204704,\n",
+ " 0.12343788, 0.202528 , 0.1306417 , 0.14448158, 0.16433088,\n",
+ " 0.17828568, 0.15127762, 0.16031146, 0.15140502, 0.17451549,\n",
+ " 0.13383325, 0.19785619, 0.14872464, 0.19805606, 0.17880678,\n",
+ " 0.14933427, 0.18098283, 0.17271209, 0.19098576, 0.15773439,\n",
+ " 0.14277951, 0.14034979, 0.17362897, 0.14588912, 0.15714224,\n",
+ " 0.19136397, 0.14665333, 0.18524861, 0.1639146 , 0.15949496,\n",
+ " 0.14914314, 0.19116807, 0.16394456, 0.17290258, 0.14597694,\n",
+ " 0.18869273, 0.17521318, 0.18450022, 0.15120141, 0.14013394]),\n",
+ " 'std_fit_time': array([0.0426388 , 0.00865779, 0.03047002, 0.01959866, 0.01996265,\n",
+ " 0.03991559, 0.02089307, 0.0115939 , 0.01577824, 0.01259437,\n",
+ " 0.02861695, 0.02803293, 0.02632727, 0.03513436, 0.10327593,\n",
+ " 0.06408524, 0.04651852, 0.03295957, 0.0471307 , 0.11679622,\n",
+ " 0.0110411 , 0.02803024, 0.05824908, 0.01165112, 0.01597652,\n",
+ " 0.06360679, 0.0225909 , 0.03841467, 0.06331909, 0.02901401,\n",
+ " 0.02955105, 0.00180821, 0.01985555, 0.0523942 , 0.03268085,\n",
+ " 0.03894983, 0.06517167, 0.02243237, 0.02200533, 0.0133168 ,\n",
+ " 0.02014168, 0.05538047, 0.00691204, 0.02019455, 0.04787927,\n",
+ " 0.02202076, 0.03367995, 0.04032929, 0.04035338, 0.05086764,\n",
+ " 0.03783577, 0.02842441, 0.01027949, 0.02261441, 0.07413437,\n",
+ " 0.03057069, 0.07406925, 0.06374613, 0.07499723, 0.00702891,\n",
+ " 0.01449466, 0.01535535, 0.06266348, 0.01594578, 0.03962651,\n",
+ " 0.03872874, 0.02924308, 0.03921288, 0.03093211, 0.02737551,\n",
+ " 0.03913621, 0.03922343, 0.06958397, 0.0321589 , 0.03550322,\n",
+ " 0.06330303, 0.04521662, 0.01877706, 0.01843914, 0.02003679]),\n",
+ " 'mean_score_time': array([0.07024606, 0.09829744, 0.06123559, 0.0934546 , 0.06724668,\n",
+ " 0.08956035, 0.06278229, 0.08930341, 0.07657051, 0.08451796,\n",
+ " 0.06686974, 0.06430848, 0.06159838, 0.06787952, 0.0685486 ,\n",
+ " 0.08175969, 0.07736111, 0.08548617, 0.10391927, 0.11483145,\n",
+ " 0.06140908, 0.08571712, 0.05673496, 0.07094908, 0.07356191,\n",
+ " 0.08419561, 0.06646315, 0.07401323, 0.0662907 , 0.07332158,\n",
+ " 0.06078386, 0.06057103, 0.05959137, 0.07239731, 0.06452703,\n",
+ " 0.06926155, 0.07553283, 0.07713358, 0.06963356, 0.09399152,\n",
+ " 0.05499752, 0.06101354, 0.06951396, 0.06801232, 0.06193964,\n",
+ " 0.08509183, 0.06340122, 0.09640757, 0.09270819, 0.07950346,\n",
+ " 0.06135122, 0.07454085, 0.06276194, 0.08477139, 0.06570975,\n",
+ " 0.06936661, 0.07554563, 0.08220371, 0.06743741, 0.10737284,\n",
+ " 0.06237586, 0.06374478, 0.07077082, 0.06676245, 0.06271029,\n",
+ " 0.10701195, 0.07044252, 0.07724508, 0.0787247 , 0.07522114,\n",
+ " 0.06066545, 0.08716504, 0.06308579, 0.08094724, 0.08140651,\n",
+ " 0.07310184, 0.07593838, 0.07488211, 0.06320556, 0.04453659]),\n",
+ " 'std_score_time': array([0.0307069 , 0.0182529 , 0.02101209, 0.03205303, 0.0257132 ,\n",
+ " 0.02687809, 0.0173987 , 0.02640571, 0.01539859, 0.01465485,\n",
+ " 0.02122433, 0.01356409, 0.00834315, 0.01751938, 0.01171655,\n",
+ " 0.03144068, 0.02179221, 0.01969381, 0.04494128, 0.00622263,\n",
+ " 0.01579416, 0.02897917, 0.01304806, 0.0202541 , 0.02151299,\n",
+ " 0.03211671, 0.01358637, 0.01938508, 0.01422302, 0.01688134,\n",
+ " 0.00869969, 0.0141093 , 0.01194079, 0.01273514, 0.01648673,\n",
+ " 0.01545342, 0.01899566, 0.0178712 , 0.02332053, 0.03043117,\n",
+ " 0.01200954, 0.01238742, 0.02325767, 0.01629294, 0.01778484,\n",
+ " 0.03084458, 0.0157806 , 0.04616991, 0.02142426, 0.01231431,\n",
+ " 0.0059662 , 0.01826184, 0.02046084, 0.02321078, 0.00858559,\n",
+ " 0.01727098, 0.01823882, 0.02170296, 0.0142323 , 0.06476885,\n",
+ " 0.02283514, 0.01249623, 0.00957304, 0.01621181, 0.00941332,\n",
+ " 0.04093458, 0.01894285, 0.0226533 , 0.02176937, 0.01821806,\n",
+ " 0.01055958, 0.03213191, 0.00837758, 0.02668343, 0.03007748,\n",
+ " 0.01731415, 0.01637027, 0.01601808, 0.02890825, 0.01039945]),\n",
+ " 'param_clf__n_neighbors': masked_array(data=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
+ " 3, 3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
+ " 5, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,\n",
+ " 7, 7, 7, 7, 7, 7, 7, 7],\n",
+ " mask=[False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False],\n",
+ " fill_value=999999),\n",
+ " 'param_tfidf__use_idf': masked_array(data=[True, True, True, True, True, True, True, True, True,\n",
+ " True, False, False, False, False, False, False, False,\n",
+ " False, False, False, True, True, True, True, True,\n",
+ " True, True, True, True, True, False, False, False,\n",
+ " False, False, False, False, False, False, False, True,\n",
+ " True, True, True, True, True, True, True, True, True,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, True, True, True, True, True, True, True,\n",
+ " True, True, True, False, False, False, False, False,\n",
+ " False, False, False, False, False],\n",
+ " mask=[False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False],\n",
+ " fill_value=True),\n",
+ " 'param_vect__max_features': masked_array(data=[100, 100, 500, 500, 1000, 1000, 5000, 5000, 10000,\n",
+ " 10000, 100, 100, 500, 500, 1000, 1000, 5000, 5000,\n",
+ " 10000, 10000, 100, 100, 500, 500, 1000, 1000, 5000,\n",
+ " 5000, 10000, 10000, 100, 100, 500, 500, 1000, 1000,\n",
+ " 5000, 5000, 10000, 10000, 100, 100, 500, 500, 1000,\n",
+ " 1000, 5000, 5000, 10000, 10000, 100, 100, 500, 500,\n",
+ " 1000, 1000, 5000, 5000, 10000, 10000, 100, 100, 500,\n",
+ " 500, 1000, 1000, 5000, 5000, 10000, 10000, 100, 100,\n",
+ " 500, 500, 1000, 1000, 5000, 5000, 10000, 10000],\n",
+ " mask=[False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False],\n",
+ " fill_value=999999),\n",
+ " 'param_vect__stop_words': masked_array(data=['english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None, 'english', None, 'english', None,\n",
+ " 'english', None],\n",
+ " mask=[False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False,\n",
+ " False, False, False, False, False, False, False, False],\n",
+ " fill_value=np.str_('?'),\n",
+ " dtype=object),\n",
+ " 'params': [{'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 1,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 3,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 5,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': True,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 100,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 500,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 1000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 5000,\n",
+ " 'vect__stop_words': None},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': 'english'},\n",
+ " {'clf__n_neighbors': 7,\n",
+ " 'tfidf__use_idf': False,\n",
+ " 'vect__max_features': 10000,\n",
+ " 'vect__stop_words': None}],\n",
+ " 'split0_test_score': array([0.76817583, 0.72894636, 0.53358157, 0.6374826 , 0.52798315,\n",
+ " 0.56519764, 0.51309075, 0.49836392, 0.49549296, 0.48037932,\n",
+ " 0.74678068, 0.6781238 , 0.53815493, 0.67256907, 0.59194787,\n",
+ " 0.61495561, 0.50607973, 0.58348819, 0.48282372, 0.56647032,\n",
+ " 0.78847084, 0.7252709 , 0.48067615, 0.52118859, 0.51721694,\n",
+ " 0.48760057, 0.51577841, 0.46346843, 0.43209944, 0.42022184,\n",
+ " 0.79654539, 0.72168041, 0.50540586, 0.62161402, 0.53994855,\n",
+ " 0.57362945, 0.53619645, 0.54414205, 0.43884702, 0.4762766 ,\n",
+ " 0.75153416, 0.72862518, 0.45804881, 0.49005168, 0.52767212,\n",
+ " 0.49479873, 0.48059474, 0.45872108, 0.43320814, 0.46831129,\n",
+ " 0.77799622, 0.71691323, 0.468479 , 0.58540742, 0.53125944,\n",
+ " 0.55188163, 0.53165214, 0.49549296, 0.46225076, 0.46334516,\n",
+ " 0.73731272, 0.722074 , 0.42768101, 0.49765146, 0.48090102,\n",
+ " 0.59104363, 0.51802701, 0.48304986, 0.44659485, 0.52984156,\n",
+ " 0.76674346, 0.71632941, 0.42110969, 0.55322776, 0.49049555,\n",
+ " 0.55175372, 0.55288072, 0.52752419, 0.47634782, 0.48746189]),\n",
+ " 'split1_test_score': array([0.79662124, 0.74654727, 0.55840557, 0.65284461, 0.56848052,\n",
+ " 0.58959475, 0.50824228, 0.52616761, 0.55281361, 0.53900106,\n",
+ " 0.79198638, 0.74396927, 0.60547906, 0.71852445, 0.55601273,\n",
+ " 0.70906678, 0.55288072, 0.70132316, 0.54064496, 0.72369393,\n",
+ " 0.8200291 , 0.76394456, 0.50142662, 0.65973942, 0.57835078,\n",
+ " 0.58411319, 0.56231765, 0.54929577, 0.53755372, 0.54631601,\n",
+ " 0.76828781, 0.73851721, 0.57218138, 0.70341759, 0.54466558,\n",
+ " 0.70088681, 0.52052675, 0.68262612, 0.56159728, 0.69962829,\n",
+ " 0.78081641, 0.76676231, 0.47784768, 0.65979122, 0.55877981,\n",
+ " 0.58351195, 0.55934603, 0.5526496 , 0.53805724, 0.54275916,\n",
+ " 0.7590099 , 0.74428157, 0.50962843, 0.69741916, 0.56721274,\n",
+ " 0.6845754 , 0.53674839, 0.65880645, 0.56642635, 0.66528626,\n",
+ " 0.75249835, 0.76114625, 0.5033107 , 0.61728125, 0.54473029,\n",
+ " 0.54902822, 0.55716104, 0.55466537, 0.51280439, 0.54540585,\n",
+ " 0.72900844, 0.73304581, 0.49250319, 0.69558085, 0.57262651,\n",
+ " 0.66997019, 0.54339107, 0.66489176, 0.56953075, 0.65745565]),\n",
+ " 'split2_test_score': array([0.82526209, 0.7519849 , 0.59510748, 0.68219353, 0.5764876 ,\n",
+ " 0.6305322 , 0.54587929, 0.56161317, 0.52387661, 0.56501726,\n",
+ " 0.82242212, 0.71247211, 0.60517653, 0.7037331 , 0.54624902,\n",
+ " 0.68817364, 0.5608424 , 0.65146866, 0.53526413, 0.66897713,\n",
+ " 0.81950927, 0.79423613, 0.52825543, 0.66907312, 0.56248767,\n",
+ " 0.60389162, 0.49652131, 0.55180864, 0.50857805, 0.52741727,\n",
+ " 0.80244926, 0.69231892, 0.53887943, 0.71969825, 0.53158497,\n",
+ " 0.6792325 , 0.49764733, 0.65713584, 0.49078794, 0.66897713,\n",
+ " 0.80172943, 0.75452358, 0.54031316, 0.63620099, 0.52529445,\n",
+ " 0.59626871, 0.51327659, 0.53187959, 0.50624241, 0.51743203,\n",
+ " 0.76711749, 0.70123506, 0.49774614, 0.71460958, 0.50316727,\n",
+ " 0.65013293, 0.49732848, 0.64249816, 0.52623393, 0.65444272,\n",
+ " 0.76791783, 0.73768667, 0.53419234, 0.63714302, 0.50165182,\n",
+ " 0.5932877 , 0.49723701, 0.51859739, 0.49993588, 0.52478814,\n",
+ " 0.76415908, 0.69550001, 0.48572333, 0.68222059, 0.50972411,\n",
+ " 0.66549536, 0.47022833, 0.64386329, 0.50648577, 0.6415092 ]),\n",
+ " 'mean_test_score': array([0.79668639, 0.74249284, 0.56236487, 0.65750691, 0.55765042,\n",
+ " 0.5951082 , 0.52240411, 0.5287149 , 0.52406106, 0.52813255,\n",
+ " 0.78706306, 0.71152173, 0.58293684, 0.69827554, 0.56473654,\n",
+ " 0.67073201, 0.53993428, 0.64542667, 0.5195776 , 0.65304713,\n",
+ " 0.8093364 , 0.76115053, 0.50345273, 0.61666704, 0.55268513,\n",
+ " 0.55853513, 0.52487246, 0.52152428, 0.49274373, 0.49798504,\n",
+ " 0.78909415, 0.71750551, 0.53882222, 0.68157662, 0.53873303,\n",
+ " 0.65124959, 0.51812351, 0.627968 , 0.49707742, 0.61496068,\n",
+ " 0.77802667, 0.74997036, 0.49206988, 0.59534796, 0.53724879,\n",
+ " 0.55819313, 0.51773912, 0.51441676, 0.4925026 , 0.50950083,\n",
+ " 0.7680412 , 0.72080995, 0.49195119, 0.66581205, 0.53387982,\n",
+ " 0.62886332, 0.52190967, 0.59893252, 0.51830368, 0.59435805,\n",
+ " 0.7525763 , 0.74030231, 0.48839468, 0.58402524, 0.50909438,\n",
+ " 0.57778652, 0.52414169, 0.51877087, 0.48644504, 0.53334518,\n",
+ " 0.75330366, 0.71495841, 0.4664454 , 0.6436764 , 0.52428206,\n",
+ " 0.62907309, 0.52216671, 0.61209308, 0.51745478, 0.59547558]),\n",
+ " 'std_test_score': array([0.02330541, 0.00983268, 0.02527339, 0.01854849, 0.02123109,\n",
+ " 0.02695613, 0.01671706, 0.02588415, 0.02340142, 0.03539763,\n",
+ " 0.0310761 , 0.0268897 , 0.03166583, 0.01915399, 0.01964985,\n",
+ " 0.04035167, 0.02415844, 0.04829527, 0.02608159, 0.06516717,\n",
+ " 0.01475571, 0.02822417, 0.01947692, 0.06762091, 0.02590243,\n",
+ " 0.05080407, 0.02762023, 0.0410645 , 0.04448367, 0.05552553,\n",
+ " 0.01490843, 0.01909001, 0.02726102, 0.04291775, 0.00540887,\n",
+ " 0.05559311, 0.0158291 , 0.06018046, 0.05030954, 0.09885959,\n",
+ " 0.02058686, 0.01589883, 0.03505766, 0.07507599, 0.01525564,\n",
+ " 0.04512812, 0.03230456, 0.04028527, 0.04389321, 0.0309063 ,\n",
+ " 0.0077786 , 0.01778837, 0.01729171, 0.05728616, 0.02621202,\n",
+ " 0.05622103, 0.0175056 , 0.07344521, 0.04289759, 0.09274581,\n",
+ " 0.0124946 , 0.01605805, 0.04474395, 0.06161139, 0.0265843 ,\n",
+ " 0.02035581, 0.02484303, 0.02923717, 0.02866389, 0.00877417,\n",
+ " 0.01721168, 0.01535864, 0.03217646, 0.064189 , 0.03507444,\n",
+ " 0.05470356, 0.03692975, 0.06041232, 0.03882443, 0.07665416]),\n",
+ " 'rank_test_score': array([ 2, 11, 41, 21, 44, 35, 58, 52, 57, 53, 4, 16, 38, 17, 40, 19, 46,\n",
+ " 24, 62, 22, 1, 7, 71, 29, 45, 42, 54, 61, 74, 72, 3, 14, 47, 18,\n",
+ " 48, 23, 65, 28, 73, 30, 5, 10, 76, 34, 49, 43, 66, 68, 75, 69, 6,\n",
+ " 13, 77, 20, 50, 27, 60, 32, 64, 36, 9, 12, 78, 37, 70, 39, 56, 63,\n",
+ " 79, 51, 8, 15, 80, 25, 55, 26, 59, 31, 67, 33], dtype=int32)}"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+    "source": [
+    "# Inspect the full grid-search results: fit/score timings, every parameter\n",
+    "# combination tried, and the per-split and mean test scores with ranks.\n",
+    "gs_clf.cv_results_"
+   ]
+ },
+ {
+ "cell_type": "code",
+   "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "np.float64(0.8093364049132378)"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+    "source": [
+    "# Best mean cross-validated score over all parameter combinations\n",
+    "# (matches the rank-1 entry in cv_results_ above).\n",
+    "gs_clf.best_score_"
+   ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Word embedding"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['fasttext-wiki-news-subwords-300', 'conceptnet-numberbatch-17-06-300', 'word2vec-ruscorpora-300', 'word2vec-google-news-300', 'glove-wiki-gigaword-50', 'glove-wiki-gigaword-100', 'glove-wiki-gigaword-200', 'glove-wiki-gigaword-300', 'glove-twitter-25', 'glove-twitter-50', 'glove-twitter-100', 'glove-twitter-200', '__testing_word2vec-matrix-synopsis']\n"
+ ]
+ }
+ ],
+    "source": [
+    "# Import gensim's downloader and list the pretrained models it provides.\n",
+    "import gensim.downloader\n",
+    "print(list(gensim.downloader.info()['models'].keys()))"
+   ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## GloVe\n",
+ "Загрузим \"glove-twitter-25\" и будем работать с ней"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[==================================================] 100.0% 104.8/104.8MB downloaded\n"
+ ]
+ }
+ ],
+    "source": [
+    "# Download (~105 MB on first run) and load the 25-dimensional GloVe\n",
+    "# vectors trained on Twitter data.\n",
+    "glove_model = gensim.downloader.load(\"glove-twitter-25\") # load glove vectors\n"
+   ]
+ },
+ {
+ "cell_type": "code",
+   "execution_count": 36,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[-0.96419 -0.60978 0.67449 0.35113 0.41317 -0.21241 1.3796\n",
+ " 0.12854 0.31567 0.66325 0.3391 -0.18934 -3.325 -1.1491\n",
+ " -0.4129 0.2195 0.8706 -0.50616 -0.12781 -0.066965 0.065761\n",
+ " 0.43927 0.1758 -0.56058 0.13529 ]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[('dog', 0.9590820074081421),\n",
+ " ('monkey', 0.920357882976532),\n",
+ " ('bear', 0.9143136739730835),\n",
+ " ('pet', 0.9108031392097473),\n",
+ " ('girl', 0.8880629539489746),\n",
+ " ('horse', 0.8872726559638977),\n",
+ " ('kitty', 0.8870542049407959),\n",
+ " ('puppy', 0.886769711971283),\n",
+ " ('hot', 0.886525571346283),\n",
+ " ('lady', 0.8845519423484802)]"
+ ]
+ },
+ "execution_count": 36,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+    "source": [
+    "# Inspect the embedding vector for the word \"cat\" and its nearest neighbours.\n",
+    "print(glove_model['cat']) # word embedding for 'cat'\n",
+    "glove_model.most_similar(\"cat\") # show the words most similar to 'cat'"
+   ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Векторизуем обучающую выборку\n",
+    "Получаем матрицу \"Документ-термин\". В столбцах - все слова выборки, в строках - каждый документ. Значения в ячейках - количество вхождений слова в документе. \n",
+ "В отображенной части видим только нули - наглядная демонстрация проблем разреженности матрицы"
+ ]
+ },
+ {
+ "cell_type": "code",
+   "execution_count": 40,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(1064, 16170)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 00 \n",
+ " 000 \n",
+ " 000005102000 \n",
+ " 000100255pixel \n",
+ " 0007 \n",
+ " 000usd \n",
+ " 001200201pixel \n",
+ " 00196 \n",
+ " 0028 \n",
+ " 0038 \n",
+ " ... \n",
+ " zorg \n",
+ " zorn \n",
+ " zsoft \n",
+ " zues \n",
+ " zug \n",
+ " zur \n",
+ " zurich \n",
+ " zus \n",
+ " zvi \n",
+ " zyxel \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " ... \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " ... \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " ... \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " \n",
+ " \n",
+ " 3 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " ... \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " ... \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
5 rows × 16170 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 00 000 000005102000 000100255pixel 0007 000usd 001200201pixel 00196 \\\n",
+ "0 0 0 0 0 0 0 0 0 \n",
+ "1 0 0 0 0 0 0 0 0 \n",
+ "2 0 0 0 0 0 0 0 0 \n",
+ "3 0 0 0 0 0 0 0 0 \n",
+ "4 0 0 0 0 0 0 0 0 \n",
+ "\n",
+ " 0028 0038 ... zorg zorn zsoft zues zug zur zurich zus zvi zyxel \n",
+ "0 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n",
+ "1 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n",
+ "2 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n",
+ "3 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n",
+ "4 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n",
+ "\n",
+ "[5 rows x 16170 columns]"
+ ]
+ },
+ "execution_count": 40,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+    "source": [
+    "# Build the document-term matrix for the training texts, dropping English\n",
+    "# stop words: rows are documents, columns are vocabulary terms.\n",
+    "vectorizer = CountVectorizer(stop_words='english')\n",
+    "train_data = vectorizer.fit_transform(twenty_train['data'])\n",
+    "# Densify into a DataFrame for display; note this expands the sparse matrix,\n",
+    "# which is memory-heavy for large vocabularies.\n",
+    "CV_data=pd.DataFrame(train_data.toarray(), columns=vectorizer.get_feature_names_out())\n",
+    "print(CV_data.shape)\n",
+    "CV_data.head()"
+   ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Index(['00', '000', '000005102000', '000100255pixel', '0007', '000usd',\n",
+ " '001200201pixel', '00196', '0028', '0038'],\n",
+ " dtype='str')\n"
+ ]
+ }
+ ],
+    "source": [
+    "# Collect the unique vocabulary words (the document-term matrix columns).\n",
+    "words_vocab=CV_data.columns\n",
+    "print(words_vocab[0:10])"
+   ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Векторизуем с помощью GloVe\n",
+ "\n",
+ "Нужно для каждого документа сложить glove-вектора слов, из которых он состоит.\n",
+ "В результате получим вектор документа как сумму векторов слов, из него состоящих"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Посмотрим на примере как будет работать векторизация"
+ ]
+ },
+ {
+ "cell_type": "code",
+   "execution_count": 42,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 00 \n",
+ " 000 \n",
+ " 000005102000 \n",
+ " 000100255pixel \n",
+ " 0007 \n",
+ " 000usd \n",
+ " 001200201pixel \n",
+ " 00196 \n",
+ " 0028 \n",
+ " 0038 \n",
+ " ... \n",
+ " zorg \n",
+ " zorn \n",
+ " zsoft \n",
+ " zues \n",
+ " zug \n",
+ " zur \n",
+ " zurich \n",
+ " zus \n",
+ " zvi \n",
+ " zyxel \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " ... \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " 1 \n",
+ " 1 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " ... \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 0 \n",
+ " 1 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
2 rows × 16170 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 00 000 000005102000 000100255pixel 0007 000usd 001200201pixel 00196 \\\n",
+ "0 0 0 0 0 0 0 0 0 \n",
+ "1 1 1 0 0 0 0 0 0 \n",
+ "\n",
+ " 0028 0038 ... zorg zorn zsoft zues zug zur zurich zus zvi zyxel \n",
+ "0 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n",
+ "1 0 0 ... 0 0 0 0 0 0 0 0 0 1 \n",
+ "\n",
+ "[2 rows x 16170 columns]"
+ ]
+ },
+ "execution_count": 42,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Toy sample of two documents to demonstrate the vectorization step:\n",
+ "text_data = ['Hello world I love python', 'This is a great computer game! 00 000 zyxel']\n",
+ "# Vectorize with the CountVectorizer already fitted on the training corpus\n",
+ "X = vectorizer.transform(text_data)\n",
+ "CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())\n",
+ "CV_text_data\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "hello : [-0.77069 0.12827 0.33137 0.0050893 -0.47605 -0.50116\n",
+ " 1.858 1.0624 -0.56511 0.13328 -0.41918 -0.14195\n",
+ " -2.8555 -0.57131 -0.13418 -0.44922 0.48591 -0.6479\n",
+ " -0.84238 0.61669 -0.19824 -0.57967 -0.65885 0.43928\n",
+ " -0.50473 ]\n",
+ "love : [-0.62645 -0.082389 0.070538 0.5782 -0.87199 -0.14816 2.2315\n",
+ " 0.98573 -1.3154 -0.34921 -0.8847 0.14585 -4.97 -0.73369\n",
+ " -0.94359 0.035859 -0.026733 -0.77538 -0.30014 0.48853 -0.16678\n",
+ " -0.016651 -0.53164 0.64236 -0.10922 ]\n",
+ "python : [-0.25645 -0.22323 0.025901 0.22901 0.49028 -0.060829 0.24563\n",
+ " -0.84854 1.5882 -0.7274 0.60603 0.25205 -1.8064 -0.95526\n",
+ " 0.44867 0.013614 0.60856 0.65423 0.82506 0.99459 -0.29403\n",
+ " -0.27013 -0.348 -0.7293 0.2201 ]\n",
+ "world : [ 0.10301 0.095666 -0.14789 -0.22383 -0.14775 -0.11599 1.8513\n",
+ " 0.24886 -0.41877 -0.20384 -0.08509 0.33246 -4.6946 0.84096\n",
+ " -0.46666 -0.031128 -0.19539 -0.037349 0.58949 0.13941 -0.57667\n",
+ " -0.44426 -0.43085 -0.52875 0.25855 ]\n",
+ "Hello world I love python : [ -1.55058002 -0.081683 0.27991899 0.58846928 -1.00551002\n",
+ " -0.82613902 6.18642995 1.44844997 -0.71108004 -1.14717001\n",
+ " -0.78294002 0.58841 -14.32649982 -1.41929996 -1.09575997\n",
+ " -0.430875 0.87234702 -0.806399 0.27203003 2.23921998\n",
+ " -1.23571999 -1.31071102 -1.96934 -0.17641005 -0.1353 ]\n",
+ "computer : [ 0.64005 -0.019514 0.70148 -0.66123 1.1723 -0.58859 0.25917\n",
+ " -0.81541 1.1708 1.1413 -0.15405 -0.11369 -3.8414 -0.87233\n",
+ " 0.47489 1.1541 0.97678 1.1107 -0.14572 -0.52013 -0.52234\n",
+ " -0.92349 0.34651 0.061939 -0.57375 ]\n",
+ "game : [ 1.146 0.3291 0.26878 -1.3945 -0.30044 0.77901 1.3537\n",
+ " 0.37393 0.50478 -0.44266 -0.048706 0.51396 -4.3136 0.39805\n",
+ " 1.197 0.10287 -0.17618 -1.2881 -0.59801 0.26131 -1.2619\n",
+ " 0.39202 0.59309 -0.55232 0.005087]\n",
+ "great : [-8.4229e-01 3.6512e-01 -3.8841e-01 -4.6118e-01 2.4301e-01 3.2412e-01\n",
+ " 1.9009e+00 -2.2630e-01 -3.1335e-01 -1.0970e+00 -4.1494e-03 6.2074e-01\n",
+ " -5.0964e+00 6.7418e-01 5.0080e-01 -6.2119e-01 5.1765e-01 -4.4122e-01\n",
+ " -1.4364e-01 1.9130e-01 -7.4608e-01 -2.5903e-01 -7.8010e-01 1.1030e-01\n",
+ " -2.7928e-01]\n",
+ "zyxel : [ 0.79234 0.067376 -0.22639 -2.2272 0.30057 -0.85676 -1.7268\n",
+ " -0.78626 1.2042 -0.92348 -0.83987 -0.74233 0.29689 -1.208\n",
+ " 0.98706 -1.1624 0.61415 -0.27825 0.27813 1.5838 -0.63593\n",
+ " -0.10225 1.7102 -0.95599 -1.3867 ]\n",
+ "This is a great computer game! 00 000 zyxel : [ 1.73610002 0.74208201 0.35545996 -4.74411008 1.41543998\n",
+ " -0.34222007 1.78697008 -1.45404002 2.56643 -1.32184002\n",
+ " -1.04677537 0.27867999 -12.95450976 -1.00809997 3.15975004\n",
+ " -0.52662008 1.93239999 -0.89686999 -0.60924001 1.51628\n",
+ " -3.16624993 -0.89275002 1.86969995 -1.33607102 -2.23464306]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Collect one summed glove vector per document, then build the frame once.\n",
+ "# (pd.concat inside the loop is quadratic: each call copies the whole frame.)\n",
+ "doc_rows = []\n",
+ "\n",
+ "# Iterate over the rows of the document-term matrix (one row = one document)\n",
+ "for i in range(CV_text_data.shape[0]):\n",
+ "\n",
+ "    # Accumulator for one document, sized like the glove model (25 dims)\n",
+ "    one_doc = np.zeros(25)\n",
+ "\n",
+ "    # For every vocabulary word present in this document, add its glove vector\n",
+ "    for word in words_vocab[CV_text_data.iloc[i, :] >= 1]:\n",
+ "        if word in glove_model.key_to_index:  # .keys() is redundant for membership\n",
+ "            print(word, ': ', glove_model[word])\n",
+ "            one_doc += glove_model[word]\n",
+ "    print(text_data[i], ': ', one_doc)\n",
+ "    doc_rows.append(one_doc)\n",
+ "\n",
+ "# Single construction replaces the per-iteration pd.concat(..., ignore_index=True)\n",
+ "glove_data = pd.DataFrame(doc_rows)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "На основе ячейки выше, напишем функцию, которая для каждого слова в тексте ищет это слово в glove_model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def text2vec(text_data):\n",
+ "    \"\"\"Vectorize documents as sums of glove word vectors.\n",
+ "\n",
+ "    Each document is encoded with the globally fitted CountVectorizer;\n",
+ "    the glove vectors of all vocabulary words found in the document are\n",
+ "    summed into one 25-dimensional row (unknown words contribute zero).\n",
+ "\n",
+ "    Parameters: text_data -- iterable of raw document strings.\n",
+ "    Returns: pd.DataFrame of shape (n_documents, 25) with index 0..n-1.\n",
+ "    \"\"\"\n",
+ "    # Bag-of-words counts for the input documents (uses the global vectorizer)\n",
+ "    X = vectorizer.transform(text_data)\n",
+ "    CV_text_data = pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())\n",
+ "\n",
+ "    # Collect one summed vector per document, then build the frame once:\n",
+ "    # pd.concat inside the loop is quadratic, and concat with axis=0 without\n",
+ "    # ignore_index left every row labelled 0 (duplicate index labels).\n",
+ "    doc_rows = []\n",
+ "    for i in range(CV_text_data.shape[0]):\n",
+ "        one_doc = np.zeros(25)\n",
+ "        for word in words_vocab[CV_text_data.iloc[i, :] >= 1]:\n",
+ "            if word in glove_model.key_to_index:\n",
+ "                one_doc += glove_model[word]\n",
+ "        doc_rows.append(one_doc)\n",
+ "    return pd.DataFrame(doc_rows)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 1 \n",
+ " 2 \n",
+ " 3 \n",
+ " 4 \n",
+ " 5 \n",
+ " 6 \n",
+ " 7 \n",
+ " 8 \n",
+ " 9 \n",
+ " ... \n",
+ " 15 \n",
+ " 16 \n",
+ " 17 \n",
+ " 18 \n",
+ " 19 \n",
+ " 20 \n",
+ " 21 \n",
+ " 22 \n",
+ " 23 \n",
+ " 24 \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " -1.55058 \n",
+ " -0.081683 \n",
+ " 0.279919 \n",
+ " 0.588469 \n",
+ " -1.00551 \n",
+ " -0.826139 \n",
+ " 6.18643 \n",
+ " 1.44845 \n",
+ " -0.71108 \n",
+ " -1.14717 \n",
+ " ... \n",
+ " -0.430875 \n",
+ " 0.872347 \n",
+ " -0.806399 \n",
+ " 0.27203 \n",
+ " 2.23922 \n",
+ " -1.23572 \n",
+ " -1.310711 \n",
+ " -1.96934 \n",
+ " -0.176410 \n",
+ " -0.135300 \n",
+ " \n",
+ " \n",
+ " 1 \n",
+ " 1.73610 \n",
+ " 0.742082 \n",
+ " 0.355460 \n",
+ " -4.744110 \n",
+ " 1.41544 \n",
+ " -0.342220 \n",
+ " 1.78697 \n",
+ " -1.45404 \n",
+ " 2.56643 \n",
+ " -1.32184 \n",
+ " ... \n",
+ " -0.526620 \n",
+ " 1.932400 \n",
+ " -0.896870 \n",
+ " -0.60924 \n",
+ " 1.51628 \n",
+ " -3.16625 \n",
+ " -0.892750 \n",
+ " 1.86970 \n",
+ " -1.336071 \n",
+ " -2.234643 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
2 rows × 25 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4 5 6 7 \\\n",
+ "0 -1.55058 -0.081683 0.279919 0.588469 -1.00551 -0.826139 6.18643 1.44845 \n",
+ "1 1.73610 0.742082 0.355460 -4.744110 1.41544 -0.342220 1.78697 -1.45404 \n",
+ "\n",
+ " 8 9 ... 15 16 17 18 19 \\\n",
+ "0 -0.71108 -1.14717 ... -0.430875 0.872347 -0.806399 0.27203 2.23922 \n",
+ "1 2.56643 -1.32184 ... -0.526620 1.932400 -0.896870 -0.60924 1.51628 \n",
+ "\n",
+ " 20 21 22 23 24 \n",
+ "0 -1.23572 -1.310711 -1.96934 -0.176410 -0.135300 \n",
+ "1 -3.16625 -0.892750 1.86970 -1.336071 -2.234643 \n",
+ "\n",
+ "[2 rows x 25 columns]"
+ ]
+ },
+ "execution_count": 47,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# This is what the sample looks like after glove vectorization.\n",
+ "# Exercise: explain the matrix dimensions and the meaning of the cell values.\n",
+ "glove_data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Можем использовать написанную функцию на нашей реальной выборке\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 1 \n",
+ " 2 \n",
+ " 3 \n",
+ " 4 \n",
+ " 5 \n",
+ " 6 \n",
+ " 7 \n",
+ " 8 \n",
+ " 9 \n",
+ " ... \n",
+ " 15 \n",
+ " 16 \n",
+ " 17 \n",
+ " 18 \n",
+ " 19 \n",
+ " 20 \n",
+ " 21 \n",
+ " 22 \n",
+ " 23 \n",
+ " 24 \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 0.499258 \n",
+ " 3.090499 \n",
+ " -17.267230 \n",
+ " -0.997832 \n",
+ " -2.746523 \n",
+ " 1.586072 \n",
+ " 25.890665 \n",
+ " -27.914878 \n",
+ " -1.656527 \n",
+ " -6.599336 \n",
+ " ... \n",
+ " 6.840056 \n",
+ " 5.475214 \n",
+ " 4.488208 \n",
+ " 17.680875 \n",
+ " -8.254198 \n",
+ " 1.772313 \n",
+ " 6.184074 \n",
+ " -8.879014 \n",
+ " -0.216338 \n",
+ " -13.408250 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 2.579560 \n",
+ " -1.498130 \n",
+ " 1.752350 \n",
+ " 0.407907 \n",
+ " 1.528970 \n",
+ " 0.152390 \n",
+ " 0.283640 \n",
+ " 1.937990 \n",
+ " -0.659960 \n",
+ " -2.136370 \n",
+ " ... \n",
+ " 0.434040 \n",
+ " -0.274430 \n",
+ " 0.445033 \n",
+ " -0.645750 \n",
+ " 1.055500 \n",
+ " -0.850250 \n",
+ " -0.433200 \n",
+ " 0.212190 \n",
+ " -1.164330 \n",
+ " 0.905880 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " -3.838998 \n",
+ " 3.339424 \n",
+ " 3.943270 \n",
+ " -1.428108 \n",
+ " 4.171174 \n",
+ " 1.918481 \n",
+ " 9.177921 \n",
+ " -3.241092 \n",
+ " 1.280560 \n",
+ " -1.781829 \n",
+ " ... \n",
+ " -3.719443 \n",
+ " 9.627485 \n",
+ " -1.198027 \n",
+ " -4.010974 \n",
+ " 1.464305 \n",
+ " 0.398107 \n",
+ " -1.013318 \n",
+ " -1.591893 \n",
+ " 4.290867 \n",
+ " -3.888199 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " -1.182812 \n",
+ " 51.118736 \n",
+ " -28.117878 \n",
+ " -28.586500 \n",
+ " 24.244726 \n",
+ " -9.567325 \n",
+ " 111.896700 \n",
+ " -114.387355 \n",
+ " 12.647457 \n",
+ " -24.559720 \n",
+ " ... \n",
+ " 24.085172 \n",
+ " 39.112016 \n",
+ " -23.011944 \n",
+ " 41.767239 \n",
+ " -42.362471 \n",
+ " -16.060568 \n",
+ " -20.090755 \n",
+ " -4.457053 \n",
+ " -3.181086 \n",
+ " -90.876348 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " ... \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " \n",
+ " \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " -5.839209 \n",
+ " 12.729176 \n",
+ " -6.201106 \n",
+ " -11.255287 \n",
+ " 2.878065 \n",
+ " -11.993908 \n",
+ " 11.418405 \n",
+ " -15.479517 \n",
+ " 16.634253 \n",
+ " -10.534109 \n",
+ " ... \n",
+ " 11.665809 \n",
+ " 13.761738 \n",
+ " -2.579413 \n",
+ " -7.877240 \n",
+ " 3.299712 \n",
+ " -9.241404 \n",
+ " -16.102834 \n",
+ " 5.343221 \n",
+ " -3.646908 \n",
+ " -17.216441 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 1.286424 \n",
+ " 4.341360 \n",
+ " -29.110272 \n",
+ " 5.057644 \n",
+ " 5.610822 \n",
+ " -10.954865 \n",
+ " 29.974885 \n",
+ " -38.261129 \n",
+ " 10.369041 \n",
+ " -15.652271 \n",
+ " ... \n",
+ " 12.201855 \n",
+ " 5.913883 \n",
+ " 15.941801 \n",
+ " 26.251438 \n",
+ " -9.280780 \n",
+ " -6.656035 \n",
+ " 6.158630 \n",
+ " -14.878060 \n",
+ " -1.138172 \n",
+ " -25.532792 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 6.698275 \n",
+ " 8.135166 \n",
+ " -10.347349 \n",
+ " 0.165546 \n",
+ " 2.089456 \n",
+ " -7.695734 \n",
+ " 31.593479 \n",
+ " -38.582911 \n",
+ " -3.285587 \n",
+ " -6.087527 \n",
+ " ... \n",
+ " 16.575447 \n",
+ " 15.696628 \n",
+ " 0.785602 \n",
+ " 13.384657 \n",
+ " -13.329273 \n",
+ " 1.983042 \n",
+ " -4.041431 \n",
+ " -5.571411 \n",
+ " 0.581912 \n",
+ " -32.037549 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " -0.461364 \n",
+ " 4.499502 \n",
+ " -10.446904 \n",
+ " 1.502338 \n",
+ " -10.415269 \n",
+ " -8.801252 \n",
+ " 18.272721 \n",
+ " -10.398248 \n",
+ " -6.686713 \n",
+ " -0.250122 \n",
+ " ... \n",
+ " 5.758498 \n",
+ " 4.540718 \n",
+ " -0.372422 \n",
+ " 10.668546 \n",
+ " -6.377995 \n",
+ " 9.305774 \n",
+ " 11.568183 \n",
+ " -18.814704 \n",
+ " -1.204665 \n",
+ " -14.551387 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 8.348875 \n",
+ " 10.837893 \n",
+ " -1.528737 \n",
+ " -18.246933 \n",
+ " 33.950209 \n",
+ " -0.376254 \n",
+ " 48.017779 \n",
+ " -57.636160 \n",
+ " 25.231878 \n",
+ " -13.371779 \n",
+ " ... \n",
+ " 29.297609 \n",
+ " 34.342018 \n",
+ " -8.527635 \n",
+ " 13.555851 \n",
+ " -12.649073 \n",
+ " -28.706501 \n",
+ " -27.042654 \n",
+ " 4.292338 \n",
+ " 2.614328 \n",
+ " -48.834758 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
1064 rows × 25 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4 5 \\\n",
+ "0 0.499258 3.090499 -17.267230 -0.997832 -2.746523 1.586072 \n",
+ "0 2.579560 -1.498130 1.752350 0.407907 1.528970 0.152390 \n",
+ "0 -3.838998 3.339424 3.943270 -1.428108 4.171174 1.918481 \n",
+ "0 -1.182812 51.118736 -28.117878 -28.586500 24.244726 -9.567325 \n",
+ "0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
+ ".. ... ... ... ... ... ... \n",
+ "0 -5.839209 12.729176 -6.201106 -11.255287 2.878065 -11.993908 \n",
+ "0 1.286424 4.341360 -29.110272 5.057644 5.610822 -10.954865 \n",
+ "0 6.698275 8.135166 -10.347349 0.165546 2.089456 -7.695734 \n",
+ "0 -0.461364 4.499502 -10.446904 1.502338 -10.415269 -8.801252 \n",
+ "0 8.348875 10.837893 -1.528737 -18.246933 33.950209 -0.376254 \n",
+ "\n",
+ " 6 7 8 9 ... 15 16 \\\n",
+ "0 25.890665 -27.914878 -1.656527 -6.599336 ... 6.840056 5.475214 \n",
+ "0 0.283640 1.937990 -0.659960 -2.136370 ... 0.434040 -0.274430 \n",
+ "0 9.177921 -3.241092 1.280560 -1.781829 ... -3.719443 9.627485 \n",
+ "0 111.896700 -114.387355 12.647457 -24.559720 ... 24.085172 39.112016 \n",
+ "0 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 \n",
+ ".. ... ... ... ... ... ... ... \n",
+ "0 11.418405 -15.479517 16.634253 -10.534109 ... 11.665809 13.761738 \n",
+ "0 29.974885 -38.261129 10.369041 -15.652271 ... 12.201855 5.913883 \n",
+ "0 31.593479 -38.582911 -3.285587 -6.087527 ... 16.575447 15.696628 \n",
+ "0 18.272721 -10.398248 -6.686713 -0.250122 ... 5.758498 4.540718 \n",
+ "0 48.017779 -57.636160 25.231878 -13.371779 ... 29.297609 34.342018 \n",
+ "\n",
+ " 17 18 19 20 21 22 \\\n",
+ "0 4.488208 17.680875 -8.254198 1.772313 6.184074 -8.879014 \n",
+ "0 0.445033 -0.645750 1.055500 -0.850250 -0.433200 0.212190 \n",
+ "0 -1.198027 -4.010974 1.464305 0.398107 -1.013318 -1.591893 \n",
+ "0 -23.011944 41.767239 -42.362471 -16.060568 -20.090755 -4.457053 \n",
+ "0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
+ ".. ... ... ... ... ... ... \n",
+ "0 -2.579413 -7.877240 3.299712 -9.241404 -16.102834 5.343221 \n",
+ "0 15.941801 26.251438 -9.280780 -6.656035 6.158630 -14.878060 \n",
+ "0 0.785602 13.384657 -13.329273 1.983042 -4.041431 -5.571411 \n",
+ "0 -0.372422 10.668546 -6.377995 9.305774 11.568183 -18.814704 \n",
+ "0 -8.527635 13.555851 -12.649073 -28.706501 -27.042654 4.292338 \n",
+ "\n",
+ " 23 24 \n",
+ "0 -0.216338 -13.408250 \n",
+ "0 -1.164330 0.905880 \n",
+ "0 4.290867 -3.888199 \n",
+ "0 -3.181086 -90.876348 \n",
+ "0 0.000000 0.000000 \n",
+ ".. ... ... \n",
+ "0 -3.646908 -17.216441 \n",
+ "0 -1.138172 -25.532792 \n",
+ "0 0.581912 -32.037549 \n",
+ "0 -1.204665 -14.551387 \n",
+ "0 2.614328 -48.834758 \n",
+ "\n",
+ "[1064 rows x 25 columns]"
+ ]
+ },
+ "execution_count": 50,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Vectorize the full training corpus: one 25-dim glove row per document.\n",
+ "train_data_glove = text2vec(twenty_train['data'])\n",
+ "train_data_glove"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Мы получили матрицу, где каждый документ представлен вектором-строкой. Можем подавать эти вектора на вход для обучения классификатора"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
KNeighborsClassifier() In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. \n",
+ "
\n",
+ "
\n",
+ " Parameters \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " n_neighbors\n",
+ " n_neighbors: int, default=5 Number of neighbors to use by default for :meth:`kneighbors` queries. \n",
+ " \n",
+ " \n",
+ " 5 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " weights\n",
+ " weights: {'uniform', 'distance'}, callable or None, default='uniform' Weight function used in prediction. Possible values: - 'uniform' : uniform weights. All points in each neighborhood are weighted equally. - 'distance' : weight points by the inverse of their distance. in this case, closer neighbors of a query point will have a greater influence than neighbors which are further away. - [callable] : a user-defined function which accepts an array of distances, and returns an array of the same shape containing the weights. Refer to the example entitled :ref:`sphx_glr_auto_examples_neighbors_plot_classification.py` showing the impact of the `weights` parameter on the decision boundary. \n",
+ " \n",
+ " \n",
+ " 'uniform' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " algorithm\n",
+ " algorithm: {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto' Algorithm used to compute the nearest neighbors: - 'ball_tree' will use :class:`BallTree` - 'kd_tree' will use :class:`KDTree` - 'brute' will use a brute-force search. - 'auto' will attempt to decide the most appropriate algorithm based on the values passed to :meth:`fit` method. Note: fitting on sparse input will override the setting of this parameter, using brute force. \n",
+ " \n",
+ " \n",
+ " 'auto' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " leaf_size\n",
+ " leaf_size: int, default=30 Leaf size passed to BallTree or KDTree. This can affect the speed of the construction and query, as well as the memory required to store the tree. The optimal value depends on the nature of the problem. \n",
+ " \n",
+ " \n",
+ " 30 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " p\n",
+ " p: float, default=2 Power parameter for the Minkowski metric. When p = 1, this is equivalent to using manhattan_distance (l1), and euclidean_distance (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used. This parameter is expected to be positive. \n",
+ " \n",
+ " \n",
+ " 2 \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " metric\n",
+ " metric: str or callable, default='minkowski' Metric to use for distance computation. Default is \"minkowski\", which results in the standard Euclidean distance when p = 2. See the documentation of `scipy.spatial.distance`_ and the metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for valid metric values. If metric is \"precomputed\", X is assumed to be a distance matrix and must be square during fit. X may be a :term:`sparse graph`, in which case only \"nonzero\" elements may be considered neighbors. If metric is a callable function, it takes two arrays representing 1D vectors as inputs and must return one value indicating the distance between those vectors. This works for Scipy's metrics, but is less efficient than passing the metric name as a string. \n",
+ " \n",
+ " \n",
+ " 'minkowski' \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " metric_params\n",
+ " metric_params: dict, default=None Additional keyword arguments for the metric function. \n",
+ " \n",
+ " \n",
+ " None \n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " n_jobs\n",
+ " n_jobs: int, default=None The number of parallel jobs to run for neighbors search. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. Doesn't affect :meth:`fit` method. \n",
+ " \n",
+ " \n",
+ " None \n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ "KNeighborsClassifier()"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Train a 5-nearest-neighbours classifier on the glove document vectors.\n",
+ "clf = KNeighborsClassifier(n_neighbors = 5)\n",
+ "clf.fit(train_data_glove, twenty_train['target'])\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Аналогичную операцию проводим с тестовой частью выборки и оцениваем качество модели"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 1 \n",
+ " 2 \n",
+ " 3 \n",
+ " 4 \n",
+ " 5 \n",
+ " 6 \n",
+ " 7 \n",
+ " 8 \n",
+ " 9 \n",
+ " ... \n",
+ " 15 \n",
+ " 16 \n",
+ " 17 \n",
+ " 18 \n",
+ " 19 \n",
+ " 20 \n",
+ " 21 \n",
+ " 22 \n",
+ " 23 \n",
+ " 24 \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 2.118167 \n",
+ " 10.172361 \n",
+ " -12.929062 \n",
+ " -10.235129 \n",
+ " -5.895256 \n",
+ " -2.960585 \n",
+ " 13.365803 \n",
+ " -26.779260 \n",
+ " 0.811607 \n",
+ " -3.139421 \n",
+ " ... \n",
+ " 4.893207 \n",
+ " 4.652184 \n",
+ " 2.903990 \n",
+ " 13.077540 \n",
+ " -12.438726 \n",
+ " 2.602789 \n",
+ " -7.450777 \n",
+ " -5.202635 \n",
+ " -20.983743 \n",
+ " -14.107781 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " ... \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " 0.000000 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 4.873320 \n",
+ " 27.474242 \n",
+ " -10.773876 \n",
+ " 8.112889 \n",
+ " 0.623212 \n",
+ " -2.199475 \n",
+ " 59.094818 \n",
+ " -59.831612 \n",
+ " -2.963759 \n",
+ " -18.048856 \n",
+ " ... \n",
+ " 23.890453 \n",
+ " 12.469234 \n",
+ " -3.087457 \n",
+ " 12.981009 \n",
+ " -29.858967 \n",
+ " 13.105756 \n",
+ " 10.670630 \n",
+ " -17.378340 \n",
+ " -6.079635 \n",
+ " -41.408141 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " -3.543091 \n",
+ " 1.986200 \n",
+ " -2.191232 \n",
+ " 0.006059 \n",
+ " -2.351208 \n",
+ " 0.354323 \n",
+ " 5.127530 \n",
+ " -4.232705 \n",
+ " 0.855633 \n",
+ " -1.511475 \n",
+ " ... \n",
+ " -1.034241 \n",
+ " 2.211705 \n",
+ " -3.255103 \n",
+ " 3.145205 \n",
+ " 0.115450 \n",
+ " -0.696187 \n",
+ " 0.647662 \n",
+ " 2.188560 \n",
+ " 1.237610 \n",
+ " -3.731906 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " -0.810280 \n",
+ " -0.041986 \n",
+ " 0.309050 \n",
+ " 0.268330 \n",
+ " -0.735400 \n",
+ " -0.585930 \n",
+ " -0.207310 \n",
+ " 0.084744 \n",
+ " -0.024049 \n",
+ " -1.729000 \n",
+ " ... \n",
+ " -0.981390 \n",
+ " -0.348830 \n",
+ " -0.012710 \n",
+ " -0.399160 \n",
+ " 0.237930 \n",
+ " 0.608960 \n",
+ " 0.257370 \n",
+ " -0.611980 \n",
+ " -0.569530 \n",
+ " -0.627590 \n",
+ " \n",
+ " \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " ... \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " -2.713570 \n",
+ " 17.157689 \n",
+ " 3.372045 \n",
+ " -5.867599 \n",
+ " -6.276993 \n",
+ " -12.745005 \n",
+ " 25.654686 \n",
+ " -28.636649 \n",
+ " 0.390143 \n",
+ " -9.184376 \n",
+ " ... \n",
+ " 1.313231 \n",
+ " 11.260476 \n",
+ " -3.974922 \n",
+ " 13.881956 \n",
+ " -5.040806 \n",
+ " 4.269640 \n",
+ " 9.495090 \n",
+ " -6.392856 \n",
+ " -17.443242 \n",
+ " -8.663642 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 0.330970 \n",
+ " 4.718122 \n",
+ " -1.751891 \n",
+ " -2.035309 \n",
+ " -0.343631 \n",
+ " -4.570000 \n",
+ " 6.742370 \n",
+ " -4.367192 \n",
+ " -2.187082 \n",
+ " -1.016773 \n",
+ " ... \n",
+ " -0.032756 \n",
+ " 3.431770 \n",
+ " -4.601260 \n",
+ " -0.623544 \n",
+ " -1.379208 \n",
+ " -2.430931 \n",
+ " -7.340540 \n",
+ " -4.909410 \n",
+ " 1.812950 \n",
+ " -7.518150 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 2.374983 \n",
+ " 5.077650 \n",
+ " -3.002268 \n",
+ " -3.933630 \n",
+ " 2.078387 \n",
+ " -5.470335 \n",
+ " 7.654599 \n",
+ " -12.212619 \n",
+ " 7.815100 \n",
+ " -2.849330 \n",
+ " ... \n",
+ " -0.234639 \n",
+ " 5.634821 \n",
+ " -3.027221 \n",
+ " 1.661357 \n",
+ " 0.393759 \n",
+ " -6.279475 \n",
+ " -4.927666 \n",
+ " -2.592987 \n",
+ " 1.981002 \n",
+ " -12.255251 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 1.410159 \n",
+ " 0.770070 \n",
+ " -3.955777 \n",
+ " -2.791190 \n",
+ " 5.210700 \n",
+ " 1.621952 \n",
+ " 7.780780 \n",
+ " -9.859718 \n",
+ " 8.394491 \n",
+ " -5.347944 \n",
+ " ... \n",
+ " 1.896947 \n",
+ " 4.079237 \n",
+ " -1.207350 \n",
+ " 5.772389 \n",
+ " 1.186500 \n",
+ " -7.286901 \n",
+ " -2.485993 \n",
+ " -0.947197 \n",
+ " 0.788222 \n",
+ " -9.518580 \n",
+ " \n",
+ " \n",
+ " 0 \n",
+ " 5.710371 \n",
+ " 6.638114 \n",
+ " 0.178107 \n",
+ " -6.406337 \n",
+ " 9.719197 \n",
+ " -8.779200 \n",
+ " 8.866295 \n",
+ " -11.437220 \n",
+ " 4.871641 \n",
+ " -11.415244 \n",
+ " ... \n",
+ " 0.851029 \n",
+ " 4.182781 \n",
+ " -3.302557 \n",
+ " 1.213709 \n",
+ " -4.294990 \n",
+ " -7.751608 \n",
+ " -12.556040 \n",
+ " -10.339437 \n",
+ " 1.448986 \n",
+ " -10.394629 \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
708 rows × 25 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 0 1 2 3 4 5 6 \\\n",
+ "0 2.118167 10.172361 -12.929062 -10.235129 -5.895256 -2.960585 13.365803 \n",
+ "0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
+ "0 4.873320 27.474242 -10.773876 8.112889 0.623212 -2.199475 59.094818 \n",
+ "0 -3.543091 1.986200 -2.191232 0.006059 -2.351208 0.354323 5.127530 \n",
+ "0 -0.810280 -0.041986 0.309050 0.268330 -0.735400 -0.585930 -0.207310 \n",
+ ".. ... ... ... ... ... ... ... \n",
+ "0 -2.713570 17.157689 3.372045 -5.867599 -6.276993 -12.745005 25.654686 \n",
+ "0 0.330970 4.718122 -1.751891 -2.035309 -0.343631 -4.570000 6.742370 \n",
+ "0 2.374983 5.077650 -3.002268 -3.933630 2.078387 -5.470335 7.654599 \n",
+ "0 1.410159 0.770070 -3.955777 -2.791190 5.210700 1.621952 7.780780 \n",
+ "0 5.710371 6.638114 0.178107 -6.406337 9.719197 -8.779200 8.866295 \n",
+ "\n",
+ " 7 8 9 ... 15 16 17 \\\n",
+ "0 -26.779260 0.811607 -3.139421 ... 4.893207 4.652184 2.903990 \n",
+ "0 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 \n",
+ "0 -59.831612 -2.963759 -18.048856 ... 23.890453 12.469234 -3.087457 \n",
+ "0 -4.232705 0.855633 -1.511475 ... -1.034241 2.211705 -3.255103 \n",
+ "0 0.084744 -0.024049 -1.729000 ... -0.981390 -0.348830 -0.012710 \n",
+ ".. ... ... ... ... ... ... ... \n",
+ "0 -28.636649 0.390143 -9.184376 ... 1.313231 11.260476 -3.974922 \n",
+ "0 -4.367192 -2.187082 -1.016773 ... -0.032756 3.431770 -4.601260 \n",
+ "0 -12.212619 7.815100 -2.849330 ... -0.234639 5.634821 -3.027221 \n",
+ "0 -9.859718 8.394491 -5.347944 ... 1.896947 4.079237 -1.207350 \n",
+ "0 -11.437220 4.871641 -11.415244 ... 0.851029 4.182781 -3.302557 \n",
+ "\n",
+ " 18 19 20 21 22 23 \\\n",
+ "0 13.077540 -12.438726 2.602789 -7.450777 -5.202635 -20.983743 \n",
+ "0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
+ "0 12.981009 -29.858967 13.105756 10.670630 -17.378340 -6.079635 \n",
+ "0 3.145205 0.115450 -0.696187 0.647662 2.188560 1.237610 \n",
+ "0 -0.399160 0.237930 0.608960 0.257370 -0.611980 -0.569530 \n",
+ ".. ... ... ... ... ... ... \n",
+ "0 13.881956 -5.040806 4.269640 9.495090 -6.392856 -17.443242 \n",
+ "0 -0.623544 -1.379208 -2.430931 -7.340540 -4.909410 1.812950 \n",
+ "0 1.661357 0.393759 -6.279475 -4.927666 -2.592987 1.981002 \n",
+ "0 5.772389 1.186500 -7.286901 -2.485993 -0.947197 0.788222 \n",
+ "0 1.213709 -4.294990 -7.751608 -12.556040 -10.339437 1.448986 \n",
+ "\n",
+ " 24 \n",
+ "0 -14.107781 \n",
+ "0 0.000000 \n",
+ "0 -41.408141 \n",
+ "0 -3.731906 \n",
+ "0 -0.627590 \n",
+ ".. ... \n",
+ "0 -8.663642 \n",
+ "0 -7.518150 \n",
+ "0 -12.255251 \n",
+ "0 -9.518580 \n",
+ "0 -10.394629 \n",
+ "\n",
+ "[708 rows x 25 columns]"
+ ]
+ },
+ "execution_count": 53,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Apply the same glove vectorization to the held-out test documents.\n",
+ "test_data_glove = text2vec(twenty_test['data'])\n",
+ "test_data_glove"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Predict class labels for the glove-vectorized test documents.\n",
+ "predict = clf.predict(test_data_glove)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[258 61]\n",
+ " [ 31 358]]\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.89 0.81 0.85 319\n",
+ " 1 0.85 0.92 0.89 389\n",
+ "\n",
+ " accuracy 0.87 708\n",
+ " macro avg 0.87 0.86 0.87 708\n",
+ "weighted avg 0.87 0.87 0.87 708\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Evaluate: confusion matrix plus per-class precision/recall/F1.\n",
+ "print(confusion_matrix(twenty_test['target'], predict))\n",
+ "print(classification_report(twenty_test['target'], predict))"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -410,7 +4822,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -424,7 +4836,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.13"
+ "version": "3.12.3"
}
},
"nbformat": 4,