
Methodological Guidelines for Laboratory Work No. 2

In this lab we continue working with the scikit-learn library (http://scikit-learn.org) and explore its capabilities for working with text documents.

Below are the new modules that will be used in this lab.

For stemming, it is suggested to use the NLTK library and the Porter stemmer.

Importing the libraries

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer 
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_auc_score

Loading the dataset

The 20 newsgroups dataset consists of messages made up of a header, a body, and a signature or footer; a message may also quote a previous message (quotes). The fetch_20newsgroups module lets you select the topics of interest and strip the unneeded parts of the messages. To select messages on particular topics, pass a list of topics in the categories parameter.

To strip unneeded message parts, pass them in the remove parameter. Another important parameter of fetch_20newsgroups is subset, which selects the split: training or test.

Let us select messages on the topics Atheism and Computer Graphics, and specify that we are not interested in headers, quotes, or signatures:

categories = ['alt.atheism', 'comp.graphics'] 
remove = ('headers', 'footers', 'quotes')
twenty_train = fetch_20newsgroups(subset='train', shuffle=True, random_state=42,
                                  categories=categories, remove=remove)
twenty_test = fetch_20newsgroups(subset='test', shuffle=True, random_state=42,
                                 categories=categories, remove=remove)

The returned dataset is a scikit-learn Bunch: a dictionary-like container whose fields can be accessed as object attributes. For example, target_names holds the list of names of the requested categories, target holds the topic of each message, and data holds the message text itself:

print(twenty_train.data[2])
Does anyone know of any good shareware animation or paint software for an SGI
 machine?  I've exhausted everyplace on the net I can find and still don't hava
 a nice piece of software.

Thanks alot!

Chad


Vectorization

To apply machine learning to text documents, the first step is to turn the text content into a numeric feature vector. Text preprocessing, tokenization, and stop-word filtering are built into the CountVectorizer module, which builds a vocabulary of characteristic features and converts documents into feature vectors. Let us create a vectorizer object vect with the following parameters:

  • max_features = 10000 - the number of most frequent terms that make up the vocabulary. By default all words are used.
  • stop_words = 'english' - at the moment the module only supports filtering English stop words; alternatively, a custom list of stop words can be given here. If the parameter is omitted, no stop words are removed and all vocabulary terms are used.
vect = CountVectorizer(max_features = 10000, stop_words = 'english')

The following parameters may also need tuning in this lab:

  • max_df - float in the range [0.0, 1.0] or int, default=1.0. When building the vocabulary, ignore terms whose document frequency is strictly higher than the given threshold (corpus-specific stop words). A float denotes a proportion of documents, an integer an absolute count.
  • min_df - float in the range [0.0, 1.0] or int, default=1. When building the vocabulary, ignore terms whose document frequency is strictly lower than the given threshold. In the literature this value is also known as the cut-off. A float denotes a proportion of documents, an integer an absolute count.

Once the vectorizer object has been created, build the vocabulary of characteristic features with the fit() method and convert the documents into feature vectors with the transform() method, passing in the training data:

vect.fit(twenty_train.data)

train_data = vect.transform(twenty_train.data)
test_data = vect.transform(twenty_test.data)

Note that these two steps can be combined into the single method fit_transform(). In that case, keep in mind that the test data must still be converted into feature vectors with the transform() method:

train_data = vect.fit_transform(twenty_train.data)
test_data = vect.transform(twenty_test.data)

If fit_transform() is applied to the test data as well, the feature vocabulary will be rebuilt, leading to incorrect classification results.
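A minimal sketch of why this matters (toy data, made up for illustration): transform() keeps the vocabulary learned on the training data, so words unseen at training time are simply dropped and the feature dimensionality stays fixed:

```python
from sklearn.feature_extraction.text import CountVectorizer

vect = CountVectorizer()
X_train = vect.fit_transform(["red apple", "green apple"])   # vocabulary built here
X_test = vect.transform(["blue banana apple"])               # 'blue', 'banana' are dropped

print(X_train.shape, X_test.shape)   # (2, 3) (1, 3) - same number of columns
```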

The following code block prints the first 10 terms ordered by frequency of occurrence:

# pair each term with its total count over the corpus, then sort by frequency
x = list(zip(vect.get_feature_names_out(), np.ravel(train_data.sum(axis=0))))
x.sort(key=lambda term_count: term_count[1], reverse=True)
print(x[:10])
[('image', np.int64(489)), ('don', np.int64(417)), ('graphics', np.int64(410)), ('god', np.int64(409)), ('people', np.int64(384)), ('does', np.int64(364)), ('edu', np.int64(349)), ('like', np.int64(329)), ('just', np.int64(327)), ('know', np.int64(319))]

Stemming

There is a whole family of stemming algorithms. One of the most popular is Porter's algorithm, implemented in the nltk library. To perform stemming, create a PorterStemmer() object; it provides a stem method that stems a single word. So each part of the dataset (training and test) must be split into individual documents; then, looping over every word in a document, stem it and join the stemmed words back into a new document.

Stemming is not required in this lab.

import nltk
from nltk.stem import PorterStemmer
from nltk import word_tokenize
nltk.download('punkt_tab')

porter_stemmer = PorterStemmer()
stem_train = []
for text in twenty_train.data:
    # tokenize, stem each token, then re-join into a single document string
    tokens = word_tokenize(text)
    stem_train.append(' '.join(porter_stemmer.stem(word) for word in tokens))
print(stem_train[0])
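To get a feel for what the Porter stemmer does to individual words, here is a small standalone check (the word list is chosen arbitrarily):

```python
from nltk.stem import PorterStemmer

porter_stemmer = PorterStemmer()
for word in ["running", "flies", "graphics", "easily"]:
    # each word is reduced to its stem, e.g. "running" -> "run"
    print(word, '->', porter_stemmer.stem(word))
```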

TF and TF-IDF weighting

CountVectorizer can only count how often a term occurs in the corpus, and this approach to identifying informative terms does not always give good results. In practice, more advanced weighting schemes are used, the most widespread being TF and TF-IDF. Let us use the fit() method of the TfidfTransformer() class, which converts a matrix of raw term counts into TF or TF-IDF weights.

tfidf = TfidfTransformer(use_idf = True).fit(train_data)
train_data_tfidf = tfidf.transform(train_data)

Note that fit() must be given not the raw text but the matrix of term counts produced by the transform() method of CountVectorizer. To obtain tf-idf values, set use_idf = True; otherwise the output will be plain tf values.
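The difference between tf and tf-idf weights can be seen on a toy count matrix (two made-up documents). A term that occurs in every document, like 'apple' here, is down-weighted when use_idf = True:

```python
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

counts = CountVectorizer().fit_transform(["apple apple banana", "apple cherry"])

tf = TfidfTransformer(use_idf=False).fit_transform(counts)      # plain normalized tf
tfidf = TfidfTransformer(use_idf=True).fit_transform(counts)    # tf scaled by idf

# column 0 is 'apple' (feature names are sorted alphabetically)
print(tf.toarray().round(3))
print(tfidf.toarray().round(3))
```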

Classification

Once the text has been vectorized, training a model and classifying text data looks exactly the same as classifying objects in the first lab.

Training a model is not only about choosing suitable training data that characterizes the objects well, but also about tuning the numerous parameters of the classification method, preprocessing the data, and so on. Let us look at what the scikit-learn library offers to automate and simplify this task.

Pipeline

To make the vectorizer => transformer => classifier chain easier to work with, scikit-learn provides the Pipeline class, which acts as a compound (pipelined) classifier.

from sklearn.pipeline import Pipeline

The intermediate steps of a pipeline must be transformations, i.e. they must implement the fit() and transform() methods, while the final step only needs fit(). A pipeline also lets you set parameters at each of its steps. Thus, the steps we performed above (vectorization, TF-IDF weighting, and k-NN classification) look like this with a pipeline:

text_clf = Pipeline([('vect', CountVectorizer(max_features=1000, stop_words='english')),
                     ('tfidf', TfidfTransformer(use_idf=True)),
                     ('clf', KNeighborsClassifier(n_neighbors=1))])

The names vect, tfidf, and clf are chosen arbitrarily; we will make use of them in the next lab. Now let us train the model with a single command:

text_clf = text_clf.fit(twenty_train.data, twenty_train.target)

And run classification on the test set:

prediction = text_clf.predict(twenty_test.data)
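A self-contained sketch of the same fit/predict cycle on a tiny made-up two-class corpus (the documents and labels below are hypothetical, standing in for the 20 newsgroups data):

```python
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

train_docs = ["god faith religion", "render image pixel",
              "atheism god belief", "graphics image format"]
train_y = [0, 1, 0, 1]            # 0 = atheism, 1 = graphics
test_docs = ["belief in god", "pixel image rendering"]
test_y = [0, 1]

text_clf = Pipeline([('vect', CountVectorizer()),
                     ('tfidf', TfidfTransformer()),
                     ('clf', KNeighborsClassifier(n_neighbors=1))])
text_clf.fit(train_docs, train_y)
prediction = text_clf.predict(test_docs)
print(accuracy_score(test_y, prediction))
```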
from sklearn.model_selection import GridSearchCV

Pipelines make it much easier to specify model training parameters, but trying every possible combination by hand would be costly even in our simple case, which has four tunable parameters: max_features, stop_words, use_idf, and n_neighbors.

Instead of searching for the best pipeline parameters by hand, you can run an exhaustive search for the best parameters over a grid of possible values. This is done with an object of the GridSearchCV class.

To define the parameter grid, create a dictionary whose keys have the form "PipelineStepName__ParameterName" and whose values are tuples of parameter values:

parameters = {'vect__max_features': (100,500,1000,5000,10000),
              'vect__stop_words': ('english', None),
              'tfidf__use_idf': (True, False),              
              'clf__n_neighbors': (1,3,5,7)} 

Next, create a GridSearchCV object, passing it the pipeline object (or a classifier) and the parameter grid, and optionally setting other parameters such as the number of CPU cores to use (n_jobs), the number of cross-validation folds (cv), the metric used to judge model quality (scoring), and others:

gs_clf = GridSearchCV(text_clf, parameters, n_jobs=-1, cv=3, scoring = 'f1_weighted')

Now the gs_clf object can be fitted over all parameters like an ordinary classifier, using the fit() method. Once fitting has finished, the best parameter combination is available in the best_params_ attribute. For an unfitted model these attributes do not exist.

gs_clf.fit(twenty_train.data, twenty_train.target)
GridSearchCV(cv=3,
             estimator=Pipeline(steps=[('vect',
                                        CountVectorizer(max_features=1000,
                                                        stop_words='english')),
                                       ('tfidf', TfidfTransformer()),
                                       ('clf',
                                        KNeighborsClassifier(n_neighbors=1))]),
             n_jobs=-1,
             param_grid={'clf__n_neighbors': (1, 3, 5, 7),
                         'tfidf__use_idf': (True, False),
                         'vect__max_features': (100, 500, 1000, 5000, 10000),
                         'vect__stop_words': ('english', None)},
             scoring='f1_weighted')
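The same search can be sketched end-to-end on toy data (the documents and the grid below are made up and much smaller than in the lab) to show how best_params_ is read after fitting:

```python
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.neighbors import KNeighborsClassifier

docs = ["god faith belief", "image pixel render", "god atheism religion",
        "graphics image format", "faith god belief", "pixel graphics render"]
y = [0, 1, 0, 1, 0, 1]

pipe = Pipeline([('vect', CountVectorizer()),
                 ('tfidf', TfidfTransformer()),
                 ('clf', KNeighborsClassifier())])
grid = {'tfidf__use_idf': (True, False),
        'clf__n_neighbors': (1, 3)}          # 'step__param' naming convention

gs = GridSearchCV(pipe, grid, cv=2)
gs.fit(docs, y)
print(gs.best_params_)    # dictionary with the best value for each grid key
```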
'english'
token_pattern token_pattern: str or None, default=r"(?u)\\b\\w\\w+\\b"

Regular expression denoting what constitutes a "token", only used
if ``analyzer == 'word'``. The default regexp select tokens of 2
or more alphanumeric characters (punctuation is completely ignored
and always treated as a token separator).

If there is a capturing group in token_pattern then the
captured group content, not the entire match, becomes the token.
At most one capturing group is permitted.
'(?u)\\b\\w\\w+\\b'
ngram_range ngram_range: tuple (min_n, max_n), default=(1, 1)

The lower and upper boundary of the range of n-values for different
word n-grams or char n-grams to be extracted. All values of n such
such that min_n <= n <= max_n will be used. For example an
``ngram_range`` of ``(1, 1)`` means only unigrams, ``(1, 2)`` means
unigrams and bigrams, and ``(2, 2)`` means only bigrams.
Only applies if ``analyzer`` is not callable.
(1, ...)
analyzer analyzer: {'word', 'char', 'char_wb'} or callable, default='word'

Whether the feature should be made of word n-gram or character
n-grams.
Option 'char_wb' creates character n-grams only from text inside
word boundaries; n-grams at the edges of words are padded with space.

If a callable is passed it is used to extract the sequence of features
out of the raw, unprocessed input.

.. versionchanged:: 0.21

Since v0.21, if ``input`` is ``filename`` or ``file``, the data is
first read from the file and then passed to the given callable
analyzer.
'word'
max_df max_df: float in range [0.0, 1.0] or int, default=1.0

When building the vocabulary ignore terms that have a document
frequency strictly higher than the given threshold (corpus-specific
stop words).
If float, the parameter represents a proportion of documents, integer
absolute counts.
This parameter is ignored if vocabulary is not None.
1.0
min_df min_df: float in range [0.0, 1.0] or int, default=1

When building the vocabulary ignore terms that have a document
frequency strictly lower than the given threshold. This value is also
called cut-off in the literature.
If float, the parameter represents a proportion of documents, integer
absolute counts.
This parameter is ignored if vocabulary is not None.
1
max_features max_features: int, default=None

If not None, build a vocabulary that only consider the top
`max_features` ordered by term frequency across the corpus.
Otherwise, all features are used.

This parameter is ignored if vocabulary is not None.
100
vocabulary vocabulary: Mapping or iterable, default=None

Either a Mapping (e.g., a dict) where keys are terms and values are
indices in the feature matrix, or an iterable over terms. If not
given, a vocabulary is determined from the input documents. Indices
in the mapping should not be repeated and should not have any gap
between 0 and the largest index.
None
binary binary: bool, default=False

If True, all non zero counts are set to 1. This is useful for discrete
probabilistic models that model binary events rather than integer
counts.
False
dtype dtype: dtype, default=np.int64

Type of the matrix returned by fit_transform() or transform().
<class 'numpy.int64'>
Parameters
norm norm: {'l1', 'l2'} or None, default='l2'

Each output row will have unit norm, either:

- 'l2': Sum of squares of vector elements is 1. The cosine
similarity between two vectors is their dot product when l2 norm has
been applied.
- 'l1': Sum of absolute values of vector elements is 1.
See :func:`~sklearn.preprocessing.normalize`.
- None: No normalization.
'l2'
use_idf use_idf: bool, default=True

Enable inverse-document-frequency reweighting. If False, idf(t) = 1.
True
smooth_idf smooth_idf: bool, default=True

Smooth idf weights by adding one to document frequencies, as if an
extra document was seen containing every term in the collection
exactly once. Prevents zero divisions.
True
sublinear_tf sublinear_tf: bool, default=False

Apply sublinear tf scaling, i.e. replace tf with 1 + log(tf).
False
Parameters
n_neighbors n_neighbors: int, default=5

Number of neighbors to use by default for :meth:`kneighbors` queries.
3
weights weights: {'uniform', 'distance'}, callable or None, default='uniform'

Weight function used in prediction. Possible values:

- 'uniform' : uniform weights. All points in each neighborhood
are weighted equally.
- 'distance' : weight points by the inverse of their distance.
in this case, closer neighbors of a query point will have a
greater influence than neighbors which are further away.
- [callable] : a user-defined function which accepts an
array of distances, and returns an array of the same shape
containing the weights.

Refer to the example entitled
:ref:`sphx_glr_auto_examples_neighbors_plot_classification.py`
showing the impact of the `weights` parameter on the decision
boundary.
'uniform'
algorithm algorithm: {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto'

Algorithm used to compute the nearest neighbors:

- 'ball_tree' will use :class:`BallTree`
- 'kd_tree' will use :class:`KDTree`
- 'brute' will use a brute-force search.
- 'auto' will attempt to decide the most appropriate algorithm
based on the values passed to :meth:`fit` method.

Note: fitting on sparse input will override the setting of
this parameter, using brute force.
'auto'
leaf_size leaf_size: int, default=30

Leaf size passed to BallTree or KDTree. This can affect the
speed of the construction and query, as well as the memory
required to store the tree. The optimal value depends on the
nature of the problem.
30
p p: float, default=2

Power parameter for the Minkowski metric. When p = 1, this is equivalent
to using manhattan_distance (l1), and euclidean_distance (l2) for p = 2.
For arbitrary p, minkowski_distance (l_p) is used. This parameter is expected
to be positive.
2
metric metric: str or callable, default='minkowski'

Metric to use for distance computation. Default is "minkowski", which
results in the standard Euclidean distance when p = 2. See the
documentation of `scipy.spatial.distance
`_ and
the metrics listed in
:class:`~sklearn.metrics.pairwise.distance_metrics` for valid metric
values.

If metric is "precomputed", X is assumed to be a distance matrix and
must be square during fit. X may be a :term:`sparse graph`, in which
case only "nonzero" elements may be considered neighbors.

If metric is a callable function, it takes two arrays representing 1D
vectors as inputs and must return one value indicating the distance
between those vectors. This works for Scipy's metrics, but is less
efficient than passing the metric name as a string.
'minkowski'
metric_params metric_params: dict, default=None

Additional keyword arguments for the metric function.
None
n_jobs n_jobs: int, default=None

The number of parallel jobs to run for neighbors search.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary `
for more details.
Doesn't affect :meth:`fit` method.
None
<script>function copyToClipboard(text, element) { // Get the parameter prefix from the closest toggleable content const toggleableContent = element.closest('.sk-toggleable__content'); const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : ''; const fullParamName = paramPrefix ? `${paramPrefix}${text}` : text; const originalStyle = element.style; const computedStyle = window.getComputedStyle(element); const originalWidth = computedStyle.width; const originalHTML = element.innerHTML.replace('Copied!', ''); navigator.clipboard.writeText(fullParamName) .then(() => { element.style.width = originalWidth; element.style.color = 'green'; element.innerHTML = "Copied!"; setTimeout(() => { element.innerHTML = originalHTML; element.style = originalStyle; }, 2000); }) .catch(err => { console.error('Failed to copy:', err); element.style.color = 'red'; element.innerHTML = "Failed!"; setTimeout(() => { element.innerHTML = originalHTML; element.style = originalStyle; }, 2000); }); return false; } document.querySelectorAll('.copy-paste-icon').forEach(function(element) { const toggleableContent = element.closest('.sk-toggleable__content'); const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : ''; const paramName = element.parentElement.nextElementSibling .textContent.trim().split(' ')[0]; const fullParamName = paramPrefix ? 
`${paramPrefix}${paramName}` : paramName; element.setAttribute('title', fullParamName); }); /** * Adapted from Skrub * 403466d1d5/skrub/_reporting/_data/templates/report.js (L789) * @returns "light" or "dark" */ function detectTheme(element) { const body = document.querySelector('body'); // Check VSCode theme const themeKindAttr = body.getAttribute('data-vscode-theme-kind'); const themeNameAttr = body.getAttribute('data-vscode-theme-name'); if (themeKindAttr && themeNameAttr) { const themeKind = themeKindAttr.toLowerCase(); const themeName = themeNameAttr.toLowerCase(); if (themeKind.includes("dark") || themeName.includes("dark")) { return "dark"; } if (themeKind.includes("light") || themeName.includes("light")) { return "light"; } } // Check Jupyter theme if (body.getAttribute('data-jp-theme-light') === 'false') { return 'dark'; } else if (body.getAttribute('data-jp-theme-light') === 'true') { return 'light'; } // Guess based on a parent element's color const color = window.getComputedStyle(element.parentNode, null).getPropertyValue('color'); const match = color.match(/^rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)\s*$/i); if (match) { const [r, g, b] = [ parseFloat(match[1]), parseFloat(match[2]), parseFloat(match[3]) ]; // https://en.wikipedia.org/wiki/HSL_and_HSV#Lightness const luma = 0.299 * r + 0.587 * g + 0.114 * b; if (luma > 180) { // If the text is very bright we have a dark theme return 'dark'; } if (luma < 75) { // If the text is very dark we have a light theme return 'light'; } // Otherwise fall back to the next heuristic. } // Fallback to system preference return window.matchMedia('(prefers-color-scheme: dark)').matches ? 
'dark' : 'light'; } function forceTheme(elementId) { const estimatorElement = document.querySelector(`#${elementId}`); if (estimatorElement === null) { console.error(`Element with id ${elementId} not found.`); } else { const theme = detectTheme(estimatorElement); estimatorElement.classList.add(theme); } } forceTheme('sk-container-id-1');</script>
gs_clf.best_params_
{'clf__n_neighbors': 3,
 'tfidf__use_idf': True,
 'vect__max_features': 100,
 'vect__stop_words': 'english'}
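Тот же механизм перебора параметров можно воспроизвести на маленьком искусственном корпусе — ниже набросок (тексты и метки условные, сетка параметров сокращена), показывающий, как устроены `best_params_`, `best_score_` и `cv_results_` у обученного `GridSearchCV`:

```python
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV

# условный мини-корпус из двух "тематик"
docs = ["graphics image render pixel", "image pixel color render",
        "god atheism belief religion", "religion belief god faith"] * 5
labels = [0, 0, 1, 1] * 5

pipe = Pipeline([
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', KNeighborsClassifier()),
])

# сокращённая сетка: 2 x 2 = 4 комбинации
params = {
    'clf__n_neighbors': [1, 3],
    'tfidf__use_idf': [True, False],
}

gs = GridSearchCV(pipe, params, cv=2)
gs.fit(docs, labels)

print(gs.best_params_)                  # лучшая комбинация параметров
print(gs.best_score_)                   # средняя accuracy по фолдам для неё
print(len(gs.cv_results_['params']))    # по одной записи на каждую комбинацию
```

Имена шагов конвейера (`vect`, `tfidf`, `clf`) задают префиксы ключей в сетке параметров — именно поэтому в `best_params_` выше фигурируют ключи вида `clf__n_neighbors` и `vect__max_features`.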
gs_clf.cv_results_
{'mean_fit_time': array([0.14711912, 0.1838789 , 0.16524474, 0.13485821, 0.17077629,
        0.171736  , 0.13478573, 0.17135119, 0.1455044 , 0.16278481,
        0.14029678, 0.17597159, 0.14752992, 0.1515433 , 0.2238245 ,
        0.16685406, 0.19492237, 0.18400868, 0.1637164 , 0.23180572,
        0.13470483, 0.19008255, 0.168998  , 0.14402525, 0.14959741,
        0.21908172, 0.12458984, 0.17676258, 0.181101  , 0.15396865,
        0.13980309, 0.17127609, 0.12496074, 0.21157686, 0.14910388,
        0.16302236, 0.20699684, 0.14392487, 0.13715124, 0.19204704,
        0.12343788, 0.202528  , 0.1306417 , 0.14448158, 0.16433088,
        0.17828568, 0.15127762, 0.16031146, 0.15140502, 0.17451549,
        0.13383325, 0.19785619, 0.14872464, 0.19805606, 0.17880678,
        0.14933427, 0.18098283, 0.17271209, 0.19098576, 0.15773439,
        0.14277951, 0.14034979, 0.17362897, 0.14588912, 0.15714224,
        0.19136397, 0.14665333, 0.18524861, 0.1639146 , 0.15949496,
        0.14914314, 0.19116807, 0.16394456, 0.17290258, 0.14597694,
        0.18869273, 0.17521318, 0.18450022, 0.15120141, 0.14013394]),
 'std_fit_time': array([0.0426388 , 0.00865779, 0.03047002, 0.01959866, 0.01996265,
        0.03991559, 0.02089307, 0.0115939 , 0.01577824, 0.01259437,
        0.02861695, 0.02803293, 0.02632727, 0.03513436, 0.10327593,
        0.06408524, 0.04651852, 0.03295957, 0.0471307 , 0.11679622,
        0.0110411 , 0.02803024, 0.05824908, 0.01165112, 0.01597652,
        0.06360679, 0.0225909 , 0.03841467, 0.06331909, 0.02901401,
        0.02955105, 0.00180821, 0.01985555, 0.0523942 , 0.03268085,
        0.03894983, 0.06517167, 0.02243237, 0.02200533, 0.0133168 ,
        0.02014168, 0.05538047, 0.00691204, 0.02019455, 0.04787927,
        0.02202076, 0.03367995, 0.04032929, 0.04035338, 0.05086764,
        0.03783577, 0.02842441, 0.01027949, 0.02261441, 0.07413437,
        0.03057069, 0.07406925, 0.06374613, 0.07499723, 0.00702891,
        0.01449466, 0.01535535, 0.06266348, 0.01594578, 0.03962651,
        0.03872874, 0.02924308, 0.03921288, 0.03093211, 0.02737551,
        0.03913621, 0.03922343, 0.06958397, 0.0321589 , 0.03550322,
        0.06330303, 0.04521662, 0.01877706, 0.01843914, 0.02003679]),
 'mean_score_time': array([0.07024606, 0.09829744, 0.06123559, 0.0934546 , 0.06724668,
        0.08956035, 0.06278229, 0.08930341, 0.07657051, 0.08451796,
        0.06686974, 0.06430848, 0.06159838, 0.06787952, 0.0685486 ,
        0.08175969, 0.07736111, 0.08548617, 0.10391927, 0.11483145,
        0.06140908, 0.08571712, 0.05673496, 0.07094908, 0.07356191,
        0.08419561, 0.06646315, 0.07401323, 0.0662907 , 0.07332158,
        0.06078386, 0.06057103, 0.05959137, 0.07239731, 0.06452703,
        0.06926155, 0.07553283, 0.07713358, 0.06963356, 0.09399152,
        0.05499752, 0.06101354, 0.06951396, 0.06801232, 0.06193964,
        0.08509183, 0.06340122, 0.09640757, 0.09270819, 0.07950346,
        0.06135122, 0.07454085, 0.06276194, 0.08477139, 0.06570975,
        0.06936661, 0.07554563, 0.08220371, 0.06743741, 0.10737284,
        0.06237586, 0.06374478, 0.07077082, 0.06676245, 0.06271029,
        0.10701195, 0.07044252, 0.07724508, 0.0787247 , 0.07522114,
        0.06066545, 0.08716504, 0.06308579, 0.08094724, 0.08140651,
        0.07310184, 0.07593838, 0.07488211, 0.06320556, 0.04453659]),
 'std_score_time': array([0.0307069 , 0.0182529 , 0.02101209, 0.03205303, 0.0257132 ,
        0.02687809, 0.0173987 , 0.02640571, 0.01539859, 0.01465485,
        0.02122433, 0.01356409, 0.00834315, 0.01751938, 0.01171655,
        0.03144068, 0.02179221, 0.01969381, 0.04494128, 0.00622263,
        0.01579416, 0.02897917, 0.01304806, 0.0202541 , 0.02151299,
        0.03211671, 0.01358637, 0.01938508, 0.01422302, 0.01688134,
        0.00869969, 0.0141093 , 0.01194079, 0.01273514, 0.01648673,
        0.01545342, 0.01899566, 0.0178712 , 0.02332053, 0.03043117,
        0.01200954, 0.01238742, 0.02325767, 0.01629294, 0.01778484,
        0.03084458, 0.0157806 , 0.04616991, 0.02142426, 0.01231431,
        0.0059662 , 0.01826184, 0.02046084, 0.02321078, 0.00858559,
        0.01727098, 0.01823882, 0.02170296, 0.0142323 , 0.06476885,
        0.02283514, 0.01249623, 0.00957304, 0.01621181, 0.00941332,
        0.04093458, 0.01894285, 0.0226533 , 0.02176937, 0.01821806,
        0.01055958, 0.03213191, 0.00837758, 0.02668343, 0.03007748,
        0.01731415, 0.01637027, 0.01601808, 0.02890825, 0.01039945]),
 'param_clf__n_neighbors': masked_array(data=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                    1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                    3, 3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
                    5, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
                    7, 7, 7, 7, 7, 7, 7, 7],
              mask=[False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False],
        fill_value=999999),
 'param_tfidf__use_idf': masked_array(data=[True, True, True, True, True, True, True, True, True,
                    True, False, False, False, False, False, False, False,
                    False, False, False, True, True, True, True, True,
                    True, True, True, True, True, False, False, False,
                    False, False, False, False, False, False, False, True,
                    True, True, True, True, True, True, True, True, True,
                    False, False, False, False, False, False, False, False,
                    False, False, True, True, True, True, True, True, True,
                    True, True, True, False, False, False, False, False,
                    False, False, False, False, False],
              mask=[False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False],
        fill_value=True),
 'param_vect__max_features': masked_array(data=[100, 100, 500, 500, 1000, 1000, 5000, 5000, 10000,
                    10000, 100, 100, 500, 500, 1000, 1000, 5000, 5000,
                    10000, 10000, 100, 100, 500, 500, 1000, 1000, 5000,
                    5000, 10000, 10000, 100, 100, 500, 500, 1000, 1000,
                    5000, 5000, 10000, 10000, 100, 100, 500, 500, 1000,
                    1000, 5000, 5000, 10000, 10000, 100, 100, 500, 500,
                    1000, 1000, 5000, 5000, 10000, 10000, 100, 100, 500,
                    500, 1000, 1000, 5000, 5000, 10000, 10000, 100, 100,
                    500, 500, 1000, 1000, 5000, 5000, 10000, 10000],
              mask=[False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False],
        fill_value=999999),
 'param_vect__stop_words': masked_array(data=['english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None, 'english', None, 'english', None,
                    'english', None],
              mask=[False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False],
        fill_value=np.str_('?'),
             dtype=object),
 'params': [{'clf__n_neighbors': 1,
   'tfidf__use_idf': True,
   'vect__max_features': 100,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': True,
   'vect__max_features': 100,
   'vect__stop_words': None},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': True,
   'vect__max_features': 500,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': True,
   'vect__max_features': 500,
   'vect__stop_words': None},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': True,
   'vect__max_features': 1000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': True,
   'vect__max_features': 1000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': True,
   'vect__max_features': 5000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': True,
   'vect__max_features': 5000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': True,
   'vect__max_features': 10000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': True,
   'vect__max_features': 10000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': False,
   'vect__max_features': 100,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': False,
   'vect__max_features': 100,
   'vect__stop_words': None},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': False,
   'vect__max_features': 500,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': False,
   'vect__max_features': 500,
   'vect__stop_words': None},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': False,
   'vect__max_features': 1000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': False,
   'vect__max_features': 1000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': False,
   'vect__max_features': 5000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': False,
   'vect__max_features': 5000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': False,
   'vect__max_features': 10000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 1,
   'tfidf__use_idf': False,
   'vect__max_features': 10000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': True,
   'vect__max_features': 100,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': True,
   'vect__max_features': 100,
   'vect__stop_words': None},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': True,
   'vect__max_features': 500,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': True,
   'vect__max_features': 500,
   'vect__stop_words': None},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': True,
   'vect__max_features': 1000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': True,
   'vect__max_features': 1000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': True,
   'vect__max_features': 5000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': True,
   'vect__max_features': 5000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': True,
   'vect__max_features': 10000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': True,
   'vect__max_features': 10000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': False,
   'vect__max_features': 100,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': False,
   'vect__max_features': 100,
   'vect__stop_words': None},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': False,
   'vect__max_features': 500,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': False,
   'vect__max_features': 500,
   'vect__stop_words': None},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': False,
   'vect__max_features': 1000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': False,
   'vect__max_features': 1000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': False,
   'vect__max_features': 5000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': False,
   'vect__max_features': 5000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': False,
   'vect__max_features': 10000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 3,
   'tfidf__use_idf': False,
   'vect__max_features': 10000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': True,
   'vect__max_features': 100,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': True,
   'vect__max_features': 100,
   'vect__stop_words': None},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': True,
   'vect__max_features': 500,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': True,
   'vect__max_features': 500,
   'vect__stop_words': None},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': True,
   'vect__max_features': 1000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': True,
   'vect__max_features': 1000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': True,
   'vect__max_features': 5000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': True,
   'vect__max_features': 5000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': True,
   'vect__max_features': 10000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': True,
   'vect__max_features': 10000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': False,
   'vect__max_features': 100,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': False,
   'vect__max_features': 100,
   'vect__stop_words': None},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': False,
   'vect__max_features': 500,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': False,
   'vect__max_features': 500,
   'vect__stop_words': None},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': False,
   'vect__max_features': 1000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': False,
   'vect__max_features': 1000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': False,
   'vect__max_features': 5000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': False,
   'vect__max_features': 5000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': False,
   'vect__max_features': 10000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 5,
   'tfidf__use_idf': False,
   'vect__max_features': 10000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': True,
   'vect__max_features': 100,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': True,
   'vect__max_features': 100,
   'vect__stop_words': None},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': True,
   'vect__max_features': 500,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': True,
   'vect__max_features': 500,
   'vect__stop_words': None},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': True,
   'vect__max_features': 1000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': True,
   'vect__max_features': 1000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': True,
   'vect__max_features': 5000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': True,
   'vect__max_features': 5000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': True,
   'vect__max_features': 10000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': True,
   'vect__max_features': 10000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': False,
   'vect__max_features': 100,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': False,
   'vect__max_features': 100,
   'vect__stop_words': None},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': False,
   'vect__max_features': 500,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': False,
   'vect__max_features': 500,
   'vect__stop_words': None},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': False,
   'vect__max_features': 1000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': False,
   'vect__max_features': 1000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': False,
   'vect__max_features': 5000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': False,
   'vect__max_features': 5000,
   'vect__stop_words': None},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': False,
   'vect__max_features': 10000,
   'vect__stop_words': 'english'},
  {'clf__n_neighbors': 7,
   'tfidf__use_idf': False,
   'vect__max_features': 10000,
   'vect__stop_words': None}],
 'split0_test_score': array([0.76817583, 0.72894636, 0.53358157, 0.6374826 , 0.52798315,
        0.56519764, 0.51309075, 0.49836392, 0.49549296, 0.48037932,
        0.74678068, 0.6781238 , 0.53815493, 0.67256907, 0.59194787,
        0.61495561, 0.50607973, 0.58348819, 0.48282372, 0.56647032,
        0.78847084, 0.7252709 , 0.48067615, 0.52118859, 0.51721694,
        0.48760057, 0.51577841, 0.46346843, 0.43209944, 0.42022184,
        0.79654539, 0.72168041, 0.50540586, 0.62161402, 0.53994855,
        0.57362945, 0.53619645, 0.54414205, 0.43884702, 0.4762766 ,
        0.75153416, 0.72862518, 0.45804881, 0.49005168, 0.52767212,
        0.49479873, 0.48059474, 0.45872108, 0.43320814, 0.46831129,
        0.77799622, 0.71691323, 0.468479  , 0.58540742, 0.53125944,
        0.55188163, 0.53165214, 0.49549296, 0.46225076, 0.46334516,
        0.73731272, 0.722074  , 0.42768101, 0.49765146, 0.48090102,
        0.59104363, 0.51802701, 0.48304986, 0.44659485, 0.52984156,
        0.76674346, 0.71632941, 0.42110969, 0.55322776, 0.49049555,
        0.55175372, 0.55288072, 0.52752419, 0.47634782, 0.48746189]),
 'split1_test_score': array([0.79662124, 0.74654727, 0.55840557, 0.65284461, 0.56848052,
        0.58959475, 0.50824228, 0.52616761, 0.55281361, 0.53900106,
        0.79198638, 0.74396927, 0.60547906, 0.71852445, 0.55601273,
        0.70906678, 0.55288072, 0.70132316, 0.54064496, 0.72369393,
        0.8200291 , 0.76394456, 0.50142662, 0.65973942, 0.57835078,
        0.58411319, 0.56231765, 0.54929577, 0.53755372, 0.54631601,
        0.76828781, 0.73851721, 0.57218138, 0.70341759, 0.54466558,
        0.70088681, 0.52052675, 0.68262612, 0.56159728, 0.69962829,
        0.78081641, 0.76676231, 0.47784768, 0.65979122, 0.55877981,
        0.58351195, 0.55934603, 0.5526496 , 0.53805724, 0.54275916,
        0.7590099 , 0.74428157, 0.50962843, 0.69741916, 0.56721274,
        0.6845754 , 0.53674839, 0.65880645, 0.56642635, 0.66528626,
        0.75249835, 0.76114625, 0.5033107 , 0.61728125, 0.54473029,
        0.54902822, 0.55716104, 0.55466537, 0.51280439, 0.54540585,
        0.72900844, 0.73304581, 0.49250319, 0.69558085, 0.57262651,
        0.66997019, 0.54339107, 0.66489176, 0.56953075, 0.65745565]),
 'split2_test_score': array([0.82526209, 0.7519849 , 0.59510748, 0.68219353, 0.5764876 ,
        0.6305322 , 0.54587929, 0.56161317, 0.52387661, 0.56501726,
        0.82242212, 0.71247211, 0.60517653, 0.7037331 , 0.54624902,
        0.68817364, 0.5608424 , 0.65146866, 0.53526413, 0.66897713,
        0.81950927, 0.79423613, 0.52825543, 0.66907312, 0.56248767,
        0.60389162, 0.49652131, 0.55180864, 0.50857805, 0.52741727,
        0.80244926, 0.69231892, 0.53887943, 0.71969825, 0.53158497,
        0.6792325 , 0.49764733, 0.65713584, 0.49078794, 0.66897713,
        0.80172943, 0.75452358, 0.54031316, 0.63620099, 0.52529445,
        0.59626871, 0.51327659, 0.53187959, 0.50624241, 0.51743203,
        0.76711749, 0.70123506, 0.49774614, 0.71460958, 0.50316727,
        0.65013293, 0.49732848, 0.64249816, 0.52623393, 0.65444272,
        0.76791783, 0.73768667, 0.53419234, 0.63714302, 0.50165182,
        0.5932877 , 0.49723701, 0.51859739, 0.49993588, 0.52478814,
        0.76415908, 0.69550001, 0.48572333, 0.68222059, 0.50972411,
        0.66549536, 0.47022833, 0.64386329, 0.50648577, 0.6415092 ]),
 'mean_test_score': array([0.79668639, 0.74249284, 0.56236487, 0.65750691, 0.55765042,
        0.5951082 , 0.52240411, 0.5287149 , 0.52406106, 0.52813255,
        0.78706306, 0.71152173, 0.58293684, 0.69827554, 0.56473654,
        0.67073201, 0.53993428, 0.64542667, 0.5195776 , 0.65304713,
        0.8093364 , 0.76115053, 0.50345273, 0.61666704, 0.55268513,
        0.55853513, 0.52487246, 0.52152428, 0.49274373, 0.49798504,
        0.78909415, 0.71750551, 0.53882222, 0.68157662, 0.53873303,
        0.65124959, 0.51812351, 0.627968  , 0.49707742, 0.61496068,
        0.77802667, 0.74997036, 0.49206988, 0.59534796, 0.53724879,
        0.55819313, 0.51773912, 0.51441676, 0.4925026 , 0.50950083,
        0.7680412 , 0.72080995, 0.49195119, 0.66581205, 0.53387982,
        0.62886332, 0.52190967, 0.59893252, 0.51830368, 0.59435805,
        0.7525763 , 0.74030231, 0.48839468, 0.58402524, 0.50909438,
        0.57778652, 0.52414169, 0.51877087, 0.48644504, 0.53334518,
        0.75330366, 0.71495841, 0.4664454 , 0.6436764 , 0.52428206,
        0.62907309, 0.52216671, 0.61209308, 0.51745478, 0.59547558]),
 'std_test_score': array([0.02330541, 0.00983268, 0.02527339, 0.01854849, 0.02123109,
        0.02695613, 0.01671706, 0.02588415, 0.02340142, 0.03539763,
        0.0310761 , 0.0268897 , 0.03166583, 0.01915399, 0.01964985,
        0.04035167, 0.02415844, 0.04829527, 0.02608159, 0.06516717,
        0.01475571, 0.02822417, 0.01947692, 0.06762091, 0.02590243,
        0.05080407, 0.02762023, 0.0410645 , 0.04448367, 0.05552553,
        0.01490843, 0.01909001, 0.02726102, 0.04291775, 0.00540887,
        0.05559311, 0.0158291 , 0.06018046, 0.05030954, 0.09885959,
        0.02058686, 0.01589883, 0.03505766, 0.07507599, 0.01525564,
        0.04512812, 0.03230456, 0.04028527, 0.04389321, 0.0309063 ,
        0.0077786 , 0.01778837, 0.01729171, 0.05728616, 0.02621202,
        0.05622103, 0.0175056 , 0.07344521, 0.04289759, 0.09274581,
        0.0124946 , 0.01605805, 0.04474395, 0.06161139, 0.0265843 ,
        0.02035581, 0.02484303, 0.02923717, 0.02866389, 0.00877417,
        0.01721168, 0.01535864, 0.03217646, 0.064189  , 0.03507444,
        0.05470356, 0.03692975, 0.06041232, 0.03882443, 0.07665416]),
 'rank_test_score': array([ 2, 11, 41, 21, 44, 35, 58, 52, 57, 53,  4, 16, 38, 17, 40, 19, 46,
        24, 62, 22,  1,  7, 71, 29, 45, 42, 54, 61, 74, 72,  3, 14, 47, 18,
        48, 23, 65, 28, 73, 30,  5, 10, 76, 34, 49, 43, 66, 68, 75, 69,  6,
        13, 77, 20, 50, 27, 60, 32, 64, 36,  9, 12, 78, 37, 70, 39, 56, 63,
        79, 51,  8, 15, 80, 25, 55, 26, 59, 31, 67, 33], dtype=int32)}
gs_clf.best_score_
np.float64(0.8093364049132378)
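Сырые массивы cv_results_, как видно выше, трудно читать. Их удобно преобразовать в pandas.DataFrame и отсортировать по rank_test_score. Ниже - набросок этого приёма на игрушечных данных (make_classification), чтобы пример был самодостаточным; к нашему gs_clf те же строки применяются через gs_clf.cv_results_:

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# Игрушечный grid search (не наша выборка) - только чтобы показать приём
X, y = make_classification(n_samples=60, n_features=5, random_state=0)
gs = GridSearchCV(KNeighborsClassifier(), {'n_neighbors': [3, 5, 7]}, cv=3)
gs.fit(X, y)

# Оставляем ключевые столбцы и сортируем по рангу: лучшая комбинация - первой
results = pd.DataFrame(gs.cv_results_)
top = results[['params', 'mean_test_score', 'std_test_score', 'rank_test_score']] \
        .sort_values('rank_test_score')
print(top.head())
```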

Word embedding

# Импортируем библиотеку gensim и посмотрим, какие обученные модели в ней есть.
import gensim.downloader
print(list(gensim.downloader.info()['models'].keys()))
['fasttext-wiki-news-subwords-300', 'conceptnet-numberbatch-17-06-300', 'word2vec-ruscorpora-300', 'word2vec-google-news-300', 'glove-wiki-gigaword-50', 'glove-wiki-gigaword-100', 'glove-wiki-gigaword-200', 'glove-wiki-gigaword-300', 'glove-twitter-25', 'glove-twitter-50', 'glove-twitter-100', 'glove-twitter-200', '__testing_word2vec-matrix-synopsis']

GloVe

Загрузим модель "glove-twitter-25" и будем работать с ней

glove_model = gensim.downloader.load("glove-twitter-25")  # load glove vectors
[==================================================] 100.0% 104.8/104.8MB downloaded
# Можем посмотреть, что за векторы у слова "cat"
print(glove_model['cat']) # word embedding for 'cat'
glove_model.most_similar("cat")  # show words that similar to word 'cat'
[-0.96419  -0.60978   0.67449   0.35113   0.41317  -0.21241   1.3796
  0.12854   0.31567   0.66325   0.3391   -0.18934  -3.325    -1.1491
 -0.4129    0.2195    0.8706   -0.50616  -0.12781  -0.066965  0.065761
  0.43927   0.1758   -0.56058   0.13529 ]
[('dog', 0.9590820074081421),
 ('monkey', 0.920357882976532),
 ('bear', 0.9143136739730835),
 ('pet', 0.9108031392097473),
 ('girl', 0.8880629539489746),
 ('horse', 0.8872726559638977),
 ('kitty', 0.8870542049407959),
 ('puppy', 0.886769711971283),
 ('hot', 0.886525571346283),
 ('lady', 0.8845519423484802)]
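Метод most_similar ранжирует слова по косинусной близости их векторов. Саму метрику несложно воспроизвести вручную; ниже - мини-набросок на игрушечных векторах (они выдуманы для иллюстрации, это не векторы из glove-модели):

```python
import numpy as np

# Косинусная близость - та же метрика, что использует most_similar
def cosine_sim(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Игрушечные трёхмерные "эмбеддинги" (не из glove-twitter-25)
cat = np.array([1.0, 0.5, 0.2])
dog = np.array([0.9, 0.6, 0.1])
car = np.array([-0.5, 1.2, -0.8])

print(cosine_sim(cat, dog))  # близкие по направлению векторы - значение около 1
print(cosine_sim(cat, car))  # непохожие векторы - значение заметно меньше
```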

Векторизуем обучающую выборку

Получаем матрицу "Документ-термин". В столбцах - все слова выборки, в строках - документы. Значения в ячейках - число вхождений слова в документ. В отображённой части видим только нули - наглядная демонстрация разреженности такой матрицы.

vectorizer = CountVectorizer(stop_words='english')
train_data = vectorizer.fit_transform(twenty_train['data'])
CV_data=pd.DataFrame(train_data.toarray(), columns=vectorizer.get_feature_names_out())
print(CV_data.shape)
CV_data.head()
(1064, 16170)
00 000 000005102000 000100255pixel 0007 000usd 001200201pixel 00196 0028 0038 ... zorg zorn zsoft zues zug zur zurich zus zvi zyxel
0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
2 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
3 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
4 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0

5 rows × 16170 columns

# Создадим список уникальных слов, присутствующих в словаре.
words_vocab=CV_data.columns
print(words_vocab[0:10])
Index(['00', '000', '000005102000', '000100255pixel', '0007', '000usd',
       '001200201pixel', '00196', '0028', '0038'],
      dtype='str')

Векторизуем с помощью GloVe

Нужно для каждого документа сложить glove-векторы слов, из которых он состоит. В результате получим вектор документа как сумму векторов входящих в него слов.

Посмотрим на примере, как будет работать векторизация

# Пусть выборка состоит из двух документов: 
text_data = ['Hello world I love python', 'This is a great computer game! 00 000 zyxel']
# Векторизуем с помощью обученного CountVectorizer
X = vectorizer.transform(text_data)
CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())
CV_text_data
00 000 000005102000 000100255pixel 0007 000usd 001200201pixel 00196 0028 0038 ... zorg zorn zsoft zues zug zur zurich zus zvi zyxel
0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
1 1 1 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 1

2 rows × 16170 columns

# Создадим датафрейм, в который будем сохранять вектор документа
glove_data=pd.DataFrame()

# Пробегаем по каждой строке датафрейма (по каждому документу)
for i in range(CV_text_data.shape[0]):
    
    # Вектор одного документа с размерностью glove-модели:
    one_doc = np.zeros(25) 
    
    # Пробегаемся по каждому слову документа и смотрим, какие из них присутствуют в нашем словаре
    # Суммируем glove-векторы каждого известного слова в one_doc
    for word in words_vocab[CV_text_data.iloc[i,:] >= 1]:
        if word in glove_model.key_to_index.keys(): 
            print(word, ': ', glove_model[word])
            one_doc += glove_model[word]
    print(text_data[i], ': ', one_doc)
    glove_data= pd.concat([glove_data, pd.DataFrame([one_doc])], ignore_index=True)    
hello :  [-0.77069    0.12827    0.33137    0.0050893 -0.47605   -0.50116
  1.858      1.0624    -0.56511    0.13328   -0.41918   -0.14195
 -2.8555    -0.57131   -0.13418   -0.44922    0.48591   -0.6479
 -0.84238    0.61669   -0.19824   -0.57967   -0.65885    0.43928
 -0.50473  ]
love :  [-0.62645  -0.082389  0.070538  0.5782   -0.87199  -0.14816   2.2315
  0.98573  -1.3154   -0.34921  -0.8847    0.14585  -4.97     -0.73369
 -0.94359   0.035859 -0.026733 -0.77538  -0.30014   0.48853  -0.16678
 -0.016651 -0.53164   0.64236  -0.10922 ]
python :  [-0.25645  -0.22323   0.025901  0.22901   0.49028  -0.060829  0.24563
 -0.84854   1.5882   -0.7274    0.60603   0.25205  -1.8064   -0.95526
  0.44867   0.013614  0.60856   0.65423   0.82506   0.99459  -0.29403
 -0.27013  -0.348    -0.7293    0.2201  ]
world :  [ 0.10301   0.095666 -0.14789  -0.22383  -0.14775  -0.11599   1.8513
  0.24886  -0.41877  -0.20384  -0.08509   0.33246  -4.6946    0.84096
 -0.46666  -0.031128 -0.19539  -0.037349  0.58949   0.13941  -0.57667
 -0.44426  -0.43085  -0.52875   0.25855 ]
Hello world I love python :  [ -1.55058002  -0.081683     0.27991899   0.58846928  -1.00551002
  -0.82613902   6.18642995   1.44844997  -0.71108004  -1.14717001
  -0.78294002   0.58841    -14.32649982  -1.41929996  -1.09575997
  -0.430875     0.87234702  -0.806399     0.27203003   2.23921998
  -1.23571999  -1.31071102  -1.96934     -0.17641005  -0.1353    ]
computer :  [ 0.64005  -0.019514  0.70148  -0.66123   1.1723   -0.58859   0.25917
 -0.81541   1.1708    1.1413   -0.15405  -0.11369  -3.8414   -0.87233
  0.47489   1.1541    0.97678   1.1107   -0.14572  -0.52013  -0.52234
 -0.92349   0.34651   0.061939 -0.57375 ]
game :  [ 1.146     0.3291    0.26878  -1.3945   -0.30044   0.77901   1.3537
  0.37393   0.50478  -0.44266  -0.048706  0.51396  -4.3136    0.39805
  1.197     0.10287  -0.17618  -1.2881   -0.59801   0.26131  -1.2619
  0.39202   0.59309  -0.55232   0.005087]
great :  [-8.4229e-01  3.6512e-01 -3.8841e-01 -4.6118e-01  2.4301e-01  3.2412e-01
  1.9009e+00 -2.2630e-01 -3.1335e-01 -1.0970e+00 -4.1494e-03  6.2074e-01
 -5.0964e+00  6.7418e-01  5.0080e-01 -6.2119e-01  5.1765e-01 -4.4122e-01
 -1.4364e-01  1.9130e-01 -7.4608e-01 -2.5903e-01 -7.8010e-01  1.1030e-01
 -2.7928e-01]
zyxel :  [ 0.79234   0.067376 -0.22639  -2.2272    0.30057  -0.85676  -1.7268
 -0.78626   1.2042   -0.92348  -0.83987  -0.74233   0.29689  -1.208
  0.98706  -1.1624    0.61415  -0.27825   0.27813   1.5838   -0.63593
 -0.10225   1.7102   -0.95599  -1.3867  ]
This is a great computer game! 00 000 zyxel :  [  1.73610002   0.74208201   0.35545996  -4.74411008   1.41543998
  -0.34222007   1.78697008  -1.45404002   2.56643     -1.32184002
  -1.04677537   0.27867999 -12.95450976  -1.00809997   3.15975004
  -0.52662008   1.93239999  -0.89686999  -0.60924001   1.51628
  -3.16624993  -0.89275002   1.86969995  -1.33607102  -2.23464306]

На основе ячейки выше напишем функцию, которая векторизует каждый документ: для каждого слова текста ищет его вектор в glove_model и суммирует найденные векторы

def text2vec(text_data):
    
    # Векторизуем с помощью обученного CountVectorizer
    X = vectorizer.transform(text_data)
    CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())
    CV_text_data
    # Создадим датафрейм, в который будем сохранять вектор документа
    glove_data=pd.DataFrame()

    # Пробегаем по каждой строке (по каждому документу)
    for i in range(CV_text_data.shape[0]):

        # Вектор одного документа с размерностью glove-модели:
        one_doc = np.zeros(25) 

        # Пробегаемся по каждому слову документа и смотрим, какие из них присутствуют в нашем словаре
        # Суммируем glove-векторы каждого известного слова в one_doc
        for word in words_vocab[CV_text_data.iloc[i,:] >= 1]:
            if word in glove_model.key_to_index.keys(): 
                #print(word, ': ', glove_model[word])
                one_doc += glove_model[word]
        #print(text_data[i], ': ', one_doc)
        glove_data = pd.concat([glove_data, pd.DataFrame([one_doc])], axis = 0)
    #print('glove_data: ', glove_data)
    return glove_data
# Наша выборка, векторизованная через glove, выглядит так.
# Попытайтесь понять, почему матрица имеет такую размерность и что означают значения в ячейках
glove_data
0 1 2 3 4 5 6 7 8 9 ... 15 16 17 18 19 20 21 22 23 24
0 -1.55058 -0.081683 0.279919 0.588469 -1.00551 -0.826139 6.18643 1.44845 -0.71108 -1.14717 ... -0.430875 0.872347 -0.806399 0.27203 2.23922 -1.23572 -1.310711 -1.96934 -0.176410 -0.135300
1 1.73610 0.742082 0.355460 -4.744110 1.41544 -0.342220 1.78697 -1.45404 2.56643 -1.32184 ... -0.526620 1.932400 -0.896870 -0.60924 1.51628 -3.16625 -0.892750 1.86970 -1.336071 -2.234643

2 rows × 25 columns
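Заметим, что при суммировании длинные документы получают "большие" векторы: сумма растёт с числом слов. Распространённая альтернатива - усреднять векторы слов документа. Ниже - набросок такого варианта на гипотетическом мини-словаре эмбеддингов (словарь выдуман для иллюстрации, это не glove-модель):

```python
import numpy as np

# Гипотетический словарь эмбеддингов размерности 3
emb = {
    'hello': np.array([0.1, 0.2, 0.3]),
    'world': np.array([0.4, 0.0, -0.1]),
    'python': np.array([-0.2, 0.5, 0.1]),
}

def doc2vec_mean(tokens, emb, dim=3):
    # Среднее вместо суммы нормирует вектор по числу известных слов
    vecs = [emb[t] for t in tokens if t in emb]
    if not vecs:
        return np.zeros(dim)
    return np.mean(vecs, axis=0)

v = doc2vec_mean(['hello', 'world', 'python', 'unknown'], emb)
print(v)
```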

Можем использовать написанную функцию на нашей реальной выборке

train_data_glove = text2vec(twenty_train['data'])
train_data_glove
0 1 2 3 4 5 6 7 8 9 ... 15 16 17 18 19 20 21 22 23 24
0 0.499258 3.090499 -17.267230 -0.997832 -2.746523 1.586072 25.890665 -27.914878 -1.656527 -6.599336 ... 6.840056 5.475214 4.488208 17.680875 -8.254198 1.772313 6.184074 -8.879014 -0.216338 -13.408250
0 2.579560 -1.498130 1.752350 0.407907 1.528970 0.152390 0.283640 1.937990 -0.659960 -2.136370 ... 0.434040 -0.274430 0.445033 -0.645750 1.055500 -0.850250 -0.433200 0.212190 -1.164330 0.905880
0 -3.838998 3.339424 3.943270 -1.428108 4.171174 1.918481 9.177921 -3.241092 1.280560 -1.781829 ... -3.719443 9.627485 -1.198027 -4.010974 1.464305 0.398107 -1.013318 -1.591893 4.290867 -3.888199
0 -1.182812 51.118736 -28.117878 -28.586500 24.244726 -9.567325 111.896700 -114.387355 12.647457 -24.559720 ... 24.085172 39.112016 -23.011944 41.767239 -42.362471 -16.060568 -20.090755 -4.457053 -3.181086 -90.876348
0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
0 -5.839209 12.729176 -6.201106 -11.255287 2.878065 -11.993908 11.418405 -15.479517 16.634253 -10.534109 ... 11.665809 13.761738 -2.579413 -7.877240 3.299712 -9.241404 -16.102834 5.343221 -3.646908 -17.216441
0 1.286424 4.341360 -29.110272 5.057644 5.610822 -10.954865 29.974885 -38.261129 10.369041 -15.652271 ... 12.201855 5.913883 15.941801 26.251438 -9.280780 -6.656035 6.158630 -14.878060 -1.138172 -25.532792
0 6.698275 8.135166 -10.347349 0.165546 2.089456 -7.695734 31.593479 -38.582911 -3.285587 -6.087527 ... 16.575447 15.696628 0.785602 13.384657 -13.329273 1.983042 -4.041431 -5.571411 0.581912 -32.037549
0 -0.461364 4.499502 -10.446904 1.502338 -10.415269 -8.801252 18.272721 -10.398248 -6.686713 -0.250122 ... 5.758498 4.540718 -0.372422 10.668546 -6.377995 9.305774 11.568183 -18.814704 -1.204665 -14.551387
0 8.348875 10.837893 -1.528737 -18.246933 33.950209 -0.376254 48.017779 -57.636160 25.231878 -13.371779 ... 29.297609 34.342018 -8.527635 13.555851 -12.649073 -28.706501 -27.042654 4.292338 2.614328 -48.834758

1064 rows × 25 columns

Мы получили матрицу, где каждый документ представлен вектором-строкой. Можем подавать эти векторы на вход для обучения классификатора.

clf = KNeighborsClassifier(n_neighbors = 5)
clf.fit(train_data_glove, twenty_train['target'])
KNeighborsClassifier()

Аналогичную операцию проводим с тестовой частью выборки и оцениваем качество модели:
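Функция `text2vec` определена ранее в работе. Для наглядности ниже приведён лишь условный набросок того, как может выглядеть такое преобразование: каждое сообщение представляется суммой векторов входящих в него слов (мини-словарь `embeddings` и размерность векторов здесь вымышленные; в работе используются предобученные векторы GloVe размерности 25).

```python
import numpy as np
import pandas as pd

# Условный мини-словарь эмбеддингов (только для иллюстрации)
embeddings = {
    'graphics': np.array([0.1, 0.2, -0.3]),
    'image':    np.array([0.0, 0.5, -0.1]),
    'god':      np.array([-0.4, 0.1, 0.2]),
}

def text2vec(texts):
    """Суммирует векторы известных слов каждого документа;
    документ без известных слов даёт нулевой вектор."""
    rows = []
    for text in texts:
        vecs = [embeddings[w] for w in text.lower().split() if w in embeddings]
        rows.append(np.sum(vecs, axis=0) if vecs else np.zeros(3))
    return pd.DataFrame(rows)

print(text2vec(['graphics image', 'unknown words']))
```

Обратите внимание: строка из нулей в таблице ниже соответствует именно такому случаю — сообщению, ни одно слово которого не нашлось в словаре эмбеддингов.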

test_data_glove = text2vec(twenty_test['data'])
test_data_glove
0 1 2 3 4 5 6 7 8 9 ... 15 16 17 18 19 20 21 22 23 24
0 2.118167 10.172361 -12.929062 -10.235129 -5.895256 -2.960585 13.365803 -26.779260 0.811607 -3.139421 ... 4.893207 4.652184 2.903990 13.077540 -12.438726 2.602789 -7.450777 -5.202635 -20.983743 -14.107781
0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
0 4.873320 27.474242 -10.773876 8.112889 0.623212 -2.199475 59.094818 -59.831612 -2.963759 -18.048856 ... 23.890453 12.469234 -3.087457 12.981009 -29.858967 13.105756 10.670630 -17.378340 -6.079635 -41.408141
0 -3.543091 1.986200 -2.191232 0.006059 -2.351208 0.354323 5.127530 -4.232705 0.855633 -1.511475 ... -1.034241 2.211705 -3.255103 3.145205 0.115450 -0.696187 0.647662 2.188560 1.237610 -3.731906
0 -0.810280 -0.041986 0.309050 0.268330 -0.735400 -0.585930 -0.207310 0.084744 -0.024049 -1.729000 ... -0.981390 -0.348830 -0.012710 -0.399160 0.237930 0.608960 0.257370 -0.611980 -0.569530 -0.627590
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
0 -2.713570 17.157689 3.372045 -5.867599 -6.276993 -12.745005 25.654686 -28.636649 0.390143 -9.184376 ... 1.313231 11.260476 -3.974922 13.881956 -5.040806 4.269640 9.495090 -6.392856 -17.443242 -8.663642
0 0.330970 4.718122 -1.751891 -2.035309 -0.343631 -4.570000 6.742370 -4.367192 -2.187082 -1.016773 ... -0.032756 3.431770 -4.601260 -0.623544 -1.379208 -2.430931 -7.340540 -4.909410 1.812950 -7.518150
0 2.374983 5.077650 -3.002268 -3.933630 2.078387 -5.470335 7.654599 -12.212619 7.815100 -2.849330 ... -0.234639 5.634821 -3.027221 1.661357 0.393759 -6.279475 -4.927666 -2.592987 1.981002 -12.255251
0 1.410159 0.770070 -3.955777 -2.791190 5.210700 1.621952 7.780780 -9.859718 8.394491 -5.347944 ... 1.896947 4.079237 -1.207350 5.772389 1.186500 -7.286901 -2.485993 -0.947197 0.788222 -9.518580
0 5.710371 6.638114 0.178107 -6.406337 9.719197 -8.779200 8.866295 -11.437220 4.871641 -11.415244 ... 0.851029 4.182781 -3.302557 1.213709 -4.294990 -7.751608 -12.556040 -10.339437 1.448986 -10.394629

708 rows × 25 columns

predict = clf.predict(test_data_glove)
print(confusion_matrix(twenty_test['target'], predict))
print(classification_report(twenty_test['target'], predict))
[[258  61]
 [ 31 358]]
              precision    recall  f1-score   support

           0       0.89      0.81      0.85       319
           1       0.85      0.92      0.89       389

    accuracy                           0.87       708
   macro avg       0.87      0.86      0.87       708
weighted avg       0.87      0.87      0.87       708
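Значения в отчёте легко проверить вручную по матрице ошибок: accuracy — это доля верно классифицированных сообщений (сумма диагонали, делённая на объём тестовой выборки), а precision и recall для каждого класса считаются по соответствующим строкам и столбцам. Ниже числа взяты из матрицы, напечатанной выше:

```python
# Матрица ошибок из вывода выше: [[258 61], [31 358]]
tn, fp = 258, 61   # строка класса 0 (alt.atheism)
fn, tp = 31, 358   # строка класса 1 (comp.graphics)
total = tn + fp + fn + tp              # 708 сообщений

accuracy = (tn + tp) / total           # доля верных ответов
precision_0 = tn / (tn + fn)           # точность для класса 0: верные 0 / все предсказанные 0
recall_0 = tn / (tn + fp)              # полнота для класса 0: верные 0 / все истинные 0

print(round(accuracy, 2), round(precision_0, 2), round(recall_0, 2))  # 0.87 0.89 0.81
```

Результат совпадает со строками отчёта `classification_report`: accuracy 0.87, для класса 0 — precision 0.89 и recall 0.81.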