{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "srXC6pLGLwS6"
},
"source": [
"# Загрузка библиотек"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "yG_n40gFzf9s"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from sklearn.model_selection import train_test_split\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"import time\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LshgkZ0cIOor",
"outputId": "903898c0-4205-47d7-a6e3-ca22e79ffba9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found GPU at: \n"
]
}
],
"source": [
"device_name = tf.test.gpu_device_name()\n",
"print('Found GPU at: {}'.format(device_name))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "vWkRnHV0DK0L"
},
"outputs": [],
"source": [
"RANDOM_STATE = 42"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EW8HRqz8Oz_b"
},
"source": [
"# Данные"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHjdCjDuSvX_"
},
"source": [
"## Загрузка данных\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "pP1Ou4nq_W1o"
},
"outputs": [],
"source": [
"# Выбираем поэта\n",
"poet = 'pushkin' #@param ['mayakovskiy', 'pushkin']\n",
"\n",
"path_to_file = f'{poet}.txt'\n",
"path_to_file = tf.keras.utils.get_file(path_to_file, f'http://uit.mpei.ru/git/main/TDA/raw/branch/master/assets/poems/{path_to_file}')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aavnuByVymwK",
"outputId": "1c4c379c-ab1e-4b84-939c-66a7471c0210"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Length of text: 586731 characters\n"
]
}
],
"source": [
"# Загружаем текст из файла.\n",
"# Стихотворения в файле разделены токеном '' - сохраняем в переменную\n",
"with open(path_to_file,encoding = \"utf-8\") as f:\n",
" text = f.read()\n",
"\n",
"print(f'Length of text: {len(text)} characters')\n",
"\n",
"EOS_TOKEN = ''"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Duhg9NrUymwO",
"outputId": "8f485753-1712-47a4-c328-a4a714ec0ea4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Так и мне узнать случилось,\n",
"Что за птица Купидон;\n",
"Сердце страстное пленилось;\n",
"Признаюсь – и я влюблен!\n",
"Пролетело счастья время,\n",
"Как, любви не зная бремя,\n",
"Я живал да попевал,\n",
"Как в театре и на балах,\n",
"На гуляньях иль в воксалах\n",
"Легким зефиром летал;\n",
"Как, смеясь во зло Амуру,\n",
"Я писал карикатуру\n",
"На любезный женской пол;\n",
"Но напрасно я смеялся,\n",
"Наконец и сам попался,\n",
"Сам, увы! с ума сошел.\n",
"Смехи, вольность – всё под лавку\n",
"Из Катонов я в отставку,\n",
"И теперь я – Селадон!\n",
"Миловидной жрицы Тальи\n",
"Видел прел\n"
]
}
],
"source": [
"# Посмотрим на текст\n",
"print(text[:500])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dLZNbAnzj2lR"
},
"source": [
"## Подсчет статистик\n",
"\n",
"describe_poems - функция, разбивающая файл на отдельные стихотворения (poem) и рассчитывающая их характеристики:\n",
"* длину (len), \n",
"* количество строк (lines)\n",
"* среднюю длину строки (mean_line_len)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "C7G_weaWnMSg"
},
"outputs": [],
"source": [
"def mean_line_len(poem):\n",
" lines = [len(line.strip()) for line in poem.split('\\n') if len(line.strip())>0]\n",
" return sum(lines)/len(lines)\n",
"\n",
"\n",
"def describe_poems(text,return_df = False):\n",
" poems_list = [poem.strip() for poem in text.split(EOS_TOKEN) if len(poem.strip())>0]\n",
" df = pd.DataFrame(data=poems_list,columns=['poem'])\n",
" df['len'] = df.poem.map(len)\n",
" df['lines'] = df.poem.str.count('\\n')\n",
" df['mean_line_len'] = df.poem.map(mean_line_len)\n",
" if return_df:\n",
" return df\n",
" return df.describe()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 424
},
"id": "8t4QIKLgj8_y",
"outputId": "4ffe0325-70be-4a3f-9fd6-910be40e5dd3"
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" poem | \n",
" len | \n",
" lines | \n",
" mean_line_len | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Так и мне узнать случилось,\\nЧто за птица Купи... | \n",
" 2536 | \n",
" 109 | \n",
" 23.114286 | \n",
"
\n",
" \n",
" 1 | \n",
" Хочу воспеть, как дух нечистый Ада\\nОседлан бы... | \n",
" 5543 | \n",
" 170 | \n",
" 33.372671 | \n",
"
\n",
" \n",
" 2 | \n",
" Покаместь ночь еще не удалилась,\\nПокаместь св... | \n",
" 4279 | \n",
" 131 | \n",
" 33.451613 | \n",
"
\n",
" \n",
" 3 | \n",
" Ах, отчего мне дивная природа\\nКорреджио искус... | \n",
" 4435 | \n",
" 131 | \n",
" 33.364341 | \n",
"
\n",
" \n",
" 4 | \n",
" Арист! и ты в толпе служителей Парнасса!\\nТы х... | \n",
" 3893 | \n",
" 106 | \n",
" 38.642857 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 714 | \n",
" Чудный сон мне бог послал —\\n\\nС длинной белой... | \n",
" 860 | \n",
" 38 | \n",
" 22.833333 | \n",
"
\n",
" \n",
" 715 | \n",
" О нет, мне жизнь не надоела,\\nЯ жить люблю, я ... | \n",
" 196 | \n",
" 7 | \n",
" 23.625000 | \n",
"
\n",
" \n",
" 716 | \n",
" \"Твой и мой, – говорит Лафонтен —\\nРасторгло у... | \n",
" 187 | \n",
" 5 | \n",
" 30.333333 | \n",
"
\n",
" \n",
" 717 | \n",
" Когда луны сияет лик двурогой\\nИ луч ее во мра... | \n",
" 269 | \n",
" 7 | \n",
" 32.750000 | \n",
"
\n",
" \n",
" 718 | \n",
" Там, устарелый вождь! как ратник молодой,\\nИск... | \n",
" 256 | \n",
" 5 | \n",
" 41.833333 | \n",
"
\n",
" \n",
"
\n",
"
719 rows × 4 columns
\n",
"
"
],
"text/plain": [
" poem len lines \\\n",
"0 Так и мне узнать случилось,\\nЧто за птица Купи... 2536 109 \n",
"1 Хочу воспеть, как дух нечистый Ада\\nОседлан бы... 5543 170 \n",
"2 Покаместь ночь еще не удалилась,\\nПокаместь св... 4279 131 \n",
"3 Ах, отчего мне дивная природа\\nКорреджио искус... 4435 131 \n",
"4 Арист! и ты в толпе служителей Парнасса!\\nТы х... 3893 106 \n",
".. ... ... ... \n",
"714 Чудный сон мне бог послал —\\n\\nС длинной белой... 860 38 \n",
"715 О нет, мне жизнь не надоела,\\nЯ жить люблю, я ... 196 7 \n",
"716 \"Твой и мой, – говорит Лафонтен —\\nРасторгло у... 187 5 \n",
"717 Когда луны сияет лик двурогой\\nИ луч ее во мра... 269 7 \n",
"718 Там, устарелый вождь! как ратник молодой,\\nИск... 256 5 \n",
"\n",
" mean_line_len \n",
"0 23.114286 \n",
"1 33.372671 \n",
"2 33.451613 \n",
"3 33.364341 \n",
"4 38.642857 \n",
".. ... \n",
"714 22.833333 \n",
"715 23.625000 \n",
"716 30.333333 \n",
"717 32.750000 \n",
"718 41.833333 \n",
"\n",
"[719 rows x 4 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"poem_df = describe_poems(text,return_df = True)\n",
"poem_df"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300
},
"id": "TmCI6rv1f49T",
"outputId": "444fe362-1a5f-45b0-dc21-146c08e094cb"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" len | \n",
" lines | \n",
" mean_line_len | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 719.000000 | \n",
" 719.000000 | \n",
" 719.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 808.037552 | \n",
" 29.464534 | \n",
" 27.445404 | \n",
"
\n",
" \n",
" std | \n",
" 1046.786862 | \n",
" 39.244020 | \n",
" 5.854564 | \n",
"
\n",
" \n",
" min | \n",
" 74.000000 | \n",
" 5.000000 | \n",
" 8.250000 | \n",
"
\n",
" \n",
" 25% | \n",
" 280.500000 | \n",
" 9.000000 | \n",
" 24.125000 | \n",
"
\n",
" \n",
" 50% | \n",
" 453.000000 | \n",
" 16.000000 | \n",
" 25.758065 | \n",
"
\n",
" \n",
" 75% | \n",
" 852.000000 | \n",
" 33.000000 | \n",
" 31.522727 | \n",
"
\n",
" \n",
" max | \n",
" 8946.000000 | \n",
" 437.000000 | \n",
" 48.923077 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" len lines mean_line_len\n",
"count 719.000000 719.000000 719.000000\n",
"mean 808.037552 29.464534 27.445404\n",
"std 1046.786862 39.244020 5.854564\n",
"min 74.000000 5.000000 8.250000\n",
"25% 280.500000 9.000000 24.125000\n",
"50% 453.000000 16.000000 25.758065\n",
"75% 852.000000 33.000000 31.522727\n",
"max 8946.000000 437.000000 48.923077"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"poem_df.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rNnrKn_lL-IJ"
},
"source": [
"## Подготовка датасетов"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3mOXOtj1FB1v"
},
"source": [
"Разбиваем данные на тренировочные, валидационные и тестовые"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "MM5Rk7B8D1n-"
},
"outputs": [],
"source": [
"train_poems, test_poems = train_test_split(poem_df.poem.to_list(),test_size = 0.1,random_state = RANDOM_STATE)\n",
"train_poems, val_poems = train_test_split(train_poems,test_size = 0.1,random_state = RANDOM_STATE)\n",
"\n",
"train_poems = f'\\n\\n{EOS_TOKEN}\\n\\n'.join(train_poems)\n",
"val_poems = f'\\n\\n{EOS_TOKEN}\\n\\n'.join(val_poems)\n",
"test_poems = f'\\n\\n{EOS_TOKEN}\\n\\n'.join(test_poems)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6QfP2RCpqdCS"
},
"source": [
"Создаем словарь уникальных символов из текста. Не забываем добавить токен конца стиха."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IlCgQBRVymwR",
"outputId": "a9769da4-3417-44e4-e491-cd1a09a51869"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"143 unique characters\n",
"['\\n', ' ', '!', '\"', \"'\", '(', ')', '*', ',', '-', '.', '/', ':', ';', '<', '>', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'H', 'I', 'J', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Z', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'x', 'y', 'z', '\\xa0', '«', '»', 'à', 'â', 'ç', 'è', 'é', 'ê', 'ô', 'û', 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'Й', 'К', 'Л', 'М', 'Н', 'О', 'П', 'Р', 'С', 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Э', 'Ю', 'Я', 'а', 'б', 'в', 'г', 'д', 'е', 'ж', 'з', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у', 'ф', 'х', 'ц', 'ч', 'ш', 'щ', 'ъ', 'ы', 'ь', 'э', 'ю', 'я', 'ё', '–', '—', '„', '…', '']\n"
]
}
],
"source": [
"vocab = sorted(set(text))+[EOS_TOKEN]\n",
"print(f'{len(vocab)} unique characters')\n",
"print (vocab)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1s4f1q3iqY8f"
},
"source": [
"Для подачи на вход нейронной сети необходимо закодировать текст в виде числовой последовательности.\n",
"\n",
"Воспользуемся для этого слоем StringLookup \n",
"https://www.tensorflow.org/api_docs/python/tf/keras/layers/StringLookup"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "6GMlCe3qzaL9"
},
"outputs": [],
"source": [
"ids_from_chars = tf.keras.layers.StringLookup(\n",
" vocabulary=list(vocab), mask_token=None)\n",
"chars_from_ids = tf.keras.layers.StringLookup(\n",
" vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)\n",
"\n",
"def text_from_ids(ids):\n",
" return tf.strings.reduce_join(chars_from_ids(ids), axis=-1).numpy().decode('utf-8')\n",
" \n",
"def ids_from_text(text):\n",
" return ids_from_chars(tf.strings.unicode_split(text, input_encoding='UTF-8'))\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Wd2m3mqkDjRj",
"outputId": "88ebef4b-6488-465e-d2f7-612e27efc0d2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Корабль испанский тр\n",
"tf.Tensor(\n",
"[ 87 120 122 106 107 117 134 2 114 123 121 106 119 123 116 114 115 2\n",
" 124 122], shape=(20,), dtype=int64)\n",
"Корабль испанский тр\n"
]
}
],
"source": [
"# пример кодирования\n",
"ids = ids_from_text(train_poems[:20])\n",
"res_text = text_from_ids(ids)\n",
"print(train_poems[:20],ids,res_text,sep = '\\n')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uzC2u022WHsa"
},
"source": [
"Кодируем данные и преобразуем их в Датасеты"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "UopbsKi88tm5"
},
"outputs": [],
"source": [
"train_ids = ids_from_text(train_poems)\n",
"val_ids = ids_from_text(val_poems)\n",
"test_ids = ids_from_text(test_poems)\n",
"\n",
"train_ids_dataset = tf.data.Dataset.from_tensor_slices(train_ids)\n",
"val_ids_dataset = tf.data.Dataset.from_tensor_slices(val_ids)\n",
"test_ids_dataset = tf.data.Dataset.from_tensor_slices(test_ids)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-ZSYAcQV8OGP"
},
"source": [
"Весь текст разбивается на последовательности длины `seq_length`. По этим последовательностям будет предсказываться следующий символ.\n",
"\n",
"**Попробовать разные длины - среднюю длину строки, среднюю длину стиха**"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "C-G2oaTxy6km"
},
"outputs": [],
"source": [
"seq_length = 100\n",
"examples_per_epoch = len(train_ids_dataset)//(seq_length+1)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BpdjRO2CzOfZ",
"outputId": "515651a7-765d-4bbf-9e90-8eb9a9380759"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Корабль испанский трехмачтовый,\n",
"Пристать в Голландию готовый:\n",
"На нем мерзавцев сотни три,\n",
"Две обезьян\n"
]
}
],
"source": [
"train_sequences = train_ids_dataset.batch(seq_length+1, drop_remainder=True)\n",
"val_sequences = val_ids_dataset.batch(seq_length+1, drop_remainder=True)\n",
"test_sequences = test_ids_dataset.batch(seq_length+1, drop_remainder=True)\n",
"\n",
"for seq in train_sequences.take(1):\n",
" print(text_from_ids(seq))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UbLcIPBj_mWZ"
},
"source": [
"Создаем датасет с input и target строками\n",
"\n",
"target сдвинута относительно input на один символ.\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "9NGu-FkO_kYU"
},
"outputs": [],
"source": [
"def split_input_target(sequence):\n",
" input_text = sequence[:-1]\n",
" target_text = sequence[1:]\n",
" return input_text, target_text"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WxbDTJTw5u_P",
"outputId": "f44e70c6-b600-4fb3-8073-ba8d774d8832"
},
"outputs": [
{
"data": {
"text/plain": [
"(['П', 'у', 'ш', 'к', 'и'], ['у', 'ш', 'к', 'и', 'н'])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# пример\n",
"split_input_target(list(\"Пушкин\"))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "B9iKPXkw5xwa"
},
"outputs": [],
"source": [
"train_dataset = train_sequences.map(split_input_target)\n",
"val_dataset = val_sequences.map(split_input_target)\n",
"test_dataset = test_sequences.map(split_input_target)\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GNbw-iR0ymwj",
"outputId": "888b397b-5370-4ae4-d2ba-98d54595ee72"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input : Прими сей череп, Дельвиг, он\n",
"Принадлежит тебе по праву.\n",
"Тебе поведаю, барон,\n",
"Его готическую славу.\n",
"\n",
"\n",
"Target: рими сей череп, Дельвиг, он\n",
"Принадлежит тебе по праву.\n",
"Тебе поведаю, барон,\n",
"Его готическую славу.\n",
"\n",
"П\n"
]
}
],
"source": [
"for input_example, target_example in val_dataset.take(1):\n",
" print(\"Input :\", text_from_ids(input_example))\n",
" print(\"Target:\", text_from_ids(target_example))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MJdfPmdqzf-R"
},
"source": [
"Перемешиваем датасеты и разбиваем их на батчи для оптимизации обучения"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "p2pGotuNzf-S",
"outputId": "d9ae90cb-d904-4528-d0ed-eb11caeb43d3"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Batch size\n",
"BATCH_SIZE = 64\n",
"\n",
"BUFFER_SIZE = 10000\n",
"\n",
"def prepare_dataset(dataset):\n",
" dataset = (\n",
" dataset\n",
" .shuffle(BUFFER_SIZE)\n",
" .batch(BATCH_SIZE, drop_remainder=True)\n",
" .prefetch(tf.data.experimental.AUTOTUNE))\n",
" return dataset \n",
"\n",
"train_dataset = prepare_dataset(train_dataset)\n",
"val_dataset = prepare_dataset(val_dataset)\n",
"test_dataset = prepare_dataset(test_dataset)\n",
"\n",
"train_dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r6oUuElIMgVx"
},
"source": [
"# Нейросеть"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "22uVCbSyPBjD"
},
"source": [
"## Построение модели"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m8gPwEjRzf-Z"
},
"source": [
"Модель состоит из трех слоев\n",
"\n",
"* `tf.keras.layers.Embedding`: Входной слой. Кодирует каждый идентификатор символа в вектор размерностью `embedding_dim`; \n",
"* `tf.keras.layers.GRU`: Рекуррентный слой на ячейках GRU. Выходной вектор размерностью `units=rnn_units` **(Здесь нужно указать тип ячеек в соответствии с вариантом)**\n",
"* `tf.keras.layers.Dense`: Выходной полносвязный слой размерностью `vocab_size`, в который выводится вероятность каждого символа в словаре. \n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "zHT8cLh7EAsg"
},
"outputs": [],
"source": [
"# Длина словаря символов\n",
"vocab_size = len(vocab)\n",
"\n",
"# размерность Embedding'а\n",
"embedding_dim = 20 #@param{type:\"number\"}\n",
"\n",
"# Параметры RNN-слоя\n",
"rnn_units = 300 #@param {type:\"number\"}\n",
"dropout_p = 0.5"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "wj8HQ2w8z4iO"
},
"outputs": [],
"source": [
"class MyModel(tf.keras.Model):\n",
" def __init__(self, vocab_size, embedding_dim, rnn_units):\n",
" super().__init__(self)\n",
" self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
" self.gru = tf.keras.layers.GRU(rnn_units,\n",
" dropout = dropout_p,\n",
" return_sequences=True,\n",
" return_state=True)\n",
" self.dense = tf.keras.layers.Dense(vocab_size)\n",
"\n",
" def call(self, inputs, states=None, return_state=False, training=False):\n",
" x = inputs\n",
" x = self.embedding(x, training=training)\n",
" \n",
" if states is None:\n",
" states = self.gru.get_initial_state(x)\n",
"\n",
" x, states = self.gru(x, initial_state=states, training=training)\n",
" x = self.dense(x, training=training)\n",
"\n",
" if return_state:\n",
" return x, states\n",
" else:\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "IX58Xj9z47Aw"
},
"outputs": [],
"source": [
"model = MyModel(\n",
" vocab_size=len(ids_from_chars.get_vocabulary()),\n",
" embedding_dim=embedding_dim,\n",
" rnn_units=rnn_units)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RkA5upJIJ7W7"
},
"source": [
"Иллюстрация работы сети\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LdgaUC3tPHAy"
},
"source": [
"## Проверка необученной модели"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "C-_70kKAPrPU",
"outputId": "02de9cf7-29d5-4345-8e02-b8c940ca5c1a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(64, 100, 144) # (batch_size, sequence_length, vocab_size)\n"
]
}
],
"source": [
"# посмотрим на один батч из датасета\n",
"for input_example_batch, target_example_batch in train_dataset.take(1):\n",
" example_batch_predictions = model(input_example_batch)\n",
" print(example_batch_predictions.shape, \"# (batch_size, sequence_length, vocab_size)\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uwv0gEkURfx1"
},
"source": [
"prediction() предсказывает логиты вероятности каждого символа на следующей позиции. При этом, если мы будем выбирать символ с максимальной вероятностью, то из раза в раз модель нам будет выдавать один и тот же текст. \n",
"Чтобы этого избежать, нужно выбирать очередной индекс из распределения\n",
"`tf.random.categorical` - чем выше значение на выходном слое полносвязной сети, тем вероятнее, что данный символ будет выбран в качестве очередного. Однако, это не обязательно будет символ с максимальной вероятностью.\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "W6_G5P7W1TrE",
"outputId": "5fdc06a0-9be3-42b4-8054-454e997586ee"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"example_batch_predictions[0][0]\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0yJ5Je4J1XRb"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UuAT_ymD15U7"
},
"source": [
"На картинке отмечены наиболее вероятные символы."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "4V4MfFg0RQJg"
},
"outputs": [],
"source": [
"sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)\n",
"sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YqFMUQc_UFgM",
"outputId": "1842194b-d7d7-4702-bca5-6b03f04b9432"
},
"outputs": [
{
"data": {
"text/plain": [
"array([103, 2, 125, 127, 46, 128, 85, 84, 14, 37, 55, 7, 129,\n",
" 123, 72, 38, 138, 88, 116, 125, 142, 109, 110, 131, 21, 29,\n",
" 15, 99, 118, 48, 15, 143, 106, 139, 28, 115, 119, 8, 41,\n",
" 55, 138, 68, 130, 133, 135, 4, 114, 10, 62, 130, 120, 47,\n",
" 119, 16, 87, 71, 32, 111, 121, 85, 13, 87, 87, 75, 25,\n",
" 91, 41, 83, 51, 106, 100, 133, 3, 9, 37, 36, 103, 61,\n",
" 2, 79, 136, 56, 99, 101, 20, 143, 70, 117, 60, 43, 6,\n",
" 49, 51, 15, 85, 115, 113, 138, 23, 89], dtype=int64)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sampled_indices"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xWcFwPwLSo05",
"outputId": "bed36421-88f9-429f-a9cd-099717127029"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input:\n",
" же прошли. Их мненья, толки, страсти\n",
"Забыты для других. Смотри: вокруг тебя\n",
"Всё новое кипит, былое и\n",
"\n",
"Next Char Predictions:\n",
" Э ухfцИЗ;Vo)чсèWёЛку…гдщDN<Цмh<а–Mйн*aoё»шыэ\"и-vшоgн>КçQепИ:ККôIОaЖkаЧы!,VUЭu ВюpЦШCâлtc(ik<ИйзёFМ\n"
]
}
],
"source": [
"print(\"Input:\\n\", text_from_ids(input_example_batch[0]))\n",
"print()\n",
"print(\"Next Char Predictions:\\n\", text_from_ids(sampled_indices))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LJL0Q0YPY6Ee"
},
"source": [
"## Обучение модели"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YCbHQHiaa4Ic"
},
"source": [
"\n",
"Можно представить задачу как задачу классификации - по предыдущему состоянию RNN и входу в данный момент времени предсказать класс (очередной символ). "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "trpqTWyvk0nr"
},
"source": [
"### Настройка оптимизатора и функции потерь"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UAjbjY03eiQ4"
},
"source": [
"В этом случае работает стандартная функция потерь `tf.keras.losses.sparse_categorical_crossentropy`- кроссэнтропия, которая равна минус логарифму предсказанной вероятности для верного класса.\n",
"\n",
"Поскольку модель возвращает логиты, вам необходимо установить флаг `from_logits`."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"id": "ZOeWdgxNFDXq"
},
"outputs": [],
"source": [
"loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4HrXTACTdzY-",
"outputId": "f0dbf6f2-738e-48f8-df7c-0bbf60fc17e4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction shape: (64, 100, 144) # (batch_size, sequence_length, vocab_size)\n",
"Mean loss: tf.Tensor(4.970122, shape=(), dtype=float32)\n"
]
}
],
"source": [
"example_batch_mean_loss = loss(target_example_batch, example_batch_predictions)\n",
"print(\"Prediction shape: \", example_batch_predictions.shape, \" # (batch_size, sequence_length, vocab_size)\")\n",
"print(\"Mean loss: \", example_batch_mean_loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vkvUIneTFiow"
},
"source": [
"Необученная модель не может делать адекватные предсказания. Ее перплексия («коэффициент неопределённости») приблизительно равна размеру словаря. Это говорит о полной неопределенности модели при генерации текста.\n",
"\n",
"Перплексия = exp(кроссэнтропия)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MAJfS5YoFiHf",
"outputId": "0588a9b2-8598-458f-db1d-6923e6a50502"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"perplexity: 144.04443\n"
]
}
],
"source": [
"print('perplexity: ',np.exp(example_batch_mean_loss))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jeOXriLcymww"
},
"source": [
"Настраиваем обучение, используя метод `tf.keras.Model.compile`. Используйте `tf.keras.optimizers.Adam` с аргументами по умолчанию и функцией потерь."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "DDl1_Een6rL0"
},
"outputs": [],
"source": [
"model.compile(optimizer='adam', loss=loss)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vPGmAAXmVLGC",
"outputId": "1399a58c-85ff-430d-f67f-942942dec071"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"my_model\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" embedding (Embedding) multiple 2880 \n",
" \n",
" gru (GRU) multiple 289800 \n",
" \n",
" dense (Dense) multiple 43344 \n",
" \n",
"=================================================================\n",
"Total params: 336,024\n",
"Trainable params: 336,024\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C6XBUUavgF56"
},
"source": [
"Используем `tf.keras.callbacks.ModelCheckpoint`, чтобы убедиться, что контрольные точки сохраняются во время обучения:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"id": "W6fWTriUZP-n"
},
"outputs": [],
"source": [
"# Directory where the checkpoints will be saved\n",
"checkpoint_dir = './training_checkpoints'\n",
"# Name of the checkpoint files\n",
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n",
"\n",
"checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
" filepath=checkpoint_prefix,\n",
" monitor=\"val_loss\",\n",
" save_weights_only=True,\n",
" save_best_only=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3Ky3F_BhgkTW"
},
"source": [
"### Обучение!"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"id": "7yGBE2zxMMHs"
},
"outputs": [],
"source": [
"EPOCHS = 5"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UK-hmKjYVoll",
"outputId": "4403459e-c93d-4361-b6b4-d4f4a29797e5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"72/72 [==============================] - 24s 301ms/step - loss: 3.7387 - val_loss: 3.4320\n",
"Epoch 2/5\n",
"72/72 [==============================] - 22s 294ms/step - loss: 3.2194 - val_loss: 2.8863\n",
"Epoch 3/5\n",
"72/72 [==============================] - 22s 292ms/step - loss: 2.8698 - val_loss: 2.7059\n",
"Epoch 4/5\n",
"72/72 [==============================] - 23s 309ms/step - loss: 2.7643 - val_loss: 2.6365\n",
"Epoch 5/5\n",
"72/72 [==============================] - 24s 320ms/step - loss: 2.6941 - val_loss: 2.5800\n"
]
}
],
"source": [
"history = model.fit(train_dataset, validation_data = val_dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UejMtXPmH-U0",
"outputId": "ec06daed-25a1-4ae3-9f9a-cd130e04c216"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9/9 [==============================] - 1s 101ms/step - loss: 2.5769\n",
"eval loss: 2.576871871948242\n",
"perplexity 13.155920322524834\n"
]
}
],
"source": [
"eval_loss = model.evaluate(test_dataset)\n",
"print('eval loss:',eval_loss)\n",
"print('perplexity',np.exp(eval_loss))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kKkD5M6eoSiN"
},
"source": [
"## Генерация текста"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oIdQ8c8NvMzV"
},
"source": [
"Самый простой способ сгенерировать текст с помощью этой модели — запустить ее в цикле и отслеживать внутреннее состояние модели по мере ее выполнения.\n",
"\n",
"\n",
"\n",
"Каждый раз, когда вы вызываете модель, вы передаете некоторый текст и внутреннее состояние. Модель возвращает прогноз для следующего символа и его нового состояния. Передайте предсказание и состояние обратно, чтобы продолжить создание текста.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DjGz1tDkzf-u"
},
"source": [
"Создаем модель реализующую один шаг предсказания:"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"id": "iSBU1tHmlUSs"
},
"outputs": [],
"source": [
"class OneStep(tf.keras.Model):\n",
" def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):\n",
" super().__init__()\n",
" self.temperature = temperature\n",
" self.model = model\n",
" self.chars_from_ids = chars_from_ids\n",
" self.ids_from_chars = ids_from_chars\n",
"\n",
" # Create a mask to prevent \"[UNK]\" from being generated.\n",
" skip_ids = self.ids_from_chars(['[UNK]'])[:, None]\n",
" sparse_mask = tf.SparseTensor(\n",
" # Put a -inf at each bad index.\n",
" values=[-float('inf')]*len(skip_ids),\n",
" indices=skip_ids,\n",
" # Match the shape to the vocabulary\n",
" dense_shape=[len(ids_from_chars.get_vocabulary())])\n",
" self.prediction_mask = tf.sparse.to_dense(sparse_mask)\n",
"\n",
" \n",
" # Этот фрагмент целиком написан с использованием Tensorflow, поэтому его можно выполнять \n",
" # не с помощью интерпретатора языка Python, а через граф операций. Это будет значительно быстрее. \n",
" # Для этого воспользуемся декоратором @tf.function \n",
" @tf.function \n",
" def generate_one_step(self, inputs, states=None,temperature=1.0):\n",
" # Convert strings to token IDs.\n",
" input_chars = tf.strings.unicode_split(inputs, 'UTF-8')\n",
" input_ids = self.ids_from_chars(input_chars).to_tensor()\n",
"\n",
" # Run the model.\n",
" # predicted_logits.shape is [batch, char, next_char_logits]\n",
" predicted_logits, states = self.model(inputs=input_ids, states=states,\n",
" return_state=True)\n",
" # Only use the last prediction.\n",
" predicted_logits = predicted_logits[:, -1, :]\n",
" predicted_logits = predicted_logits/temperature\n",
" # Apply the prediction mask: prevent \"[UNK]\" from being generated.\n",
" predicted_logits = predicted_logits + self.prediction_mask\n",
"\n",
" # Sample the output logits to generate token IDs.\n",
" predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)\n",
" predicted_ids = tf.squeeze(predicted_ids, axis=-1)\n",
"\n",
" # Convert from token ids to characters\n",
" predicted_chars = self.chars_from_ids(predicted_ids)\n",
"\n",
"\n",
" # Return the characters and model state.\n",
" return predicted_chars, states"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"id": "fqMOuDutnOxK"
},
"outputs": [],
"source": [
"one_step_model = OneStep(model, chars_from_ids, ids_from_chars)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p9yDoa0G3IgQ"
},
"source": [
"Изменяя температуру можно регулировать вариативность текста"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ST7PSyk9t1mT",
"outputId": "ed0dfecc-3d31-4bfe-efeb-d2315044ef06"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"uКочаму,\n",
"Лыт\n",
"Нума Гве!\n",
"àvо етекы ва;\n",
"Хо> вы,\n",
"Муметы,.\n",
"Воцу:\n",
"ОтекаЕдраэмь эзоспосанах…Я кы.\n",
") маа Почи защоючеты>\n",
"Алаций поши)\"Пры – /й пий лазой\n",
"Еже?\n",
"И,\n",
"В, à*\n",
"logJ\"nДаогей piКо '\n",
"Дя цый,;\n",
"Зази:\n",
"Унумыйй поль ь,\n",
"Тост,,,\n",
"Дежы!Шоны\n",
"H-му,\n",
"iБачуй,,;цо,\n",
"Тахрей\n",
"Ты а;\n",
"Памогой елыны…\n",
"Жебы\n",
"Ве яны!\n",
"Экычигруй,\n",
"И вьчам ре кабощалуй, ль наша водух, ё ко,\n",
"Чячу бемий>\n",
"Елетанех гружара, eалодищи:\n",
"И,\n",
"Гро вишемороцемапа\"oОмицы\n",
"upЛицль: Преща, Оды iH бочавоча!, Ценобый ни\n",
"Вныла,\n",
"Чу:.\n",
"Ввуй,\n",
"eРаДобе по!\n",
"Qтый,\n",
"Чки fХралошумо:\n",
"na-Койкаца –\n",
"Ай\n",
"Мо,\n",
"i)цыденатедубодех выйха\n",
"Полынене на бы.\n",
"Заму рыхасла cИмино Одорафута уга гоЮй вожве,\n",
"ГчаCКазко.\n",
"Мотя выйхоруnАчлатоный\n",
"Рабойх удапи,\n",
"Ты уТашь,\n",
"Обымаекоги вайзоты зой кишаметумилетелы\n",
"Я, —\n",
"yJи, солыFлы!éuХуZетце, ё\n",
"oПробемнанетече ей.\n",
"Ри,\n",
"Шоный.\n",
"По!\n",
"Ностыгу!\n",
"Hv\n",
"________________________________________________________________________________\n",
"\n",
"Run time: 1.620434284210205\n"
]
}
],
"source": [
"T = 0.5 #@param {type:\"slider\", min:0, max:2, step:0.1}\n",
"N = 1000\n",
"\n",
"start = time.time()\n",
"states = None\n",
"next_char = tf.constant(['\\n'])\n",
"result = [next_char]\n",
"\n",
"for n in range(N):\n",
" next_char, states = one_step_model.generate_one_step(next_char, states=states,temperature=T)\n",
" result.append(next_char)\n",
"\n",
"result = tf.strings.join(result)\n",
"end = time.time()\n",
"\n",
"result_text = result[0].numpy().decode('utf-8')\n",
"print(result_text)\n",
"print('_'*80)\n",
"print('\\nRun time:', end - start)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300
},
"id": "VCwWY9xM6KCB",
"outputId": "e709c661-8bbe-4abf-e329-db9af7e6eb55"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" len | \n",
" lines | \n",
" mean_line_len | \n",
"
\n",
" \n",
" \n",
" \n",
" count | \n",
" 2.000000 | \n",
" 2.000000 | \n",
" 2.000000 | \n",
"
\n",
" \n",
" mean | \n",
" 499.000000 | \n",
" 35.500000 | \n",
" 12.123718 | \n",
"
\n",
" \n",
" std | \n",
" 482.246825 | \n",
" 33.234019 | \n",
" 1.262820 | \n",
"
\n",
" \n",
" min | \n",
" 158.000000 | \n",
" 12.000000 | \n",
" 11.230769 | \n",
"
\n",
" \n",
" 25% | \n",
" 328.500000 | \n",
" 23.750000 | \n",
" 11.677244 | \n",
"
\n",
" \n",
" 50% | \n",
" 499.000000 | \n",
" 35.500000 | \n",
" 12.123718 | \n",
"
\n",
" \n",
" 75% | \n",
" 669.500000 | \n",
" 47.250000 | \n",
" 12.570192 | \n",
"
\n",
" \n",
" max | \n",
" 840.000000 | \n",
" 59.000000 | \n",
" 13.016667 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" len lines mean_line_len\n",
"count 2.000000 2.000000 2.000000\n",
"mean 499.000000 35.500000 12.123718\n",
"std 482.246825 33.234019 1.262820\n",
"min 158.000000 12.000000 11.230769\n",
"25% 328.500000 23.750000 11.677244\n",
"50% 499.000000 35.500000 12.123718\n",
"75% 669.500000 47.250000 12.570192\n",
"max 840.000000 59.000000 13.016667"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"describe_poems(result_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "db7UJQr9ILfW"
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "22rnSwqqICn2"
},
"source": [
"По мотивам https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/text_generation.ipynb"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}