From 13bfecb6111836680412e34cd7e41fa3f28c11b5 Mon Sep 17 00:00:00 2001 From: Andrey Date: Tue, 9 Apr 2024 23:18:46 +0300 Subject: [PATCH] lab04: EOS in to_matrix --- labs/OATD_LR4.ipynb | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/labs/OATD_LR4.ipynb b/labs/OATD_LR4.ipynb index f0a7aca..26f7f1b 100644 --- a/labs/OATD_LR4.ipynb +++ b/labs/OATD_LR4.ipynb @@ -163,7 +163,7 @@ "source": [ "!wget -O poems.txt http://uit.mpei.ru/git/main/TDA/raw/branch/master/assets/poems/pushkin.txt\n", "\n", - "# Маяковский: http://uit.mpei.ru/git/main/TDA/raw/branch/master/assets/poems/mayakovskiy.txt\n" + "# Маяковский: http://uit.mpei.ru/git/main/TDA/raw/branch/master/assets/poems/mayakovskiy.txt\n" ] }, { @@ -466,20 +466,18 @@ }, { "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "YYGc7vo2FoTw" - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "def to_matrix(data, char2id, max_len=None, dtype='int64', batch_first = True):\n", "\n", " max_len = max_len if max_len else max(map(len, data))\n", - " data = [text[:max_len] + \" EOS\" for text in data]\n", - " data_ix = np.zeros([len(data), max_len], dtype)\n", + " data = [text[:max_len] for text in data]\n", + " data_ix = np.zeros([len(data), max_len+1], dtype)\n", "\n", " for i in range(len(data)):\n", - " line_ix = [char2id[c] for c in data[i][:max_len]]\n", + " line_ix = [char2id[c] for c in data[i][:max_len]] + [char2id[\"EOS\"]]\n", " data_ix[i, :len(line_ix)] = line_ix\n", "\n", " if not batch_first: # convert [batch, time] into [time, batch]\n",