Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
人
人工智能系统实战第三期
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
yy
人工智能系统实战第三期
Commits
014d32d2
Commit
014d32d2
authored
Feb 16, 2024
by
前钰
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
239e6bc6
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1349 additions
and
0 deletions
+1349
-0
BiLSTM_for_PoS_Tagging.ipynb
...基于FNN、RNN和Transformer的词性标注实战/BiLSTM_for_PoS_Tagging.ipynb
+1349
-0
No files found.
人工智能系统实战第三期/实战代码/自然语言处理/基于FNN、RNN和Transformer的词性标注实战/BiLSTM_for_PoS_Tagging.ipynb
0 → 100644
View file @
014d32d2
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"source": [
        "%pip uninstall -y torchtext"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "w9_nFDYUMx8V",
"outputId": "5eafc13a-0a81-4ae4-ae1c-2dbe8b66772c"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found existing installation: torchtext 0.5.0\n",
"Uninstalling torchtext-0.5.0:\n",
" Would remove:\n",
" /usr/local/lib/python3.10/dist-packages/torchtext-0.5.0.dist-info/*\n",
" /usr/local/lib/python3.10/dist-packages/torchtext/*\n",
"Proceed (Y/n)? y\n",
" Successfully uninstalled torchtext-0.5.0\n"
]
}
]
},
{
"cell_type": "code",
"source": [
        "%pip install torchtext==0.5.0"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6GJJtWbWSbUP",
"outputId": "c86370de-0d0f-4a20-9946-c62ada7826d8"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting torchtext==0.5.0\n",
" Using cached torchtext-0.5.0-py3-none-any.whl (73 kB)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (4.66.1)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (2.31.0)\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (2.1.0+cu121)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (1.23.5)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (1.16.0)\n",
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from torchtext==0.5.0) (0.1.99)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.5.0) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.5.0) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.5.0) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.5.0) (2024.2.2)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (3.13.1)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (4.9.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (3.1.3)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (2023.6.0)\n",
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch->torchtext==0.5.0) (2.1.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->torchtext==0.5.0) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->torchtext==0.5.0) (1.3.0)\n",
"Installing collected packages: torchtext\n",
"Successfully installed torchtext-0.5.0\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "_4gbBNMaLkMP"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"from torchtext import data,datasets\n",
"\n",
"import spacy\n",
"import numpy as np\n",
"\n",
"import time\n",
"import random"
]
},
{
"cell_type": "code",
"source": [
"SEED = 1234\n",
"\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True"
],
"metadata": {
"id": "spx1Z90lMnB_"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"TEXT = data.Field(lower = True)\n",
"UD_TAGS = data.Field(unk_token = None)\n",
"PTB_TAGS = data.Field(unk_token = None)"
],
"metadata": {
"id": "kMcOX2FJMpP-"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"fields = ((\"text\", TEXT), (\"udtags\", UD_TAGS), (\"ptbtags\", PTB_TAGS))"
],
"metadata": {
"id": "2EC4HzpdNtnQ"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_data, valid_data, test_data = datasets.UDPOS.splits(fields)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KHh4bmRhSxbE",
"outputId": "f149e1a3-6fbc-4244-abff-fed820872bd4"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"downloading en-ud-v2.zip\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"en-ud-v2.zip: 100%|██████████| 688k/688k [00:00<00:00, 7.52MB/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"extracting\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(f\"Number of training examples: {len(train_data)}\")\n",
"print(f\"Number of validation examples: {len(valid_data)}\")\n",
"print(f\"Number of testing examples: {len(test_data)}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "q4eM725LSyi8",
"outputId": "01425ced-bf53-444b-b2ba-907c5c288f5a"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of training examples: 12543\n",
"Number of validation examples: 2002\n",
"Number of testing examples: 2077\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(vars(train_data.examples[0]))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TaRQJiiVS0NR",
"outputId": "4081d745-cf1a-4564-ec06-33e3a33b66f6"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'text': ['al', '-', 'zaman', ':', 'american', 'forces', 'killed', 'shaikh', 'abdullah', 'al', '-', 'ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'qaim', ',', 'near', 'the', 'syrian', 'border', '.'], 'udtags': ['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT'], 'ptbtags': ['NNP', 'HYPH', 'NNP', ':', 'JJ', 'NNS', 'VBD', 'NNP', 'NNP', 'NNP', 'HYPH', 'NNP', ',', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'NNP', ',', 'IN', 'DT', 'JJ', 'NN', '.']}\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(vars(train_data.examples[0])['text'])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ao1b06NOS1oS",
"outputId": "fb948684-5bf7-44b6-9b70-128abc472565"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['al', '-', 'zaman', ':', 'american', 'forces', 'killed', 'shaikh', 'abdullah', 'al', '-', 'ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'qaim', ',', 'near', 'the', 'syrian', 'border', '.']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(vars(train_data.examples[0])['udtags'])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9Vqjl52IS3O0",
"outputId": "4c399263-2888-4830-8e06-b7f38805a300"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(vars(train_data.examples[0])['ptbtags'])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ANZbJVQ4S4Vr",
"outputId": "b1c139c0-05d2-491d-d4db-a83d32cbbf99"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['NNP', 'HYPH', 'NNP', ':', 'JJ', 'NNS', 'VBD', 'NNP', 'NNP', 'NNP', 'HYPH', 'NNP', ',', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'NNP', ',', 'IN', 'DT', 'JJ', 'NN', '.']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"MIN_FREQ = 2\n",
"\n",
"TEXT.build_vocab(train_data,\n",
" min_freq = MIN_FREQ,\n",
" vectors = \"glove.6B.100d\",\n",
" unk_init = torch.Tensor.normal_)\n",
"\n",
"\n",
"UD_TAGS.build_vocab(train_data)\n",
"PTB_TAGS.build_vocab(train_data)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8WIWCycgS5ZH",
"outputId": "0528fbaf-e0eb-4abe-f90b-ad96c9aea08b"
},
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
".vector_cache/glove.6B.zip: 862MB [03:04, 4.67MB/s] \n",
"100%|█████████▉| 399999/400000 [00:30<00:00, 13181.49it/s]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(f\"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}\")\n",
"print(f\"Unique tokens in UD_TAG vocabulary: {len(UD_TAGS.vocab)}\")\n",
"print(f\"Unique tokens in PTB_TAG vocabulary: {len(PTB_TAGS.vocab)}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9fV1-uRDS77N",
"outputId": "f3da5965-9e98-46a9-bb70-9c1c9cfbd069"
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Unique tokens in TEXT vocabulary: 8866\n",
"Unique tokens in UD_TAG vocabulary: 18\n",
"Unique tokens in PTB_TAG vocabulary: 51\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(TEXT.vocab.freqs.most_common(20))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iYgP-itMTAIf",
"outputId": "71bcfca9-05e8-4dc4-8c8c-464724646401"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[('the', 9076), ('.', 8640), (',', 7021), ('to', 5137), ('and', 5002), ('a', 3782), ('of', 3622), ('i', 3379), ('in', 3112), ('is', 2239), ('you', 2156), ('that', 2036), ('it', 1850), ('for', 1842), ('-', 1426), ('have', 1359), ('\"', 1296), ('on', 1273), ('was', 1244), ('with', 1216)]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(UD_TAGS.vocab.itos)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "R4B3J-aWTB8W",
"outputId": "da9009a0-c3e7-4b52-e225-9c522c0e61a6"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['<pad>', 'NOUN', 'PUNCT', 'VERB', 'PRON', 'ADP', 'DET', 'PROPN', 'ADJ', 'AUX', 'ADV', 'CCONJ', 'PART', 'NUM', 'SCONJ', 'X', 'INTJ', 'SYM']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(PTB_TAGS.vocab.itos)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GBMmathDTC-V",
"outputId": "1770c0d3-e330-4797-c6a3-69666d2c6ed2"
},
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['<pad>', 'NN', 'IN', 'DT', 'NNP', 'PRP', 'JJ', 'RB', '.', 'VB', 'NNS', ',', 'CC', 'VBD', 'VBP', 'VBZ', 'CD', 'VBN', 'VBG', 'MD', 'TO', 'PRP$', '-RRB-', '-LRB-', 'WDT', 'WRB', ':', '``', \"''\", 'WP', 'RP', 'UH', 'POS', 'HYPH', 'JJR', 'NNPS', 'JJS', 'EX', 'NFP', 'GW', 'ADD', 'RBR', '$', 'PDT', 'RBS', 'SYM', 'LS', 'FW', 'AFX', 'WP$', 'XX']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(UD_TAGS.vocab.freqs.most_common())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dq6Wn8BATELe",
"outputId": "2410f1bc-28fd-4d45-8f81-024a6c4809b9"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[('NOUN', 34781), ('PUNCT', 23679), ('VERB', 23081), ('PRON', 18577), ('ADP', 17638), ('DET', 16285), ('PROPN', 12946), ('ADJ', 12477), ('AUX', 12343), ('ADV', 10548), ('CCONJ', 6707), ('PART', 5567), ('NUM', 3999), ('SCONJ', 3843), ('X', 847), ('INTJ', 688), ('SYM', 599)]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(PTB_TAGS.vocab.freqs.most_common())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XuAMn3K2TFbC",
"outputId": "73f430f8-37b3-4700-8769-fcac7fecb74a"
},
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[('NN', 26915), ('IN', 20724), ('DT', 16817), ('NNP', 12449), ('PRP', 12193), ('JJ', 11591), ('RB', 10831), ('.', 10317), ('VB', 9476), ('NNS', 8438), (',', 8062), ('CC', 6706), ('VBD', 5402), ('VBP', 5374), ('VBZ', 4578), ('CD', 3998), ('VBN', 3967), ('VBG', 3330), ('MD', 3294), ('TO', 3286), ('PRP$', 3068), ('-RRB-', 1008), ('-LRB-', 973), ('WDT', 948), ('WRB', 869), (':', 866), ('``', 813), (\"''\", 785), ('WP', 760), ('RP', 755), ('UH', 689), ('POS', 684), ('HYPH', 664), ('JJR', 503), ('NNPS', 498), ('JJS', 383), ('EX', 359), ('NFP', 338), ('GW', 294), ('ADD', 292), ('RBR', 276), ('$', 258), ('PDT', 175), ('RBS', 169), ('SYM', 156), ('LS', 117), ('FW', 93), ('AFX', 48), ('WP$', 15), ('XX', 1)]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def tag_percentage(tag_counts):\n",
"\n",
" total_count = sum([count for tag, count in tag_counts])\n",
"\n",
" tag_counts_percentages = [(tag, count, count/total_count) for tag, count in tag_counts]\n",
"\n",
" return tag_counts_percentages"
],
"metadata": {
"id": "A9FbORvyTGj4"
},
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(\"Tag\\t\\tCount\\t\\tPercentage\\n\")\n",
"\n",
"for tag, count, percent in tag_percentage(UD_TAGS.vocab.freqs.most_common()):\n",
" print(f\"{tag}\\t\\t{count}\\t\\t{percent*100:4.1f}%\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WsnW25s5TIyK",
"outputId": "8078b619-1322-4fec-826f-aac16e217d6b"
},
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Tag\t\tCount\t\tPercentage\n",
"\n",
"NOUN\t\t34781\t\t17.0%\n",
"PUNCT\t\t23679\t\t11.6%\n",
"VERB\t\t23081\t\t11.3%\n",
"PRON\t\t18577\t\t 9.1%\n",
"ADP\t\t17638\t\t 8.6%\n",
"DET\t\t16285\t\t 8.0%\n",
"PROPN\t\t12946\t\t 6.3%\n",
"ADJ\t\t12477\t\t 6.1%\n",
"AUX\t\t12343\t\t 6.0%\n",
"ADV\t\t10548\t\t 5.2%\n",
"CCONJ\t\t6707\t\t 3.3%\n",
"PART\t\t5567\t\t 2.7%\n",
"NUM\t\t3999\t\t 2.0%\n",
"SCONJ\t\t3843\t\t 1.9%\n",
"X\t\t847\t\t 0.4%\n",
"INTJ\t\t688\t\t 0.3%\n",
"SYM\t\t599\t\t 0.3%\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(\"Tag\\t\\tCount\\t\\tPercentage\\n\")\n",
"\n",
"for tag, count, percent in tag_percentage(PTB_TAGS.vocab.freqs.most_common()):\n",
" print(f\"{tag}\\t\\t{count}\\t\\t{percent*100:4.1f}%\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eEsT_qOaTLe0",
"outputId": "e4a2dd6c-d40c-4c10-9ff9-0898de4b9ce4"
},
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Tag\t\tCount\t\tPercentage\n",
"\n",
"NN\t\t26915\t\t13.2%\n",
"IN\t\t20724\t\t10.1%\n",
"DT\t\t16817\t\t 8.2%\n",
"NNP\t\t12449\t\t 6.1%\n",
"PRP\t\t12193\t\t 6.0%\n",
"JJ\t\t11591\t\t 5.7%\n",
"RB\t\t10831\t\t 5.3%\n",
".\t\t10317\t\t 5.0%\n",
"VB\t\t9476\t\t 4.6%\n",
"NNS\t\t8438\t\t 4.1%\n",
",\t\t8062\t\t 3.9%\n",
"CC\t\t6706\t\t 3.3%\n",
"VBD\t\t5402\t\t 2.6%\n",
"VBP\t\t5374\t\t 2.6%\n",
"VBZ\t\t4578\t\t 2.2%\n",
"CD\t\t3998\t\t 2.0%\n",
"VBN\t\t3967\t\t 1.9%\n",
"VBG\t\t3330\t\t 1.6%\n",
"MD\t\t3294\t\t 1.6%\n",
"TO\t\t3286\t\t 1.6%\n",
"PRP$\t\t3068\t\t 1.5%\n",
"-RRB-\t\t1008\t\t 0.5%\n",
"-LRB-\t\t973\t\t 0.5%\n",
"WDT\t\t948\t\t 0.5%\n",
"WRB\t\t869\t\t 0.4%\n",
":\t\t866\t\t 0.4%\n",
"``\t\t813\t\t 0.4%\n",
"''\t\t785\t\t 0.4%\n",
"WP\t\t760\t\t 0.4%\n",
"RP\t\t755\t\t 0.4%\n",
"UH\t\t689\t\t 0.3%\n",
"POS\t\t684\t\t 0.3%\n",
"HYPH\t\t664\t\t 0.3%\n",
"JJR\t\t503\t\t 0.2%\n",
"NNPS\t\t498\t\t 0.2%\n",
"JJS\t\t383\t\t 0.2%\n",
"EX\t\t359\t\t 0.2%\n",
"NFP\t\t338\t\t 0.2%\n",
"GW\t\t294\t\t 0.1%\n",
"ADD\t\t292\t\t 0.1%\n",
"RBR\t\t276\t\t 0.1%\n",
"$\t\t258\t\t 0.1%\n",
"PDT\t\t175\t\t 0.1%\n",
"RBS\t\t169\t\t 0.1%\n",
"SYM\t\t156\t\t 0.1%\n",
"LS\t\t117\t\t 0.1%\n",
"FW\t\t93\t\t 0.0%\n",
"AFX\t\t48\t\t 0.0%\n",
"WP$\t\t15\t\t 0.0%\n",
"XX\t\t1\t\t 0.0%\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"BATCH_SIZE = 128\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
" (train_data, valid_data, test_data),\n",
" batch_size = BATCH_SIZE,\n",
" device = device)\n"
],
"metadata": {
"id": "-z4jitcMTNa_"
},
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class BiLSTMPOSTagger(nn.Module):\n",
" def __init__(self,\n",
" input_dim,\n",
" embedding_dim,\n",
" hidden_dim,\n",
" output_dim,\n",
" n_layers,\n",
" bidirectional,\n",
" dropout,\n",
" pad_idx):\n",
"\n",
" super().__init__()\n",
"\n",
" self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx = pad_idx)\n",
"\n",
" self.lstm = nn.LSTM(embedding_dim,\n",
" hidden_dim,\n",
" num_layers = n_layers,\n",
" bidirectional = bidirectional,\n",
" dropout = dropout if n_layers > 1 else 0)\n",
"\n",
" self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)\n",
"\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, text):\n",
"\n",
" #text = [sent len, batch size]\n",
"\n",
" #pass text through embedding layer\n",
" embedded = self.dropout(self.embedding(text))\n",
"\n",
" #embedded = [sent len, batch size, emb dim]\n",
"\n",
" #pass embeddings into LSTM\n",
" outputs, (hidden, cell) = self.lstm(embedded)\n",
"\n",
" #outputs holds the backward and forward hidden states in the final layer\n",
" #hidden and cell are the backward and forward hidden and cell states at the final time-step\n",
"\n",
" #output = [sent len, batch size, hid dim * n directions]\n",
" #hidden/cell = [n layers * n directions, batch size, hid dim]\n",
"\n",
" #we use our outputs to make a prediction of what the tag should be\n",
" predictions = self.fc(self.dropout(outputs))\n",
" #predictions = [sent len, batch size, output dim]\n",
"\n",
" return predictions"
],
"metadata": {
"id": "ID9qCiGVTPbp"
},
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"source": [
"INPUT_DIM = len(TEXT.vocab)\n",
"EMBEDDING_DIM = 100\n",
"HIDDEN_DIM = 128\n",
"OUTPUT_DIM = len(UD_TAGS.vocab)\n",
"N_LAYERS = 2\n",
"BIDIRECTIONAL = True\n",
"DROPOUT = 0.25\n",
"PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
"\n",
"model = BiLSTMPOSTagger(INPUT_DIM,\n",
" EMBEDDING_DIM,\n",
" HIDDEN_DIM,\n",
" OUTPUT_DIM,\n",
" N_LAYERS,\n",
" BIDIRECTIONAL,\n",
" DROPOUT,\n",
" PAD_IDX)"
],
"metadata": {
"id": "DSNg9SkHTXRd"
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def init_weights(m):\n",
" for name, param in m.named_parameters():\n",
" nn.init.normal_(param.data, mean = 0, std = 0.1)\n",
"\n",
"model.apply(init_weights)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4AcETUvaTa2b",
"outputId": "e5489380-65d8-4391-ae9d-7965a5867eed"
},
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"BiLSTMPOSTagger(\n",
" (embedding): Embedding(8866, 100, padding_idx=1)\n",
" (lstm): LSTM(100, 128, num_layers=2, dropout=0.25, bidirectional=True)\n",
" (fc): Linear(in_features=256, out_features=18, bias=True)\n",
" (dropout): Dropout(p=0.25, inplace=False)\n",
")"
]
},
"metadata": {},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ryNFiwzdTcZJ",
"outputId": "b1eda886-34a4-4ff7-8a51-70de48becddd"
},
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The model has 1,522,010 trainable parameters\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"pretrained_embeddings = TEXT.vocab.vectors\n",
"\n",
"print(pretrained_embeddings.shape)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nctKuyIMTd0Q",
"outputId": "2c3aadac-274e-4d07-a56f-264ae247fd64"
},
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([8866, 100])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"model.embedding.weight.data.copy_(pretrained_embeddings)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hzU0JrMLTfTy",
"outputId": "7c2229cb-4c6d-4260-bf37-55de2f0a8d6c"
},
"execution_count": 29,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n",
" [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
" ...,\n",
" [ 0.9261, 2.3049, 0.5502, ..., -0.3492, -0.5298, -0.1577],\n",
" [-0.5972, 0.0471, -0.2406, ..., -0.9446, -0.1126, -0.2260],\n",
" [-0.4809, 2.5629, 0.9530, ..., 0.5278, -0.4588, 0.7294]])"
]
},
"metadata": {},
"execution_count": 29
}
]
},
{
"cell_type": "code",
"source": [
"model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)\n",
"\n",
"print(model.embedding.weight.data)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jAe39yDZTgnZ",
"outputId": "ebcdfc9d-1704-463e-da4d-a63b025a15fc"
},
"execution_count": 30,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
" ...,\n",
" [ 0.9261, 2.3049, 0.5502, ..., -0.3492, -0.5298, -0.1577],\n",
" [-0.5972, 0.0471, -0.2406, ..., -0.9446, -0.1126, -0.2260],\n",
" [-0.4809, 2.5629, 0.9530, ..., 0.5278, -0.4588, 0.7294]])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"optimizer = optim.Adam(model.parameters())"
],
"metadata": {
"id": "D86Xga31TiQU"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"TAG_PAD_IDX = UD_TAGS.vocab.stoi[UD_TAGS.pad_token]\n",
"\n",
"criterion = nn.CrossEntropyLoss(ignore_index = TAG_PAD_IDX)"
],
"metadata": {
"id": "fzNTl9v5TjnS"
},
"execution_count": 32,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model = model.to(device)\n",
"criterion = criterion.to(device)"
],
"metadata": {
"id": "bqBLvtuJTlCe"
},
"execution_count": 33,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def categorical_accuracy(preds, y, tag_pad_idx):\n",
" \"\"\"\n",
" Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
" \"\"\"\n",
" max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability\n",
" non_pad_elements = (y != tag_pad_idx).nonzero()\n",
" correct = max_preds[non_pad_elements].squeeze(1).eq(y[non_pad_elements])\n",
" return correct.sum() / y[non_pad_elements].shape[0]"
],
"metadata": {
"id": "AUip03CeTmmk"
},
"execution_count": 34,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def train(model, iterator, optimizer, criterion, tag_pad_idx):\n",
"\n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
"\n",
" model.train()\n",
"\n",
" for batch in iterator:\n",
"\n",
" text = batch.text\n",
" tags = batch.udtags\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" #text = [sent len, batch size]\n",
"\n",
" predictions = model(text)\n",
"\n",
" #predictions = [sent len, batch size, output dim]\n",
" #tags = [sent len, batch size]\n",
"\n",
" predictions = predictions.view(-1, predictions.shape[-1])\n",
" tags = tags.view(-1)\n",
"\n",
" #predictions = [sent len * batch size, output dim]\n",
" #tags = [sent len * batch size]\n",
"\n",
" loss = criterion(predictions, tags)\n",
"\n",
" acc = categorical_accuracy(predictions, tags, tag_pad_idx)\n",
"\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
"\n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
],
"metadata": {
"id": "wuAwRkfWToZ2"
},
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def evaluate(model, iterator, criterion, tag_pad_idx):\n",
"\n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
"\n",
" model.eval()\n",
"\n",
" with torch.no_grad():\n",
"\n",
" for batch in iterator:\n",
"\n",
" text = batch.text\n",
" tags = batch.udtags\n",
"\n",
" predictions = model(text)\n",
"\n",
" predictions = predictions.view(-1, predictions.shape[-1])\n",
" tags = tags.view(-1)\n",
"\n",
" loss = criterion(predictions, tags)\n",
"\n",
" acc = categorical_accuracy(predictions, tags, tag_pad_idx)\n",
"\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
"\n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
],
"metadata": {
"id": "iIohEqwbTrZg"
},
"execution_count": 36,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def epoch_time(start_time, end_time):\n",
" elapsed_time = end_time - start_time\n",
" elapsed_mins = int(elapsed_time / 60)\n",
" elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
" return elapsed_mins, elapsed_secs"
],
"metadata": {
"id": "wlC50XjJTtrh"
},
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"source": [
"N_EPOCHS = 10\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"\n",
" start_time = time.time()\n",
"\n",
" train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)\n",
" valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)\n",
"\n",
" end_time = time.time()\n",
"\n",
" epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
"\n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), 'tut1-model.pt')\n",
"\n",
" print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
" print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
" print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gFfB1AscTvHA",
"outputId": "d386cfd4-cdbb-4cca-ab72-11dbf3822e6a"
},
"execution_count": 38,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch: 01 | Epoch Time: 1m 51s\n",
"\tTrain Loss: 1.337 | Train Acc: 58.14%\n",
"\t Val. Loss: 0.677 | Val. Acc: 78.71%\n",
"Epoch: 02 | Epoch Time: 1m 51s\n",
"\tTrain Loss: 0.481 | Train Acc: 84.90%\n",
"\t Val. Loss: 0.508 | Val. Acc: 83.09%\n",
"Epoch: 03 | Epoch Time: 1m 50s\n",
"\tTrain Loss: 0.350 | Train Acc: 88.96%\n",
"\t Val. Loss: 0.445 | Val. Acc: 85.30%\n",
"Epoch: 04 | Epoch Time: 1m 50s\n",
"\tTrain Loss: 0.292 | Train Acc: 90.71%\n",
"\t Val. Loss: 0.407 | Val. Acc: 86.23%\n",
"Epoch: 05 | Epoch Time: 1m 47s\n",
"\tTrain Loss: 0.253 | Train Acc: 92.02%\n",
"\t Val. Loss: 0.397 | Val. Acc: 86.28%\n",
"Epoch: 06 | Epoch Time: 1m 58s\n",
"\tTrain Loss: 0.226 | Train Acc: 92.82%\n",
"\t Val. Loss: 0.381 | Val. Acc: 87.09%\n",
"Epoch: 07 | Epoch Time: 1m 51s\n",
"\tTrain Loss: 0.206 | Train Acc: 93.40%\n",
"\t Val. Loss: 0.372 | Val. Acc: 87.45%\n",
"Epoch: 08 | Epoch Time: 1m 52s\n",
"\tTrain Loss: 0.192 | Train Acc: 93.89%\n",
"\t Val. Loss: 0.362 | Val. Acc: 87.59%\n",
"Epoch: 09 | Epoch Time: 1m 52s\n",
"\tTrain Loss: 0.176 | Train Acc: 94.30%\n",
"\t Val. Loss: 0.356 | Val. Acc: 88.34%\n",
"Epoch: 10 | Epoch Time: 1m 52s\n",
"\tTrain Loss: 0.167 | Train Acc: 94.62%\n",
"\t Val. Loss: 0.350 | Val. Acc: 88.28%\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"model.load_state_dict(torch.load('tut1-model.pt'))\n",
"\n",
"test_loss, test_acc = evaluate(model, test_iterator, criterion, TAG_PAD_IDX)\n",
"\n",
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LBfU8G1nTw-P",
"outputId": "9c65940e-efde-45ad-e314-e68b3b9aae1e"
},
"execution_count": 39,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Test Loss: 0.365 | Test Acc: 88.30%\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def tag_sentence(model, device, sentence, text_field, tag_field):\n",
"\n",
" model.eval()\n",
"\n",
" if isinstance(sentence, str):\n",
" nlp = spacy.load('en_core_web_sm')\n",
" tokens = [token.text for token in nlp(sentence)]\n",
" else:\n",
" tokens = [token for token in sentence]\n",
"\n",
" if text_field.lower:\n",
" tokens = [t.lower() for t in tokens]\n",
"\n",
" numericalized_tokens = [text_field.vocab.stoi[t] for t in tokens]\n",
"\n",
" unk_idx = text_field.vocab.stoi[text_field.unk_token]\n",
"\n",
" unks = [t for t, n in zip(tokens, numericalized_tokens) if n == unk_idx]\n",
"\n",
" token_tensor = torch.LongTensor(numericalized_tokens)\n",
"\n",
" token_tensor = token_tensor.unsqueeze(-1).to(device)\n",
"\n",
" predictions = model(token_tensor)\n",
"\n",
" top_predictions = predictions.argmax(-1)\n",
"\n",
" predicted_tags = [tag_field.vocab.itos[t.item()] for t in top_predictions]\n",
"\n",
" return tokens, predicted_tags, unks"
],
"metadata": {
"id": "KgzX7IV6TzWM"
},
"execution_count": 40,
"outputs": []
},
{
"cell_type": "code",
"source": [
"example_index = 1\n",
"\n",
"sentence = vars(train_data.examples[example_index])['text']\n",
"actual_tags = vars(train_data.examples[example_index])['udtags']\n",
"\n",
"print(sentence)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3WvTT5GxT1Gn",
"outputId": "473983aa-788d-4c8d-eb43-cab443799cf1"
},
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['[', 'this', 'killing', 'of', 'a', 'respected', 'cleric', 'will', 'be', 'causing', 'us', 'trouble', 'for', 'years', 'to', 'come', '.', ']']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"tokens, pred_tags, unks = tag_sentence(model,\n",
" device,\n",
" sentence,\n",
" TEXT,\n",
" UD_TAGS)\n",
"\n",
"print(unks)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TxRypSLVT20a",
"outputId": "5f2b4fdf-8820-4339-87ae-57aa18c48ad7"
},
"execution_count": 42,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['respected', 'cleric']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(\"Pred. Tag\\tActual Tag\\tCorrect?\\tToken\\n\")\n",
"\n",
"for token, pred_tag, actual_tag in zip(tokens, pred_tags, actual_tags):\n",
" correct = '✔' if pred_tag == actual_tag else '✘'\n",
" print(f\"{pred_tag}\\t\\t{actual_tag}\\t\\t{correct}\\t\\t{token}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UcQJ7AcLT3_U",
"outputId": "d6cf11ba-1b17-4016-d126-1a3b70e1f7f0"
},
"execution_count": 43,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Pred. Tag\tActual Tag\tCorrect?\tToken\n",
"\n",
"PUNCT\t\tPUNCT\t\t✔\t\t[\n",
"DET\t\tDET\t\t✔\t\tthis\n",
"VERB\t\tNOUN\t\t✘\t\tkilling\n",
"ADP\t\tADP\t\t✔\t\tof\n",
"DET\t\tDET\t\t✔\t\ta\n",
"ADJ\t\tADJ\t\t✔\t\trespected\n",
"NOUN\t\tNOUN\t\t✔\t\tcleric\n",
"AUX\t\tAUX\t\t✔\t\twill\n",
"AUX\t\tAUX\t\t✔\t\tbe\n",
"VERB\t\tVERB\t\t✔\t\tcausing\n",
"PRON\t\tPRON\t\t✔\t\tus\n",
"NOUN\t\tNOUN\t\t✔\t\ttrouble\n",
"ADP\t\tADP\t\t✔\t\tfor\n",
"NOUN\t\tNOUN\t\t✔\t\tyears\n",
"PART\t\tPART\t\t✔\t\tto\n",
"VERB\t\tVERB\t\t✔\t\tcome\n",
"PUNCT\t\tPUNCT\t\t✔\t\t.\n",
"PUNCT\t\tPUNCT\t\t✔\t\t]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"sentence = 'The Queen will deliver a speech about the conflict in North Korea at 1pm tomorrow.'\n",
"\n",
"tokens, tags, unks = tag_sentence(model,\n",
" device,\n",
" sentence,\n",
" TEXT,\n",
" UD_TAGS)\n",
"\n",
"print(unks)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_SaAEPsDT5tk",
"outputId": "e2d74a00-6d4c-4bce-bd4e-4deab8264398"
},
"execution_count": 44,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(\"Pred. Tag\\tToken\\n\")\n",
"\n",
"for token, tag in zip(tokens, tags):\n",
" print(f\"{tag}\\t\\t{token}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AbucCuhyT7Po",
"outputId": "1483f038-0ae4-47a5-a8d5-80c6a3652192"
},
"execution_count": 45,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Pred. Tag\tToken\n",
"\n",
"DET\t\tthe\n",
"NOUN\t\tqueen\n",
"AUX\t\twill\n",
"VERB\t\tdeliver\n",
"DET\t\ta\n",
"NOUN\t\tspeech\n",
"ADP\t\tabout\n",
"DET\t\tthe\n",
"NOUN\t\tconflict\n",
"ADP\t\tin\n",
"PROPN\t\tnorth\n",
"PROPN\t\tkorea\n",
"ADP\t\tat\n",
"NUM\t\t1\n",
"NOUN\t\tpm\n",
"NOUN\t\ttomorrow\n",
"PUNCT\t\t.\n"
]
}
]
}
]
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment