{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"source": [
"!pip uninstall -y torchtext"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "w9_nFDYUMx8V",
"outputId": "5eafc13a-0a81-4ae4-ae1c-2dbe8b66772c"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found existing installation: torchtext 0.5.0\n",
"Uninstalling torchtext-0.5.0:\n",
" Would remove:\n",
" /usr/local/lib/python3.10/dist-packages/torchtext-0.5.0.dist-info/*\n",
" /usr/local/lib/python3.10/dist-packages/torchtext/*\n",
"Proceed (Y/n)? y\n",
" Successfully uninstalled torchtext-0.5.0\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%pip install torchtext==0.5.0"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6GJJtWbWSbUP",
"outputId": "c86370de-0d0f-4a20-9946-c62ada7826d8"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting torchtext==0.5.0\n",
" Using cached torchtext-0.5.0-py3-none-any.whl (73 kB)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (4.66.1)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (2.31.0)\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (2.1.0+cu121)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (1.23.5)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (1.16.0)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (0.1.99)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.5.0) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.5.0) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.5.0) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.5.0) (2024.2.2)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (3.13.1)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (4.9.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (3.1.3)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (2023.6.0)\n",
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (2.1.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->torchtext==0.5.0) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->torchtext==0.5.0) (1.3.0)\n",
"Installing collected packages: torchtext\n",
"Successfully installed torchtext-0.5.0\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "_4gbBNMaLkMP"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"from torchtext import data,datasets\n",
"\n",
"import spacy\n",
"import numpy as np\n",
"\n",
"import time\n",
"import random"
]
},
{
"cell_type": "code",
"source": [
"# Seed every RNG in use so the run is reproducible\n",
"SEED = 1234\n",
"\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"torch.cuda.manual_seed_all(SEED)  # seeds the GPU RNG too; safe no-op on CPU-only machines\n",
"torch.backends.cudnn.deterministic = True"
],
"metadata": {
"id": "spx1Z90lMnB_"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"TEXT = data.Field(lower = True)\n",
"UD_TAGS = data.Field(unk_token = None)\n",
"PTB_TAGS = data.Field(unk_token = None)"
],
"metadata": {
"id": "kMcOX2FJMpP-"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"fields = ((\"text\", TEXT), (\"udtags\", UD_TAGS), (\"ptbtags\", PTB_TAGS))"
],
"metadata": {
"id": "2EC4HzpdNtnQ"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_data, valid_data, test_data = datasets.UDPOS.splits(fields)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KHh4bmRhSxbE",
"outputId": "f149e1a3-6fbc-4244-abff-fed820872bd4"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"downloading en-ud-v2.zip\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"en-ud-v2.zip: 100%|██████████| 688k/688k [00:00<00:00, 7.52MB/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"extracting\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(f\"Number of training examples: {len(train_data)}\")\n",
"print(f\"Number of validation examples: {len(valid_data)}\")\n",
"print(f\"Number of testing examples: {len(test_data)}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "q4eM725LSyi8",
"outputId": "01425ced-bf53-444b-b2ba-907c5c288f5a"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of training examples: 12543\n",
"Number of validation examples: 2002\n",
"Number of testing examples: 2077\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(vars(train_data.examples[0]))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TaRQJiiVS0NR",
"outputId": "4081d745-cf1a-4564-ec06-33e3a33b66f6"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'text': ['al', '-', 'zaman', ':', 'american', 'forces', 'killed', 'shaikh', 'abdullah', 'al', '-', 'ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'qaim', ',', 'near', 'the', 'syrian', 'border', '.'], 'udtags': ['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT'], 'ptbtags': ['NNP', 'HYPH', 'NNP', ':', 'JJ', 'NNS', 'VBD', 'NNP', 'NNP', 'NNP', 'HYPH', 'NNP', ',', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'NNP', ',', 'IN', 'DT', 'JJ', 'NN', '.']}\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(vars(train_data.examples[0])['text'])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ao1b06NOS1oS",
"outputId": "fb948684-5bf7-44b6-9b70-128abc472565"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['al', '-', 'zaman', ':', 'american', 'forces', 'killed', 'shaikh', 'abdullah', 'al', '-', 'ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'qaim', ',', 'near', 'the', 'syrian', 'border', '.']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(vars(train_data.examples[0])['udtags'])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9Vqjl52IS3O0",
"outputId": "4c399263-2888-4830-8e06-b7f38805a300"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(vars(train_data.examples[0])['ptbtags'])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ANZbJVQ4S4Vr",
"outputId": "b1c139c0-05d2-491d-d4db-a83d32cbbf99"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['NNP', 'HYPH', 'NNP', ':', 'JJ', 'NNS', 'VBD', 'NNP', 'NNP', 'NNP', 'HYPH', 'NNP', ',', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'NNP', ',', 'IN', 'DT', 'JJ', 'NN', '.']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"MIN_FREQ = 2\n",
"\n",
"TEXT.build_vocab(train_data,\n",
" min_freq = MIN_FREQ,\n",
" vectors = \"glove.6B.100d\",\n",
" unk_init = torch.Tensor.normal_)\n",
"\n",
"\n",
"UD_TAGS.build_vocab(train_data)\n",
"PTB_TAGS.build_vocab(train_data)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8WIWCycgS5ZH",
"outputId": "0528fbaf-e0eb-4abe-f90b-ad96c9aea08b"
},
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
".vector_cache/glove.6B.zip: 862MB [03:04, 4.67MB/s] \n",
"100%|█████████▉| 399999/400000 [00:30<00:00, 13181.49it/s]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(f\"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}\")\n",
"print(f\"Unique tokens in UD_TAG vocabulary: {len(UD_TAGS.vocab)}\")\n",
"print(f\"Unique tokens in PTB_TAG vocabulary: {len(PTB_TAGS.vocab)}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9fV1-uRDS77N",
"outputId": "f3da5965-9e98-46a9-bb70-9c1c9cfbd069"
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Unique tokens in TEXT vocabulary: 8866\n",
"Unique tokens in UD_TAG vocabulary: 18\n",
"Unique tokens in PTB_TAG vocabulary: 51\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(TEXT.vocab.freqs.most_common(20))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iYgP-itMTAIf",
"outputId": "71bcfca9-05e8-4dc4-8c8c-464724646401"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[('the', 9076), ('.', 8640), (',', 7021), ('to', 5137), ('and', 5002), ('a', 3782), ('of', 3622), ('i', 3379), ('in', 3112), ('is', 2239), ('you', 2156), ('that', 2036), ('it', 1850), ('for', 1842), ('-', 1426), ('have', 1359), ('\"', 1296), ('on', 1273), ('was', 1244), ('with', 1216)]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(UD_TAGS.vocab.itos)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "R4B3J-aWTB8W",
"outputId": "da9009a0-c3e7-4b52-e225-9c522c0e61a6"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['<pad>', 'NOUN', 'PUNCT', 'VERB', 'PRON', 'ADP', 'DET', 'PROPN', 'ADJ', 'AUX', 'ADV', 'CCONJ', 'PART', 'NUM', 'SCONJ', 'X', 'INTJ', 'SYM']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(PTB_TAGS.vocab.itos)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GBMmathDTC-V",
"outputId": "1770c0d3-e330-4797-c6a3-69666d2c6ed2"
},
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['<pad>', 'NN', 'IN', 'DT', 'NNP', 'PRP', 'JJ', 'RB', '.', 'VB', 'NNS', ',', 'CC', 'VBD', 'VBP', 'VBZ', 'CD', 'VBN', 'VBG', 'MD', 'TO', 'PRP$', '-RRB-', '-LRB-', 'WDT', 'WRB', ':', '``', \"''\", 'WP', 'RP', 'UH', 'POS', 'HYPH', 'JJR', 'NNPS', 'JJS', 'EX', 'NFP', 'GW', 'ADD', 'RBR', '$', 'PDT', 'RBS', 'SYM', 'LS', 'FW', 'AFX', 'WP$', 'XX']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(UD_TAGS.vocab.freqs.most_common())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dq6Wn8BATELe",
"outputId": "2410f1bc-28fd-4d45-8f81-024a6c4809b9"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[('NOUN', 34781), ('PUNCT', 23679), ('VERB', 23081), ('PRON', 18577), ('ADP', 17638), ('DET', 16285), ('PROPN', 12946), ('ADJ', 12477), ('AUX', 12343), ('ADV', 10548), ('CCONJ', 6707), ('PART', 5567), ('NUM', 3999), ('SCONJ', 3843), ('X', 847), ('INTJ', 688), ('SYM', 599)]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(PTB_TAGS.vocab.freqs.most_common())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XuAMn3K2TFbC",
"outputId": "73f430f8-37b3-4700-8769-fcac7fecb74a"
},
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[('NN', 26915), ('IN', 20724), ('DT', 16817), ('NNP', 12449), ('PRP', 12193), ('JJ', 11591), ('RB', 10831), ('.', 10317), ('VB', 9476), ('NNS', 8438), (',', 8062), ('CC', 6706), ('VBD', 5402), ('VBP', 5374), ('VBZ', 4578), ('CD', 3998), ('VBN', 3967), ('VBG', 3330), ('MD', 3294), ('TO', 3286), ('PRP$', 3068), ('-RRB-', 1008), ('-LRB-', 973), ('WDT', 948), ('WRB', 869), (':', 866), ('``', 813), (\"''\", 785), ('WP', 760), ('RP', 755), ('UH', 689), ('POS', 684), ('HYPH', 664), ('JJR', 503), ('NNPS', 498), ('JJS', 383), ('EX', 359), ('NFP', 338), ('GW', 294), ('ADD', 292), ('RBR', 276), ('$', 258), ('PDT', 175), ('RBS', 169), ('SYM', 156), ('LS', 117), ('FW', 93), ('AFX', 48), ('WP$', 15), ('XX', 1)]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def tag_percentage(tag_counts):\n",
"    \"\"\"Annotate (tag, count) pairs with each tag's share of the total.\n",
"\n",
"    Args:\n",
"        tag_counts: sequence of (tag, count) pairs, e.g. vocab.freqs.most_common().\n",
"\n",
"    Returns:\n",
"        List of (tag, count, percentage) triples. An empty/zero-count input\n",
"        yields 0.0 percentages instead of raising ZeroDivisionError.\n",
"    \"\"\"\n",
"    total_count = sum(count for _, count in tag_counts)\n",
"\n",
"    if total_count == 0:  # guard: nothing to normalise by\n",
"        return [(tag, count, 0.0) for tag, count in tag_counts]\n",
"\n",
"    return [(tag, count, count / total_count) for tag, count in tag_counts]"
],
"metadata": {
"id": "A9FbORvyTGj4"
},
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(\"Tag\\t\\tCount\\t\\tPercentage\\n\")\n",
"\n",
"for tag, count, percent in tag_percentage(UD_TAGS.vocab.freqs.most_common()):\n",
" print(f\"{tag}\\t\\t{count}\\t\\t{percent*100:4.1f}%\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WsnW25s5TIyK",
"outputId": "8078b619-1322-4fec-826f-aac16e217d6b"
},
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Tag\t\tCount\t\tPercentage\n",
"\n",
"NOUN\t\t34781\t\t17.0%\n",
"PUNCT\t\t23679\t\t11.6%\n",
"VERB\t\t23081\t\t11.3%\n",
"PRON\t\t18577\t\t 9.1%\n",
"ADP\t\t17638\t\t 8.6%\n",
"DET\t\t16285\t\t 8.0%\n",
"PROPN\t\t12946\t\t 6.3%\n",
"ADJ\t\t12477\t\t 6.1%\n",
"AUX\t\t12343\t\t 6.0%\n",
"ADV\t\t10548\t\t 5.2%\n",
"CCONJ\t\t6707\t\t 3.3%\n",
"PART\t\t5567\t\t 2.7%\n",
"NUM\t\t3999\t\t 2.0%\n",
"SCONJ\t\t3843\t\t 1.9%\n",
"X\t\t847\t\t 0.4%\n",
"INTJ\t\t688\t\t 0.3%\n",
"SYM\t\t599\t\t 0.3%\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(\"Tag\\t\\tCount\\t\\tPercentage\\n\")\n",
"\n",
"for tag, count, percent in tag_percentage(PTB_TAGS.vocab.freqs.most_common()):\n",
" print(f\"{tag}\\t\\t{count}\\t\\t{percent*100:4.1f}%\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eEsT_qOaTLe0",
"outputId": "e4a2dd6c-d40c-4c10-9ff9-0898de4b9ce4"
},
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Tag\t\tCount\t\tPercentage\n",
"\n",
"NN\t\t26915\t\t13.2%\n",
"IN\t\t20724\t\t10.1%\n",
"DT\t\t16817\t\t 8.2%\n",
"NNP\t\t12449\t\t 6.1%\n",
"PRP\t\t12193\t\t 6.0%\n",
"JJ\t\t11591\t\t 5.7%\n",
"RB\t\t10831\t\t 5.3%\n",
".\t\t10317\t\t 5.0%\n",
"VB\t\t9476\t\t 4.6%\n",
"NNS\t\t8438\t\t 4.1%\n",
",\t\t8062\t\t 3.9%\n",
"CC\t\t6706\t\t 3.3%\n",
"VBD\t\t5402\t\t 2.6%\n",
"VBP\t\t5374\t\t 2.6%\n",
"VBZ\t\t4578\t\t 2.2%\n",
"CD\t\t3998\t\t 2.0%\n",
"VBN\t\t3967\t\t 1.9%\n",
"VBG\t\t3330\t\t 1.6%\n",
"MD\t\t3294\t\t 1.6%\n",
"TO\t\t3286\t\t 1.6%\n",
"PRP$\t\t3068\t\t 1.5%\n",
"-RRB-\t\t1008\t\t 0.5%\n",
"-LRB-\t\t973\t\t 0.5%\n",
"WDT\t\t948\t\t 0.5%\n",
"WRB\t\t869\t\t 0.4%\n",
":\t\t866\t\t 0.4%\n",
"``\t\t813\t\t 0.4%\n",
"''\t\t785\t\t 0.4%\n",
"WP\t\t760\t\t 0.4%\n",
"RP\t\t755\t\t 0.4%\n",
"UH\t\t689\t\t 0.3%\n",
"POS\t\t684\t\t 0.3%\n",
"HYPH\t\t664\t\t 0.3%\n",
"JJR\t\t503\t\t 0.2%\n",
"NNPS\t\t498\t\t 0.2%\n",
"JJS\t\t383\t\t 0.2%\n",
"EX\t\t359\t\t 0.2%\n",
"NFP\t\t338\t\t 0.2%\n",
"GW\t\t294\t\t 0.1%\n",
"ADD\t\t292\t\t 0.1%\n",
"RBR\t\t276\t\t 0.1%\n",
"$\t\t258\t\t 0.1%\n",
"PDT\t\t175\t\t 0.1%\n",
"RBS\t\t169\t\t 0.1%\n",
"SYM\t\t156\t\t 0.1%\n",
"LS\t\t117\t\t 0.1%\n",
"FW\t\t93\t\t 0.0%\n",
"AFX\t\t48\t\t 0.0%\n",
"WP$\t\t15\t\t 0.0%\n",
"XX\t\t1\t\t 0.0%\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"BATCH_SIZE = 128\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
" (train_data, valid_data, test_data),\n",
" batch_size = BATCH_SIZE,\n",
" device = device)\n"
],
"metadata": {
"id": "-z4jitcMTNa_"
},
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class BiLSTMPOSTagger(nn.Module):\n",
" def __init__(self,\n",
" input_dim,\n",
" embedding_dim,\n",
" hidden_dim,\n",
" output_dim,\n",
" n_layers,\n",
" bidirectional,\n",
" dropout,\n",
" pad_idx):\n",
"\n",
" super().__init__()\n",
"\n",
" self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx = pad_idx)\n",
"\n",
" self.lstm = nn.LSTM(embedding_dim,\n",
" hidden_dim,\n",
" num_layers = n_layers,\n",
" bidirectional = bidirectional,\n",
" dropout = dropout if n_layers > 1 else 0)\n",
"\n",
" self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)\n",
"\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, text):\n",
"\n",
" #text = [sent len, batch size]\n",
"\n",
" #pass text through embedding layer\n",
" embedded = self.dropout(self.embedding(text))\n",
"\n",
" #embedded = [sent len, batch size, emb dim]\n",
"\n",
" #pass embeddings into LSTM\n",
" outputs, (hidden, cell) = self.lstm(embedded)\n",
"\n",
" #outputs holds the backward and forward hidden states in the final layer\n",
" #hidden and cell are the backward and forward hidden and cell states at the final time-step\n",
"\n",
" #output = [sent len, batch size, hid dim * n directions]\n",
" #hidden/cell = [n layers * n directions, batch size, hid dim]\n",
"\n",
" #we use our outputs to make a prediction of what the tag should be\n",
" predictions = self.fc(self.dropout(outputs))\n",
" #predictions = [sent len, batch size, output dim]\n",
"\n",
" return predictions"
],
"metadata": {
"id": "ID9qCiGVTPbp"
},
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"source": [
"INPUT_DIM = len(TEXT.vocab)\n",
"EMBEDDING_DIM = 100\n",
"HIDDEN_DIM = 128\n",
"OUTPUT_DIM = len(UD_TAGS.vocab)\n",
"N_LAYERS = 2\n",
"BIDIRECTIONAL = True\n",
"DROPOUT = 0.25\n",
"PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
"\n",
"model = BiLSTMPOSTagger(INPUT_DIM,\n",
" EMBEDDING_DIM,\n",
" HIDDEN_DIM,\n",
" OUTPUT_DIM,\n",
" N_LAYERS,\n",
" BIDIRECTIONAL,\n",
" DROPOUT,\n",
" PAD_IDX)"
],
"metadata": {
"id": "DSNg9SkHTXRd"
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def init_weights(module):\n",
"    \"\"\"Initialise every parameter of `module` in place from N(0, 0.1).\"\"\"\n",
"    for parameter in module.parameters():\n",
"        nn.init.normal_(parameter.data, mean = 0, std = 0.1)\n",
"\n",
"model.apply(init_weights)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4AcETUvaTa2b",
"outputId": "e5489380-65d8-4391-ae9d-7965a5867eed"
},
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"BiLSTMPOSTagger(\n",
" (embedding): Embedding(8866, 100, padding_idx=1)\n",
" (lstm): LSTM(100, 128, num_layers=2, dropout=0.25, bidirectional=True)\n",
" (fc): Linear(in_features=256, out_features=18, bias=True)\n",
" (dropout): Dropout(p=0.25, inplace=False)\n",
")"
]
},
"metadata": {},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"source": [
"def count_parameters(model):\n",
"    \"\"\"Total number of trainable (requires_grad) parameters in `model`.\"\"\"\n",
"    trainable_sizes = (p.numel() for p in model.parameters() if p.requires_grad)\n",
"    return sum(trainable_sizes)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ryNFiwzdTcZJ",
"outputId": "b1eda886-34a4-4ff7-8a51-70de48becddd"
},
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The model has 1,522,010 trainable parameters\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"pretrained_embeddings = TEXT.vocab.vectors\n",
"\n",
"print(pretrained_embeddings.shape)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nctKuyIMTd0Q",
"outputId": "2c3aadac-274e-4d07-a56f-264ae247fd64"
},
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([8866, 100])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"model.embedding.weight.data.copy_(pretrained_embeddings)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hzU0JrMLTfTy",
"outputId": "7c2229cb-4c6d-4260-bf37-55de2f0a8d6c"
},
"execution_count": 29,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n",
" [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
" ...,\n",
" [ 0.9261, 2.3049, 0.5502, ..., -0.3492, -0.5298, -0.1577],\n",
" [-0.5972, 0.0471, -0.2406, ..., -0.9446, -0.1126, -0.2260],\n",
" [-0.4809, 2.5629, 0.9530, ..., 0.5278, -0.4588, 0.7294]])"
]
},
"metadata": {},
"execution_count": 29
}
]
},
{
"cell_type": "code",
"source": [
"model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)\n",
"\n",
"print(model.embedding.weight.data)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jAe39yDZTgnZ",
"outputId": "ebcdfc9d-1704-463e-da4d-a63b025a15fc"
},
"execution_count": 30,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
" ...,\n",
" [ 0.9261, 2.3049, 0.5502, ..., -0.3492, -0.5298, -0.1577],\n",
" [-0.5972, 0.0471, -0.2406, ..., -0.9446, -0.1126, -0.2260],\n",
" [-0.4809, 2.5629, 0.9530, ..., 0.5278, -0.4588, 0.7294]])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"optimizer = optim.Adam(model.parameters())"
],
"metadata": {
"id": "D86Xga31TiQU"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"TAG_PAD_IDX = UD_TAGS.vocab.stoi[UD_TAGS.pad_token]\n",
"\n",
"criterion = nn.CrossEntropyLoss(ignore_index = TAG_PAD_IDX)"
],
"metadata": {
"id": "fzNTl9v5TjnS"
},
"execution_count": 32,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model = model.to(device)\n",
"criterion = criterion.to(device)"
],
"metadata": {
"id": "bqBLvtuJTlCe"
},
"execution_count": 33,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def categorical_accuracy(preds, y, tag_pad_idx):\n",
"    \"\"\"\n",
"    Fraction of non-padding tokens tagged correctly, per batch\n",
"    (e.g. 8/10 right returns 0.8, NOT 8).\n",
"    \"\"\"\n",
"    # most probable tag for every token\n",
"    predicted_tags = preds.argmax(dim = 1)\n",
"    # boolean mask selecting the real (non-<pad>) positions\n",
"    non_pad = y != tag_pad_idx\n",
"    correct = predicted_tags[non_pad].eq(y[non_pad])\n",
"    return correct.sum() / non_pad.sum()"
],
"metadata": {
"id": "AUip03CeTmmk"
},
"execution_count": 34,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def train(model, iterator, optimizer, criterion, tag_pad_idx):\n",
"\n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
"\n",
" model.train()\n",
"\n",
" for batch in iterator:\n",
"\n",
" text = batch.text\n",
" tags = batch.udtags\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" #text = [sent len, batch size]\n",
"\n",
" predictions = model(text)\n",
"\n",
" #predictions = [sent len, batch size, output dim]\n",
" #tags = [sent len, batch size]\n",
"\n",
" predictions = predictions.view(-1, predictions.shape[-1])\n",
" tags = tags.view(-1)\n",
"\n",
" #predictions = [sent len * batch size, output dim]\n",
" #tags = [sent len * batch size]\n",
"\n",
" loss = criterion(predictions, tags)\n",
"\n",
" acc = categorical_accuracy(predictions, tags, tag_pad_idx)\n",
"\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
"\n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
],
"metadata": {
"id": "wuAwRkfWToZ2"
},
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def evaluate(model, iterator, criterion, tag_pad_idx):\n",
"\n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
"\n",
" model.eval()\n",
"\n",
" with torch.no_grad():\n",
"\n",
" for batch in iterator:\n",
"\n",
" text = batch.text\n",
" tags = batch.udtags\n",
"\n",
" predictions = model(text)\n",
"\n",
" predictions = predictions.view(-1, predictions.shape[-1])\n",
" tags = tags.view(-1)\n",
"\n",
" loss = criterion(predictions, tags)\n",
"\n",
" acc = categorical_accuracy(predictions, tags, tag_pad_idx)\n",
"\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
"\n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
],
"metadata": {
"id": "iIohEqwbTrZg"
},
"execution_count": 36,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def epoch_time(start_time, end_time):\n",
"    \"\"\"Split an elapsed wall-clock interval into (whole minutes, leftover seconds).\"\"\"\n",
"    elapsed = end_time - start_time\n",
"    minutes = int(elapsed / 60)\n",
"    seconds = int(elapsed - minutes * 60)\n",
"    return minutes, seconds"
],
"metadata": {
"id": "wlC50XjJTtrh"
},
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"source": [
"N_EPOCHS = 10\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"\n",
" start_time = time.time()\n",
"\n",
" train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)\n",
" valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)\n",
"\n",
" end_time = time.time()\n",
"\n",
" epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
"\n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), 'tut1-model.pt')\n",
"\n",
" print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
" print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
" print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gFfB1AscTvHA",
"outputId": "d386cfd4-cdbb-4cca-ab72-11dbf3822e6a"
},
"execution_count": 38,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch: 01 | Epoch Time: 1m 51s\n",
"\tTrain Loss: 1.337 | Train Acc: 58.14%\n",
"\t Val. Loss: 0.677 | Val. Acc: 78.71%\n",
"Epoch: 02 | Epoch Time: 1m 51s\n",
"\tTrain Loss: 0.481 | Train Acc: 84.90%\n",
"\t Val. Loss: 0.508 | Val. Acc: 83.09%\n",
"Epoch: 03 | Epoch Time: 1m 50s\n",
"\tTrain Loss: 0.350 | Train Acc: 88.96%\n",
"\t Val. Loss: 0.445 | Val. Acc: 85.30%\n",
"Epoch: 04 | Epoch Time: 1m 50s\n",
"\tTrain Loss: 0.292 | Train Acc: 90.71%\n",
"\t Val. Loss: 0.407 | Val. Acc: 86.23%\n",
"Epoch: 05 | Epoch Time: 1m 47s\n",
"\tTrain Loss: 0.253 | Train Acc: 92.02%\n",
"\t Val. Loss: 0.397 | Val. Acc: 86.28%\n",
"Epoch: 06 | Epoch Time: 1m 58s\n",
"\tTrain Loss: 0.226 | Train Acc: 92.82%\n",
"\t Val. Loss: 0.381 | Val. Acc: 87.09%\n",
"Epoch: 07 | Epoch Time: 1m 51s\n",
"\tTrain Loss: 0.206 | Train Acc: 93.40%\n",
"\t Val. Loss: 0.372 | Val. Acc: 87.45%\n",
"Epoch: 08 | Epoch Time: 1m 52s\n",
"\tTrain Loss: 0.192 | Train Acc: 93.89%\n",
"\t Val. Loss: 0.362 | Val. Acc: 87.59%\n",
"Epoch: 09 | Epoch Time: 1m 52s\n",
"\tTrain Loss: 0.176 | Train Acc: 94.30%\n",
"\t Val. Loss: 0.356 | Val. Acc: 88.34%\n",
"Epoch: 10 | Epoch Time: 1m 52s\n",
"\tTrain Loss: 0.167 | Train Acc: 94.62%\n",
"\t Val. Loss: 0.350 | Val. Acc: 88.28%\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# reload the best checkpoint; map_location lets a GPU-trained checkpoint\n",
"# load on a CPU-only machine as well\n",
"model.load_state_dict(torch.load('tut1-model.pt', map_location = device))\n",
"\n",
"test_loss, test_acc = evaluate(model, test_iterator, criterion, TAG_PAD_IDX)\n",
"\n",
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LBfU8G1nTw-P",
"outputId": "9c65940e-efde-45ad-e314-e68b3b9aae1e"
},
"execution_count": 39,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Test Loss: 0.365 | Test Acc: 88.30%\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def tag_sentence(model, device, sentence, text_field, tag_field):\n",
"    \"\"\"Tag one sentence with the trained model.\n",
"\n",
"    Args:\n",
"        sentence: raw string (tokenised with spaCy) or a pre-tokenised list.\n",
"\n",
"    Returns:\n",
"        (tokens, predicted_tags, unks) where `unks` lists the tokens that\n",
"        fell back to the <unk> embedding.\n",
"    \"\"\"\n",
"    model.eval()\n",
"\n",
"    if isinstance(sentence, str):\n",
"        nlp = spacy.load('en_core_web_sm')\n",
"        tokens = [token.text for token in nlp(sentence)]\n",
"    else:\n",
"        tokens = [token for token in sentence]\n",
"\n",
"    # mirror the preprocessing used at training time\n",
"    if text_field.lower:\n",
"        tokens = [t.lower() for t in tokens]\n",
"\n",
"    numericalized_tokens = [text_field.vocab.stoi[t] for t in tokens]\n",
"\n",
"    unk_idx = text_field.vocab.stoi[text_field.unk_token]\n",
"    unks = [t for t, n in zip(tokens, numericalized_tokens) if n == unk_idx]\n",
"\n",
"    token_tensor = torch.LongTensor(numericalized_tokens)\n",
"    token_tensor = token_tensor.unsqueeze(-1).to(device)  # [sent len, batch = 1]\n",
"\n",
"    # inference only: skip building the autograd graph\n",
"    with torch.no_grad():\n",
"        predictions = model(token_tensor)\n",
"\n",
"    top_predictions = predictions.argmax(-1)\n",
"\n",
"    predicted_tags = [tag_field.vocab.itos[t.item()] for t in top_predictions]\n",
"\n",
"    return tokens, predicted_tags, unks"
],
"metadata": {
"id": "KgzX7IV6TzWM"
},
"execution_count": 40,
"outputs": []
},
{
"cell_type": "code",
"source": [
"example_index = 1\n",
"\n",
"sentence = vars(train_data.examples[example_index])['text']\n",
"actual_tags = vars(train_data.examples[example_index])['udtags']\n",
"\n",
"print(sentence)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3WvTT5GxT1Gn",
"outputId": "473983aa-788d-4c8d-eb43-cab443799cf1"
},
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['[', 'this', 'killing', 'of', 'a', 'respected', 'cleric', 'will', 'be', 'causing', 'us', 'trouble', 'for', 'years', 'to', 'come', '.', ']']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"tokens, pred_tags, unks = tag_sentence(model,\n",
" device,\n",
" sentence,\n",
" TEXT,\n",
" UD_TAGS)\n",
"\n",
"print(unks)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TxRypSLVT20a",
"outputId": "5f2b4fdf-8820-4339-87ae-57aa18c48ad7"
},
"execution_count": 42,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['respected', 'cleric']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(\"Pred. Tag\\tActual Tag\\tCorrect?\\tToken\\n\")\n",
"\n",
"for token, pred_tag, actual_tag in zip(tokens, pred_tags, actual_tags):\n",
" correct = '✔' if pred_tag == actual_tag else '✘'\n",
" print(f\"{pred_tag}\\t\\t{actual_tag}\\t\\t{correct}\\t\\t{token}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UcQJ7AcLT3_U",
"outputId": "d6cf11ba-1b17-4016-d126-1a3b70e1f7f0"
},
"execution_count": 43,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Pred. Tag\tActual Tag\tCorrect?\tToken\n",
"\n",
"PUNCT\t\tPUNCT\t\t✔\t\t[\n",
"DET\t\tDET\t\t✔\t\tthis\n",
"VERB\t\tNOUN\t\t✘\t\tkilling\n",
"ADP\t\tADP\t\t✔\t\tof\n",
"DET\t\tDET\t\t✔\t\ta\n",
"ADJ\t\tADJ\t\t✔\t\trespected\n",
"NOUN\t\tNOUN\t\t✔\t\tcleric\n",
"AUX\t\tAUX\t\t✔\t\twill\n",
"AUX\t\tAUX\t\t✔\t\tbe\n",
"VERB\t\tVERB\t\t✔\t\tcausing\n",
"PRON\t\tPRON\t\t✔\t\tus\n",
"NOUN\t\tNOUN\t\t✔\t\ttrouble\n",
"ADP\t\tADP\t\t✔\t\tfor\n",
"NOUN\t\tNOUN\t\t✔\t\tyears\n",
"PART\t\tPART\t\t✔\t\tto\n",
"VERB\t\tVERB\t\t✔\t\tcome\n",
"PUNCT\t\tPUNCT\t\t✔\t\t.\n",
"PUNCT\t\tPUNCT\t\t✔\t\t]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"sentence = 'The Queen will deliver a speech about the conflict in North Korea at 1pm tomorrow.'\n",
"\n",
"tokens, tags, unks = tag_sentence(model,\n",
" device,\n",
" sentence,\n",
" TEXT,\n",
" UD_TAGS)\n",
"\n",
"print(unks)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_SaAEPsDT5tk",
"outputId": "e2d74a00-6d4c-4bce-bd4e-4deab8264398"
},
"execution_count": 44,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(\"Pred. Tag\\tToken\\n\")\n",
"\n",
"for token, tag in zip(tokens, tags):\n",
" print(f\"{tag}\\t\\t{token}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AbucCuhyT7Po",
"outputId": "1483f038-0ae4-47a5-a8d5-80c6a3652192"
},
"execution_count": 45,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Pred. Tag\tToken\n",
"\n",
"DET\t\tthe\n",
"NOUN\t\tqueen\n",
"AUX\t\twill\n",
"VERB\t\tdeliver\n",
"DET\t\ta\n",
"NOUN\t\tspeech\n",
"ADP\t\tabout\n",
"DET\t\tthe\n",
"NOUN\t\tconflict\n",
"ADP\t\tin\n",
"PROPN\t\tnorth\n",
"PROPN\t\tkorea\n",
"ADP\t\tat\n",
"NUM\t\t1\n",
"NOUN\t\tpm\n",
"NOUN\t\ttomorrow\n",
"PUNCT\t\t.\n"
]
}
]
}
]
}